Skip to content

Commit

Permalink
Expose StoreT as a template parameter for a potential speed-up (#838)
Browse files Browse the repository at this point in the history
* Expose StoreT as a template parameter for a potential speed-up

* Add the StoreT parameter to more elementwise classes

---------

Co-authored-by: Haicheng Wu <[email protected]>
  • Loading branch information
erees1 and hwu36 authored Mar 10, 2023
1 parent 29801e3 commit 86cae03
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ template <
int ElementsPerAccess,
typename ElementwiseOp_ = Identity<ElementCompute_>,
typename BinaryOp_ = plus<ElementCompute_>,
+ bool StoreT_ = true,
typename ElementVector_ = ElementC_
>
class LinearCombinationBiasElementwise {
Expand Down Expand Up @@ -97,7 +98,7 @@ class LinearCombinationBiasElementwise {
static bool const kStoreZ = true;

/// If true, the 'T' tensor is stored
- static bool const kStoreT = true;
+ static bool const kStoreT = StoreT_;

/// Host-constructable parameters structure
struct Params {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ template <
typename ElementCompute_,
typename ElementZ_,
int ElementsPerAccess,
- bool StoreT = true,
+ bool StoreT_ = true,
typename ElementVector_ = ElementC_
>
class LinearCombinationBiasRelu {
Expand Down Expand Up @@ -238,7 +238,7 @@ class LinearCombinationBiasRelu {
static bool const kStoreZ = true;

/// If true, the 'T' tensor is stored
- static bool const kStoreT = StoreT;
+ static bool const kStoreT = StoreT_;

/// Host-constructable parameters structure
struct Params {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ template <typename ElementOutput_, typename ElementAccumulator_,
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_,
template <typename T> class BinaryOp2_ = detail::NoOp,
+ bool StoreT_ = false,
typename ElementVector_ = ElementC_>
class LinearCombinationResidualBlock {
public:
Expand Down Expand Up @@ -90,7 +91,7 @@ class LinearCombinationResidualBlock {

static bool const kIsHeavy = true;
static bool const kStoreZ = true;
- static bool const kStoreT = false;
+ static bool const kStoreT = StoreT_;

/// Host-constructable parameters structure
struct Params {
Expand Down Expand Up @@ -182,11 +183,12 @@ template <typename ElementOutput_, typename ElementAccumulator_,
template <typename T> class ActivationOp_,
template <typename T> class BinaryOp1_,
template <typename T> class UnaryOp_,
+ bool StoreT_,
typename ElementVector_>
class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,
ElementCompute_, ElementC_, ElementsPerAccess,
ActivationOp_, BinaryOp1_, UnaryOp_,
- detail::NoOp, ElementVector_> {
+ detail::NoOp, StoreT_, ElementVector_> {
public:
static bool const kIsSingleSource = true;

Expand Down Expand Up @@ -214,7 +216,7 @@ class LinearCombinationResidualBlock<ElementOutput_, ElementAccumulator_,

static bool const kIsHeavy = true;
static bool const kStoreZ = true;
- static bool const kStoreT = false;
+ static bool const kStoreT = StoreT_;

/// Host-constructable parameters structure
struct Params {
Expand Down

0 comments on commit 86cae03

Please sign in to comment.