Skip to content

Commit

Permalink
Added @materialize convenience DSL
Browse files Browse the repository at this point in the history
  • Loading branch information
jagot committed Jan 7, 2020
1 parent 7336bc5 commit e006a95
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/LazyArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ end

export Mul, Applied, MulArray, MulVector, MulMatrix, InvMatrix, PInvMatrix,
Hcat, Vcat, Kron, BroadcastArray, BroadcastMatrix, BroadcastVector, cache, Ldiv, Inv, PInv, Diff, Cumsum,
applied, materialize, materialize!, ApplyArray, ApplyMatrix, ApplyVector, apply, , @~, LazyArray
applied, materialize, materialize!, @materialize, ApplyArray, ApplyMatrix, ApplyVector, apply, , @~, LazyArray


include("lazyapplying.jl")
include("materialize_dsl.jl")
include("lazybroadcasting.jl")
include("linalg/linalg.jl")
include("cache.jl")
Expand Down
124 changes: 124 additions & 0 deletions src/materialize_dsl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# For unparametrized destination types
generate_copyto!_signature(dest, dest_type::Symbol, Msig) =
:(Base.copyto!($(dest)::$(dest_type), applied_obj::$(Msig)))

# For parametrized destination types
function generate_copyto!_signature(dest, dest_type::Expr, Msig)
dest_type.head == :curly ||
throw(ArgumentError("Invalid destination specification $(dest)::$(dest_type)"))
:(Base.copyto!($(dest)::$(dest_type), applied_obj::$(Msig)) where {$(dest_type.args[2:end]...)})
end

function generate_copyto!(body, factor_names, Msig)
body.head == :(->) ||
throw(ArgumentError("Invalid copyto! specification"))
body.args[1].head == :(::) ||
throw(ArgumentError("Invalid destination specification $(body.args[1])"))
(dest,dest_type) = body.args[1].args
copyto!_signature = generate_copyto!_signature(dest, dest_type, Msig)
f_body = quote
axes($dest) == axes(applied_obj) || throw(DimensionMismatch("axes must be same"))
$(factor_names) = applied_obj.args
$(body.args[2].args...)
$(dest)
end
Expr(:function, copyto!_signature, f_body)
end

"""
@materialize function op(args...)
This macro simplifies the setup of a few functions necessary for the
materialization of [`Applied`](@ref) objects:
- `ApplyStyle`, used to ensure dispatch of the applied object to the
routines below
- `copyto!(dest::DestType, applied_obj::Applied{...,op})` performs the
actual materialization of `applied_obj` into the destination object
which has been generated by
- `similar` which usually returns a suitable matrix
- `materialize` which makes use of the above functions
# Example
```julia
@materialize function *(Ac::MyAdjointBasis,
O::MyOperator,
B::MyBasis)
MyApplyStyle # An instance of this type will be returned by ApplyStyle
T -> begin # generates similar
A = parent(Ac)
parent(A) == parent(B) ||
throw(ArgumentError("Incompatible bases"))
# There may be different matrices best representing different
# situations:
if ...
Diagonal(Vector{T}(undef, size(B,1)))
else
Tridiagonal(Vector{T}(undef, size(B,1)-1),
Vector{T}(undef, size(B,1)),
Vector{T}(undef, size(B,1)-1))
end
end
dest::Diagonal{T} -> begin # generate copyto!(dest::Diagonal{T}, ...) where T
dest.diag .= 1
end
dest::Tridiagonal{T} -> begin # generate copyto!(dest::Tridiagonal{T}, ...) where T
dest.dl .= -2
dest.ev .= 1
dest.du .= 3
end
end
```
"""
macro materialize(expr)
expr.head == :function || expr.head == :(=) || error("Must start with a function")
@assert expr.args[1].head == :call
op = expr.args[1].args[1]

bodies = filter(e -> !(e isa LineNumberNode), expr.args[2].args)
length(bodies) < 3 &&
throw(ArgumentError("At least three blocks required (ApplyStyle, similar, and at least one copyto!)"))

factor_types = :(<:Tuple{})
factor_names = :(())
apply_style = first(bodies)
apply_style_fun = :(LazyArrays.ApplyStyle(::typeof($op)) = $(apply_style)())

# Generate Applied signature
for arg in expr.args[1].args[2:end]
arg isa Expr && arg.head == :(::) ||
throw(ArgumentError("Invalid argument specification $(arg)"))
arg_name, arg_typ = arg.args
push!(factor_types.args[1].args, :(<:$(arg_typ)))
push!(factor_names.args, arg_name)
push!(apply_style_fun.args[1].args, :(::Type{<:$(arg_typ)}))
end
Msig = :(LazyArrays.Applied{$(apply_style), typeof($op), $(factor_types)})

sim_body = bodies[2]
sim_body.head == :(->) ||
throw(ArgumentError("Invalid similar specification"))
T = first(sim_body.args)

copytos! = map(body -> generate_copyto!(body, factor_names, Msig), bodies[3:end])

f = quote
$(apply_style_fun)

function Base.similar(applied_obj::$Msig, ::Type{$T}=eltype(applied_obj)) where $T
$(factor_names) = applied_obj.args
$(sim_body.args[2])
end

$(copytos!...)

LazyArrays.materialize(applied_obj::$Msig) =
copyto!(similar(applied_obj, eltype(applied_obj)), applied_obj)
end
esc(f)
end
59 changes: 59 additions & 0 deletions test/materialize_dsl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
struct MyOperator{T}
n::Int
kind::Symbol
end

Base.axes(O::MyOperator) = (Base.OneTo(O.n),Base.OneTo(O.n))
Base.axes(O::MyOperator,i) = axes(O)[i]
Base.size(O::MyOperator) = (O.n,O.n)
Base.eltype(::MyOperator{T}) where T = T

struct MyApplyStyle <: ApplyStyle end

@materialize function *(Ac::Adjoint{<:Any,<:AbstractMatrix},
O::MyOperator,
B::AbstractMatrix)
MyApplyStyle
T -> begin
A = parent(Ac)

if O.kind == :diagonal
Diagonal(Vector{T}(undef, O.n))
else
Tridiagonal(Vector{T}(undef, O.n-1),
Vector{T}(undef, O.n),
Vector{T}(undef, O.n-1))
end
end
dest::Diagonal{T} -> begin
dest.diag .= 1
end
dest::Tridiagonal{T} -> begin
dest.dl .= -2
dest.d .= 1
dest.du .= 3
end
end

@testset "Materialize DSL" begin
o = ones(10)
M = ones(10,10)
D = MyOperator{Float64}(10, :diagonal)
T = MyOperator{ComplexF64}(10, :tridiagonal)

@test LazyArrays.ApplyStyle(*, typeof(M'), typeof(D), typeof(M)) == MyApplyStyle()
@test LazyArrays.ApplyStyle(*, typeof(M'), typeof(T), typeof(M)) == MyApplyStyle()

d = apply(*, M', D, M)
@test d isa Diagonal{Float64}
@test all(d.diag .== 1)

t = apply(*, M', T, M)
@test t isa Tridiagonal
@test all(t.dl .== -2)
@test all(t.d .== 1)
@test all(t.du .== 3)

= ones(11,11)
@test_throws DimensionMismatch apply(*, M̃', D, M̃)
end
6 changes: 4 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test, LinearAlgebra, LazyArrays, StaticArrays, FillArrays, ArrayLayouts
import LazyArrays: CachedArray, colsupport, rowsupport, LazyArrayStyle, broadcasted,
PaddedLayout, ApplyLayout, BroadcastLayout, AddArray, LazyLayout
PaddedLayout, ApplyLayout, BroadcastLayout, AddArray, LazyLayout,
ApplyStyle

@testset "Lazy MemoryLayout" begin
@testset "ApplyArray" begin
Expand All @@ -25,6 +26,7 @@ import LazyArrays: CachedArray, colsupport, rowsupport, LazyArrayStyle, broadcas
end
end
include("applytests.jl")
include("materialize_dsl.jl")
include("multests.jl")
include("ldivtests.jl")
include("addtests.jl")
Expand Down Expand Up @@ -341,4 +343,4 @@ end
@test exp.(transpose(v)) isa BroadcastMatrix
@test exp.(M') isa BroadcastMatrix
@test exp.(transpose(M)) isa BroadcastMatrix
end
end

0 comments on commit e006a95

Please sign in to comment.