# Source: cmdla/10-04/GMQ.jl
# (repository file view: Julia, 247 lines, 6.2 KiB; revision of 2023-10-29 02:06:02 +01:00)
using LinearAlgebra
using Printf
using Plots
# Validate the inputs of `gmq`; throw an `ArgumentError` on the first violation
# and return the problem dimension `n` otherwise.
function _gmq_check_inputs(Q, q, x, alpha, MaxIter, tol)
    isreal(Q) || throw(ArgumentError("Q not a real matrix"))
    n = size(Q, 1)
    n <= 1 && throw(ArgumentError("Q is too small"))
    n == size(Q, 2) || throw(ArgumentError("Q is not square"))
    isreal(q) || throw(ArgumentError("q not a real vector"))
    size(q, 1) == n || throw(ArgumentError("q size does not match with Q"))
    isreal(x) || throw(ArgumentError("x not a real vector"))
    size(x, 1) == n || throw(ArgumentError("x size does not match with Q"))
    alpha < 0 && throw(ArgumentError("alpha can not be negative"))
    MaxIter < 1 && throw(ArgumentError("MaxIter too small"))
    tol < 0 && throw(ArgumentError("eps can not be negative"))
    return n
end

# Print the banner and the column header of the `gmq` iteration log.
function _gmq_print_header(alpha, fStar)
    print("Gradient method for quadratic functions ")
    print(alpha == 0 ? "(optimal stepsize)\n" : "(fixed stepsize)\n")
    print("iter\tf(x)\t\t\t||g||")
    fStar > -Inf && print("\t\tgap\t\trate")
    alpha == 0 && print("\t\talpha")
    print("\n\n")
    return nothing
end

"""
    gmq(Q, q; x=nothing, fStar=-Inf, alpha=0, MaxIter=1000, eps=1e-6,
        plt=nothing, Plotf=2, printing=true) -> (x, status)

Minimize the quadratic function `f(x) = 1/2 x' Q x + q' x` with the gradient
method. When `alpha == 0` the exact ("optimal") stepsize along `-g` is used,
`t = g'g / g'Qg`; when `alpha > 0` the fixed stepsize `alpha` is used.

# Arguments
- `Q::AbstractMatrix`: real square `n × n` matrix with `n ≥ 2`.
- `q::AbstractVector`: real vector of length `n`.

# Keywords
- `x`: starting point of length `n`; defaults to the zero vector.
- `fStar`: known optimal value. When finite, the relative gap
  `(f - fStar) / max(|fStar|, 1)` and the ratio `(f - fStar) / (prevf - fStar)`
  are printed, and the gap is plotted when `Plotf == 1`.
- `alpha`: fixed stepsize (`≥ 0`); `0` selects the exact stepsize.
- `MaxIter`: maximum number of iterations (`≥ 1`).
- `eps`: stop as soon as `norm(∇f(x)) ≤ eps` (`≥ 0`).
- `plt`: plot to draw into; a new one is created when `nothing` and plotting
  is requested.
- `Plotf`: `0` = nothing is plotted, `1` = the gap per iteration is plotted,
  `2` = the trajectory is drawn (only when `n == 2`). Level sets of `f` are not
  drawn here; presumably the caller draws them into `plt` beforehand (TODO confirm).
- `printing`: print an iteration log to `stdout`.

# Returns
`(x, status)`, where `x` is the last iterate and `status` is
- `"optimal"` if the gradient norm dropped to `eps` or below;
- `"stopped"` if `MaxIter` steps were performed without that happening;
- `"unbounded"` if `g' Q g ≤ 1e-14` along the current gradient `g`, which is
  taken as evidence that `f` is unbounded below.

# Throws
`ArgumentError` if `Q`, `q` or `x` are not real or have mismatched sizes, if
`n < 2`, or if `alpha < 0`, `MaxIter < 1` or `eps < 0`.
"""
function gmq(Q::AbstractMatrix, q::AbstractVector;
             x::Union{AbstractVector, Nothing}=nothing, fStar::Real=-Inf,
             alpha::Real=0, MaxIter::Int=1000, eps::Real=1e-6,
             plt::Union{Plots.Plot, Nothing}=nothing, Plotf::Int=2,
             printing::Bool=true)::Tuple{Vector, String}
    # When true, the gradient is updated as g - t * (Q * g), reusing the product
    # needed for the stepsize, so only one O(n^2) operation is done per iteration.
    Streamlined = true

    # reading and checking input - - - - - - - - - - - - - - - - - - - - - - -
    if x === nothing
        x = zeros(size(Q, 1))
    end
    n = _gmq_check_inputs(Q, q, x, alpha, MaxIter, eps)

    # initializations - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    printing && _gmq_print_header(alpha, fStar)

    prevf = Inf        # previous f value; only used when fStar is finite
    gap = Float64[]    # relative gaps; only filled when Plotf == 1 and fStar finite
    if plt === nothing
        if Plotf == 1
            plt = plot(yscale=:log, xlims=(0, MaxIter), ylims=(1e-15, Inf),
                       guidefontsize=16)
        elseif Plotf == 2
            plt = plot()
        end
    end

    g = Q * x + q      # gradient of f at x
    i = 0
    status = ""

    # main loop - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    while true
        if !Streamlined
            g = Q * x + q
        end
        ng = norm(g)
        # f(x) = 1/2 x'Qx + q'x = 1/2 (Qx + q + q)'x = 1/2 (g + q)'x
        f = dot(g + q, x) / 2
        i += 1

        printing && @printf("%4d\t%1.8e\t\t%1.4e", i, f, ng)
        if fStar > -Inf
            gapk = (f - fStar) / max(abs(fStar), 1)
            if printing
                @printf("\t%1.4e", gapk)
                if prevf < Inf
                    @printf("\t%1.4e", (f - fStar) / (prevf - fStar))
                else
                    @printf("\t\t")
                end
            end
            prevf = f
            Plotf == 1 && push!(gap, gapk)
        end

        # stopping criteria; the log row is still open, so terminate it
        if ng <= eps
            status = "optimal"
            printing && print("\n")
            break
        end
        if i > MaxIter
            status = "stopped"
            printing && print("\n")
            break
        end

        # Curvature g'Qg along the gradient. With alpha > 0 it is only needed
        # for the unboundedness check; in streamlined mode Q * g is reused for
        # the gradient update, so the check costs only O(n) extra.
        if Streamlined
            v = Q * g
            den = dot(g, v)
        else
            den = dot(g, Q * g)
        end
        if den <= 1e-14
            # Two cases, with 1e-14 acting as the numerical zero:
            # - g'Qg = 0: f is linear along g, and since g != 0 it is unbounded below;
            # - g'Qg < 0: g is a direction of negative curvature, so f is
            #   necessarily unbounded below.
            if printing
                print("\n")
                @printf("g' * Q * g = %1.4e ==> unbounded\n", den)
            end
            status = "unbounded"
            break
        end

        if alpha > 0
            t = alpha
        else
            t = ng^2 / den  # exact minimizer of t -> f(x - t g)
            printing && @printf("\t%1.2e", t)
        end
        printing && print("\n")

        # compute the new point, drawing the step into plt when requested
        xnew = x - t * g
        if n == 2 && Plotf == 2
            plot!(plt, [x[1], xnew[1]], [x[2], xnew[2]],
                  linestyle=:solid, linewidth=2, markershape=:circle,
                  seriescolor=colorant"black", label="")
        end
        x = xnew
        if Streamlined
            g = g - t * v  # ∇f(x - t g) = g - t Q g
        end
    end

    if Plotf == 1
        plot!(plt, gap, linewidth=2, seriescolor=colorant"black")
        display(plt)
    elseif Plotf == 2
        display(plt)
    end
    return (vec(x), status)
end