Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/PALC/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ mutable struct InternalTerminateContinuationCallback{FType} <:
end
end

# Callback sets
mutable struct TerminateContinuationCallbackSet{T<:Tuple} <: RootSolveContinuationCallback
callbacks::T # Each argument should be a RootSolveContinuationCallback, or else errors will occur later
# the user CAN costruct with a tuple of anything, but should use the below constructor with varargs
end

# don't need to do any handling like in SciMLBase, since we only have 1 type of callback for now
function TerminateContinuationCallbackSet(callbacks::Union{RootSolveContinuationCallback, Nothing}...)
TerminateContinuationCallbackSet(callbacks)
end

# Simple continuation callback for analyzing the status of the continuation process.
# Has no zero finding functionality
struct AnalysisContinuationCallback{FType} <: AbstractContinuationCallback
Expand All @@ -76,6 +87,13 @@ function initialize!(cb::InternalRootSolveContinuationCallback, cache::PALCCache
cb.val_0 = cb.f(cache.u0, cache.λ0, cache, alg, prob)
return nothing
end
function initialize!(cb::TerminateContinuationCallbackSet, cache::PALCCache, alg, prob)
# initialize each callback in the set
@inbounds for i in eachindex(cb.callbacks)
initialize!(cb.callbacks[i], cache, alg, prob)
end
return nothing
end

# Callback update
update!(cb::Nothing, cache::PALCCache, alg, prob) = nothing
Expand All @@ -87,6 +105,11 @@ function update!(cb::InternalRootSolveContinuationCallback, cache::PALCCache, al
cb.val_0 = cb.f(cache.u0, cache.λ0, cache, alg, prob)
return nothing
end
function update!(cb::TerminateContinuationCallbackSet, cache::PALCCache, alg, prob)
@inbounds for i in eachindex(cb.callbacks)
update!(cb.callbacks[i], cache, alg, prob)
end
end

# Check the root solve callback (returns true if we stepped over zero)
check(cb::Nothing, uλ0, cache::PALCCache, alg, prob) = false
Expand Down Expand Up @@ -132,3 +155,8 @@ end
function handle_termination_callback(cb, cache, alg, p)
return cb
end
function handle_termination_callback(cb::TerminateContinuationCallbackSet{T}, cache, alg, p) where {T<:Tuple}
handle_cb_map_fun = cb-> handle_termination_callback(cb, cache, alg, p)
cb = TerminateContinuationCallbackSet(map(handle_cb_map_fun, cb.callbacks))
return cb
end
171 changes: 170 additions & 1 deletion src/PALC/correction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,164 @@ function palc_correction!(
return success, terminate_continuation
end

## Correction method if we are using a set of callbacks
# Only difference here is that instead of looking for 1 callback being triggered, we look for if any of the set is triggered sequentially, then perform RF.
# Note: should probably later improve this to cover the case that multiple callbacks trigger in a single step, but for now we are looking at them sequentially
function palc_correction!(
cache,
alg,
p::ContinuationProblem,
solvers,
dsmin,
dsmax,
term_callback::TerminateContinuationCallbackSet,
analysis_callback,
trace,
)
# Get cache variables
u0 = cache.u0
λ0 = cache.λ0
δu0 = cache.δu0
δλ0 = cache.δλ0
uλpred = cache.uλpred
n = length(δu0)

# Compute inner product of tangent with itself
dotδ = alg.inner_prod(δu0, δλ0)

# Get problem variables
λmin = p.λ_bounds[1]
λmax = p.λ_bounds[2]

# Solve nonlinear problem (reducing step-size if necessary)
attempts = 0
success = true
done = false
hit_bnd = NaN
cb_trig = false
rf_succ = false
triggered_callback = nothing # which (if any) callback was triggered
while !done
# Update attempts
attempts += 1

# Compute α for ds
α = cache.ds / dotδ

# Update uλpred (clamping ds to try and stay in λ bounds)
# If clamped, set hit_bnd to the bound hit and we'll resolve
# if successful with constant λ

uλpred[end] = λ0 + α * δλ0
if uλpred[end] < λmin
α = (λmin - λ0) / δλ0
cache.ds = α * dotδ
uλpred[end] = λmin
hit_bnd = λmin
print_correction_trace(cache, trace, 2)
elseif uλpred[end] > λmax
α = (λmax - λ0) / δλ0
cache.ds = α * dotδ
uλpred[end] = λmax
hit_bnd = λmax
print_correction_trace(cache, trace, 2)
else
print_correction_trace(cache, trace, 1)
end

uλpred[1:n] .= u0 .+ α .* δu0

# Solve the palc nonlinear problem
uλ, retcode = solve_palc_nlp!(solvers, uλpred, trace)

# Check if successful
if SciMLBase.successful_retcode(retcode)

# Check if callback triggered (iterating through in order)
for jj in eachindex(term_callback.callbacks)
cb_trig = check(term_callback.callbacks[jj], uλ, cache, alg, p)
if cb_trig # break once one is triggered
triggered_callback = jj
break
end
end

# If callback triggered, perform regula falsi root finding method and update uλ
if cb_trig
rf_succ = palc_target_callback_event!(
uλ, cache, alg, p, solvers, term_callback.callbacks[triggered_callback], trace
)
hit_bnd = NaN # Reset since we're likely not stepping as far and will recheck
end

# Check if we crossed the boundary
if uλ[end] < λmin
hit_bnd = λmin
elseif uλ[end] > λmax
hit_bnd = λmax
end

if cb_trig && !rf_succ # Triggered callback but rootfind was unsuccessful
cb_trig = false
hit_bnd = NaN
scale_and_clamp_ds!(cache, 0.5, dsmin, dsmax)
elseif isnan(hit_bnd)
# Push solution and set done
set_successful_iterate!(cache, uλ)
done = true

# Print trace if desired
print_correction_trace(cache, trace, 3)
else
# Update cache without pushing solution to curve
set_successful_iterate!(cache, uλ, false)

# Print trace if desired
print_correction_trace(cache, trace, 3)

# Target solution on boundary
flag = palc_target_solution_on_boundary!(cache, hit_bnd, solvers, trace)

# If targeting solution on boundary was successful, we're done. Otherwise, reduce ds
if flag
done = true
else
hit_bnd = NaN
scale_and_clamp_ds!(cache, 0.5, dsmin, dsmax)
end
end
else # solve is not successful
if abs(cache.ds) == dsmin
done = true
success = false
cache.ret = :MinimumStepSize # update ret with 'done' condition. This won't be overwritten since success=false (see continuation.jl)
else
# Reduce step-size and reattempt
scale_and_clamp_ds!(cache, 0.5, dsmin, dsmax)
end
end
end

# Update ds is we were successful
# Consider only updating is successful in < n number of attempts
success && scale_and_clamp_ds!(cache, 1.2, dsmin, dsmax)

# Update the callback if we were successfull
success && update!(term_callback, cache, alg, p)

# Call the analysis callback if we were successful
success && call!(analysis_callback, cache)

# Handle termination flag
terminate_continuation = !isnan(hit_bnd) || cb_trig # hit bound or triggered callback

if terminate_continuation
set_successful_retcode!(cache, hit_bnd, cb_trig, triggered_callback)
end

return success, terminate_continuation
end

function print_correction_trace(cache::PALCCache, trace::Silent, stage)
return nothing
end
Expand Down Expand Up @@ -400,8 +558,19 @@ function palc_correction_jacobian!(J, uλ, p)
end

function set_successful_retcode!(cache, hit_bnd, cb_trig)
# recover success condition and set
if isnan(hit_bnd) && cb_trig
cache.ret = :Callback
elseif !cb_trig
cache.ret = :HitBound
end
end

function set_successful_retcode!(cache, hit_bnd, cb_trig, jj)
# recover success condition and set
if isnan(hit_bnd) && cb_trig
cache.ret = :CallbackTermination
retstr = "Callback$jj"
cache.ret = Symbol(retstr)
elseif !cb_trig
cache.ret = :HitBound
end
Expand Down
7 changes: 7 additions & 0 deletions src/PALC/palc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ mutable struct PALCCache{MT<:Union{Matrix{Float64},SparseMatrixCSC{Float64,Int}}

# Return code
ret::Symbol # could also do a custom type for pretty viewing like sciml but symbols should work just fine
# current implemented retcodes:
# :HitBound; successful, hit boundary
# :Callback; successful, terminated due to callback
# :Callbacki; successful, terminated due to callback i in callback set
# :Maxiters; unsuccessful (generally), terminated due to max iterations
# :MinimumStepSize; unsuccessful, terminated due to shrinking stepsize
# :None; this is what the retcode is initialized to, so seeing this after running continuation signifies an unhandled exit condition
end

function PALCCache(
Expand Down
53 changes: 53 additions & 0 deletions src/PALC/special_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,56 @@ function handle_termination_callback(cb::FoldBifurcationTerminationCallback, cac
fold_detection_callback_function, cache, alg, p; tol=cb.tol
)
end

# Incomplete - just testing. Remove later if not used
# This callback terminates the continuation process when the slope of the predicted tangent in a prescried dimension rises above a certain threshold.
# For example, I'm using this to detect rapid increases in the time of flight. This may prematurely trigger as we approach
# fold bifurcations (though in an initial test this did not happen, I think because the folds in κ-tf space are generally cusps rather than smooth)
# so a better implementation is needed. But it works for the problem I'm using it on right now, where I'm avoiding degeneracy
# due to a rapid increase in tf rather than a fold
struct MaxSlopeTerminationCallback <: RootSolveContinuationCallback
indexes::Vector{Int}
threshold::Float64
tol::Float64
function MaxSlopeTerminationCallback(indexes,threshold; tol=1e-12)
return new(indexes, threshold, tol)
end
end


function max_slope_callback_function(u, λ, cache, alg, prob, threshold, indexes)

# Calculate the bordered prediction in the same manner as the FoldBifurcationTerminationCallback
n = length(u)
eval_J!(cache.Jfun, cache.Ffun, u, λ, prob.f)
cache.bordered_mat[1:n, 1:(n + 1)] .= cache.Jfun
ddotdu1!(view(cache.bordered_mat, n + 1, 1:n), cache.δu0, alg.inner_prod)
cache.bordered_mat[n + 1, n + 1] = ddotdλ1(cache.δλ0, alg.inner_prod)

x = begin
ls_sol = solve(LinearProblem(cache.bordered_mat, cache.bordered_b), alg.linsolve)
ls_sol.u
end

# Scale the predicted tangent
#scale_predicted_tangent!(x, cache, alg)
xn = sqrt(alg.inner_prod(view(x, 1:n), x[n + 1]))
α = sign(alg.inner_prod(view(x, 1:n), cache.δu0, x[n + 1], cache.δλ0)) / xn
x .*= α

un = view(x, 1:n)
λn = view(x, n+1)

# Now check if the slope (norm) for the desired indexes is greater than some threshold
ret = norm(abs.(un[indexes])./λn)-threshold
return ret

end

# Handle slope termination callback
function handle_termination_callback(cb::MaxSlopeTerminationCallback, cache, alg, p)
fn(u, λ, cache, alg, prob) = max_slope_callback_function(u, λ, cache, alg, prob, cb.threshold, cb.indexes)
return InternalTerminateContinuationCallback(
fn, cache, alg, p; tol=cb.tol
)
end
2 changes: 1 addition & 1 deletion src/SimpleContinuation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export Bordered, Secant
export ContinuationFunction, SparseContinuationFunction
export ContinuationProblem

export TerminateContinuationCallback, AnalysisContinuationCallback
export TerminateContinuationCallback, AnalysisContinuationCallback, TerminateContinuationCallbackSet
export FoldBifurcationTerminationCallback

export PALC
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ using SafeTestsets
@time begin
@time @safetestset "Functions" begin
include("test_functions.jl")
include("test_retcodes.jl")
end
end
Loading
Loading