Skip to content

Commit

Permalink
Support different ortho directions in factorize (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Jan 18, 2025
1 parent e37a397 commit e1b953a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorBase"
uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.7"
version = "0.1.8"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
3 changes: 3 additions & 0 deletions src/ITensorBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ using UnspecifiedTypes: UnspecifiedZero
function specify_eltype(a::Zeros{UnspecifiedZero}, elt::Type)
return Zeros{elt}(axes(a))
end
function specify_eltype(a::AbstractArray, elt::Type)
return a
end

# TODO: Use `adapt` to reach down into the storage.
function specify_eltype!(a::AbstractITensor, elt::Type)
Expand Down
54 changes: 34 additions & 20 deletions src/quirks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,34 +36,48 @@ function onehot(iv::Pair{<:Index,<:Int})
return a
end

# TODO: This is just a stand-in for truncated SVD
# that only makes use of `maxdim`, just to get some
# functionality running in `ITensorMPS.jl`.
# Define a proper truncated SVD in
# `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
function svd_truncated(a::AbstractITensor, codomain_inds; maxdim)
U, S, V = svd(a, codomain_inds)
r = Base.OneTo(min(maxdim, minimum(Int.(size(S)))))
u = commonind(U, S)
v = commonind(V, S)
us = uniqueinds(U, S)
vs = uniqueinds(V, S)
U′ = U[(us .=> :)..., u => r]
S′ = S[u => r, v => r]
V′ = V[v => r, (vs .=> :)...]
return U′, S′, V′
end

using LinearAlgebra: qr, svd
# TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
function factorize(
a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, kwargs...
a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, ortho="left", kwargs...
)
# TODO: Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`?
# See https://github.com/ITensor/NamedDimsArrays.jl/issues/22.
codomain_inds′ = intersect(inds(a), codomain_inds)
if isnothing(maxdim) && isnothing(cutoff)
Q, R = qr(a, codomain_inds′)
return Q, R, (; truncerr=zero(Bool),)
codomain_inds′ = if ortho == "left"
intersect(inds(a), codomain_inds)
elseif ortho == "right"
setdiff(inds(a), codomain_inds)
else
U, S, V = svd(a, codomain_inds′)
# TODO: This is just a stand-in for truncated SVD
# that only makes use of `maxdim`, just to get some
# functionality running in `ITensorMPS.jl`.
# Define a proper truncated SVD in
# `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
r = Base.OneTo(min(maxdim, minimum(Int.(size(S)))))
u = commonind(U, S)
v = commonind(V, S)
us = uniqueinds(U, S)
vs = uniqueinds(V, S)
U′ = U[(us .=> :)..., u => r]
S′ = S[u => r, v => r]
V′ = V[v => r, (vs .=> :)...]
return U′, S′ * V′, (; truncerr=zero(Bool),)
error("Bad `ortho` input.")
end
F1, F2 = if isnothing(maxdim) && isnothing(cutoff)
qr(a, codomain_inds′)
else
U, S, V = svd_truncated(a, codomain_inds′; maxdim)
U, S * V
end
if ortho == "right"
F2, F1 = F1, F2
end
return F1, F2, (; truncerr=zero(Bool),)
end

# TODO: Used in `ITensorMPS.jl`, decide where or if to define it.
Expand Down

0 comments on commit e1b953a

Please sign in to comment.