Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DiagonalTensorMap constructors and converters #212

Merged
merged 11 commits into from
Feb 5, 2025
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ end
TensorKit.jl is a package that provides types and methods to represent and manipulate
tensors with symmetries. The emphasis is on the structure and functionality needed to build
tensor network algorithms for the simulation of quantum many-body systems. Such tensors are
typically invariant under a symmetry group which acts via specific representions on each of
typically invariant under a symmetry group which acts via specific representations on each of
the indices of the tensor. TensorKit.jl provides the functionality for constructing such
tensors and performing typical operations such as tensor contractions and decompositions,
thereby preserving the symmetries and exploiting them for optimal performance.
Expand Down
29 changes: 24 additions & 5 deletions src/tensors/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@
return DiagonalTensorMap{T}(data, V)
end

function DiagonalTensorMap(t::AbstractTensorMap{T,S,1,1}) where {T,S}
isa(t, DiagonalTensorMap) && return t
domain(t) == codomain(t) ||
throw(SpaceMismatch("DiagonalTensorMap requires equal domain and codomain"))
A = storagetype(t)
d = DiagonalTensorMap{T,S,A}(undef, space(t, 1))
for (c, b) in blocks(d)
bt = block(t, c)
# TODO: rewrite in terms of `diagview` from MatrixAlgebraKit.jl
copy!(b.diag, view(bt, LinearAlgebra.diagind(bt)))
end
return d
end

# TODO: more constructors needed?

# Special case adjoint:
Expand All @@ -73,12 +87,17 @@
TensorMap(d::DiagonalTensorMap) = copy!(similar(d), d)
Base.convert(::Type{TensorMap}, d::DiagonalTensorMap) = TensorMap(d)

function Base.convert(::Type{DiagonalTensorMap{T,S,A}},
d::DiagonalTensorMap{T,S,A}) where {T,S,A}
return d
end
function Base.convert(D::Type{<:DiagonalTensorMap}, d::DiagonalTensorMap)
return DiagonalTensorMap(convert(storagetype(D), d.data), d.domain)
return (d isa D) ? d : DiagonalTensorMap(convert(storagetype(D), d.data), d.domain)
end
Base.convert(::Type{DiagonalTensorMap}, t::DiagonalTensorMap) = t
function Base.convert(::Type{DiagonalTensorMap}, t::AbstractTensorMap)
LinearAlgebra.isdiag(t) ||

Check warning on line 95 in src/tensors/diagonal.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/diagonal.jl#L93-L95

Added lines #L93 - L95 were not covered by tests
throw(ArgumentError("DiagonalTensorMap requires input tensor that is diagonal"))
return DiagonalTensorMap(t)

Check warning on line 97 in src/tensors/diagonal.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/diagonal.jl#L97

Added line #L97 was not covered by tests
end
function Base.convert(::Type{DiagonalTensorMap}, d::Dict{Symbol,Any})
return convert(DiagonalTensorMap, convert(TensorMap, d))

Check warning on line 100 in src/tensors/diagonal.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/diagonal.jl#L99-L100

Added lines #L99 - L100 were not covered by tests
end

# Complex, real and imaginary parts
Expand Down
5 changes: 5 additions & 0 deletions test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,14 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
@timedtestset "Tensor conversion" begin
t = @constinferred DiagonalTensorMap(undef, V)
rand!(t.data)
# element type conversion
tc = complex(t)
@test convert(typeof(tc), t) == tc
@test typeof(convert(typeof(tc), t)) == typeof(tc)
# to and from generic TensorMap
td = DiagonalTensorMap(TensorMap(t))
@test t == td
@test typeof(td) == typeof(t)
end
I = sectortype(V)
if BraidingStyle(I) isa SymmetricBraiding
Expand Down
Loading