diff --git a/src/tensors/contraction.jl b/src/tensors/contraction.jl index 03b8e1c9..21df02ab 100644 --- a/src/tensors/contraction.jl +++ b/src/tensors/contraction.jl @@ -45,6 +45,7 @@ function contraction!(plan::ContractionPlan, α, a, b, β, c, d) unaryOpA = plan.desc.descA.unaryOp unaryOpB = plan.desc.descB.unaryOp unaryOpC = plan.desc.descC.unaryOp + unaryOpD = plan.desc.descD.unaryOp α = plan.desc.computeType(α) β = plan.desc.computeType(β) @@ -52,9 +53,10 @@ function contraction!(plan::ContractionPlan, α, a, b, β, c, d) if plan.algo == ALGO_GETT GemmKernels.matmul( a, b, c, d, plan.algorithmPlan.gemmConf, - transform_shared_to_regs_a = Transform.Elementwise(x -> α * unaryOpA(x)), + transform_shared_to_regs_a = Transform.Elementwise(x -> unaryOpA(α * x)), transform_shared_to_regs_b = Transform.Elementwise(x -> unaryOpB(x)), transform_shared_to_regs_c = Transform.Elementwise(x -> β * unaryOpC(x)), + transform_regs_to_shared_d = Transform.Elementwise(x -> unaryOpD(x)), kernel = Kernel.matmul_pipelined, ) else diff --git a/src/tensors/descriptor.jl b/src/tensors/descriptor.jl index b8487e86..2565156b 100644 --- a/src/tensors/descriptor.jl +++ b/src/tensors/descriptor.jl @@ -1,3 +1,4 @@ +using CUDA using GemmKernels ModeType = AbstractVector{<:Union{Char,Integer}} @@ -55,10 +56,27 @@ mutable struct ContractionDescriptor dataType::DataType function ContractionDescriptor( - a, modeA::ModeType, - b, modeB::ModeType, - c, modeC::ModeType, - d, modeD::ModeType; + descA::TensorDescriptor, modeA::ModeType, + descB::TensorDescriptor, modeB::ModeType, + descC::TensorDescriptor, modeC::ModeType, + descD::TensorDescriptor, modeD::ModeType, + computeType, + dataType + ) + return new( + descA, modeA, + descB, modeB, + descC, modeC, + descD, modeD, + computeType, dataType + ) + end + + function ContractionDescriptor( + a::CuArray, modeA::ModeType, + b::CuArray, modeB::ModeType, + c::CuArray, modeC::ModeType, + d::CuArray, modeD::ModeType; computeType=eltype(a), dataType=eltype(c) )