diff --git a/src/PALC/callback.jl b/src/PALC/callback.jl index 380d4d2..9af1ecb 100644 --- a/src/PALC/callback.jl +++ b/src/PALC/callback.jl @@ -53,6 +53,17 @@ mutable struct InternalTerminateContinuationCallback{FType} <: end end +# Callback sets +mutable struct TerminateContinuationCallbackSet{T<:Tuple} <: RootSolveContinuationCallback + callbacks::T # Each argument should be a RootSolveContinuationCallback, or else errors will occur later + # the user CAN costruct with a tuple of anything, but should use the below constructor with varargs +end + +# don't need to do any handling like in SciMLBase, since we only have 1 type of callback for now +function TerminateContinuationCallbackSet(callbacks::Union{RootSolveContinuationCallback, Nothing}...) + TerminateContinuationCallbackSet(callbacks) +end + # Simple continuation callback for analyzing the status of the continuation process. # Has no zero finding functionality struct AnalysisContinuationCallback{FType} <: AbstractContinuationCallback @@ -76,6 +87,13 @@ function initialize!(cb::InternalRootSolveContinuationCallback, cache::PALCCache cb.val_0 = cb.f(cache.u0, cache.λ0, cache, alg, prob) return nothing end +function initialize!(cb::TerminateContinuationCallbackSet, cache::PALCCache, alg, prob) + # initialize each callback in the set + @inbounds for i in eachindex(cb.callbacks) + initialize!(cb.callbacks[i], cache, alg, prob) + end + return nothing +end # Callback update update!(cb::Nothing, cache::PALCCache, alg, prob) = nothing @@ -87,6 +105,11 @@ function update!(cb::InternalRootSolveContinuationCallback, cache::PALCCache, al cb.val_0 = cb.f(cache.u0, cache.λ0, cache, alg, prob) return nothing end +function update!(cb::TerminateContinuationCallbackSet, cache::PALCCache, alg, prob) + @inbounds for i in eachindex(cb.callbacks) + update!(cb.callbacks[i], cache, alg, prob) + end +end # Check the root solve callback (returns true if we stepped over zero) check(cb::Nothing, uλ0, cache::PALCCache, alg, prob) = false @@ -132,3 +155,8 @@ end function handle_termination_callback(cb, cache, alg, p) return cb end +function handle_termination_callback(cb::TerminateContinuationCallbackSet{T}, cache, alg, p) where {T<:Tuple} + handle_cb_map_fun = cb-> handle_termination_callback(cb, cache, alg, p) + cb = TerminateContinuationCallbackSet(map(handle_cb_map_fun, cb.callbacks)) + return cb +end \ No newline at end of file diff --git a/src/PALC/correction.jl b/src/PALC/correction.jl index 4241669..2c48c21 100644 --- a/src/PALC/correction.jl +++ b/src/PALC/correction.jl @@ -147,6 +147,164 @@ function palc_correction!( return success, terminate_continuation end +## Correction method if we are using a set of callbacks +# Only difference here is that instead of looking for 1 callback being triggered, we look for if any of the set is triggered sequentially, then perform RF. +# Note: should probably later improve this to cover the case that multiple callbacks trigger in a single step, but for now we are looking at them sequentially +function palc_correction!( + cache, + alg, + p::ContinuationProblem, + solvers, + dsmin, + dsmax, + term_callback::TerminateContinuationCallbackSet, + analysis_callback, + trace, +) + # Get cache variables + u0 = cache.u0 + λ0 = cache.λ0 + δu0 = cache.δu0 + δλ0 = cache.δλ0 + uλpred = cache.uλpred + n = length(δu0) + + # Compute inner product of tangent with itself + dotδ = alg.inner_prod(δu0, δλ0) + + # Get problem variables + λmin = p.λ_bounds[1] + λmax = p.λ_bounds[2] + + # Solve nonlinear problem (reducing step-size if necessary) + attempts = 0 + success = true + done = false + hit_bnd = NaN + cb_trig = false + rf_succ = false + triggered_callback = nothing # which (if any) callback was triggered + while !done + # Update attempts + attempts += 1 + + # Compute α for ds + α = cache.ds / dotδ + + # Update uλpred (clamping ds to try and stay in λ bounds) + # If clamped, set hit_bnd to the bound hit and we'll resolve + # if successful with constant λ + + uλpred[end] = λ0 + α * δλ0 + if uλpred[end] < λmin + α = (λmin - λ0) / δλ0 + cache.ds = α * dotδ + uλpred[end] = λmin + hit_bnd = λmin + print_correction_trace(cache, trace, 2) + elseif uλpred[end] > λmax + α = (λmax - λ0) / δλ0 + cache.ds = α * dotδ + uλpred[end] = λmax + hit_bnd = λmax + print_correction_trace(cache, trace, 2) + else + print_correction_trace(cache, trace, 1) + end + + uλpred[1:n] .= u0 .+ α .* δu0 + + # Solve the palc nonlinear problem + uλ, retcode = solve_palc_nlp!(solvers, uλpred, trace) + + # Check if successful + if SciMLBase.successful_retcode(retcode) + + # Check if callback triggered (iterating through in order) + for jj in eachindex(term_callback.callbacks) + cb_trig = check(term_callback.callbacks[jj], uλ, cache, alg, p) + if cb_trig # break once one is triggered + triggered_callback = jj + break + end + end + + # If callback triggered, perform regula falsi root finding method and update uλ + if cb_trig + rf_succ = palc_target_callback_event!( + uλ, cache, alg, p, solvers, term_callback.callbacks[triggered_callback], trace + ) + hit_bnd = NaN # Reset since we're likely not stepping as far and will recheck + end + + # Check if we crossed the boundary + if uλ[end] < λmin + hit_bnd = λmin + elseif uλ[end] > λmax + hit_bnd = λmax + end + + if cb_trig && !rf_succ # Triggered callback but rootfind was unsuccessful + cb_trig = false + hit_bnd = NaN + scale_and_clamp_ds!(cache, 0.5, dsmin, dsmax) + elseif isnan(hit_bnd) + # Push solution and set done + set_successful_iterate!(cache, uλ) + done = true + + # Print trace if desired + print_correction_trace(cache, trace, 3) + else + # Update cache without pushing solution to curve + set_successful_iterate!(cache, uλ, false) + + # Print trace if desired + print_correction_trace(cache, trace, 3) + + # Target solution on boundary + flag = palc_target_solution_on_boundary!(cache, hit_bnd, solvers, trace) + + # If targeting solution on boundary was successful, we're done. Otherwise, reduce ds + if flag + done = true + else + hit_bnd = NaN + scale_and_clamp_ds!(cache, 0.5, dsmin, dsmax) + end + end + else # solve is not successful + if abs(cache.ds) == dsmin + done = true + success = false + cache.ret = :MinimumStepSize # update ret with 'done' condition. This won't be overwritten since success=false (see continuation.jl) + else + # Reduce step-size and reattempt + scale_and_clamp_ds!(cache, 0.5, dsmin, dsmax) + end + end + end + + # Update ds is we were successful + # Consider only updating is successful in < n number of attempts + success && scale_and_clamp_ds!(cache, 1.2, dsmin, dsmax) + + # Update the callback if we were successfull + success && update!(term_callback, cache, alg, p) + + # Call the analysis callback if we were successful + success && call!(analysis_callback, cache) + + # Handle termination flag + terminate_continuation = !isnan(hit_bnd) || cb_trig # hit bound or triggered callback + + if terminate_continuation + set_successful_retcode!(cache, hit_bnd, cb_trig, triggered_callback) + end + + return success, terminate_continuation +end + function print_correction_trace(cache::PALCCache, trace::Silent, stage) return nothing end @@ -400,8 +558,19 @@ function palc_correction_jacobian!(J, uλ, p) end function set_successful_retcode!(cache, hit_bnd, cb_trig) + # recover success condition and set + if isnan(hit_bnd) && cb_trig + cache.ret = :Callback + elseif !cb_trig + cache.ret = :HitBound + end +end + +function set_successful_retcode!(cache, hit_bnd, cb_trig, jj) + # recover success condition and set if isnan(hit_bnd) && cb_trig - cache.ret = :CallbackTermination + retstr = "Callback$jj" + cache.ret = Symbol(retstr) elseif !cb_trig cache.ret = :HitBound end diff --git a/src/PALC/palc.jl b/src/PALC/palc.jl index 09e1a7d..eb8f169 100644 --- a/src/PALC/palc.jl +++ b/src/PALC/palc.jl @@ -82,6 +82,13 @@ mutable struct PALCCache{MT<:Union{Matrix{Float64},SparseMatrixCSC{Float64,Int}} # Return code ret::Symbol # could also do a custom type for pretty viewing like sciml but symbols should work just fine + # current implemented retcodes: + # :HitBound; successful, hit boundary + # :Callback; successful, terminated due to callback + # :Callbacki; successful, terminated due to callback i in callback set + # :Maxiters; unsuccessful (generally), terminated due to max iterations + # :MinimumStepSize; unsuccessful, terminated due to shrinking stepsize + # :None; this is what the retcode is initialized to, so seeing this after running continuation signifies an unhandled exit condition end function PALCCache( diff --git a/src/PALC/special_callbacks.jl b/src/PALC/special_callbacks.jl index 292efac..1fd53fb 100644 --- a/src/PALC/special_callbacks.jl +++ b/src/PALC/special_callbacks.jl @@ -46,3 +46,56 @@ function handle_termination_callback(cb::FoldBifurcationTerminationCallback, cac fold_detection_callback_function, cache, alg, p; tol=cb.tol ) end + +# Incomplete - just testing. Remove later if not used +# This callback terminates the continuation process when the slope of the predicted tangent in a prescried dimension rises above a certain threshold. +# For example, I'm using this to detect rapid increases in the time of flight. This may prematurely trigger as we approach +# fold bifurcations (though in an initial test this did not happen, I think because the folds in κ-tf space are generally cusps rather than smooth) +# so a better implementation is needed. But it works for the problem I'm using it on right now, where I'm avoiding degeneracy +# due to a rapid increase in tf rather than a fold +struct MaxSlopeTerminationCallback <: RootSolveContinuationCallback + indexes::Vector{Int} + threshold::Float64 + tol::Float64 + function MaxSlopeTerminationCallback(indexes,threshold; tol=1e-12) + return new(indexes, threshold, tol) + end +end + + +function max_slope_callback_function(u, λ, cache, alg, prob, threshold, indexes) + + # Calculate the bordered prediction in the same manner as the FoldBifurcationTerminationCallback + n = length(u) + eval_J!(cache.Jfun, cache.Ffun, u, λ, prob.f) + cache.bordered_mat[1:n, 1:(n + 1)] .= cache.Jfun + ddotdu1!(view(cache.bordered_mat, n + 1, 1:n), cache.δu0, alg.inner_prod) + cache.bordered_mat[n + 1, n + 1] = ddotdλ1(cache.δλ0, alg.inner_prod) + + x = begin + ls_sol = solve(LinearProblem(cache.bordered_mat, cache.bordered_b), alg.linsolve) + ls_sol.u + end + + # Scale the predicted tangent + #scale_predicted_tangent!(x, cache, alg) + xn = sqrt(alg.inner_prod(view(x, 1:n), x[n + 1])) + α = sign(alg.inner_prod(view(x, 1:n), cache.δu0, x[n + 1], cache.δλ0)) / xn + x .*= α + + un = view(x, 1:n) + λn = view(x, n+1) + + # Now check if the slope (norm) for the desired indexes is greater than some threshold + ret = norm(abs.(un[indexes])./λn)-threshold + return ret + +end + +# Handle slope termination callback +function handle_termination_callback(cb::MaxSlopeTerminationCallback, cache, alg, p) + fn(u, λ, cache, alg, prob) = max_slope_callback_function(u, λ, cache, alg, prob, cb.threshold, cb.indexes) + return InternalTerminateContinuationCallback( + fn, cache, alg, p; tol=cb.tol + ) +end \ No newline at end of file diff --git a/src/SimpleContinuation.jl b/src/SimpleContinuation.jl index c7c42ef..4e19fbb 100644 --- a/src/SimpleContinuation.jl +++ b/src/SimpleContinuation.jl @@ -35,7 +35,7 @@ export Bordered, Secant export ContinuationFunction, SparseContinuationFunction export ContinuationProblem -export TerminateContinuationCallback, AnalysisContinuationCallback +export TerminateContinuationCallback, AnalysisContinuationCallback, TerminateContinuationCallbackSet export FoldBifurcationTerminationCallback export PALC diff --git a/test/runtests.jl b/test/runtests.jl index 530e6bd..02718ff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,6 @@ using SafeTestsets @time begin @time @safetestset "Functions" begin include("test_functions.jl") + include("test_retcodes.jl") end end diff --git a/test/test_callbacks.jl b/test/test_callbacks.jl new file mode 100644 index 0000000..09de012 --- /dev/null +++ b/test/test_callbacks.jl @@ -0,0 +1,97 @@ +using SimpleContinuation +using Test + +const SC = SimpleContinuation + +# The scalar test function is from https://bifurcationkit.github.io/BifurcationKitDocs.jl/stable/gettingstarted/ + +function scalar_test_fun!(F, x, λ) + F[1] = (λ + x[1] - x[1]^3 / 3) +end + +function scalar_fun_ujac!(J, F, x, λ) + scalar_test_fun!(F, x, λ) + J[1,1] = 1 - x[1]^2 +end + +function scalar_fun_jac!(J, F, x, λ) + scalar_test_fun!(F, x, λ) + J[1,1] = 1 - x[1]^2 + J[1,2] = 1 +end + +f_min_time = (F, u, s) -> scalar_test_fun!(F, u, s) +Jz_min_time = (J, F, u, s) -> scalar_fun_ujac!(J, F, u, s) +J_min_time = (J, F, u, s) -> scalar_fun_jac!(J, F, u, s) + + +f1 = (x,λ) -> -1.5-x[1] # this should trigger BEFORE first fold bifurcation +f2 = (x,λ) -> x[1] # this would trigger AFTER first fold bifurcation +f3 = (x,λ) -> x[1] - 100 # this should never trigger + +fold_bifurcation_cb = SC.FoldBifurcationTerminationCallback() +cb1 = SC.TerminateContinuationCallback(f1) +cb2 = SC.TerminateContinuationCallback(f2) +cb3 = SC.TerminateContinuationCallback(f3) + +# now, lets test a couple sets + +u0 = [-2.0] +λ0 = -1.0 + +# Set 1: should terminate at fold and have retcode :Callback1 +set_1 = SC.TerminateContinuationCallbackSet(fold_bifurcation_cb, cb2) +cache = continuation( + ContinuationProblem( + ContinuationFunction{Val{true}}(f_min_time, Jz_min_time, J_min_time), + u0, + λ0, + (λ0, 1.0), + ), + PALC(; predicter=Bordered()); + both_sides=false, + ds0=1e-2, + dsmin=1e-2, + dsmax=0.1, + max_cont_steps=1000, + term_callback=set_1, +) + + +# set 2 should have retcode :Callback2 +set_2 = SC.TerminateContinuationCallbackSet(cb3, cb2) +cache = continuation( + ContinuationProblem( + ContinuationFunction{Val{true}}(f_min_time, Jz_min_time, J_min_time), + u0, + λ0, + (λ0, 1.0), + ), + PALC(; predicter=Bordered()); + both_sides=false, + ds0=1e-2, + dsmin=1e-2, + dsmax=0.0, + max_cont_steps=1000, + term_callback=set_2, +) + +# Should terminate at cb4, retcode :Callback1 +cb4 = SC.MaxSlopeTerminationCallback([1], 50) +set_3 = SC.TerminateContinuationCallbackSet(cb4, fold_bifurcation_cb) +cache = continuation( + ContinuationProblem( + ContinuationFunction{Val{true}}(f_min_time, Jz_min_time, J_min_time), + u0, + λ0, + (λ0, 1.0), + ), + PALC(; predicter=Bordered()); + both_sides=false, + ds0=1e-4, + dsmin=1e-4, + dsmax=1e-3, # needed to reduce step size to force the slope to reach a high value (larger steps step over high-slope regions and trigger the fold bifurcation cb first) + max_cont_steps=10e3, + term_callback=set_3, + trace=ContinuationAndNewtonSteps() +) \ No newline at end of file diff --git a/test/test_retcodes.jl b/test/test_retcodes.jl index 92f12aa..2a47b75 100644 --- a/test/test_retcodes.jl +++ b/test/test_retcodes.jl @@ -42,7 +42,7 @@ cache = continuation( #trace=ContinuationSteps(), term_callback=FoldBifurcationTerminationCallback(), ) -@test cache.ret == :CallbackTermination +@test cache.ret == :Callback # Should term due to hitting boundary cache = continuation(