diff --git a/src/LazyArrays.jl b/src/LazyArrays.jl index d622239e..27571df2 100644 --- a/src/LazyArrays.jl +++ b/src/LazyArrays.jl @@ -60,10 +60,11 @@ end export Mul, Applied, MulArray, MulVector, MulMatrix, InvMatrix, PInvMatrix, Hcat, Vcat, Kron, BroadcastArray, BroadcastMatrix, BroadcastVector, cache, Ldiv, Inv, PInv, Diff, Cumsum, - applied, materialize, materialize!, ApplyArray, ApplyMatrix, ApplyVector, apply, ⋆, @~, LazyArray + applied, materialize, materialize!, @materialize, ApplyArray, ApplyMatrix, ApplyVector, apply, ⋆, @~, LazyArray include("lazyapplying.jl") +include("materialize_dsl.jl") include("lazybroadcasting.jl") include("linalg/linalg.jl") include("cache.jl") diff --git a/src/materialize_dsl.jl b/src/materialize_dsl.jl new file mode 100644 index 00000000..36ff8e32 --- /dev/null +++ b/src/materialize_dsl.jl @@ -0,0 +1,124 @@ +# For unparametrized destination types +generate_copyto!_signature(dest, dest_type::Symbol, Msig) = + :(Base.copyto!($(dest)::$(dest_type), applied_obj::$(Msig))) + +# For parametrized destination types +function generate_copyto!_signature(dest, dest_type::Expr, Msig) + dest_type.head == :curly || + throw(ArgumentError("Invalid destination specification $(dest)::$(dest_type)")) + :(Base.copyto!($(dest)::$(dest_type), applied_obj::$(Msig)) where {$(dest_type.args[2:end]...)}) +end + +function generate_copyto!(body, factor_names, Msig) + body.head == :(->) || + throw(ArgumentError("Invalid copyto! specification")) + body.args[1].head == :(::) || + throw(ArgumentError("Invalid destination specification $(body.args[1])")) + (dest,dest_type) = body.args[1].args + copyto!_signature = generate_copyto!_signature(dest, dest_type, Msig) + f_body = quote + axes($dest) == axes(applied_obj) || throw(DimensionMismatch("axes must be same")) + $(factor_names) = applied_obj.args + $(body.args[2].args...) + $(dest) + end + Expr(:function, copyto!_signature, f_body) +end + +""" + @materialize function op(args...) + +This macro simplifies the setup of a few functions necessary for the +materialization of [`Applied`](@ref) objects: + +- `ApplyStyle`, used to ensure dispatch of the applied object to the + routines below + +- `copyto!(dest::DestType, applied_obj::Applied{...,op})` performs the + actual materialization of `applied_obj` into the destination object + which has been generated by + +- `similar` which usually returns a suitable matrix + +- `materialize` which makes use of the above functions + +# Example + +```julia +@materialize function *(Ac::MyAdjointBasis, + O::MyOperator, + B::MyBasis) + MyApplyStyle # An instance of this type will be returned by ApplyStyle + T -> begin # generates similar + A = parent(Ac) + parent(A) == parent(B) || + throw(ArgumentError("Incompatible bases")) + + # There may be different matrices best representing different + # situations: + if ... + Diagonal(Vector{T}(undef, size(B,1))) + else + Tridiagonal(Vector{T}(undef, size(B,1)-1), + Vector{T}(undef, size(B,1)), + Vector{T}(undef, size(B,1)-1)) + end + end + dest::Diagonal{T} -> begin # generate copyto!(dest::Diagonal{T}, ...) where T + dest.diag .= 1 + end + dest::Tridiagonal{T} -> begin # generate copyto!(dest::Tridiagonal{T}, ...) where T + dest.dl .= -2 + dest.ev .= 1 + dest.du .= 3 + end +end +``` +""" +macro materialize(expr) + expr.head == :function || expr.head == :(=) || error("Must start with a function") + @assert expr.args[1].head == :call + op = expr.args[1].args[1] + + bodies = filter(e -> !(e isa LineNumberNode), expr.args[2].args) + length(bodies) < 3 && + throw(ArgumentError("At least three blocks required (ApplyStyle, similar, and at least one copyto!)")) + + factor_types = :(<:Tuple{}) + factor_names = :(()) + apply_style = first(bodies) + apply_style_fun = :(LazyArrays.ApplyStyle(::typeof($op)) = $(apply_style)()) + + # Generate Applied signature + for arg in expr.args[1].args[2:end] + arg isa Expr && arg.head == :(::) || + throw(ArgumentError("Invalid argument specification $(arg)")) + arg_name, arg_typ = arg.args + push!(factor_types.args[1].args, :(<:$(arg_typ))) + push!(factor_names.args, arg_name) + push!(apply_style_fun.args[1].args, :(::Type{<:$(arg_typ)})) + end + Msig = :(LazyArrays.Applied{$(apply_style), typeof($op), $(factor_types)}) + + sim_body = bodies[2] + sim_body.head == :(->) || + throw(ArgumentError("Invalid similar specification")) + T = first(sim_body.args) + + copytos! = map(body -> generate_copyto!(body, factor_names, Msig), bodies[3:end]) + + f = quote + $(apply_style_fun) + + function Base.similar(applied_obj::$Msig, ::Type{$T}=eltype(applied_obj)) where $T + $(factor_names) = applied_obj.args + $(sim_body.args[2]) + end + + $(copytos!...) + + LazyArrays.materialize(applied_obj::$Msig) = + copyto!(similar(applied_obj, eltype(applied_obj)), applied_obj) + end + esc(f) +end diff --git a/test/materialize_dsl.jl b/test/materialize_dsl.jl new file mode 100644 index 00000000..6c5b3cb4 --- /dev/null +++ b/test/materialize_dsl.jl @@ -0,0 +1,59 @@ +struct MyOperator{T} + n::Int + kind::Symbol +end + +Base.axes(O::MyOperator) = (Base.OneTo(O.n),Base.OneTo(O.n)) +Base.axes(O::MyOperator,i) = axes(O)[i] +Base.size(O::MyOperator) = (O.n,O.n) +Base.eltype(::MyOperator{T}) where T = T + +struct MyApplyStyle <: ApplyStyle end + +@materialize function *(Ac::Adjoint{<:Any,<:AbstractMatrix}, + O::MyOperator, + B::AbstractMatrix) + MyApplyStyle + T -> begin + A = parent(Ac) + + if O.kind == :diagonal + Diagonal(Vector{T}(undef, O.n)) + else + Tridiagonal(Vector{T}(undef, O.n-1), + Vector{T}(undef, O.n), + Vector{T}(undef, O.n-1)) + end + end + dest::Diagonal -> begin + dest.diag .= 1 + end + dest::Tridiagonal{T} -> begin + dest.dl .= -2 + dest.d .= 1 + dest.du .= 3 + end +end + +@testset "Materialize DSL" begin + o = ones(10) + M = ones(10,10) + D = MyOperator{Float64}(10, :diagonal) + T = MyOperator{ComplexF64}(10, :tridiagonal) + + @test LazyArrays.ApplyStyle(*, typeof(M'), typeof(D), typeof(M)) == MyApplyStyle() + @test LazyArrays.ApplyStyle(*, typeof(M'), typeof(T), typeof(M)) == MyApplyStyle() + + d = apply(*, M', D, M) + @test d isa Diagonal{Float64} + @test all(d.diag .== 1) + + t = apply(*, M', T, M) + @test t isa Tridiagonal + @test all(t.dl .== -2) + @test all(t.d .== 1) + @test all(t.du .== 3) + + M̃ = ones(11,11) + @test_throws DimensionMismatch apply(*, M̃', D, M̃) +end diff --git a/test/runtests.jl b/test/runtests.jl index c4b7b1d1..0bf48562 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Test, LinearAlgebra, LazyArrays, StaticArrays, FillArrays, ArrayLayouts import LazyArrays: CachedArray, colsupport, rowsupport, LazyArrayStyle, broadcasted, - PaddedLayout, ApplyLayout, BroadcastLayout, AddArray, LazyLayout + PaddedLayout, ApplyLayout, BroadcastLayout, AddArray, LazyLayout, + ApplyStyle @testset "Lazy MemoryLayout" begin @testset "ApplyArray" begin @@ -25,6 +26,7 @@ import LazyArrays: CachedArray, colsupport, rowsupport, LazyArrayStyle, broadcas end end include("applytests.jl") +include("materialize_dsl.jl") include("multests.jl") include("ldivtests.jl") include("addtests.jl") @@ -341,4 +343,4 @@ end @test exp.(transpose(v)) isa BroadcastMatrix @test exp.(M') isa BroadcastMatrix @test exp.(transpose(M)) isa BroadcastMatrix -end \ No newline at end of file +end