Skip to content

Commit

Permalink
fix: handle non-standard indices in linear_expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 20, 2025
1 parent 23b5604 commit d67cf8d
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Expand Down Expand Up @@ -79,6 +80,7 @@ Lux = "1"
MacroTools = "0.5"
NaNMath = "1"
Nemo = "0.46, 0.47, 0.48"
OffsetArrays = "1.15.0"
PreallocationTools = "0.4"
PrecompileTools = "1"
Primes = "0.5"
Expand Down
2 changes: 2 additions & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ import SymbolicLimits

using ADTypes: ADTypes

import OffsetArrays

@reexport using SymbolicUtils
RuntimeGeneratedFunctions.init(@__MODULE__)

Expand Down
2 changes: 1 addition & 1 deletion src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ function _linear_expansion(t, x)
arrt, idxst... = arguments(t)
isequal(arrt, arrx) && return (0, t, true)

indexed_t = Symbolics.scalarize(arrt)[idxst...]
indexed_t = OffsetArrays.Origin(map(first, axes(arrt)))(Symbolics.scalarize(arrt))[idxst...]
# when indexing a registered function/callable symbolic
# scalarizing and indexing leads to the same symbolic variable
# which causes a StackOverflowError without this
Expand Down
4 changes: 4 additions & 0 deletions test/linear_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,8 @@ a, b, islinear = Symbolics.linear_expansion(D(x) - x, x)
@test !Symbolics.linear_expansion(z([x...]), x[1])[3]
@test !Symbolics.linear_expansion(z(collect(Symbolics.unwrap(x))), x[1])[3]
@test !Symbolics.linear_expansion(z([x, 2x]), x[1])[3]

@variables x[0:2]
a, b, islin = Symbolics.linear_expansion(x[0] - z(x[1]), z(x[1]))
@test islin && isequal(a, -1) && isequal(b, x[0])
end

0 comments on commit d67cf8d

Please sign in to comment.