diff --git a/Project.toml b/Project.toml index f2fdcc1..cc6cb28 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorBase" uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" authors = ["ITensor developers and contributors"] -version = "0.1.7" +version = "0.1.8" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/src/ITensorBase.jl b/src/ITensorBase.jl index 669d95e..0ee5d2b 100644 --- a/src/ITensorBase.jl +++ b/src/ITensorBase.jl @@ -165,6 +165,9 @@ using UnspecifiedTypes: UnspecifiedZero function specify_eltype(a::Zeros{UnspecifiedZero}, elt::Type) return Zeros{elt}(axes(a)) end +function specify_eltype(a::AbstractArray, elt::Type) + return a +end # TODO: Use `adapt` to reach down into the storage. function specify_eltype!(a::AbstractITensor, elt::Type) diff --git a/src/quirks.jl b/src/quirks.jl index 3b4179b..58140fe 100644 --- a/src/quirks.jl +++ b/src/quirks.jl @@ -36,34 +36,48 @@ function onehot(iv::Pair{<:Index,<:Int}) return a end +# TODO: This is just a stand-in for truncated SVD +# that only makes use of `maxdim`, just to get some +# functionality running in `ITensorMPS.jl`. +# Define a proper truncated SVD in +# `MatrixAlgebra.jl`/`TensorAlgebra.jl`. +function svd_truncated(a::AbstractITensor, codomain_inds; maxdim) + U, S, V = svd(a, codomain_inds) + r = Base.OneTo(min(maxdim, minimum(Int.(size(S))))) + u = commonind(U, S) + v = commonind(V, S) + us = uniqueinds(U, S) + vs = uniqueinds(V, S) + U′ = U[(us .=> :)..., u => r] + S′ = S[u => r, v => r] + V′ = V[v => r, (vs .=> :)...] + return U′, S′, V′ +end + using LinearAlgebra: qr, svd # TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`. function factorize( - a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, kwargs... + a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, ortho="left", kwargs... ) # TODO: Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`? # See https://github.com/ITensor/NamedDimsArrays.jl/issues/22. - codomain_inds′ = intersect(inds(a), codomain_inds) - if isnothing(maxdim) && isnothing(cutoff) - Q, R = qr(a, codomain_inds′) - return Q, R, (; truncerr=zero(Bool),) + codomain_inds′ = if ortho == "left" + intersect(inds(a), codomain_inds) + elseif ortho == "right" + setdiff(inds(a), codomain_inds) else - U, S, V = svd(a, codomain_inds′) - # TODO: This is just a stand-in for truncated SVD - # that only makes use of `maxdim`, just to get some - # functionality running in `ITensorMPS.jl`. - # Define a proper truncated SVD in - # `MatrixAlgebra.jl`/`TensorAlgebra.jl`. - r = Base.OneTo(min(maxdim, minimum(Int.(size(S))))) - u = commonind(U, S) - v = commonind(V, S) - us = uniqueinds(U, S) - vs = uniqueinds(V, S) - U′ = U[(us .=> :)..., u => r] - S′ = S[u => r, v => r] - V′ = V[v => r, (vs .=> :)...] - return U′, S′ * V′, (; truncerr=zero(Bool),) + error("Bad `ortho` input.") + end + F1, F2 = if isnothing(maxdim) && isnothing(cutoff) + qr(a, codomain_inds′) + else + U, S, V = svd_truncated(a, codomain_inds′; maxdim) + U, S * V + end + if ortho == "right" + F2, F1 = F1, F2 end + return F1, F2, (; truncerr=zero(Bool),) end # TODO: Used in `ITensorMPS.jl`, decide where or if to define it.