using LinearAlgebra, Printf, Plots

"""
    SDG(f; x=nothing, astart=1, eps=1e-6, MaxFeval=1000, m1=1e-3, m2=0.9,
        tau=0.9, sfgrd=0.01, MInf=-Inf, mina=1e-16, plt=nothing,
        plotatend=true, Plotf=0, printing=true) -> (x, status)

Minimize `f` with the steepest-descent (gradient) method
`x_{k+1} = x_k - a_k * g_k`, where `g_k` is the gradient of `f` at `x_k`.

`f` is called in two ways:
- `f(nothing)` returns `(fStar, x0, _)`: the optimal value if known (`-Inf`
  when unknown) and a default starting point, used when `x === nothing`;
- `f(x)` returns `(value, gradient, _)`.
The third returned element is ignored in both cases.

# Keywords
- `x`: starting point; when `nothing`, the one returned by `f(nothing)`.
- `astart`: if `> 0`, the first step tried by the line search; if `< 0`, a
  fixed step `x + astart * g` is taken at every iteration; `0` is rejected.
- `eps`: stop when `‖g‖ ≤ eps` (`eps ≥ 0`), or when
  `‖g‖ ≤ -eps * ‖g(x0)‖` (`eps < 0`, relative criterion).
- `MaxFeval`: budget of `f` evaluations.
- `m1`: Armijo parameter, in (0, 1).
- `m2`: Wolfe parameter; if in (0, 1) an Armijo-Wolfe line search is used,
  otherwise a Backtracking (Armijo-only) line search.
- `tau`: in (0, 1); step increase factor (Armijo-Wolfe extrapolation,
  `as / tau`) and decrease factor (Backtracking, `as * tau`).
- `sfgrd`: in (0, 1); safeguard of the quadratic-interpolation step.
- `MInf`: if the f-value drops to `≤ MInf` the problem is declared unbounded.
- `mina`: minimum accepted step size (`≥ 0`).
- `plt`, `plotatend`, `Plotf`: plotting options. `Plotf` is
  0 = nothing, 1 = trajectory (only when `n == 2`),
  2 = f-value / relative gap per iteration,
  3 = f-value / relative gap per f-evaluation.
- `printing`: print a per-iteration log.

# Returns
`(x, status)` with `status` one of `"optimal"`, `"stopped"` (evaluation
budget exhausted), `"unbounded"`, or `"error"` (step size fell to `≤ mina`).

# Throws
`ArgumentError` for an invalid `astart`, `m1`, `tau`, `sfgrd`, or `mina`.
"""
function SDG(f; x::Union{Nothing, Vector}=nothing, astart::Real=1,
             eps::Real=1e-6, MaxFeval::Integer=1000, m1::Real=1e-3,
             m2::Real=0.9, tau::Real=0.9, sfgrd::Real=0.01, MInf::Real=-Inf,
             mina::Real=1e-16, plt::Union{Plots.Plot, Nothing}=nothing,
             plotatend::Bool=true, Plotf::Integer=0,
             printing::Bool=true)::Tuple{AbstractArray, String}
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # local functions
    # They are closures: they read `x`, `g`, `fStar`, `Plotf`, `MaxFeval` and
    # update the shared `lastx`, `lastg`, `feval`, `gap` of this function.
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

    # phi(alpha) = f(x - alpha * g). Stores the trial point in `lastx` and its
    # gradient in `lastg`, counts the evaluation, and (Plotf > 2) records it
    # for the per-evaluation plot. With `derivate`, also returns
    # phi'(alpha) = -g' * grad f(x - alpha * g); otherwise `nothing`.
    function f2phi(alpha, derivate=false)
        lastx = x .- alpha .* g
        (phi, lastg, _) = f(lastx)
        if Plotf > 2
            if fStar > -Inf
                push!(gap, (phi - fStar) / max(abs(fStar), 1))
            else
                push!(gap, phi)
            end
        end
        feval += 1
        if derivate
            return phi, dot(-g, lastg)
        end
        return phi, nothing
    end

    # Armijo-Wolfe line search.
    #  - phi0 = phi(0), phip0 = phi'(0) (< 0)
    #  - as (> 0) is the first step tested; while Armijo
    #        phi(as) <= phi0 + m1 * as * phip0
    #    holds but Wolfe  phi'(as) >= m2 * phip0  does not (the derivative is
    #    still "too negative"), as is divided by tau < 1 (extrapolation)
    #  - then the interval [am, as] is shrunk by safeguarded quadratic
    #    interpolation (or bisection) until both conditions hold
    # Returns (a, phi(a)).
    function ArmijoWolfeLS(phi0, phip0, as, m1, m2, tau)
        local phips, phia
        lsiter = 1          # iterations of the first phase
        while feval ≤ MaxFeval
            (phia, phips) = f2phi(as, true)
            if phia > phi0 + m1 * as * phip0
                break       # Armijo not satisfied
            end
            if phips ≥ m2 * phip0
                # Armijo + Wolfe satisfied, done
                if printing
                    @printf("%2d ", lsiter)
                end
                return (as, phia)
            end
            if phips ≥ 0 || feval > MaxFeval
                # derivative positive, or budget exhausted: stop here so that
                # `as` is still the last evaluated step
                break
            end
            as = as / tau
            lsiter += 1
        end
        if printing
            @printf("%2d ", lsiter)
        end

        lsiter = 1          # iterations of the second phase
        am = zero(float(as))
        a = as
        phipm = phip0
        # the derivative threshold is relative to |phi'(0)|: an absolute one
        # skips this phase (returning a step that failed Armijo) as soon as
        # ‖g‖² drops below it
        while feval ≤ MaxFeval && (as - am) > mina &&
                abs(phips) > 1e-12 * abs(phip0)
            if phipm < 0 && phips > 0
                # derivative negative in am and positive in as: safeguarded
                # quadratic interpolation, kept sfgrd-away from both ends
                a = (am * phips - as * phipm) / (phips - phipm)
                a = max(am + (as - am) * sfgrd, min(as - (as - am) * sfgrd, a))
            else
                a = (am + as) / 2   # bisection: midpoint of [am, as]
            end
            phia, phipa = f2phi(a, true)
            if phia ≤ phi0 + m1 * a * phip0     # Armijo satisfied at a
                if phipa ≥ m2 * phip0
                    break                       # Armijo + Wolfe, done
                end
                # Armijo holds but the derivative is still too negative:
                # move the left endpoint to a
                am = a
                phipm = phipa
            else
                # Armijo fails: move the right endpoint to a
                as = a
                phips = phipa
            end
            lsiter += 1
        end
        if printing
            @printf("%2d", lsiter)
        end
        return (a, phia)
    end

    # Backtracking (Armijo-only) line search.
    # phi0 = phi(0), phip0 = phi'(0) < 0; as > 0 is the first step tested,
    # multiplied by tau < 1 until Armijo with parameter m1 holds (or the
    # budget is exhausted, or as <= mina). Returns (as, phi(as)).
    function BacktrackingLS(phi0, phip0, as, m1, tau)
        local phia
        lsiter = 1
        while feval ≤ MaxFeval && as > mina
            (phia, _) = f2phi(as)
            if phia ≤ phi0 + m1 * as * phip0
                break       # Armijo satisfied
            end
            as = as * tau
            lsiter += 1
        end
        if printing
            @printf("\t%2d", lsiter)
        end
        return (as, phia)
    end

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # setup
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # Plotf: 0 = nothing is plotted
    #        1 = the trajectory is plotted (when n = 2)
    #        2 = the function value / gap are plotted, iteration-wise
    #        3 = the function value / gap are plotted, function-evaluation-wise
    Interactive = false     # when true, wait for <Enter> after each iteration
    local gap
    PXY = Matrix{Float64}(undef, 2, 0)  # trajectory segments (Plotf == 1)
    status = "error"

    if Plotf > 1
        # expected number of iterations / evaluations, used as x-axis limit
        MaxIter = Plotf == 2 ? 200 : 1000
        gap = Float64[]
    end

    if x === nothing
        (fStar, x, _) = f(nothing)
    else
        (fStar, _, _) = f(nothing)
    end
    n = size(x, 1)

    if astart == 0
        throw(ArgumentError("astart must be ≠ 0"))
    end
    if m1 ≤ 0 || m1 ≥ 1
        throw(ArgumentError("m1: ($m1) is not in (0, 1)"))
    end
    AWLS = (m2 > 0 && m2 < 1)   # Armijo-Wolfe LS, else Backtracking
    if tau ≤ 0 || tau ≥ 1
        throw(ArgumentError("tau: ($tau) is not in (0, 1)"))
    end
    if sfgrd ≤ 0 || sfgrd ≥ 1
        throw(ArgumentError("sfgrd: ($sfgrd) is not in (0, 1)"))
    end
    if mina < 0
        throw(ArgumentError("mina: ($mina) must be ≥ 0"))
    end

    if Plotf > 1 && plt === nothing
        plt = plot(xlims=(0, MaxIter))
    elseif plt === nothing
        plt = plot()
    end
    if Plotf > 1 && fStar > -Inf
        # relative gaps go on a log scale; set once instead of every iteration
        plot!(plt, yscale=:log)
        plot!(plt, ylims=(Plotf == 2 ? (1e-15, 1e+1) : (1e-15, 1e+4)))
    end

    # state shared with the local functions
    lastx = zeros(n)    # last point visited in the line search
    lastg = zeros(n)    # gradient at lastx
    feval = 0           # f() evaluations count (the first one is below)

    if printing
        println("Gradient method")
    end
    if fStar > -Inf
        if printing
            print("feval\trel gap\t\t|| g(x) ||\trate\t")
        end
        prevv = Inf
    elseif printing
        print("feval\tf(x)\t\t\t|| g(x) ||")
    end
    if astart > 0 && printing
        print("\tls feval\ta*")
    end
    if printing
        print("\n\n")
    end

    # first f-value and gradient in x^0: with g = 0, f2phi(0) evaluates f at
    # x itself (g must have x's shape for the broadcast in f2phi)
    g = zeros(n)
    v, _ = f2phi(0)
    g = lastg
    ng = norm(g)
    # with eps < 0 the criterion ng <= eps * ng0 becomes relative:
    # eps * (-‖g0‖) = |eps| * ‖g0‖
    ng0 = eps < 0 ? -ng : 1

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # main loop
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    while true
        # output statistics & record gap / f-values (per-evaluation records,
        # Plotf == 3, are taken inside f2phi)
        if fStar > -Inf
            gapk = (v - fStar) / max(abs(fStar), 1)
            if printing
                @printf("%4d\t%1.4e\t%1.4e", feval, gapk, ng)
            end
            if prevv < Inf
                if printing
                    @printf("\t%1.4e", (v - fStar) / (prevv - fStar))
                end
            elseif printing
                print(" \t ")
            end
            prevv = v
            if Plotf == 2
                push!(gap, gapk)
            end
        else
            if printing
                @printf("%4d\t%1.8e\t\t%1.4e", feval, v, ng)
            end
            if Plotf == 2
                push!(gap, v)
            end
        end

        # stopping criteria
        if ng ≤ eps * ng0
            status = "optimal"
            if printing
                print("\n")
            end
            break
        end
        if feval > MaxFeval
            status = "stopped"
            if printing
                print("\n")
            end
            break
        end

        # step size
        phip0 = -ng * ng    # phi'(0) along d = -g
        if astart < 0
            # fixed step x + astart * g; through f2phi so the evaluation is
            # counted and recorded like every other one
            v, _ = f2phi(-astart)
        elseif AWLS
            a, v = ArmijoWolfeLS(v, phip0, astart, m1, m2, tau)
        else
            a, v = BacktrackingLS(v, phip0, astart, m1, tau)
        end

        # output statistics
        if astart > 0
            if printing
                @printf("\t%1.4e\n", a)
            end
            if a ≤ mina
                status = "error"
                if printing
                    print("\n")
                end
                break
            end
        elseif printing
            print("\n")
        end

        if v ≤ MInf
            status = "unbounded"
            if printing
                print("\n")
            end
            break
        end

        # new point (and trajectory segment, if plotted)
        if n == 2 && Plotf == 1
            PXY = hcat(PXY, hcat(x, lastx))
        end
        x = lastx

        # new gradient
        g = lastg
        ng = norm(g)

        if Interactive
            readline()
        end
    end

    if plotatend
        if Plotf ≥ 2
            plot!(plt, gap)
        elseif Plotf == 1 && n == 2
            plot!(plt, PXY[1, :], PXY[2, :])
        end
        display(plt)
    end

    return (x, status)
end