diff --git a/docs/make.jl b/docs/make.jl index caf010d0b..58ea83467 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -53,6 +53,7 @@ makedocs( "Algebra" => [ "manual/solver.md", "manual/groebner.md", + "manual/partial_fractions.md", ], "Calculus" => [ diff --git a/docs/src/manual/ode.md b/docs/src/manual/ode.md index e7c38cfbb..53fac8f45 100644 --- a/docs/src/manual/ode.md +++ b/docs/src/manual/ode.md @@ -43,6 +43,16 @@ Symbolics.solve_symbolic_IVP Symbolics.solve_linear_ode_system ``` +### Laplace Transform + +The Laplace transform can be used to solve ODEs by transforming the whole equation, solving algebraically, then applying the inverse transform. The Laplace transform and inverse transform functionality is currently based on a rule table and applying linearity, so this method is limited in what expressions are able to be transformed and inverse transformed. + +```@docs +Symbolics.laplace +Symbolics.inverse_laplace +Symbolics.laplace_solve_ode +``` + ### SymPy ```@docs diff --git a/docs/src/manual/partial_fractions.md b/docs/src/manual/partial_fractions.md new file mode 100644 index 000000000..a7a95581b --- /dev/null +++ b/docs/src/manual/partial_fractions.md @@ -0,0 +1,9 @@ +# Partial Fraction Decomposition + +Partial fraction decomposition is performed using the cover-up method. This involves "covering up" a factor in the denominator and substituting the root into the remaining expression. When the denominator can be completely factored into non-repeated linear factors, this produces the desired result. When there are repeated or irreducible quadratic factors, it produces terms with unknown coefficients in the numerator that is solved as a system of equations. + +It is often used when solving integrals or performing an inverse Laplace transform (see [`inverse_laplace`](@ref)). + +```docs +Symbolics.partial_frac_decomposition +``` \ No newline at end of file diff --git a/src/Symbolics.jl b/src/Symbolics.jl index bfa4ce37c..f12ceb17f 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -175,6 +175,9 @@ include("operators.jl") include("limits.jl") export limit +include("partialfractions.jl") +export partial_frac_decomposition + # Hacks to make wrappers "nicer" const NumberTypes = Union{AbstractFloat,Integer,Complex{<:AbstractFloat},Complex{<:Integer}} (::Type{T})(x::SymbolicUtils.Symbolic) where {T<:NumberTypes} = throw(ArgumentError("Cannot convert Sym to $T since Sym is symbolic and $T is concrete. Use `substitute` to replace the symbolic unwraps.")) @@ -223,8 +226,9 @@ export symbolic_solve # Diff Eq Solver include("diffeqs/diffeqs.jl") include("diffeqs/systems.jl") +include("diffeqs/laplace.jl") include("diffeqs/diffeq_helpers.jl") -export SymbolicLinearODE, symbolic_solve_ode, solve_linear_ode_system, solve_symbolic_IVP +export SymbolicLinearODE, symbolic_solve_ode, solve_linear_ode_system, solve_symbolic_IVP, laplace, inverse_laplace, laplace_solve_ode # Sympy Functions diff --git a/src/diffeqs/diffeq_helpers.jl b/src/diffeqs/diffeq_helpers.jl index 26e34823a..3dabd98b5 100644 --- a/src/diffeqs/diffeq_helpers.jl +++ b/src/diffeqs/diffeq_helpers.jl @@ -12,25 +12,37 @@ function _get_der_order(expr, x, t) return maximum(_get_der_order.(factors(expr), Ref(x), Ref(t))) end - return _get_der_order(substitute(expr, Dict(Differential(t)(x) => x)), x, t) + 1 + return _get_der_order(fast_substitute(expr, Dict(Differential(t)(x) => x)), x, t) + 1 +end + +function reduce_rule(expr, Dt) + iscall(expr) && isequal(operation(expr), Dt) ? wrap(arguments(expr)[1]) : nothing +end + +""" + unwrap_der(expr, Dt) + +Helper function to unwrap derivatives of `f(t)` in `expr` with respect to the differential operator `Dt = Differential(t)`. Returns a tuple `(n, base_expr)`, where `n` is the order of the derivative and `base_expr` is the expression with the derivatives removed. If `expr` does not contain `f(t)` or its derivatives, returns `(0, expr)`. +""" +function unwrap_der(expr, Dt) + + if reduce_rule(unwrap(expr), Dt) === nothing + return 0, expr + end + + order, expr = unwrap_der(reduce_rule(unwrap(expr), Dt), Dt) + return order + 1, expr end # takes into account fractions function _true_factors(expr) - facs = factors(expr) - true_facs::Vector{Number} = [] - frac_rule = @rule (~x)/(~y) => [~x, 1/~y] - for fac in facs - frac = frac_rule(fac) - if frac !== nothing && !isequal(frac[1], 1) - append!(true_facs, _true_factors(frac[1])) - append!(true_facs, _true_factors(frac[2])) - else - push!(true_facs, fac) - end - end + expr = flatten_fractions(unwrap(expr)) # flatten nested fractions + + numerator_factors = SymbolicUtils.numerators(unwrap(expr)) + denominator_factors = SymbolicUtils.denominators(unwrap(expr)) - return convert(Vector{Num}, true_facs) + facs = filter(fac -> !isequal(fac, 1), [numerator_factors; 1 ./ denominator_factors]) + return isempty(facs) ? [1] : facs end """ @@ -45,7 +57,7 @@ function reduce_order(eq, x, t, ys) # reduction of order y_sub = Dict([[(Dt^i)(x) => ys[i+1] for i=0:n-1]; (Dt^n)(x) => variable(:𝒴)]) - eq = substitute(eq, y_sub) + eq = fast_substitute(eq, y_sub) # isolate (Dt^n)(x) f = symbolic_linear_solve(eq, variable(:𝒴), check=false) @@ -60,7 +72,7 @@ function unreduce_order(expr, x, t, ys) Dt = Differential(t) rev_y_sub = Dict(ys[i] => (Dt^(i-1))(x) for i in eachindex(ys)) - return substitute(expr, rev_y_sub) + return fast_substitute(expr, rev_y_sub) end function is_solution(solution, eq::Equation, x, t) @@ -82,15 +94,12 @@ function is_solution(solution, eq, x, t) end function _parse_trig(expr, t) - parse_sin = Symbolics.Chain([(@rule sin(t) => 1), (@rule sin(~x * t) => ~x)]) - parse_cos = Symbolics.Chain([(@rule cos(t) => 1), (@rule cos(~x * t) => ~x)]) - - if !isequal(parse_sin(expr), expr) - return parse_sin(expr), true + if iscall(expr) && isequal(operation(expr), sin) && any(isequal.(t, factors(arguments(expr)[1]))) + return arguments(expr)[1]/t, true end - if !isequal(parse_cos(expr), expr) - return parse_cos(expr), false + if iscall(expr) && isequal(operation(expr), cos) && any(isequal.(t, factors(arguments(expr)[1]))) + return arguments(expr)[1]/t, false end return nothing diff --git a/src/diffeqs/diffeqs.jl b/src/diffeqs/diffeqs.jl index 0308f87ec..523d7906b 100644 --- a/src/diffeqs/diffeqs.jl +++ b/src/diffeqs/diffeqs.jl @@ -59,7 +59,7 @@ function is_linear_ode(expr, x, t) @assert n >= 1 "ODE must have at least one derivative" y_sub = Dict([[(Dt^i)(x) => ys[i+1] for i=0:n-1]; (Dt^n)(x) => variable(:𝒴)]) - expr = substitute(expr, y_sub) + expr = fast_substitute(expr, y_sub) # isolate (Dt^n)(x) f = symbolic_linear_solve(expr, variable(:𝒴), check=false) @@ -362,7 +362,7 @@ function get_rrf_coeff(q, t) return a, r_re + r_im * im end - a = prod(filter(fac -> isempty(Symbolics.get_variables(fac, [t])), facs)) + a = prod([1; filter(fac -> isempty(Symbolics.get_variables(fac, [t])), facs)]) not_a = filter(fac -> !isempty(Symbolics.get_variables(fac, [t])), facs) # should just be e^(rt) if length(not_a) != 1 @@ -384,7 +384,7 @@ For finding particular solution when q(t) = a*e^(rt)*cos(bt) (or sin(bt)) function exp_trig_particular_solution(eq::SymbolicLinearODE) facs = _true_factors(eq.q) - a = prod(filter(fac -> isempty(Symbolics.get_variables(fac, [eq.t])), facs)) + a = prod([1; filter(fac -> isempty(Symbolics.get_variables(fac, [eq.t])), facs)]) not_a = filter(fac -> !isempty(Symbolics.get_variables(fac, [eq.t])), facs) @@ -415,12 +415,12 @@ function exp_trig_particular_solution(eq::SymbolicLinearODE) @variables π“ˆ p = characteristic_polynomial(eq, π“ˆ) Ds = Differential(π“ˆ) - while isequal(substitute(expand_derivatives((Ds^k)(p)), Dict(π“ˆ => r+b*im)), 0) + while isequal(fast_substitute(expand_derivatives((Ds^k)(p)), Dict(π“ˆ => r+b*im)), 0) k += 1 end rrf = expand(simplify(a * exp((r + b * im) * eq.t) * eq.t^k / - (substitute(expand_derivatives((Ds^k)(p)), Dict(π“ˆ => r+b*im))))) + (fast_substitute(expand_derivatives((Ds^k)(p)), Dict(π“ˆ => r+b*im))))) return is_sin ? imag(rrf) : real(rrf) end @@ -447,12 +447,12 @@ function resonant_response_formula(eq::SymbolicLinearODE) @variables π“ˆ p = characteristic_polynomial(eq, π“ˆ) Ds = Differential(π“ˆ) - while isequal(substitute(expand_derivatives((Ds^k)(p)), Dict(π“ˆ => r)), 0) + while isequal(fast_substitute(expand_derivatives((Ds^k)(p)), Dict(π“ˆ => r)), 0) k += 1 end return expand(simplify(a * exp(r * eq.t) * eq.t^k / - (substitute(expand_derivatives((Ds^k)(p)), Dict(π“ˆ => r))))) + (fast_substitute(expand_derivatives((Ds^k)(p)), Dict(π“ˆ => r))))) end function method_of_undetermined_coefficients(eq::SymbolicLinearODE) @@ -476,8 +476,8 @@ function method_of_undetermined_coefficients(eq::SymbolicLinearODE) coeff_solution = nothing end - if degree > 0 && coeff_solution !== nothing && !isempty(coeff_solution) && isequal(expand(substitute(eq_subbed, coeff_solution[1])), 0) - return substitute(form, coeff_solution[1]) + if degree > 0 && coeff_solution !== nothing && !isempty(coeff_solution) && isequal(expand(fast_substitute(eq_subbed, coeff_solution[1])), 0) + return fast_substitute(form, coeff_solution[1]) end # exponential @@ -487,13 +487,13 @@ function method_of_undetermined_coefficients(eq::SymbolicLinearODE) r = coeff[2] form = a_form*exp(r*eq.t) - eq_subbed = substitute(get_expression(eq), Dict(eq.x => form)) + eq_subbed = fast_substitute(get_expression(eq), Dict(eq.x => form)) eq_subbed = expand_derivatives(eq_subbed) eq_subbed = simplify(expand((eq_subbed.lhs - eq_subbed.rhs) / exp(r*eq.t))) coeff_solution = solve_interms_ofvar(eq_subbed, eq.t) if coeff_solution !== nothing && !isempty(coeff_solution) - return substitute(form, coeff_solution[1]) + return fast_substitute(form, coeff_solution[1]) end end @@ -505,9 +505,9 @@ function method_of_undetermined_coefficients(eq::SymbolicLinearODE) if parsed !== nothing Ο‰ = parsed[1] form = 𝒢*cos(Ο‰*eq.t) + 𝒷*sin(Ο‰*eq.t) - eq_subbed = substitute(get_expression(eq), Dict(eq.x => form)) + eq_subbed = fast_substitute(get_expression(eq), Dict(eq.x => form)) eq_subbed = expand_derivatives(eq_subbed) - eq_subbed = expand(substitute(eq_subbed.lhs - eq_subbed.rhs, Dict(cos(Ο‰*eq.t)=>π’Έπ“ˆ, sin(Ο‰*eq.t)=>π“ˆπ“ƒ))) + eq_subbed = expand(fast_substitute(eq_subbed.lhs - eq_subbed.rhs, Dict(cos(Ο‰*eq.t)=>π’Έπ“ˆ, sin(Ο‰*eq.t)=>π“ˆπ“ƒ))) cos_eq = simplify(sum(filter(term -> !isempty(Symbolics.get_variables(term, π’Έπ“ˆ)), terms(eq_subbed)))/π’Έπ“ˆ) sin_eq = simplify(sum(filter(term -> !isempty(Symbolics.get_variables(term, π“ˆπ“ƒ)), terms(eq_subbed)))/π“ˆπ“ƒ) if !isempty(Symbolics.get_variables(cos_eq, [eq.t,π“ˆπ“ƒ,π’Έπ“ˆ])) || !isempty(Symbolics.get_variables(sin_eq, [eq.t,π“ˆπ“ƒ,π’Έπ“ˆ])) @@ -517,7 +517,7 @@ function method_of_undetermined_coefficients(eq::SymbolicLinearODE) end if coeff_solution !== nothing && !isempty(coeff_solution) - return substitute(form, coeff_solution[1]) + return fast_substitute(form, coeff_solution[1]) end end end @@ -556,7 +556,7 @@ function solve_symbolic_IVP(ivp::IVP) push!(eqs, eq) end - return expand(simplify(substitute(general_solution, symbolic_solve(eqs, ivp.eq.C)[1]))) + return expand(simplify(fast_substitute(general_solution, symbolic_solve(eqs, ivp.eq.C)[1]))) end """ @@ -622,7 +622,7 @@ function solve_clairaut(expr, x, t) end C = Symbolics.variable(:C, 1) # constant of integration - f = substitute(f, Dict(Dt(x) => C)) + f = fast_substitute(f, Dict(Dt(x) => C)) if !isempty(Symbolics.get_variables(f, [x])) return nothing end @@ -649,7 +649,7 @@ function linearize_bernoulli(expr, x, t, v) for term in terms if Symbolics.hasderiv(Symbolics.value(term)) facs = _true_factors(term) - leading_coeff = prod(filter(fac -> !Symbolics.hasderiv(Symbolics.value(fac)), facs)) + leading_coeff = prod([1; filter(fac -> !Symbolics.hasderiv(Symbolics.value(fac)), facs)]) if !isequal(term//leading_coeff, Dt(x)) return nothing end @@ -661,10 +661,10 @@ function linearize_bernoulli(expr, x, t, v) end if isequal(x_fac[1], x) - p = prod(filter(fac -> isempty(Symbolics.get_variables(fac, [x])), facs)) + p = prod([1; filter(fac -> isempty(Symbolics.get_variables(fac, [x])), facs)]) else n = degree(x_fac[1]) - q = -prod(filter(fac -> isempty(Symbolics.get_variables(fac, [x])), facs)) + q = -prod([1; filter(fac -> isempty(Symbolics.get_variables(fac, [x])), facs)]) end end end diff --git a/src/diffeqs/laplace.jl b/src/diffeqs/laplace.jl new file mode 100644 index 000000000..c50fd359b --- /dev/null +++ b/src/diffeqs/laplace.jl @@ -0,0 +1,304 @@ +import DomainSets.ClosedInterval + +# from https://tutorial.math.lamar.edu/Classes/DE/Laplace_Table.aspx +transform_rules(f, t, F, s) = Symbolics.Chain([ + @rule 1 => 1/s + @rule exp(t) => 1/(s - 1) + @rule exp(~a * t) => 1/(-~a + s) + @rule t => 1/s^2 + @rule t^~n => factorial(~n)/s^(~n + 1) + @rule sqrt(t) => term(sqrt, pi)/(2 * s^(3/2)) + @rule sin(t) => 1/(1 + s^2) + @rule sin(~a * t) => ~a/((~a)^2 + s^2) + @rule cos(t) => s/(1 + s^2) + @rule cos(~a * t) => s/((~a)^2 + s^2) + @rule t*sin(t) => 1/(1 + s^2)^2 + @rule t*sin(~a * t) => 2*~a*s / ((~a)^2 + s^2)^2 + @rule t*cos(t) => (s^2 - 1) / (1 + s^2)^2 + @rule t*cos(~a * t) => (-(~a)^2 + s^2) / ((~a)^2 + s^2)^2 + @rule sin(t) - t*cos(t) => 2 / (1 + s^2)^2 + @rule sin(~a*t) - ~a*t*cos(~a*t) => 2*(~a)^3 / ((~a)^2 + s^2)^2 + @rule sin(t) + t*cos(t) => 2s^2 / (1 + s^2)^2 + @rule sin(~a*t) + ~a*t*cos(~a*t) => 2*~a*s^2 / ((~a)^2 + s^2)^2 + @rule cos(~a*t) - ~a*t*sin(~a*t) => s*((~a)^2 + s^2) / ((~a)^2 + s^2)^2 + @rule cos(~a*t) + ~a*t*sin(~a*t) => s*(s^2 + 3*(~a)^2) / ((~a)^2 + s^2)^2 + @rule sin(~b + ~a*t) => (s*sin(~b) + ~a*cos(~b)) / ((~a)^2 + s^2) + @rule cos(~b + ~a*t) => (s*cos(~b) - ~a*sin(~b)) / ((~a)^2 + s^2) + @rule sinh(~a * t) => ~a/(-(~a)^2 + s^2) + @rule cosh(~a * t) => s/(-(~a)^2 + s^2) + @rule exp(~a*t) * sin(~b * t) => ~b / ((~b)^2 + (-~a+s)^2) + @rule exp(~a*t) * cos(~b * t) => (-~a+s) / ((~b)^2 + (-~a+s)^2) + @rule exp(~a*t) * sinh(~b * t) => ~b / (-(~b)^2 + (-~a+s)^2) + @rule exp(~a*t) * cosh(~b * t) => (-~a+s) / (-(~b)^2 + (-~a+s)^2) + @rule t^~n * exp(~a * t) => factorial(~n) / (-~a + s)^(~n + 1) + @rule t*exp(~a * t) => 1 / (-~a + s)^(2) + @rule t^~n * exp(t) => factorial(~n) / (s)^(~n + 1) + @rule t*exp(t) => 1 / (s)^(2) + @rule exp(~c*t) * ~g => laplace(~g, f, t, F, s - ~c) # s-shift rule + @rule f(t)*t => -Differential(s)(F(s)) # s-derivative rule + @rule f(t)*t^(~n) => (-1)^(~n) * (Differential(s)^~n)(F(s)) # s-derivative rule + @rule f(~a + t) => exp(~a*s)*F(s) # t-shift rule + @rule f(t) => F(s) +]) + +""" + laplace(expr, f, t, F, s) + +Performs the Laplace transform of `expr` with respect to the variable `t`, where `f(t)` is a function in `expr` being transformed, and `F(s)` is the Laplace transform of `f(t)`. Returns the transformed expression in terms of `s`. + +Note that `f(t)` and `F(s)` should be defined using `@syms` + +Currently relies mostly on linearity and a rules table. When the rules table does not apply, it falls back to the integral definition of the Laplace transform. + +# Examples + +```jldoctest +julia> @variables t, s +2-element Vector{Num}: + t + s + +julia> @syms f(t)::Real F(s)::Real +(f, F) + +julia> laplace(exp(4t) + 5, f, t, F, s) +5 / s + 1 / (-4 + s) + +julia> laplace(10 + 4t - t^2, f, t, F, s) +10 / s + 4 / (s^2) + -2 / (s^3) + +julia> laplace(exp(-2t)*cos(3t) + 5exp(-2t)*sin(3t), f, t, F, s) +(2 + s) / (9 + (2 + s)^2) + 15 / (9 + (2 + s)^2) + +julia> laplace(t^2 * f(t), f, t, F, s) # s-derivative rule +Differential(s)(Differential(s)(F(s))) + +julia> laplace(5f(t-4), f, t, F, s) # t-shift rule +5F(s)*exp(-4s) + +julia> laplace(log(t), f, t, F, s) # fallback to definition +Integral(t, 0.0 .. Inf)(exp(-s*t)*log(t)) +``` +""" +function laplace(expr, f, t, F, s; rules=nothing) + expr = expand(expr) + + if isequal(expr, 0) + return 0 + end + + Dt = Differential(t) + + if rules === nothing + rules = transform_rules(f, t, F, s) + end + + transformed = rules(expr) + + # Check if transformation was successful + if !isequal(transformed, expr) + return transformed + end + + # t-derivative rule ((Dt^n)(f(t)) -> s^n*F(s) - s^(n-1)*f(0) - s^(n-2)*f'(0) - ... - f^(n-1)(0)) + n, expr = unwrap_der(expr, Dt) + if n != 0 && isequal(expr, f(t)) + f0 = Symbolics.variables(:𝒻0, 0:(n-1)) + transformed = s^n*F(s) + for i = 1:n + transformed -= s^(n-i)*f0[i] + end + + return transformed + end + + terms = Symbolics.terms(expr) + result = 0 + + # unable to apply linearity, so return based on definition + if length(terms) == 1 && length(filter(x->isempty(Symbolics.get_variables(x)), _true_factors(terms[1]))) == 0 + return Integral(t in ClosedInterval(0, Inf))(expr*exp(-s*t)) + end + + # apply linearity by splitting into terms and factoring out constants + for term in terms + factors = _true_factors(wrap(term)) + constant = filter(x -> isempty(Symbolics.get_variables(x)), factors) + if !isempty(constant) + result += laplace(term / constant[1], f, t, F, s, rules=rules) * constant[1] + else + result += laplace(term, f, t, F, s, rules=rules) + end + end + + return result +end + +function laplace(expr::Equation, f, t, F, s) + return laplace(expr.lhs, f, t, F, s) ~ laplace(expr.rhs, f, t, F, s) +end + +# postprocess_root prevents automatic evaluation of sqrt to its floating point value +function processed_sqrt(x) + return postprocess_root(term(sqrt, x)) +end + +# F and f aren't used here, but are here for future-proofing +inverse_transform_rules(F, s, f, t) = Symbolics.Chain([ + @rule 1/s => 1 + @rule 1/(~a + s) => exp(-~a * t) + @rule 1/s^(~n) => t^(~n-1) / factorial(~n-1) + @rule 1/(2 * s^(3/2)) => sqrt(t)/term(term(sqrt, pi)) + @rule 1/(~a + s^2) => sin(processed_sqrt(~a) * t)/processed_sqrt(~a) + @rule s/(~a + s^2) => cos(processed_sqrt(~a) * t) + @rule s / (~a + s^2)^2 => t*sin(processed_sqrt(~a) * t)/(2*processed_sqrt(~a)) + @rule (-~a + s^2) / (~a + s^2)^2 => t*cos(processed_sqrt(~a) * t) + @rule 1 / (~a + s^2)^2 => (sin(processed_sqrt(~a)*t) - processed_sqrt(~a)*t*cos(processed_sqrt(~a)*t))/ (2*processed_sqrt(~a)^3) + @rule s^2 / (~a + s^2)^2 => (sin(processed_sqrt(~a)*t) + processed_sqrt(~a)*t*cos(processed_sqrt(~a)*t)) / (2*processed_sqrt(~a)) + @rule s*(~a + s^2) / (~a + s^2)^2 => cos(processed_sqrt(~a)*t) - processed_sqrt(~a)*t*sin(processed_sqrt(~a)*t) + @rule s*(3*~a + s^2) / (~a + s^2)^2 => cos(processed_sqrt(~a)*t) + processed_sqrt(~a)*t*sin(processed_sqrt(~a)*t) + @rule (s*sin(~b) + ~a*cos(~b)) / (~a + s^2) => sin(~b + processed_sqrt(~a)*t) + @rule (s*cos(~b) - ~a*sin(~b)) / ((~a)^2 + s^2) => cos(~b + ~a*t) + @rule 1/(s^2 - (~b)^2) => sinh(~b * t)/~b + @rule s/(s^2 - (~b)^2) => cosh(~b * t) + @rule 1 / ((~c+s)^2 + (~b)^2) => exp(-~c*t) * sin(~b * t) / ~b + @rule (~c+s) / ((~c+s)^2 + (~b)^2) => exp(-~c*t) * cos(~b * t) + @rule 1 / ((~c+s)^2 - (~b)^2) => exp(-~c*t) * sinh(~b * t) / ~b + @rule (~c+s) / ((~c+s)^2 - (~b)^2) => exp(-~c*t) * cosh(~b * t) + @rule 1 / (~a + s)^(~n) => t^(~n-1) * exp(-~a * t) / factorial(~n-1) +]) + +""" + inverse_laplace(expr, F, s, f, t) + +Performs the inverse Laplace transform of `expr` with respect to the variable `s`, where `F(s)` is the Laplace transform of `f(t)`. Returns the transformed expression in terms of `t`. + +Note that `f(t)` and `F(s)` should be defined using `@syms`. + +Will perform partial fraction decomposition and linearity before applying the inverse Laplace transform rules. When unable to find a result, returns `nothing`. + +# Examples + +```jldoctest +julia> @variables t, s +2-element Vector{Num}: + t + s + +julia> @syms f(t)::Real F(s)::Real +(f, F) + +julia> inverse_laplace(7/(s+3)^3, F, s, f, t) +(7//2)*(t^2)*exp(-3t) + +julia> inverse_laplace((s+2)/(s^2 - 3s - 4), F, s, f, t) # using partial fraction decomposition +-(1//5)*exp(-t) + (6//5)*exp(4t) + +julia> inverse_laplace(1/s^4, F, s, f, t) +(1//6)*(t^3) +``` +""" +function inverse_laplace(expr, F, s, f, t; rules=nothing) + if isequal(expr, 0) + return 0 + end + + # check for partial fractions + partial_fractions = partial_frac_decomposition(expr, s) + if partial_fractions !== nothing && !isequal(partial_fractions, expr) + return inverse_laplace(partial_fractions, F, s, f, t) + end + + if rules === nothing + rules = inverse_transform_rules(F, s, f, t) + end + + transformed = rules(expr) + + # Check if transformation was successful + if !isequal(transformed, expr) + return transformed + end + + _terms = terms(numerator(expr)) ./ denominator(expr) + + result = 0 + if length(_terms) == 1 && length(filter(x -> isempty(get_variables(x)), _true_factors(_terms[1]))) == 0 + @warn "Inverse laplace failed: $expr" + return nothing # no result + end + + # apply linearity by splitting into terms and factoring out constants + for term in _terms + factors = _true_factors(term) + constant = filter(x -> isempty(Symbolics.get_variables(x)), factors) + if !isempty(constant) + result += inverse_laplace(term / constant[1], F, s, f, t, rules=rules) * constant[1] + else + result += inverse_laplace(term, F, s, f, t, rules=rules) + end + end + + return result +end + +function inverse_laplace(expr::Equation, F, s, f, t) + return inverse_laplace(expr.lhs, F, s, f, t) ~ inverse_laplace(expr.rhs, F, s, f, t) +end + +""" + laplace_solve_ode(eq, f, t, f0) + +Solves the ordinary differential equation `eq` for the function `f(t)` using the Laplace transform method. + +`f0` is a vector of initial conditions evaluated at `t=0` (`[f(0), f'(0), f''(0), ...]`, must be same length as order of `eq`). + +# Examples + +```jldoctest +@variables t, s +@syms f(t)::Real F(s)::Real + +Dt = Differential(t) + +julia> laplace_solve_ode(Dt(f(t)) + 3f(t) ~ t^2*exp(-3t) + t*exp(-2t) + t, f, t, [1]) +-(1//9) + (1//3)*t + (19//9)*exp(-3t) - exp(-2t) + t*exp(-2t) + (1//3)*(t^3)*exp(-3t) + +julia> laplace_solve_ode((Dt^2)(f(t)) + f(t) ~ 2 + 2cos(t), f, t, [0, 0]) +(2//1) - (2//1)*cos(t) + t*sin(t) + +julia> laplace_solve_ode((Dt^3)(f(t)) - Dt(f(t)) ~ 6 - 3t^2, f, t, [1, 1, 1]) +exp(t) + t^3 +``` +""" +function laplace_solve_ode(eq, f, t, f0) + s = variable(:π“ˆ) + @syms 𝓕(s) + + # transform equation + transformed_eq = laplace(eq, f, t, 𝓕, s) + # substitute in initial conditions + transformed_eq = fast_substitute(transformed_eq, Dict(𝓕(s) => variable(:𝓕), [variable(:𝒻0, i-1) => f0[i] for i=1:length(f0)]...)) + transformed_eq = expand(transformed_eq.lhs - transformed_eq.rhs) + + # solve for/isolate F(s) + F_terms = 0 + other_terms = [] + for term in terms(transformed_eq) + if isempty(get_variables(term, [variable(:𝓕)])) + push!(other_terms, -1*term) + else + F_terms += term/variable(:𝓕) # assumes term is something times F + end + end + + if isempty(other_terms) + other_terms = 0 + end + + # (a + b + ...)*F(s) = (c + d + ...) -> F(s) = (c + d + ...) / (a + b + ...) + transformed_soln = simplify(sum(other_terms ./ F_terms)) + + # perform inverse laplace transform to get f(t) + return expand(inverse_laplace(transformed_soln, 𝓕, s, f, t)) +end \ No newline at end of file diff --git a/src/diffeqs/systems.jl b/src/diffeqs/systems.jl index 3f586a875..7316cb03e 100644 --- a/src/diffeqs/systems.jl +++ b/src/diffeqs/systems.jl @@ -87,8 +87,8 @@ end Replacement for `LinearAlgebra.eigen` function that uses symbolic functions to avoid floating-point inaccuracies """ function symbolic_eigen(A::Matrix{<:Number}) - @variables Ξ» # eigenvalue - v = variables(:v, 1:size(A, 1)) # vector of subscripted variables to represent eigenvector + Ξ» = variable(:β„°) # eigenvalue + v = variables(:𝓋, 1:size(A, 1)) # vector of subscripted variables to represent eigenvector # find eigenvalues first p = det(Ξ»*I - A) ~ 0 # polynomial to solve @@ -99,7 +99,7 @@ function symbolic_eigen(A::Matrix{<:Number}) for value in values eqs = (value*I - A) * v# .~ zeros(size(A, 1)) # equations to give eigenvectors - eqs = substitute(eqs, Dict(v[1] => 1)) # set first element to 1 to constrain solution space + eqs = fast_substitute(eqs, Dict(v[1] => 1)) # set first element to 1 to constrain solution space sol = symbolic_solve(eqs[1:end-1], v[2:end]) # solve all but one equation (because of constraining solutions above) diff --git a/src/partialfractions.jl b/src/partialfractions.jl new file mode 100644 index 000000000..31ba00171 --- /dev/null +++ b/src/partialfractions.jl @@ -0,0 +1,175 @@ +# used to represent linear or irreducible quadratic factors +struct Factor + expr + root + multiplicity + x +end + +function Factor(expr, multiplicity, x) + fac_rule = @rule x + ~r => -~r + return Factor(expr, fac_rule(expr), multiplicity, x) +end + +function Base.isequal(a::Factor, b::Factor) + return isequal(a.expr, b.expr) && a.multiplicity == b.multiplicity +end + +# https://math.mit.edu/~hrm/18.031/pf-coverup.pdf +""" + partial_frac_decomposition(expr, x) + +Performs partial fraction decomposition for expressions with linear, repeated, or irreducible quadratic factors in the denominator. Can't currently handle irrational roots. + +When leading coefficient of the denominator is not 1, it will be factored out and then put back in at the end, often leading to non-integer coefficients in the result. Will return `nothing` if the expression is not a valid polynomial fraction, or if it has irrational roots. + +# Examples + +```jldoctest +julia> @variables x +1-element Vector{Num}: + x + +julia> partial_frac_decomposition((3x-1) / (x^2 + x - 6), x) +(1//1) / (-2 + x) + (2//1) / (3 + x) + +julia> partial_frac_decomposition((4x^3 + 16x + 7)/(x^2 + 4)^2, x) # repeated irreducible quadratic factor +(4x) / (4 + x^2) + 7 / ((4 + x^2)^2) + +julia> partial_frac_decomposition((4x^2 - 22x + 7)/((2x+3)*(x-2)^2), x) # non-one leading coefficient +(-3//1) / ((-2 + x)^2) + (2//1) / ((3//2) + x) +``` + +!!! note that irreducible quadratic and repeated linear factors require the `Groebner` package to solve a system of equations +""" +function partial_frac_decomposition(expr, x) + A, B = numerator(expr), denominator(expr) + + # check if both numerator and denominator are polynomials + if !isequal(polynomial_coeffs(A, [x])[2], 0) || !isequal(polynomial_coeffs(B, [x])[2], 0) + return nothing + end + + if degree(A) >= degree(B) + return nothing + end + + facs = factorize(B, x) + if facs === nothing + return nothing + end + + leading_coeff = coeff_vector(expand(B), x)[end] # of denominator + + # already in partial fraction form + if length(facs) == 1 && only(facs).multiplicity == 1 && degree(A) <= 1 + return expr + end + + result = [] + c_idx = 0 # index to keep track of which C subscript to use + if length(facs) == 1 + fac = only(facs) + + if fac.root === nothing # irreducible quadratic factor + for i = 1:fac.multiplicity + push!(result, (variable(:π’ž, c_idx+=1)*x + variable(:π’ž, c_idx+=1))/(fac.expr^i)) # (Ax + B)/(x-r)^i + end + else + append!(result, variables(:π’ž, (c_idx+1):(c_idx+=fac.multiplicity)) ./ fac.expr.^(1:fac.multiplicity)) # C1/(x-r) + C2/(x-2)^2 ... + end + else + for fac in facs + if fac.root === nothing # irreducible quadratic factor + for i = 1:fac.multiplicity + push!(result, (variable(:π’ž, c_idx+=1)*x + variable(:π’ž, c_idx+=1))/(fac.expr^i)) # (Ax + B)/(x-r)^i + end + continue + end + + # cover up method + other_facs = filter(f -> !isequal(f, fac), facs) + + numerator = rationalize(unwrap(fast_substitute(A / prod((f -> f.expr^f.multiplicity).(other_facs)), Dict(x => fac.root)))) # plug in root to expression without its factor in denominator + push!(result, numerator / fac.expr^fac.multiplicity) + + if fac.multiplicity > 1 + append!(result, variables(:π’ž, (c_idx+1):(c_idx+=fac.multiplicity-1)) ./ fac.expr.^(1:fac.multiplicity-1)) # C1/(x-r) + C2/(x-2)^2 ... + end + end + end + + # no unknowns, so just return + if isequal(get_variables(sum(result)), [x]) + return sum(result ./ leading_coeff) + end + + lhs = numerator(expr) + rhs = expand(sum(simplify.(numerator.(result) .* ((B/leading_coeff) ./ denominator.(result))))) # multiply each numerator by the common denominator/its denominator, and sum to get numerator of whole expression + + solution = solve_interms_ofvar(lhs - rhs, x)[1] # solve for unknowns (C's) by looking at coefficients of the polynomial + + # single unknown + if !(solution isa Dict) + solution = Dict(variable(:π’ž, 1) => solution) + end + + return sum(fast_substitute.(result, Ref(solution)) ./ leading_coeff) # fast_substitute solutions back in and sum +end + +# increasing from 0 to degree n. doesn't skip powers of x like polynomial_coeffs +function coeff_vector(poly, x, n) + coeff_dict = polynomial_coeffs(poly, [x])[1] + vec = [] + for i = 0:n + if x^i in keys(coeff_dict) + push!(vec, coeff_dict[x^i]) + else + push!(vec, 0) + end + end + + return vec +end + +# increasing from 0 to degree of poly +function coeff_vector(poly, x) + return coeff_vector(poly, x, degree(poly)) +end + +function count_multiplicities(facs) + counts = Dict() + for fac in facs + if haskey(counts, fac) + counts[fac] += 1 + else + counts[fac] = 1 + end + end + + return counts +end + +# for partial fractions, into linear and irreducible quadratic factors +function factorize(expr, x) + roots = symbolic_solve(expr, x, dropmultiplicity=false) + + counts = count_multiplicities(roots) + facs = Set() + + for root in keys(counts) + if !isequal(abs(imag(root)), 0) + fac_expr = expand((x - root)*(x - conj(root))) + if !isequal(imag(fac_expr), 0) + @warn "Encountered issue with complex irrational roots. Returning nothing." + return nothing + end + push!(facs, Factor(real(fac_expr), counts[root], x)) + continue + end + + push!(facs, Factor(x - root, root, counts[root], x)) + end + + return facs +end \ No newline at end of file diff --git a/test/laplace.jl b/test/laplace.jl new file mode 100644 index 000000000..798e45a07 --- /dev/null +++ b/test/laplace.jl @@ -0,0 +1,37 @@ +using Test +using Symbolics +using Symbolics: laplace, inverse_laplace +import Nemo, Groebner + +@variables t, s +@syms f(t)::Real F(s)::Real + +# https://sites.math.washington.edu/~aloveles/Math307Fall2019/m307LaplacePractice.pdf +@test isequal(laplace(exp(4t) + 5, f, t, F, s), 1/(s-4) + 5/s) +@test isequal(laplace(cos(2t) + 7sin(2t), f, t, F, s), s/(s^2 + 4) + 14/(s^2 + 4)) +@test isequal(laplace(exp(-2t)*cos(3t) + 5exp(-2t)*sin(3t), f, t, F, s), (s+2)/((s+2)^2 + 9) + 15/((s+2)^2 + 9)) +@test isequal(laplace(10 + 5t + t^2 - 4t^3, f, t, F, s), expand(10/s + 5/s^2 + 2/s^3 - 24/s^4)) +@test isequal(laplace(exp(3t)*(t^2 + 4t + 2), f, t, F, s), 2/(s-3)^3 + 4/(s-3)^2 + 2/(s-3)) +@test isequal(laplace(6exp(5t)*cos(2t) - exp(7t), f, t, F, s), 6(s-5)/((s-5)^2 + 4) + expand(-1/(s-7))) + +# https://www.math.lsu.edu/~adkins/m2065/2065s08review2a.pdf +@test isequal(inverse_laplace(7/(s+3)^3, F, s, f, t), (7//2)t^2 * exp(-3t)) +@test isequal(inverse_laplace((s-9)/(s^2 + 9), F, s, f, t), cos(3t) - 3sin(3t)) +# partial fraction decomposition +@test isequal(inverse_laplace((s+2)/(s^2 - 3s - 4), F, s, f, t), (6//5)*exp(4t) - (1//5)*exp(-t)) +@test isequal(inverse_laplace(1/(s^2 - 10s + 9), F, s, f, t), (1//8)*exp(9t) - (1//8)*exp(t)) + +Dt = Differential(t) +@test isequal(laplace_solve_ode(Dt(f(t)) + 3f(t) ~ t^2*exp(-3t) + t*exp(-2t) + t, f, t, [1]), (1//3)*t^3*exp(-3t) + t*exp(-2t) + (1//3)*t + (19//9)*exp(-3t) - exp(-2t) - 1//9) +@test isequal(laplace_solve_ode((Dt^2)(f(t)) - 3Dt(f(t)) + 2f(t) ~ 4, f, t, [2, 3]), 2 - 3exp(t) + 3exp(2t)) +@test isequal(laplace_solve_ode((Dt^3)(f(t)) - Dt(f(t)) ~ 2, f, t, [4,4,4]), 5exp(t) - exp(-t) - 2t) +@test isequal(laplace_solve_ode((Dt^3)(f(t)) - Dt(f(t)) ~ 6 - 3t^2, f, t, [1, 1, 1]), exp(t) + t^3) +@test isequal(laplace_solve_ode((Dt^2)(f(t)) - f(t) ~ 2sin(t), f, t, [0, 0]), (1//2)exp(t) - (1//2)exp(-t) - sin(t)) +@test isequal(laplace_solve_ode((Dt^2)(f(t)) + 2Dt(f(t)) ~ 5f(t), f, t, [0, 0]), 0) +@test isequal(laplace_solve_ode((Dt^2)(f(t)) + f(t) ~ sin(4t), f, t, [0, 0]), (4//15)sin(t) - (1//15)sin(4t)) +@test isequal(laplace_solve_ode((Dt^2)(f(t)) + Dt(f(t)) ~ 1 + 2t, f, t, [0, 0]), 1 - exp(-t) + t^2 - t) +@test isequal(laplace_solve_ode((Dt^2)(f(t)) + 4Dt(f(t)) + 3f(t) ~ 6, f, t, [0, 0]), exp(-3t) - 3exp(-t) + 2) +@test isequal(laplace_solve_ode((Dt^2)(f(t)) - 2Dt(f(t)) ~ 3*(t + exp(2t)), f, t, [0, 0]), (3//8) - (3//4)t - (3//4)t^2 - (3//8)exp(2t) + (3//2)t*exp(2t)) +@test_broken isequal(laplace_solve_ode((Dt^2)(f(t)) - 2Dt(f(t)) ~ 20*exp(-t)*cos(t), f, t, [0, 0]), 3exp(2t) - 5 + 2exp(-t)*cos(t) - 4exp(-t)*sin(t)) # irreducible quadratic in inverse laplace +@test isequal(laplace_solve_ode((Dt^2)(f(t)) + f(t) ~ 2 + 2cos(t), f, t, [0, 0]), 2 - 2cos(t) + t*sin(t)) +@test isequal(laplace_solve_ode((Dt^2)(f(t)) - Dt(f(t)) ~ 30cos(3t), f, t, [0, 0]), 3exp(t) - 3cos(3t) - sin(3t)) \ No newline at end of file diff --git a/test/partialfractions.jl b/test/partialfractions.jl new file mode 100644 index 000000000..d6c75e814 --- /dev/null +++ b/test/partialfractions.jl @@ -0,0 +1,26 @@ +using Test +using Symbolics +import Nemo, Groebner +import Symbolics: partial_frac_decomposition + +@variables x + +# https://en.neurochispas.com/algebra/4-types-of-partial-fractions-decomposition-with-examples/ +@test isequal(partial_frac_decomposition((3x-1) / (x^2 + x - 6), x), 2/(x+3) + 1/(x-2)) +@test isequal(partial_frac_decomposition((9x^2 + 34x + 14) / ((x+2)*(x^2 - x - 12)), x), expand(3/(x+2) + 7/(x-4) - 1/(x+3))) + +# https://tutorial.math.lamar.edu/Problems/Alg/PartialFractions.aspx +@test isequal(partial_frac_decomposition((17x-53)/(x^2 - 2x - 15), x), expand(4/(x-5) + 13/(x+3))) +@test isequal(partial_frac_decomposition((34-12x)/(3x^2 - 10x - 8), x), (-3)/(2//3 + x) + -1/(-4 + x)) +@test isequal(partial_frac_decomposition((125 + 4x - 9x^2)/((x-1)*(x+3)*(x+4)), x), expand(6/(x-1) - 8/(x+3) - 7/(x+4))) +@test isequal(partial_frac_decomposition((10x+35)/((x+4)^2), x), 10/(x+4) + -5/(x+4)^2) +@test isequal(partial_frac_decomposition((6x+5)/((2x-1)^2), x), (3//2)/(x-1//2) + 2/(x-1//2)^2) +@test isequal(partial_frac_decomposition((7x^2-17x+38)/((x+6)*(x-1)^2), x), 8/(x+6) + -1/(x-1) + 4/(x-1)^2) +@test isequal(partial_frac_decomposition((4x^2 - 22x + 7)/((2x+3)*(x-2)^2), x), 2/(x+3//2) + -3/(x-2)^2) +@test_broken isequal(partial_frac_decomposition((3x^2 + 7x + 28)/(x*(x^2 + x + 7)), x), expand(4/x + (3-x)/(x^2+x+7))) # irrational roots +@test isequal(partial_frac_decomposition((4x^3 + 16x + 7)/(x^2 + 4)^2, x), 4x/(x^2+4) + 7/(x^2+4)^2) + +# check valid expressions +@test partial_frac_decomposition(sin(x), x) === nothing +@test partial_frac_decomposition(x^2/(x-1), x) === nothing +@test partial_frac_decomposition(1/(x^2 + 2), x) === nothing # irrational roots, should eventually be fixed \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index d5725fe40..9e7271716 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -71,6 +71,8 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Taylor Series Test" begin include("taylor.jl") end @safetestset "Discontinuity registration test" begin include("discontinuities.jl") end @safetestset "ODE solver test" begin include("diffeqs.jl") end + @safetestset "Laplace transform test" begin include("laplace.jl") end + @safetestset "Partial Fraction Decomposition Test" begin include("partialfractions.jl") end end end diff --git a/test/sympy.jl b/test/sympy.jl index 3617b8520..bd5d873b4 100644 --- a/test/sympy.jl +++ b/test/sympy.jl @@ -85,11 +85,11 @@ canonical_sol_ode = Symbolics.substitute(sol_ode, Dict(const_sym => C1)) Dt = Symbolics.Differential(t) C = Symbolics.variables(:C, 1:5) -@test isequal(symbolic_solve_ode(LinearODE(x, t, [5/t], 7t)), Symbolics.sympy_simplify(C[1]*t^(-5) + t^2)) -@test isequal(symbolic_solve_ode(LinearODE(x, t, [cos(t)], cos(t))), 1 + C[1]*exp(-sin(t))) -@test isequal(symbolic_solve_ode(LinearODE(x, t, [-(1+t)], 1+t)), Symbolics.expand(Symbolics.sympy_simplify(C[1]*exp((1//2)t^2 + t) - 1))) +@test isequal(symbolic_solve_ode(SymbolicLinearODE(x, t, [5/t], 7t)), Symbolics.sympy_simplify(C[1]*t^(-5) + t^2)) +@test isequal(symbolic_solve_ode(SymbolicLinearODE(x, t, [cos(t)], cos(t))), 1 + C[1]*exp(-sin(t))) +@test isequal(symbolic_solve_ode(SymbolicLinearODE(x, t, [-(1+t)], 1+t)), Symbolics.expand(Symbolics.sympy_simplify(C[1]*exp((1//2)t^2 + t) - 1))) # SymPy is being weird and not simplifying correctly (and some symbols are wrong, like pi and erf being syms), but these otherwise work -@test_broken isequal(symbolic_solve_ode(LinearODE(x, t, [-2t], 1)), Symbolics.sympy_simplify(exp(t^2)*sqrt(Symbolics.variable(:pi))*erf(t)/2 + C[1]*exp(t^2))) +@test_broken isequal(symbolic_solve_ode(SymbolicLinearODE(x, t, [-2t], 1)), Symbolics.sympy_simplify(exp(t^2)*sqrt(Symbolics.variable(:pi))*erf(t)/2 + C[1]*exp(t^2))) ## Bernoulli equations @test isequal(symbolic_solve_ode(Dt(x) + (4//t)*x ~ t^3 * x^2, x, t), 1/(C[1]t^4 - t^4 * log(t)))