Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DiagonalTensorMap constructors and converters #212

Merged
merged 11 commits into from
Feb 5, 2025
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ end
TensorKit.jl is a package that provides types and methods to represent and manipulate
tensors with symmetries. The emphasis is on the structure and functionality needed to build
tensor network algorithms for the simulation of quantum many-body systems. Such tensors are
typically invariant under a symmetry group which acts via specific representions on each of
typically invariant under a symmetry group which acts via specific representations on each of
the indices of the tensor. TensorKit.jl provides the functionality for constructing such
tensors and performing typical operations such as tensor contractions and decompositions,
thereby preserving the symmetries and exploiting them for optimal performance.
Expand Down
22 changes: 22 additions & 0 deletions src/tensors/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ function DiagonalTensorMap(data::DenseVector{T}, V::IndexSpace) where {T}
return DiagonalTensorMap{T}(data, V)
end

function DiagonalTensorMap(t::AbstractTensorMap{T,S,1,1}) where {T,S}
isa(t, DiagonalTensorMap) && return t
domain(t) == codomain(t) ||
throw(SpaceMismatch("DiagonalTensorMap requires equal domain and codomain"))
A = storagetype(t)
d = DiagonalTensorMap{T,S,A}(undef, space(t, 1))
for (c, b) in blocks(d)
bt = block(t, c)
# TODO: rewrite in terms of `diagview` from MatrixAlgebraKit.jl
copy!(b.diag, view(bt, LinearAlgebra.diagind(bt)))
end
return t
end

# TODO: more constructors needed?

# Special case adjoint:
Expand Down Expand Up @@ -80,6 +94,14 @@ end
function Base.convert(D::Type{<:DiagonalTensorMap}, d::DiagonalTensorMap)
return DiagonalTensorMap(convert(storagetype(D), d.data), d.domain)
end
function Base.convert(::Type{DiagonalTensorMap}, t::AbstractTensorMap)
all(LinearAlgebra.isdiag ∘ last, blocks(t)) ||
Jutho marked this conversation as resolved.
Show resolved Hide resolved
throw(ArgumentError("DiagonalTensorMap requires input tensor that is diagonal"))
return DiagonalTensorMap(t)
end
function Base.convert(::Type{DiagonalTensorMap}, d::Dict{Symbol,Any})
return convert(DiagonalTensorMap, convert(TensorMap, d))
end

# Complex, real and imaginary parts
#-----------------------------------
Expand Down
5 changes: 5 additions & 0 deletions test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,14 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
@timedtestset "Tensor conversion" begin
t = @constinferred DiagonalTensorMap(undef, V)
rand!(t.data)
# element type conversion
tc = complex(t)
@test convert(typeof(tc), t) == tc
@test typeof(convert(typeof(tc), t)) == typeof(tc)
# to and from generic TensorMap
td = DiagonalTensorMap(TensorMap(t))
@test t == td
@test typeof(td) == typeof(t)
end
I = sectortype(V)
if BraidingStyle(I) isa SymmetricBraiding
Expand Down
Loading