Skip to content

Commit

Permalink
edited tensor contractions
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravala committed Jun 27, 2024
1 parent 2493c4d commit ed961c6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Grassmann"
uuid = "4df31cd9-4c27-5bea-88d0-e6a7146666d8"
authors = ["Michael Reed"]
version = "0.8.21"
version = "0.8.22"

[deps]
AbstractTensors = "a8e43f4a-99b7-5565-8bf1-0165161caaea"
Expand Down
13 changes: 11 additions & 2 deletions src/forms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ transpose_row(t::Chain{V,1,<:Chain},i) where V = transpose_row(value(t),i,V)
@generated _transpose(t::FixedVector{N,<:Chain{V,1}},W=V) where {N,V} = :(Chain{V,1}(transpose_row.(Ref(t),$(list(1,mdims(V))),W)))
Base.transpose(t::Chain{V,1,<:Chain{V,1}}) where V = _transpose(value(t))
Base.transpose(t::Chain{V,1,<:Chain{W,1}}) where {V,W} = _transpose(value(t),V)
Base.inv(t::TensorNested,g) = inv(t)

Base.Matrix(t::TensorAlgebra) = matrix(t)

Expand Down Expand Up @@ -415,7 +416,7 @@ end

value(t::DiagonalOperator) = t.v
matrix(m::DiagonalOperator) = matrix(TensorOperator(m))
getindex(t::DiagonalOperator,i::Int,j::Int) = ij ? zero(valuetype(value(t))) : value(t)[i]
getindex(t::DiagonalOperator,i::Int,j::Int) = ij ? zero(valuetype(value(t))) : value(value(t))[i]
getindex(t::DiagonalOperator,i::Int) = value(t)(i)

compound(m::DiagonalOperator{V,<:Chain{V,1}},::Val{0}) where V = DiagonalOperator(Chain{V,0}(1))
Expand Down Expand Up @@ -638,7 +639,8 @@ contraction(a::Single{W},b::Chain{V,G,<:Chain}) where {W,G,V} = Chain{V,G}(colum
contraction(x::Chain{V,G,<:Chain},y::Single{V,G}) where {V,G} = value(y)*x[bladeindex(mdims(V),UInt(basis(y)))]
contraction(x::Chain{V,G,<:Chain},y::Submanifold{V,G}) where {V,G} = x[bladeindex(mdims(V),UInt(y))]
#contraction(a::Chain{V,L,<:Chain{V,G},N},b::Chain{V,G,<:Chain{V},M}) where {V,G,L,N,M} = Chain{V,G}(contraction.(Ref(a),value(b)))
contraction(x::Chain{V,L,<:Chain{V,G},N},y::Chain{V,G,<:Chain{V,L},N}) where {L,N,V,G} = Chain{V,G}(contraction.(Ref(x),value(y)))
contraction(x::Chain{V,L,<:Chain{V,G},N},y::Chain{V,G,<:Chain{V,L},N}) where {L,N,V,G} = Chain{V,G}(contraction_mat.(Ref(x),value(y)))
contraction_mat(x::Chain{W,L,<:Chain{V,G},N},y::Chain{V,G,T,N}) where {W,L,N,V,G,T} = Chain{V,G}(matmul(value(x),value(y)))
contraction(x::Chain{W,L,<:Chain{V,G},N},y::Chain{V,G,T,N}) where {W,L,N,V,G,T} = Chain{V,G}(matmul(value(x),value(y)))
contraction(x::Chain{W,L,<:Multivector{V},N},y::Chain{V,G,T,N}) where {W,L,N,V,G,T} = Multivector{V}(matmul(value(x),value(y)))
contraction(x::Multivector{W,<:Chain{V,G},N},y::Multivector{V,T,N}) where {W,N,V,G,T} = Chain{V,G}(matmul(value(x),value(y)))
Expand Down Expand Up @@ -730,6 +732,13 @@ contraction(a::DiagonalOperator{V,<:Multivector{V}},b::Multivector{V}) where V =
contraction(a::Multivector{V},b::DiagonalOperator{V,<:Multivector{V}}) where V = Multivector{V}(value(a).*value(value(b)))
contraction(a::DiagonalOperator{V,<:Multivector{V}},b::DiagonalOperator{V,<:Multivector{V}}) where V = DiagonalOperator(Multivector{V}(value(value(a)).*value(value(b))))

contraction(a::Outermorphism{V},b::Endomorphism{V,<:Chain{V,G}}) where {V,G} = contraction(TensorOperator(a[G]),b)
contraction(a::Endomorphism{V,<:Chain{V,G}},b::Outermorphism{V}) where {V,G} = contraction(a,TensorOperator(b[G]))
contraction(a::DiagonalOperator{V,<:Multivector},b::Endomorphism{V,<:Chain{V,G}}) where {V,G} = contraction(DiagonalOperator(value(a)(Val(G))),b)
contraction(a::Endomorphism{V,<:Chain{V,G}},b::DiagonalOperator{V,<:Multivector}) where {V,G} = contraction(a,DiagonalOperator(value(b)(Val(G))))
contraction(a::DiagonalOperator{V,<:Chain{V,G}},b::Endomorphism{V,<:Chain{V,G}}) where {V,G} = TensorOperator(Chain{V,G}(value(value(a)).*value(value(b))))
contraction(a::Endomorphism{V,<:Chain{V,G}},b::DiagonalOperator{V,<:Chain{V,G}}) where {V,G} = TensorOperator(Chain{V,G}(value(value(a)).*value(value(b))))

contraction(a::Outermorphism{V},b::TensorGraded{V,G}) where {V,G} = contraction(a[G],b)
contraction(a::TensorGraded{V,G},b::Outermorphism{V}) where {V,G} = contraction(a,b[G])
contraction(a::Outermorphism{V},b::Outermorphism{V}) where V = Outermorphism(contraction.(a.v,b.v))
Expand Down

0 comments on commit ed961c6

Please sign in to comment.