first commit with current lessons
This commit is contained in:
247
10-04/GMQ.jl
Normal file
247
10-04/GMQ.jl
Normal file
@ -0,0 +1,247 @@
|
||||
using LinearAlgebra
|
||||
using Printf
|
||||
using Plots
|
||||
|
||||
function gmq(Q::Matrix, q::Vector; x::Union{Vector, Nothing}=nothing, fStar::Real=-Inf, alpha::Real=0 , MaxIter::Int=1000 , eps::Real=1e-6, plt::Union{Plots.Plot, Nothing}=nothing, Plotf::Int=2, printing::Bool=true)::Tuple{Vector, String}
|
||||
# Plotf
|
||||
# 0 = nothing is plotted
|
||||
# 1 = the function value / gap are plotted
|
||||
# 2 = the level sets of f and the trajectory are plotted (when n = 2)
|
||||
|
||||
Interactive = true # if we pause at every iteration
|
||||
|
||||
Streamlined = true # if the streamlined version of the algorithm, with
|
||||
# only one O( n^2 ) operation per iteration, is used
|
||||
|
||||
# reading and checking input- - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
if !isreal(Q)
|
||||
throw(ArgumentError(Q, "Q not a real matrix"))
|
||||
end
|
||||
|
||||
n = size(Q, 1)
|
||||
|
||||
if n <= 1
|
||||
throw(ArgumentError(Q, "Q is too small"))
|
||||
end
|
||||
|
||||
if n != size(Q, 2)
|
||||
throw(ArgumentError(Q, "Q is not square"))
|
||||
end
|
||||
|
||||
if !isreal(q)
|
||||
throw(ArgumentError(q, "q not a real vector"))
|
||||
end
|
||||
|
||||
if size(q, 1) != n
|
||||
throw(ArgumentError(q, "q size does not match with Q"))
|
||||
end
|
||||
|
||||
if x == nothing
|
||||
x = zeros(n, 1)
|
||||
end
|
||||
|
||||
if !isreal(x)
|
||||
throw(ArgumentError(x, "x not a real vector"))
|
||||
end
|
||||
|
||||
if size(x, 1) != n
|
||||
throw(ArgumentError(x, "x size does not match with Q"))
|
||||
end
|
||||
|
||||
if MaxIter < 1
|
||||
throw(ArgumentError(MaxIter, "MaxIter too small"))
|
||||
end
|
||||
|
||||
if eps < 0
|
||||
throw(ArgumentError(eps, "eps can not be negative"))
|
||||
end
|
||||
|
||||
# initializations - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
if printing
|
||||
print("Gradient method for quadratic functions ")
|
||||
if alpha == 0
|
||||
print("(optimal stepsize)\n")
|
||||
else
|
||||
print("(fixed stepsize)\n")
|
||||
end
|
||||
|
||||
print("iter\tf(x)\t\t\t||g||")
|
||||
end
|
||||
|
||||
if fStar > - Inf
|
||||
if printing
|
||||
print("\t\tgap\t\trate")
|
||||
end
|
||||
prevf = Inf
|
||||
end
|
||||
if printing
|
||||
if alpha == 0
|
||||
print("\t\talpha")
|
||||
end
|
||||
|
||||
print("\n\n")
|
||||
end
|
||||
|
||||
i = 0;
|
||||
if Plotf == 1
|
||||
gap = []
|
||||
end
|
||||
|
||||
if Streamlined
|
||||
g = Q * x + q
|
||||
end
|
||||
|
||||
if Plotf == 1 && plt == nothing
|
||||
plt = plot(yscale = :log,
|
||||
xlims=(0, MaxIter),
|
||||
ylims=(1e-15, Inf),
|
||||
guidefontsize=16)
|
||||
elseif Plotf == 2 && plt == nothing
|
||||
plt = plot()
|
||||
end
|
||||
|
||||
status = ""
|
||||
|
||||
# main loop - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
while true
|
||||
if !Streamlined
|
||||
g = Q * x + q
|
||||
end
|
||||
|
||||
ng = norm(g)
|
||||
|
||||
f = dot((g + q)', x) / 2 # 1/2 x^T Q x + q x
|
||||
# = 1/2 ( x^T Q x + 2 q x )
|
||||
# = 1/2 x^T ( Q x + q + q )
|
||||
# = 1/2 ( q + g ) x
|
||||
i += 1
|
||||
|
||||
if printing
|
||||
@printf("%4d\t%1.8e\t\t%1.4e", i, f, ng)
|
||||
end
|
||||
if fStar > -Inf
|
||||
gapk = (f - fStar)/maximum([abs(fStar), 1])
|
||||
if printing
|
||||
@printf("\t%1.4e", gapk)
|
||||
|
||||
if prevf < Inf
|
||||
@printf("\t%1.4e", (f - fStar)/(prevf - fStar))
|
||||
else
|
||||
@printf("\t\t")
|
||||
end
|
||||
end
|
||||
|
||||
prevf = f
|
||||
|
||||
if Plotf == 1
|
||||
push!(gap, gapk)
|
||||
end
|
||||
end
|
||||
|
||||
# stopping criteria - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
if ng <= eps
|
||||
status = "optimal"
|
||||
if alpha == 0 && printing
|
||||
print("\n")
|
||||
end
|
||||
break
|
||||
end
|
||||
|
||||
if i > MaxIter
|
||||
status = "stopped"
|
||||
if alpha == 0 && printing
|
||||
print("\n")
|
||||
end
|
||||
break
|
||||
end
|
||||
|
||||
# compute step size - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# meanwhile, check if f is unbounded below
|
||||
# note that if alpha > 0 this is only used for the unboundedness check
|
||||
# which is a bit of a waste, but there you go; anyway, in the
|
||||
# streamlined version this only costs O( n )
|
||||
|
||||
if Streamlined
|
||||
v = Q * g;
|
||||
den = dot(g', v)
|
||||
else
|
||||
den = dot(g', Q * g)
|
||||
end
|
||||
|
||||
if den <= 1e-14
|
||||
# this is actually two different cases:
|
||||
# - g' * Q * g = 0, i.e., f is linear along g, and since the
|
||||
# gradient is not zero, it is unbounded below
|
||||
#
|
||||
# - g' * Q * g < 0, i.e., g is a direction of negative curvature for
|
||||
# f, which is then necessarily unbounded below
|
||||
if printing
|
||||
if alpha == 0
|
||||
print("\n")
|
||||
end
|
||||
@printf("g' * Q * g = %1.4e ==> unbounded\n", den)
|
||||
end
|
||||
status = "unbounded"
|
||||
break
|
||||
end
|
||||
|
||||
if alpha > 0
|
||||
t = alpha
|
||||
else
|
||||
t = ng^2 / den # stepsize
|
||||
if printing
|
||||
@printf("\t%1.2e", t)
|
||||
end
|
||||
end
|
||||
|
||||
if printing
|
||||
print("\n")
|
||||
end
|
||||
|
||||
# compute new point - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
# possibly plot the trajectory
|
||||
if n == 2 && Plotf == 2
|
||||
PXY = hcat(vec(x), vec(x - t * g))
|
||||
plot!(PXY[1,:],
|
||||
PXY[2,:],
|
||||
linestyle=:solid,
|
||||
linewidth=2,
|
||||
markershape=:circle,
|
||||
seriescolor=colorant"black",
|
||||
label="")
|
||||
end
|
||||
|
||||
x = x - t * g
|
||||
|
||||
if Streamlined
|
||||
g = g - t * v
|
||||
end
|
||||
|
||||
if Interactive
|
||||
#readline()
|
||||
end
|
||||
|
||||
if Plotf != 0
|
||||
#IJulia.clear_output(true)
|
||||
#display(plt)
|
||||
end
|
||||
end
|
||||
|
||||
if Plotf == 1
|
||||
plot!(plt,
|
||||
gap,
|
||||
linewidth=2,
|
||||
seriescolor=colorant"black")
|
||||
display(plt)
|
||||
elseif Plotf == 2
|
||||
display(plt)
|
||||
end
|
||||
(vec(x), status)
|
||||
end
|
||||
Reference in New Issue
Block a user