-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
189 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# For unparametrized destination types | ||
generate_copyto!_signature(dest, dest_type::Symbol, Msig) = | ||
:(Base.copyto!($(dest)::$(dest_type), applied_obj::$(Msig))) | ||
|
||
# For parametrized destination types | ||
function generate_copyto!_signature(dest, dest_type::Expr, Msig) | ||
dest_type.head == :curly || | ||
throw(ArgumentError("Invalid destination specification $(dest)::$(dest_type)")) | ||
:(Base.copyto!($(dest)::$(dest_type), applied_obj::$(Msig)) where {$(dest_type.args[2:end]...)}) | ||
end | ||
|
||
function generate_copyto!(body, factor_names, Msig) | ||
body.head == :(->) || | ||
throw(ArgumentError("Invalid copyto! specification")) | ||
body.args[1].head == :(::) || | ||
throw(ArgumentError("Invalid destination specification $(body.args[1])")) | ||
(dest,dest_type) = body.args[1].args | ||
copyto!_signature = generate_copyto!_signature(dest, dest_type, Msig) | ||
f_body = quote | ||
axes($dest) == axes(applied_obj) || throw(DimensionMismatch("axes must be same")) | ||
$(factor_names) = applied_obj.args | ||
$(body.args[2].args...) | ||
$(dest) | ||
end | ||
Expr(:function, copyto!_signature, f_body) | ||
end | ||
|
||
""" | ||
@materialize function op(args...) | ||
This macro simplifies the setup of a few functions necessary for the | ||
materialization of [`Applied`](@ref) objects: | ||
- `ApplyStyle`, used to ensure dispatch of the applied object to the | ||
routines below | ||
- `copyto!(dest::DestType, applied_obj::Applied{...,op})` performs the | ||
actual materialization of `applied_obj` into the destination object | ||
which has been generated by | ||
- `similar` which usually returns a suitable matrix | ||
- `materialize` which makes use of the above functions | ||
# Example | ||
```julia | ||
@materialize function *(Ac::MyAdjointBasis, | ||
O::MyOperator, | ||
B::MyBasis) | ||
MyApplyStyle # An instance of this type will be returned by ApplyStyle | ||
T -> begin # generates similar | ||
A = parent(Ac) | ||
parent(A) == parent(B) || | ||
throw(ArgumentError("Incompatible bases")) | ||
# There may be different matrices best representing different | ||
# situations: | ||
if ... | ||
Diagonal(Vector{T}(undef, size(B,1))) | ||
else | ||
Tridiagonal(Vector{T}(undef, size(B,1)-1), | ||
Vector{T}(undef, size(B,1)), | ||
Vector{T}(undef, size(B,1)-1)) | ||
end | ||
end | ||
dest::Diagonal{T} -> begin # generate copyto!(dest::Diagonal{T}, ...) where T | ||
dest.diag .= 1 | ||
end | ||
dest::Tridiagonal{T} -> begin # generate copyto!(dest::Tridiagonal{T}, ...) where T | ||
dest.dl .= -2 | ||
dest.ev .= 1 | ||
dest.du .= 3 | ||
end | ||
end | ||
``` | ||
""" | ||
macro materialize(expr) | ||
expr.head == :function || expr.head == :(=) || error("Must start with a function") | ||
@assert expr.args[1].head == :call | ||
op = expr.args[1].args[1] | ||
|
||
bodies = filter(e -> !(e isa LineNumberNode), expr.args[2].args) | ||
length(bodies) < 3 && | ||
throw(ArgumentError("At least three blocks required (ApplyStyle, similar, and at least one copyto!)")) | ||
|
||
factor_types = :(<:Tuple{}) | ||
factor_names = :(()) | ||
apply_style = first(bodies) | ||
apply_style_fun = :(LazyArrays.ApplyStyle(::typeof($op)) = $(apply_style)()) | ||
|
||
# Generate Applied signature | ||
for arg in expr.args[1].args[2:end] | ||
arg isa Expr && arg.head == :(::) || | ||
throw(ArgumentError("Invalid argument specification $(arg)")) | ||
arg_name, arg_typ = arg.args | ||
push!(factor_types.args[1].args, :(<:$(arg_typ))) | ||
push!(factor_names.args, arg_name) | ||
push!(apply_style_fun.args[1].args, :(::Type{<:$(arg_typ)})) | ||
end | ||
Msig = :(LazyArrays.Applied{$(apply_style), typeof($op), $(factor_types)}) | ||
|
||
sim_body = bodies[2] | ||
sim_body.head == :(->) || | ||
throw(ArgumentError("Invalid similar specification")) | ||
T = first(sim_body.args) | ||
|
||
copytos! = map(body -> generate_copyto!(body, factor_names, Msig), bodies[3:end]) | ||
|
||
f = quote | ||
$(apply_style_fun) | ||
|
||
function Base.similar(applied_obj::$Msig, ::Type{$T}=eltype(applied_obj)) where $T | ||
$(factor_names) = applied_obj.args | ||
$(sim_body.args[2]) | ||
end | ||
|
||
$(copytos!...) | ||
|
||
LazyArrays.materialize(applied_obj::$Msig) = | ||
copyto!(similar(applied_obj, eltype(applied_obj)), applied_obj) | ||
end | ||
esc(f) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
struct MyOperator{T} | ||
n::Int | ||
kind::Symbol | ||
end | ||
|
||
Base.axes(O::MyOperator) = (Base.OneTo(O.n),Base.OneTo(O.n)) | ||
Base.axes(O::MyOperator,i) = axes(O)[i] | ||
Base.size(O::MyOperator) = (O.n,O.n) | ||
Base.eltype(::MyOperator{T}) where T = T | ||
|
||
struct MyApplyStyle <: ApplyStyle end | ||
|
||
@materialize function *(Ac::Adjoint{<:Any,<:AbstractMatrix}, | ||
O::MyOperator, | ||
B::AbstractMatrix) | ||
MyApplyStyle | ||
T -> begin | ||
A = parent(Ac) | ||
|
||
if O.kind == :diagonal | ||
Diagonal(Vector{T}(undef, O.n)) | ||
else | ||
Tridiagonal(Vector{T}(undef, O.n-1), | ||
Vector{T}(undef, O.n), | ||
Vector{T}(undef, O.n-1)) | ||
end | ||
end | ||
dest::Diagonal{T} -> begin | ||
dest.diag .= 1 | ||
end | ||
dest::Tridiagonal{T} -> begin | ||
dest.dl .= -2 | ||
dest.d .= 1 | ||
dest.du .= 3 | ||
end | ||
end | ||
|
||
@testset "Materialize DSL" begin | ||
o = ones(10) | ||
M = ones(10,10) | ||
D = MyOperator{Float64}(10, :diagonal) | ||
T = MyOperator{ComplexF64}(10, :tridiagonal) | ||
|
||
@test LazyArrays.ApplyStyle(*, typeof(M'), typeof(D), typeof(M)) == MyApplyStyle() | ||
@test LazyArrays.ApplyStyle(*, typeof(M'), typeof(T), typeof(M)) == MyApplyStyle() | ||
|
||
d = apply(*, M', D, M) | ||
@test d isa Diagonal{Float64} | ||
@test all(d.diag .== 1) | ||
|
||
t = apply(*, M', T, M) | ||
@test t isa Tridiagonal | ||
@test all(t.dl .== -2) | ||
@test all(t.d .== 1) | ||
@test all(t.du .== 3) | ||
|
||
M̃ = ones(11,11) | ||
@test_throws DimensionMismatch apply(*, M̃', D, M̃) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters