diff --git a/src/inspect.jl b/src/inspect.jl index 6f22d0ad..8dc7eaf6 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -12,10 +12,10 @@ function AbstractTrees.nodevalue(x::BasicSymbolic) string(x.impl.val) elseif isadd(x) string(exprtype(x), - (scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in x.impl.dict))) + (scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in get_dict(x)))) elseif ismul(x) string(exprtype(x), - (scalar = get_coeff(x), powers = Tuple(k => v for (k, v) in x.impl.dict))) + (scalar = get_coeff(x), powers = Tuple(k => v for (k, v) in get_dict(x)))) elseif isdiv(x) || ispow(x) string(exprtype(x)) else diff --git a/src/polyform.jl b/src/polyform.jl index 34989ece..f2630166 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -502,12 +502,12 @@ end # mul, pow case function quick_mulpow(x, y) y.impl.exp isa Number || return (x, y) - if haskey(x.impl.dict, y.impl.base) - d = copy(x.impl.dict) - if x.impl.dict[y.impl.base] > y.impl.exp + if haskey(get_dict(x), y.impl.base) + d = copy(get_dict(x)) + if get_dict(x)[y.impl.base] > y.impl.exp d[y.impl.base] -= y.impl.exp den = 1 - elseif x.impl.dict[y.impl.base] == y.impl.exp + elseif get_dict(x)[y.impl.base] == y.impl.exp delete!(d, y.impl.base) den = 1 else diff --git a/src/types.jl b/src/types.jl index 8cae7251..bfbc20a9 100644 --- a/src/types.jl +++ b/src/types.jl @@ -72,6 +72,10 @@ function get_coeff(x::BasicSymbolic) x.impl.coeff end +function get_dict(x::BasicSymbolic) + x.impl.dict +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -297,7 +301,7 @@ function _isequal(a, b, E) if E === SYM nameof(a) === nameof(b) elseif E === ADD || E === MUL - coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(a.impl.dict, b.impl.dict) + coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(get_dict(a), get_dict(b)) elseif E === DIV isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den) elseif E === POW @@ -341,7 +345,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h = s.hash[] !iszero(h) && return h hashoffset = isadd(s) ? ADD_SALT : SUB_SALT - h′ = hash(hashoffset, hash(get_coeff(s), hash(s.impl.dict, salt))) + h′ = hash(hashoffset, hash(get_coeff(s), hash(get_dict(s), salt))) s.hash[] = h′ return h′ elseif E === DIV @@ -461,7 +465,7 @@ function maybe_intcoeff(x) if ismul(x) coeff = get_coeff(x) if coeff isa Rational && isone(denominator(coeff)) - _Mul(symtype(x), coeff.num, x.impl.dict; metadata = x.metadata) + _Mul(symtype(x), coeff.num, get_dict(x); metadata = x.metadata) else x end @@ -542,7 +546,7 @@ function toterm(t::BasicSymbolic{T}) where {T} elseif E === ADD || E === MUL args = BasicSymbolic[] push!(args, get_coeff(t)) - for (k, coeff) in t.impl.dict + for (k, coeff) in get_dict(t) push!( args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k])) end @@ -567,7 +571,7 @@ function makeadd(sign, coeff, xs...) for x in xs if isadd(x) coeff += get_coeff(x) - _merge!(+, d, x.impl.dict, filter = _iszero) + _merge!(+, d, get_dict(x), filter = _iszero) continue end if x isa Number @@ -575,7 +579,7 @@ function makeadd(sign, coeff, xs...) continue end if ismul(x) - k = _Mul(symtype(x), 1, x.impl.dict) + k = _Mul(symtype(x), 1, get_dict(x)) v = sign * get_coeff(x) + get(d, k, 0) else k = x @@ -598,7 +602,7 @@ function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}()) coeff *= x elseif ismul(x) coeff *= get_coeff(x) - _merge!(+, d, x.impl.dict, filter = _iszero) + _merge!(+, d, get_dict(x), filter = _iszero) else v = 1 + get(d, x, 0) if _iszero(v) @@ -1223,10 +1227,10 @@ function +(a::SN, b::SN) !issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) return _Add( - add_t(a, b), get_coeff(a) + get_coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) + add_t(a, b), get_coeff(a) + get_coeff(b), _merge(+, get_dict(a), get_dict(b), filter = _iszero)) elseif isadd(a) coeff, dict = makeadd(1, 0, b) - return _Add(add_t(a, b), get_coeff(a) + coeff, _merge(+, a.impl.dict, dict, filter = _iszero)) + return _Add(add_t(a, b), get_coeff(a) + coeff, _merge(+, get_dict(a), dict, filter = _iszero)) elseif isadd(b) return b + a end @@ -1240,7 +1244,7 @@ function +(a::Number, b::SN) !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) - _Add(add_t(a, b), a + get_coeff(b), b.impl.dict) + _Add(add_t(a, b), a + get_coeff(b), get_dict(b)) else _Add(add_t(a, b), makeadd(1, a, b)...) end @@ -1258,7 +1262,7 @@ function -(a::SN) return term(-, a) end if isadd(a) - _Add(sub_t(a), -get_coeff(a), mapvalues((_, v) -> -v, a.impl.dict)) + _Add(sub_t(a), -get_coeff(a), mapvalues((_, v) -> -v, get_dict(a))) else _Add(sub_t(a), makeadd(-1, 0, a)...) end @@ -1266,7 +1270,7 @@ end function -(a::SN, b::SN) (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) if isadd(a) && isadd(b) - _Add(sub_t(a, b), get_coeff(a) - get_coeff(b), _merge(-, a.impl.dict, b.impl.dict, filter = _iszero)) + _Add(sub_t(a, b), get_coeff(a) - get_coeff(b), _merge(-, get_dict(a), get_dict(b), filter = _iszero)) else a + (-b) end @@ -1294,16 +1298,16 @@ function *(a::SN, b::SN) _Div(a * b.impl.num, b.impl.den) elseif ismul(a) && ismul(b) _Mul(mul_t(a, b), get_coeff(a) * get_coeff(b), - _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) + _merge(+, get_dict(a), get_dict(b), filter = _iszero)) elseif ismul(a) && ispow(b) if b.impl.exp isa Number _Mul(mul_t(a, b), get_coeff(a), - _merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp), + _merge(+, get_dict(a), Base.ImmutableDict(b.impl.base => b.impl.exp), filter = _iszero)) else _Mul(mul_t(a, b), get_coeff(a), - _merge(+, a.impl.dict, Base.ImmutableDict(b => 1), filter = _iszero)) + _merge(+, get_dict(a), Base.ImmutableDict(b => 1), filter = _iszero)) end elseif ispow(a) && ismul(b) b * a @@ -1326,7 +1330,7 @@ function *(a::Number, b::SN) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) _Add(T, get_coeff(b) * a, - Dict{BasicSymbolic, Any}(k => v * a for (k, v) in b.impl.dict)) + Dict{BasicSymbolic, Any}(k => v * a for (k, v) in get_dict(b))) else _Mul(mul_t(a, b), makemul(a, b)...) end @@ -1352,7 +1356,7 @@ function ^(a::SN, b) elseif ismul(a) && b isa Number coeff = unstable_pow(get_coeff(a), b) _Mul(promote_symtype(^, symtype(a), symtype(b)), - coeff, mapvalues((k, v) -> b * v, a.impl.dict)) + coeff, mapvalues((k, v) -> b * v, get_dict(a))) else _Pow(a, b) end diff --git a/test/basics.jl b/test/basics.jl index 9a70f23b..a392edda 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,5 +1,5 @@ using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm, - BasicSymbolic, term, get_name, get_coeff + BasicSymbolic, term, get_name, get_coeff, get_dict using SymbolicUtils using IfElse: ifelse using Setfield @@ -234,7 +234,7 @@ end @testset "maketerm" begin @syms a b c - @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).impl.dict, Dict(a=>1,b=>1,c=>1)) + @test isequal(get_dict(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing)), Dict(a=>1,b=>1,c=>1)) @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b) # test that maketerm doesn't hard-code BasicSymbolic subtype