Skip to content

Commit

Permalink
unary operator bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wardvermeulen authored and thomasfaingnaert committed Jun 26, 2023
1 parent be47324 commit 31156dc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/tensors/contraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,18 @@ function contraction!(plan::ContractionPlan, α, a, b, β, c, d)
unaryOpA = plan.desc.descA.unaryOp
unaryOpB = plan.desc.descB.unaryOp
unaryOpC = plan.desc.descC.unaryOp
unaryOpD = plan.desc.descD.unaryOp

α = plan.desc.computeType(α)
β = plan.desc.computeType(β)

if plan.algo == ALGO_GETT
GemmKernels.matmul(
a, b, c, d, plan.algorithmPlan.gemmConf,
transform_shared_to_regs_a = Transform.Elementwise(x -> α * unaryOpA(x)),
transform_shared_to_regs_a = Transform.Elementwise(x -> unaryOpA(α * x)),
transform_shared_to_regs_b = Transform.Elementwise(x -> unaryOpB(x)),
transform_shared_to_regs_c = Transform.Elementwise(x -> β * unaryOpC(x)),
transform_regs_to_shared_d = Transform.Elementwise(x -> unaryOpD(x)),
kernel = Kernel.matmul_pipelined,
)
else
Expand Down
26 changes: 22 additions & 4 deletions src/tensors/descriptor.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using CUDA
using GemmKernels

ModeType = AbstractVector{<:Union{Char,Integer}}
Expand Down Expand Up @@ -55,10 +56,27 @@ mutable struct ContractionDescriptor
dataType::DataType

function ContractionDescriptor(
a, modeA::ModeType,
b, modeB::ModeType,
c, modeC::ModeType,
d, modeD::ModeType;
descA::TensorDescriptor, modeA::ModeType,
descB::TensorDescriptor, modeB::ModeType,
descC::TensorDescriptor, modeC::ModeType,
descD::TensorDescriptor, modeD::ModeType,
computeType,
dataType
)
return new(
descA, modeA,
descB, modeB,
descC, modeC,
descD, modeD,
computeType, dataType
)
end

function ContractionDescriptor(
a::CuArray, modeA::ModeType,
b::CuArray, modeB::ModeType,
c::CuArray, modeC::ModeType,
d::CuArray, modeD::ModeType;
computeType=eltype(a),
dataType=eltype(c)
)
Expand Down

0 comments on commit 31156dc

Please sign in to comment.