diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 32f94cf7e..c34c33d40 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -210,7 +210,7 @@ include("solver/polynomialization.jl") include("solver/attract.jl") include("solver/ia_main.jl") include("solver/main.jl") -include("solver/ia_rules.jl") +include("solver/special_cases.jl") export symbolic_solve function symbolics_to_sympy end diff --git a/src/solver/ia_main.jl b/src/solver/ia_main.jl index 6f9329cce..7d00e06a1 100644 --- a/src/solver/ia_main.jl +++ b/src/solver/ia_main.jl @@ -8,10 +8,11 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri lhs = unwrap(lhs) old_lhs = nothing + while !isequal(lhs, var) subs, poly = filter_poly(lhs, var) - if check_poly_inunivar(poly, var) + if check_polynomial(poly, strict=false) roots = [] new_var = gensym() new_var = (@variables $new_var)[1] @@ -20,7 +21,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri else a, b, islin = linear_expansion(lhs - new_var, var) if islin - lhs_roots = [-b / a] + lhs_roots = [-b // a] else lhs_roots = [RootsOf(lhs - new_var, var)] if warns @@ -31,7 +32,12 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri for i in eachindex(lhs_roots) for j in eachindex(rhs) - push!(roots, substitute(lhs_roots[i], Dict(new_var=>rhs[j]), fold=false)) + if iscall(lhs_roots[i]) && operation(lhs_roots[i]) == RootsOf + lhs_roots[i].arguments[1] = substitute(lhs_roots[i].arguments[1], Dict(new_var=>rhs[j]), fold=false) + push!(roots, lhs_roots[i]) + else + push!(roots, substitute(lhs_roots[i], Dict(new_var=>rhs[j]), fold=false)) + end end end return roots, conditions @@ -39,7 +45,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri if isequal(old_lhs, lhs) warns && @warn("This expression cannot be solved with the methods available to ia_solve. Try a numerical method instead.") - return nothing + return nothing, conditions end old_lhs = deepcopy(lhs) @@ -76,7 +82,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri else # 2 / x = y lhs = args[2] - rhs = map(sol -> args[1] // sol, rhs) + rhs = map(sol -> term(/, args[1], sol), rhs) end elseif oper === (^) @@ -108,6 +114,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri elseif any(isequal(x, var) for x in get_variables(args[1])) && n_occurrences(args[2], var) == 0 lhs = args[1] + s, args[2] = filter_stuff(args[2]) rhs = map(sol -> term(^, sol, 1 // args[2]), rhs) else lhs = args[2] @@ -169,7 +176,7 @@ function attract(lhs, var; warns = true, complex_roots = true, periodic_roots = return nothing, conditions end end - + new_var = collect(keys(sub))[1] new_var_val = collect(values(sub))[1] @@ -178,6 +185,7 @@ function attract(lhs, var; warns = true, complex_roots = true, periodic_roots = new_roots = [] for root in roots + iscall(root) && operation(root) == RootsOf && continue new_sol, new_conds = isolate(new_var_val - root, var; warns = warns, complex_roots, periodic_roots) append!(conditions, new_conds) push!(new_roots, new_sol) @@ -273,9 +281,9 @@ function ia_solve(lhs, var; warns = true, complex_roots = true, periodic_roots = conditions = [] if nx == 0 warns && @warn("Var not present in given expression") - return [] + return nothing elseif nx == 1 - sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots) + sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots) elseif nx > 1 sols, conditions = attract(lhs, var; warns = warns, complex_roots, periodic_roots) end diff --git a/src/solver/main.jl b/src/solver/main.jl index a1c523aea..9082a9518 100644 --- a/src/solver/main.jl +++ b/src/solver/main.jl @@ -173,10 +173,6 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where expr = Vector{Num}(expr) end - if expr_univar && !x_univar - expr = [expr] - expr_univar = false - end if !expr_univar && x_univar x = [x] x_univar = false @@ -189,8 +185,17 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where isequal(sols, nothing) && return nothing sols = map(postprocess_root, sols) return sols + elseif expr_univar + all_vars = get_variables(expr) + diff_vars = setdiff(wrap.(all_vars), x) + if length(diff_vars) == 1 + return solve_interms_ofvar(expr, diff_vars[1], dropmultiplicity=dropmultiplicity, warns=warns) + end + + expr = [expr] end + if !x_univar for e in expr for var in x @@ -247,6 +252,7 @@ function symbolic_solve(expr; x...) return symbolic_solve(expr, vars; x...) end + """ solve_univar(expression, x; dropmultiplicity=true) This solver uses analytic solutions up to degree 4 to solve univariate polynomials. @@ -266,10 +272,12 @@ implemented in the function `get_roots` and its children. - dropmultiplicity (optional): Print repeated roots or not? +- strict (optional): Bool that enables/disables strict assert if input expression is a univariate polynomial or not. If strict=true and expression is not a polynomial, `solve_univar` throws an assertion error. + # Examples """ -function solve_univar(expression, x; dropmultiplicity=true) +function solve_univar(expression, x; dropmultiplicity=true, strict=true) args = [] mult_n = 1 expression = unwrap(expression) @@ -287,6 +295,9 @@ function solve_univar(expression, x; dropmultiplicity=true) end subs, filtered_expr, assumptions = filter_poly(expression, x, assumptions=true) + if !strict && !check_polynomial(filtered_expr, strict=false) + return [RootsOf(wrap(expression), wrap(x))] + end coeffs, constant = polynomial_coeffs(filtered_expr, [x]) degree = sdegree(coeffs, x) @@ -325,7 +336,6 @@ function solve_univar(expression, x; dropmultiplicity=true) end if isequal(arr_roots, []) - @assert check_polynomial(expression) "This expression could not be solved by `symbolic_solve`." return [RootsOf(wrap(expression), wrap(x))] end diff --git a/src/solver/nemo_stuff.jl b/src/solver/nemo_stuff.jl index 52d3da578..4cbef51b2 100644 --- a/src/solver/nemo_stuff.jl +++ b/src/solver/nemo_stuff.jl @@ -1,12 +1,16 @@ # Checks that the expression is a polynomial with integer or rational # coefficients -function check_polynomial(poly) +function check_polynomial(poly; strict=true) poly = wrap(poly) vars = get_variables(poly) distr, rem = polynomial_coeffs(poly, vars) - @assert isequal(rem, 0) "Not a polynomial" - @assert all(c -> c isa Integer || c isa Rational, collect(values(distr))) "Coefficients must be integer or rational" - return true + if strict + @assert isequal(rem, 0) "Not a polynomial" + @assert all(c -> c isa Integer || c isa Rational, collect(values(distr))) "Coefficients must be integer or rational" + return true + else + return isequal(rem, 0) + end end # factor(x^2*y + b*x*y - a*x - a*b) -> (x*y - a)*(x + b) diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index 4764690aa..b4f12f749 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -43,14 +43,26 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end end + args = arguments(x) + # (X)^0 => 1 - if oper === (^) && isequal(arguments(x)[2], 0) + if oper === (^) && isequal(args[2], 0) && !isequal(args[1], 0) return 1 end # (X)^1 => X - if oper === (^) && isequal(arguments(x)[2], 1) - return arguments(x)[1] + if oper === (^) && isequal(args[2], 1) + return args[1] + end + + # (0)^X => 0 + if oper === (^) && isequal(args[1], 0) && !isequal(args[2], 0) + return 0 + end + + # y / 0 => Inf + if oper === (/) && !isequal(args[1], 0) && isequal(args[2], 0) + return Inf end # sqrt((N / D)^2 * M) => N / D * sqrt(M) diff --git a/src/solver/preprocess.jl b/src/solver/preprocess.jl index 8b0dc5f81..7d6a21c6e 100644 --- a/src/solver/preprocess.jl +++ b/src/solver/preprocess.jl @@ -44,13 +44,10 @@ function clean_f(filtered_expr, var, subs) if oper === (/) args = arguments(unwrapped_f) - if any(isequal(var, x) for x in get_variables(args[2])) - filtered_expr = expand(args[1] * args[2]) + if !all(isequal(var, x) for x in get_variables(args[2])) + filtered_expr = args[1] push!(assumptions, substitute(args[2], subs, fold=false)) - return filtered_expr, assumptions end - filtered_expr = args[1] - @info "Assuming $(substitute(args[2], subs, fold=false) != 0)" end return filtered_expr, assumptions end diff --git a/src/solver/ia_rules.jl b/src/solver/special_cases.jl similarity index 55% rename from src/solver/ia_rules.jl rename to src/solver/special_cases.jl index 33351ae64..cb5360d66 100644 --- a/src/solver/ia_rules.jl +++ b/src/solver/special_cases.jl @@ -37,7 +37,47 @@ function cross_multiply(eq) return cross_multiply(eq) end end +""" + solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true) +This special case solver expects a single equation in multiple variables and a +variable `s` (this can be any Num, `s` is used for convenience). The function generates +a system of equations to by observing the coefficients of the powers of `s` present in `eq`. +E.g. a system would look like `a+b = 1`, `a-2b = 3` for the eq `(a+b)s + (a-2b)s^2 - (1)s - (3)s^2 = 0`. +After generating this system, it calls `symbolic_solve`, which uses `solve_multivar`. `symbolic_solve` was chosen +instead of `solve_multivar` because it postprocesses the roots in order to simplify them and make them more user friendly. +Generation of system uses cross multiplication in order to simplify the equation and convert it +to a polynomial like shape. + + +# Arguments +- eq: Single symbolics Num or SymbolicUtils.BasicSymbolic. This is equated to 0 and then solved. E.g. `expr = x+2`, we solve `x+2 = 0` + +- s: Variable to "isolate", i.e. ignore and generate the system of equations based on this variable's coefficients. + +- dropmultiplicity (optional): Print repeated roots or not? + +- warns (optional, this is not used currently): Warn user when something is wrong or not. + +# Examples +```jldoctest +julia> @variables a b x s; + +julia> eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3; + +julia> Symbolics.solve_interms_ofvar(eq, s) +2-element Vector{Any}: + Dict{Num, Any}(a => -1//10, b => 3//2, x => (0 - 1im)*√(5)) + Dict{Num, Any}(a => -1//10, b => 3//2, x => (0 + 1im)*√(5)) +``` +```jldoctest +julia> eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d)); + +julia> Symbolics.solve_interms_ofvar(eq, s) +1-element Vector{Any}: + Dict{Num, Any}(a => 1, d => 1, b => 1, c => 1) +``` +""" function solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true) @assert iscall(unwrap(eq)) vars = Symbolics.get_variables(eq) diff --git a/test/solver.jl b/test/solver.jl index b366f7d57..7dd92195a 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -5,7 +5,6 @@ import Symbolics: ssqrt, slog, scbrt, symbolic_solve, ia_solve, postprocess_root @test Base.get_extension(Symbolics, :SymbolicsNemoExt) === nothing @variables x roots = ia_solve(log(2 + x), x) - @test substitute(roots[1], Dict()) == -1.0 roots = @test_warn ["Nemo", "required"] ia_solve(log(2 + x^2), x) @test operation(roots[1]) == Symbolics.RootsOf end @@ -69,23 +68,28 @@ end @testset "Solving in terms of a constant var" begin eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d)) calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,c,d]) + solve_roots = sort_arr(symbolic_solve(eq, [a,b,c,d]), [a,b,c,d]) known_roots = sort_arr([Dict(a=>1, b=>1, c=>1, d=>1)], [a,b,c,d]) @test check_approx(calcd_roots, known_roots) + @test check_approx(solve_roots, known_roots) eq = (a+b)*s^2 - 2s^2 + 2*b*s - 3*s calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b]) + solve_roots = sort_arr(symbolic_solve(eq, [a,b]), [a,b]) known_roots = sort_arr([Dict(a=>1/2, b=>3/2)], [a,b]) @test check_approx(calcd_roots, known_roots) + @test check_approx(solve_roots, known_roots) eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3 - calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b]) + calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,x]) + solve_roots = sort_arr(symbolic_solve(eq, [a,b,x]), [a,b,x]) known_roots = sort_arr([Dict(a=>-1/10, b=>3/2, x=>-im*sqrt(5)), Dict(a=>-1/10, b=>3/2, x=>im*sqrt(5))], [a,b,x]) @test check_approx(calcd_roots, known_roots) + @test check_approx(solve_roots, known_roots) end @testset "Invalid input" begin @test_throws AssertionError symbolic_solve(x, x^2) - @test_throws AssertionError symbolic_solve(1/x, x) end @testset "Nice univar cases" begin @@ -355,14 +359,18 @@ end @testset "Post Process roots" begin SymbolicUtils.@syms __x __symsqrt(x) = SymbolicUtils.term(ssqrt, x) + term = SymbolicUtils.term @test Symbolics.postprocess_root(2 // 1) == 2 && Symbolics.postprocess_root(2 + 0*im) == 2 @test Symbolics.postprocess_root(__symsqrt(4)) == 2 @test isequal(Symbolics.postprocess_root(__symsqrt(__x)^2), __x) - @test Symbolics.postprocess_root( SymbolicUtils.term(^, __x, 0) ) == 1 - @test Symbolics.postprocess_root( SymbolicUtils.term(^, Base.MathConstants.e, 0) ) == 1 - @test Symbolics.postprocess_root( SymbolicUtils.term(^, Base.MathConstants.pi, 1) ) == Base.MathConstants.pi - @test isequal(Symbolics.postprocess_root( SymbolicUtils.term(^, __x, 1) ), __x) + + @test isequal(Symbolics.postprocess_root(term(^, 0, __x)), 0) + @test_broken isequal(Symbolics.postprocess_root(term(/, __x, 0)), Inf) + @test Symbolics.postprocess_root(term(^, __x, 0) ) == 1 + @test Symbolics.postprocess_root(term(^, Base.MathConstants.e, 0) ) == 1 + @test Symbolics.postprocess_root(term(^, Base.MathConstants.pi, 1) ) == Base.MathConstants.pi + @test isequal(Symbolics.postprocess_root(term(^, __x, 1) ), __x) x = Symbolics.term(sqrt, 2) @test isequal(Symbolics.postprocess_root( expand((x + 1)^4) ), 17 + 12x) @@ -426,7 +434,10 @@ end lhs = ia_solve(a*x^b + c, x)[1] lhs2 = symbolic_solve(a*x^b + c, x)[1] rhs = Symbolics.term(^, -c.val/a.val, 1/b.val) - #@test isequal(lhs, rhs) + @test_broken isequal(lhs, rhs) + + @test isequal(symbolic_solve(2/x, x)[1], Inf) + @test isequal(symbolic_solve(x^1.5, x)[1], 0) lhs = symbolic_solve(log(a*x)-b,x)[1] @test isequal(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))), 1E)