lesson 9/11
This commit is contained in:
423
11-09/SDG.jl
Normal file
423
11-09/SDG.jl
Normal file
@ -0,0 +1,423 @@
|
||||
using LinearAlgebra, Printf, Plots
|
||||
|
||||
function SDG(f;
|
||||
x::Union{Nothing, Vector}=nothing,
|
||||
astart::Real=1,
|
||||
eps::Real=1e-6,
|
||||
MaxFeval::Integer=1000,
|
||||
m1::Real=1e-3,
|
||||
m2::Real=0.9,
|
||||
tau::Real=0.9,
|
||||
sfgrd::Real=0.01,
|
||||
MInf::Real=-Inf,
|
||||
mina::Real=1e-16,
|
||||
plt::Union{Plots.Plot, Nothing}=nothing,
|
||||
plotatend::Bool=true,
|
||||
Plotf::Integer=0,
|
||||
printing::Bool=true)::Tuple{AbstractArray, String}
|
||||
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# local functions - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
function f2phi(alpha, derivate=false)
|
||||
lastx = x .- alpha .* g
|
||||
(phi, lastg, _) = f(lastx)
|
||||
|
||||
if (Plotf > 2)
|
||||
if fStar > -Inf
|
||||
push!(gap, (phi - fStar) / max(abs(fStar), 1))
|
||||
else
|
||||
push!(gap, phi)
|
||||
end
|
||||
end
|
||||
feval += 1
|
||||
|
||||
if derivate
|
||||
return phi, dot(-g, lastg)
|
||||
end
|
||||
return phi, nothing
|
||||
end
|
||||
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
function ArmijoWolfeLS(phi0, phip0, as, m1, m2, tau)
|
||||
# performs an Armijo-Wolfe Line Search.
|
||||
#
|
||||
# Inputs:
|
||||
#
|
||||
# - phi0 = phi( 0 )
|
||||
#
|
||||
# - phip0 = phi'( 0 ) (< 0)
|
||||
#
|
||||
# - as (> 0) is the first value to be tested: if the Armijo condition
|
||||
#
|
||||
# phi( as ) <= phi0 + m1 * as * phip0
|
||||
#
|
||||
# is satisfied but the Wolfe condition is not, which means that the
|
||||
# derivative in as is still negative, which means that longer steps
|
||||
# might be possible), then as is divided by tau < 1 (hence it is
|
||||
# increased) until this does not happen any longer
|
||||
#
|
||||
# - m1 (> 0 and < 1, typically small, like 0.01) is the parameter of
|
||||
# the Armijo condition
|
||||
#
|
||||
# - m2 (> m1 > 0, typically large, like 0.9) is the parameter of the
|
||||
# Wolfe condition
|
||||
#
|
||||
# - tau (> 0 and < 1) is the increasing coefficient for the first phase
|
||||
# (extrapolation)
|
||||
#
|
||||
# Outputs:
|
||||
#
|
||||
# - a is the "optimal" step
|
||||
#
|
||||
# - phia = phi( a ) (the "optimal" f-value)
|
||||
|
||||
lsiter = 1 # count iterations of first phase
|
||||
local phips, phia
|
||||
while feval ≤ MaxFeval
|
||||
(phia, phips) = f2phi(as, true) # compute phi( a ) and phi'( a )
|
||||
|
||||
if phia > phi0 + m1 * as * phip0 # Armijo not satisfied
|
||||
break
|
||||
end
|
||||
|
||||
if phips ≥ m2 * phip0 # Wolfe satisfied
|
||||
if printing
|
||||
@printf("%2d ", lsiter)
|
||||
end
|
||||
a = as
|
||||
return (a, phia) # Armijo + Wolfe satisfied, done
|
||||
end
|
||||
if phips ≥ 0 # derivative is positive, break
|
||||
break
|
||||
end
|
||||
as = as / tau
|
||||
lsiter += 1
|
||||
end
|
||||
|
||||
if printing
|
||||
@printf("%2d ", lsiter)
|
||||
end
|
||||
lsiter = 1 # count iterations of second phase
|
||||
|
||||
am = 0
|
||||
a = as
|
||||
phipm = phip0
|
||||
while (feval ≤ MaxFeval) && ((as - am) > abs(mina)) && (abs(phips) > 1e-12)
|
||||
|
||||
if (phipm < 0) && (phips > 0)
|
||||
# if the derivative in as is positive and that in am is negative,
|
||||
# then compute the new step by safeguarded quadratic interpolation
|
||||
a = (am * phips - as * phipm) / (phips - phipm)
|
||||
a = max(am + ( as - am ) * sfgrd, min(as - ( as - am ) * sfgrd, a))
|
||||
else
|
||||
a = (as - am) / 2 # else just use dumb binary search
|
||||
end
|
||||
|
||||
phia, phipa = f2phi(a, true) # compute phi( a ) and phi'( a )
|
||||
|
||||
if phia ≤ phi0 + m1 * as * phip0 # Armijo satisfied
|
||||
if phipa ≥ m2 * phip0 # Wolfe satisfied
|
||||
break # Armijo + Wolfe satisfied, done
|
||||
end
|
||||
|
||||
am = a # Armijo is satisfied but Wolfe is not, i.e., the
|
||||
phipm = phipa # derivative is still negative: move the left
|
||||
# endpoint of the interval to a
|
||||
else # Armijo not satisfied
|
||||
as = a # move the right endpoint of the interval to a
|
||||
phips = phipa
|
||||
end
|
||||
|
||||
lsiter += 1
|
||||
end
|
||||
|
||||
if printing
|
||||
@printf("%2d", lsiter)
|
||||
end
|
||||
|
||||
return (a, phia)
|
||||
end
|
||||
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
function BacktrackingLS(phi0, phip0, as, m1, tau)
|
||||
# performs a Backtracking Line Search.
|
||||
#
|
||||
# phi0 = phi( 0 ), phip0 = phi'( 0 ) < 0
|
||||
#
|
||||
# as > 0 is the first value to be tested, which is decreased by
|
||||
# multiplying it by tau < 1 until the Armijo condition with parameter
|
||||
# m1 is satisfied
|
||||
#
|
||||
# returns the optimal step and the optimal f-value
|
||||
|
||||
local phia
|
||||
|
||||
lsiter = 1 # count ls iterations
|
||||
while feval ≤ MaxFeval && as > mina
|
||||
(phia, _) = f2phi(as)
|
||||
if phia ≤ phi0 + m1 * as * phip0 # Armijo satisfied
|
||||
break # we are done
|
||||
end
|
||||
as = as * tau
|
||||
lsiter += 1
|
||||
end
|
||||
|
||||
if printing
|
||||
@printf("\t%2d", lsiter)
|
||||
end
|
||||
return (as, phia)
|
||||
end
|
||||
|
||||
# Plotf = 1
|
||||
# 0 = nothing is plotted
|
||||
# 1 = the level sets of f and the trajectory are plotted (when n = 2)
|
||||
# 2 = the function value / gap are plotted, iteration-wise
|
||||
# 3 = the function value / gap are plotted, function-evaluation-wise
|
||||
Interactive = false
|
||||
|
||||
local gap
|
||||
PXY = Matrix{Real}(undef, 2, 0)
|
||||
|
||||
status = "error"
|
||||
|
||||
if Plotf > 1
|
||||
if Plotf == 2
|
||||
MaxIter = 200 # expected number of iterations for the gap plot
|
||||
else
|
||||
MaxIter = 1000 # expected number of iterations for the gap plot
|
||||
end
|
||||
gap = []
|
||||
end
|
||||
|
||||
if x == nothing
|
||||
(fStar, x, _) = f(nothing)
|
||||
else
|
||||
(fStar, _, _) = f(nothing)
|
||||
end
|
||||
|
||||
n = size(x, 1)
|
||||
|
||||
if astart == 0
|
||||
throw(ArgumentError("astart must be ≠ 0"))
|
||||
end
|
||||
|
||||
if m1 ≤ 0 || m1 ≥ 1
|
||||
throw(ArgumentError("m1: ($m1) is not in (0, 1)"))
|
||||
end
|
||||
|
||||
AWLS = (m2 > 0 && m2 < 1)
|
||||
|
||||
if tau ≤ 0 || tau ≥ 1
|
||||
throw(ArgumentError("tau: ($tau) is not in (0, 1)"))
|
||||
end
|
||||
|
||||
if sfgrd ≤ 0 || sfgrd ≥ 1
|
||||
throw(ArgumentError("sfgrd: ($sfgrd) is not in (0, 1)"))
|
||||
end
|
||||
|
||||
if mina < 0
|
||||
throw(ArgumentError("mina: ($mina) must be ≥ 0"))
|
||||
end
|
||||
|
||||
if Plotf > 1 && plt == nothing
|
||||
plt = plot(xlims=(0, MaxIter))
|
||||
elseif plt == nothing
|
||||
plt = plot()
|
||||
end
|
||||
|
||||
# "global" variables- - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
lastx = zeros(n) # last point visited in the line search
|
||||
lastg = zeros(n) # gradient of lastx
|
||||
feval = 1 # f() evaluations count ("common" with LSs)
|
||||
|
||||
# initializations - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
if printing
|
||||
println("Gradient method")
|
||||
end
|
||||
if fStar > -Inf
|
||||
if printing
|
||||
print("feval\trel gap\t\t|| g(x) ||\trate\t")
|
||||
end
|
||||
prevv = Inf
|
||||
else
|
||||
if printing
|
||||
print("feval\tf(x)\t\t\t|| g(x) ||")
|
||||
end
|
||||
end
|
||||
if astart > 0
|
||||
if printing
|
||||
print("\tls feval\ta*")
|
||||
end
|
||||
end
|
||||
if printing
|
||||
print("\n\n")
|
||||
end
|
||||
|
||||
# compute first f-value and gradient in x^0 - - - - - - - - - - - - - - - -
|
||||
|
||||
g = zeros(2, 1)
|
||||
v, _ = f2phi(0)
|
||||
g = lastg
|
||||
|
||||
# compute norm of the (first) gradient- - - - - - - - - - - - - - - - - - -
|
||||
|
||||
ng = norm(g)
|
||||
if eps < 0
|
||||
ng0 = -ng # norm of first subgradient: why is there a "-"? ;-)
|
||||
else
|
||||
ng0 = 1 # un-scaled stopping criterion
|
||||
end
|
||||
|
||||
# main loop - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
while true
|
||||
# output statistics & plot gap/f-values - - - - - - - - - - - - - - - -
|
||||
|
||||
if fStar > -Inf
|
||||
gapk = (v .- fStar) / max(abs(fStar), 1)
|
||||
|
||||
if printing
|
||||
@printf("%4d\t%1.4e\t%1.4e", feval, gapk, ng)
|
||||
end
|
||||
if prevv < Inf
|
||||
if printing
|
||||
@printf("\t%1.4e", (v .- fStar) / (prevv - fStar))
|
||||
end
|
||||
else
|
||||
if printing
|
||||
print(" \t ")
|
||||
end
|
||||
end
|
||||
prevv = v
|
||||
|
||||
if Plotf > 1
|
||||
if Plotf ≥ 2
|
||||
push!(gap, gapk)
|
||||
end
|
||||
plot!(plt, yscale=:log)
|
||||
if Plotf == 2
|
||||
plot!(plt, ylims=(1e-15, 1e+1))
|
||||
else
|
||||
plot!(plt, ylims=(1e-15, 1e+4))
|
||||
end
|
||||
end
|
||||
else
|
||||
if printing
|
||||
@printf("%4d\t%1.8e\t\t%1.4e", feval, v, ng)
|
||||
end
|
||||
|
||||
if Plotf ≥ 2
|
||||
push!(gap, v)
|
||||
end
|
||||
end
|
||||
|
||||
# stopping criteria - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
if ng ≤ (eps * ng0)
|
||||
status = "optimal"
|
||||
if printing
|
||||
print("\n")
|
||||
end
|
||||
break
|
||||
end
|
||||
|
||||
if feval > MaxFeval
|
||||
status = "stopped"
|
||||
if printing
|
||||
print("\n")
|
||||
end
|
||||
break
|
||||
end
|
||||
|
||||
# compute step size - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
phip0 = -ng * ng
|
||||
|
||||
if astart < 0
|
||||
# fixed-step approach
|
||||
lastx = x .+ astart .* g
|
||||
(v, lastg, _) = f(lastx)
|
||||
feval = feval + 1
|
||||
else
|
||||
# line-search approach, either Armijo-Wolfe or Backtracking
|
||||
|
||||
if AWLS
|
||||
a, v = ArmijoWolfeLS(v, phip0, astart, m1, m2, tau)
|
||||
else
|
||||
a, v = BacktrackingLS(v, phip0, astart, m1, tau)
|
||||
end
|
||||
end
|
||||
|
||||
# output statistics - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
if astart > 0
|
||||
if printing
|
||||
@printf("\t%1.4e\n", a)
|
||||
end
|
||||
|
||||
if a ≤ mina
|
||||
status = "error"
|
||||
if printing
|
||||
print("\n")
|
||||
end
|
||||
break
|
||||
end
|
||||
else
|
||||
if printing
|
||||
print("\n")
|
||||
end
|
||||
end
|
||||
|
||||
if v ≤ MInf
|
||||
status = "unbounded"
|
||||
if printing
|
||||
print("\n")
|
||||
end
|
||||
break
|
||||
end
|
||||
|
||||
# compute new point - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
# possibly plot the trajectory
|
||||
if n == 2 && Plotf == 1
|
||||
PXY = hcat(PXY, hcat(x, lastx))
|
||||
end
|
||||
|
||||
x = lastx
|
||||
|
||||
# update gradient - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
g = lastg
|
||||
ng = norm(g)
|
||||
|
||||
# iterate - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
if Interactive
|
||||
readline()
|
||||
end
|
||||
end
|
||||
|
||||
if plotatend
|
||||
if Plotf ≥ 2
|
||||
plot!(plt, gap)
|
||||
elseif Plotf == 1 && n == 2
|
||||
plot!(plt, PXY[1, :], PXY[2, :])
|
||||
end
|
||||
display(plt)
|
||||
end
|
||||
|
||||
# end of main loop- - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
return (x, status)
|
||||
end
|
||||
Reference in New Issue
Block a user