From 93ff33d8187ecabd51afcc02d03ff891d0bf362d Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Fri, 2 Feb 2024 13:33:42 -0500 Subject: [PATCH 1/2] swizzle(swizzle(A, dims), dims_2) --- src/tensors/combinators/swizzle.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tensors/combinators/swizzle.jl b/src/tensors/combinators/swizzle.jl index 950d4e88d..be5f263f4 100644 --- a/src/tensors/combinators/swizzle.jl +++ b/src/tensors/combinators/swizzle.jl @@ -41,6 +41,8 @@ function virtualize(ex, ::Type{SwizzleArray{dims, Body}}, ctx) where {dims, Body end swizzle(body, dims::Int...) = SwizzleArray(body, dims) +swizzle(body::SwizzleArray{dims}, dims_2::Int...) where {dims} = SwizzleArray(body.body, dims[dims_2]) + function virtual_call(::typeof(swizzle), ctx, body, dims...) @assert All(isliteral)(dims) VirtualSwizzleArray(body, map(dim -> dim.val, collect(dims))) From b9e788ff4547bfe8ac1c2a320a324ea64831f4d8 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Fri, 2 Feb 2024 13:40:03 -0500 Subject: [PATCH 2/2] fix --- Project.toml | 2 +- src/tensors/combinators/swizzle.jl | 2 +- test/test_issues.jl | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 870761472..9f0df08d9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Finch" uuid = "9177782c-1635-4eb9-9bfb-d9dfa25e6bce" authors = ["Willow Ahrens"] -version = "0.6.9" +version = "0.6.10" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/tensors/combinators/swizzle.jl b/src/tensors/combinators/swizzle.jl index be5f263f4..c056d4904 100644 --- a/src/tensors/combinators/swizzle.jl +++ b/src/tensors/combinators/swizzle.jl @@ -41,7 +41,7 @@ function virtualize(ex, ::Type{SwizzleArray{dims, Body}}, ctx) where {dims, Body end swizzle(body, dims::Int...) = SwizzleArray(body, dims) -swizzle(body::SwizzleArray{dims}, dims_2::Int...) where {dims} = SwizzleArray(body.body, dims[dims_2]) +swizzle(body::SwizzleArray{dims}, dims_2::Int...) where {dims} = SwizzleArray(body.body, ntuple(n-> dims[dims_2[n]], ndims(body))) function virtual_call(::typeof(swizzle), ctx, body, dims...) @assert All(isliteral)(dims) diff --git a/test/test_issues.jl b/test/test_issues.jl index 376adb7db..e3223b9b1 100644 --- a/test/test_issues.jl +++ b/test/test_issues.jl @@ -610,4 +610,6 @@ using CIndices new_shape_2 = size(Tensor(Dense(SparseList(SparseList(Element(0.0)))), st)) @test new_shape_1 == new_shape_2 + + @test swizzle(swizzle(zeros(3, 3, 3), 3, 1, 2), 3, 2, 1) isa Finch.SwizzleArray{(2, 1, 3), <:Array} end \ No newline at end of file