86 lines
2.3 KiB
Julia
86 lines
2.3 KiB
Julia
module OracleFunction
|
|
|
|
using LinearAlgebra: norm
|
|
|
|
export OracleF, LeastSquaresF, tomography
|
|
|
|
@doc """
|
|
```julia
|
|
OracleF{T, F<:Function, G<:Function}
|
|
```
|
|
|
|
Struct that holds a generic function to evaluate.
|
|
`eval` is the function that evaluates a point, `grad` is the gradient of the function and
|
|
`starting_point` is the point from which minimization should start.
|
|
"""
|
|
struct OracleF{T, F<:Function, G<:Function}
|
|
starting_point::AbstractArray{T}
|
|
eval::F
|
|
grad::G
|
|
end
|
|
|
|
@doc """
|
|
```julia
|
|
LeastSquaresF{T, F<:Function, G<:Function}
|
|
```
|
|
|
|
Struct that holds an instance of a least squares problem. The interface is similar to the `OracleF` struct.
|
|
`eval` is the function that evaluates a point, `grad` is the gradient of the function and
|
|
`starting_point` is the point from which minimization should start.
|
|
|
|
See also [`OracleF`](@ref).
|
|
"""
|
|
struct LeastSquaresF{T, F<:Function, G<:Function}
|
|
oracle::OracleF{T, F, G}
|
|
X::AbstractMatrix{T}
|
|
y::AbstractArray{T}
|
|
symm::AbstractMatrix{T}
|
|
yX::AbstractArray{T}
|
|
end
|
|
|
|
function LeastSquaresF(starting_point::AbstractArray{T}, X::AbstractMatrix{T}, y::AbstractArray{T}) where T
|
|
f(x) = norm(X * x - y)^2
|
|
df(x) = 2 * X' * (X * x - y)
|
|
symm = X' * X
|
|
yX = y' * X
|
|
|
|
o = OracleF(starting_point, f, df)
|
|
LeastSquaresF(o, X, y, symm, yX)
|
|
end
|
|
|
|
function LeastSquaresF(t::NamedTuple)
|
|
if [:X_hat, :y_hat, :start] ⊈ keys(t)
|
|
throw(ArgumentError("Input tuple does not contain necessary values, found: " * string(keys(t))))
|
|
end
|
|
starting_point, X, y = t[:start], t[:X_hat], t[:y_hat]
|
|
|
|
LeastSquaresF(starting_point, X, y)
|
|
end
|
|
|
|
@doc """
|
|
```julia
|
|
tomography(l::LeastSquaresF{T, F, G}, w::AbstractArray{T}, p::AbstractArray{T})
|
|
```
|
|
|
|
Function that returns the minimum of the function `l` along the plane in `w` and with direction `p`.
|
|
|
|
See also [`LeastSquaresF`](@ref).
|
|
"""
|
|
function tomography(l::LeastSquaresF{T, F, G}, w::AbstractArray{T}, p::AbstractArray{T}) where {T, F, G}
|
|
(l.yX * p - w' * l.symm * p) * inv(p' * l.symm * p)
|
|
end
|
|
|
|
function Base.getproperty(l::LeastSquaresF{T, F, G}, name::Symbol) where {T, F, G}
|
|
if name === :eval
|
|
return l.oracle.eval
|
|
elseif name === :grad
|
|
return l.oracle.grad
|
|
elseif name === :starting_point
|
|
return l.oracle.starting_point
|
|
else
|
|
getfield(l, name)
|
|
end
|
|
end
|
|
|
|
end ## module OracleFunction
|