Apply vectorization for batch_norm channels last kernel #1306

Closed
wants to merge 2 commits
235 changes: 227 additions & 8 deletions src/ATen/native/xpu/sycl/BatchNormKernels.cpp
@@ -1558,6 +1558,151 @@ struct BatchNormTransformInputChannelsLastKernelFunctor {
const bool fuse_relu_;
};

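// Vectorized variant of the channels-last transform-input kernel: each
// work-item loads, normalizes and stores VEC_SIZE contiguous channels per
// iteration via memory::aligned_vector.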
template <
typename scalar_t,
typename accscalar_t,
typename layerscalar_t,
int VEC_SIZE>
struct BatchNormTransformInputChannelsLastVectorizedKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
constexpr bool WEIGHT_VEC = std::is_same<scalar_t, layerscalar_t>::value;
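// Work decomposition (channels-last, C == stride_): each work-item owns one
// VEC_SIZE-wide chunk of channels and walks the batch*spatial rows below with
// a stride of m_stride.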
int global_id = item.get_global_id(0);
int num_threads_in_c = stride_ / VEC_SIZE;
int b_id = global_id / num_threads_in_c;
int c_vec_begin = global_id % num_threads_in_c * VEC_SIZE;
int m_stride =
item.get_local_range(0) * item.get_group_range(0) / num_threads_in_c;

for (int m_offset = b_id; m_offset < reduction_size_;
m_offset += m_stride) {
using vec_t = memory::aligned_vector<scalar_t, VEC_SIZE>;
auto address_base = m_offset * stride_ + c_vec_begin;
vec_t input_vec = *reinterpret_cast<vec_t*>(
const_cast<scalar_t*>(&input_[address_base]));
vec_t output_vec, z_vec;
if (z_ != nullptr) {
z_vec =
*reinterpret_cast<vec_t*>(const_cast<scalar_t*>(&z_[address_base]));
}

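// weight/shift can only be loaded as vectors when they share the input dtype
// (WEIGHT_VEC); otherwise they are read per channel inside the loop below.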
vec_t weight_vec, shift_vec;
if constexpr (WEIGHT_VEC) {
if (weight_ != nullptr) {
weight_vec = *reinterpret_cast<vec_t*>(
const_cast<scalar_t*>(&weight_[c_vec_begin]));
}
if (shift_ != nullptr) {
shift_vec = *reinterpret_cast<vec_t*>(
const_cast<scalar_t*>(&shift_[c_vec_begin]));
}
}

#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
auto c_offset = c_vec_begin + i;
auto m_c = mean_[c_offset];
auto inv_std_c = static_cast<accscalar_t>(inv_std_[c_offset]);

accscalar_t w_c, s_c;

if constexpr (WEIGHT_VEC) {
w_c = weight_ == nullptr
? accscalar_t(1.0)
: static_cast<accscalar_t>(weight_vec.val[i]);
s_c = shift_ == nullptr ? accscalar_t(0.0)
: static_cast<accscalar_t>(shift_vec.val[i]);
} else {
w_c = weight_ == nullptr
? accscalar_t(1.0)
: static_cast<accscalar_t>(weight_[c_offset]);
s_c = shift_ == nullptr ? accscalar_t(0.0)
: static_cast<accscalar_t>(shift_[c_offset]);
}

auto tmp = w_c * (static_cast<accscalar_t>(input_vec.val[i]) - m_c) *
inv_std_c +
s_c;

if (z_ != nullptr) {
tmp += z_vec.val[i];
}

output_vec.val[i] =
(fuse_relu_ && tmp <= accscalar_t(0.0)
? scalar_t(0.0)
: static_cast<scalar_t>(tmp));
}

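// Write all VEC_SIZE results back with a single vectorized store.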
*reinterpret_cast<vec_t*>(const_cast<scalar_t*>(
&out_[m_offset * stride_ + c_vec_begin])) = output_vec;
}
}

BatchNormTransformInputChannelsLastVectorizedKernelFunctor(
const scalar_t* RESTRICT input,
const scalar_t* RESTRICT z,
const accscalar_t* RESTRICT mean,
const accscalar_t* RESTRICT inv_std,
const layerscalar_t* RESTRICT weight,
const layerscalar_t* RESTRICT shift,
scalar_t* RESTRICT out,
const int reduction_size,
const int stride,
const bool fuse_relu)
: input_(input),
z_(z),
mean_(mean),
inv_std_(inv_std),
weight_(weight),
shift_(shift),
out_(out),
reduction_size_(reduction_size),
stride_(stride),
fuse_relu_(fuse_relu) {}

private:
const scalar_t* RESTRICT input_;
const scalar_t* RESTRICT z_;
const accscalar_t* RESTRICT mean_;
const accscalar_t* RESTRICT inv_std_;
const layerscalar_t* RESTRICT weight_;
const layerscalar_t* RESTRICT shift_;
scalar_t* RESTRICT out_;
const int reduction_size_;
const int stride_;
const bool fuse_relu_;
};

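// The vectorized kernel is only eligible when every pointer is aligned for
// VEC_SIZE-wide accesses and the channel stride is a multiple of VEC_SIZE.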
template <typename scalar_t, int VEC_SIZE>
bool can_use_batch_norm_cnl_vec_kernel(
char* input,
char* output,
char* z,
char* weight,
char* shift,
int stride) {
return memory::can_vectorize_up_to<scalar_t>(input) >= VEC_SIZE &&
memory::can_vectorize_up_to<scalar_t>(output) >= VEC_SIZE &&
(z == nullptr || memory::can_vectorize_up_to<scalar_t>(z) >= VEC_SIZE) &&
(weight == nullptr ||
memory::can_vectorize_up_to<scalar_t>(weight) >= VEC_SIZE) &&
(shift == nullptr ||
memory::can_vectorize_up_to<scalar_t>(shift) >= VEC_SIZE) &&
(stride % VEC_SIZE == 0);
}

template <typename scalar_t, int VEC_SIZE>
bool can_use_batch_norm_cnl_vec_kernel(
char* input,
char* output,
char* z,
int stride) {
return memory::can_vectorize_up_to<scalar_t>(input) >= VEC_SIZE &&
memory::can_vectorize_up_to<scalar_t>(output) >= VEC_SIZE &&
(z == nullptr || memory::can_vectorize_up_to<scalar_t>(z) >= VEC_SIZE) &&
(stride % VEC_SIZE == 0);
}

void batch_norm_elemt_channels_last_template(
const at::Tensor& output,
const at::Tensor& input,
@@ -1567,8 +1712,12 @@ void batch_norm_elemt_channels_last_template(
const at::Tensor& inv_std,
const at::optional<at::Tensor>& z = c10::nullopt, // bias after BN
const bool fuse_relu = false) {
constexpr int VEC_SIZE = 4;
const auto stride = input.sizes()[1];
const auto reduction_size = input.numel() / stride;
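// Launch size for the vectorized path: one work-item per VEC_SIZE-channel
// chunk, rounded up to whole work-groups.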
int total_vecs = reduction_size * stride / VEC_SIZE;
auto wg_sz = syclMaxWorkItemsPerEU();
int num_wg_for_vec = (total_vecs + wg_sz - 1) / wg_sz;
auto config = get_adaptive_launch_config(
reduction_size, stride, false, ELEMENTS_PER_WORK_ITEM);
auto global_range = std::get<0>(config);
@@ -1582,18 +1731,51 @@
AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward_xpu", [&] {
using accscalar_t = at::acc_type_device<scalar_t, kXPU>;

auto input_data_ptr = input.const_data_ptr<scalar_t>();
auto output_data_ptr = output.mutable_data_ptr<scalar_t>();
auto z_data_ptr =
z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr;

if (can_use_batch_norm_cnl_vec_kernel<scalar_t, VEC_SIZE>(
(char*)input_data_ptr,
(char*)output_data_ptr,
(char*)z_data_ptr,
stride)) {
auto kfn =
BatchNormTransformInputChannelsLastVectorizedKernelFunctor<
scalar_t,
accscalar_t,
accscalar_t,
VEC_SIZE>(
input_data_ptr,
z_data_ptr,
mean.const_data_ptr<accscalar_t>(),
inv_std.const_data_ptr<accscalar_t>(),
weight.defined() ? weight.const_data_ptr<accscalar_t>()
: nullptr,
shift.defined() ? shift.const_data_ptr<accscalar_t>()
: nullptr,
output_data_ptr,
reduction_size,
stride,
fuse_relu);
sycl_kernel_submit(wg_sz * num_wg_for_vec, wg_sz, queue, kfn);
return;
}

auto kfn = BatchNormTransformInputChannelsLastKernelFunctor<
scalar_t,
accscalar_t,
accscalar_t,
ELEMENTS_PER_ITER>(
input_data_ptr,
z_data_ptr,
mean.const_data_ptr<accscalar_t>(),
inv_std.const_data_ptr<accscalar_t>(),
weight.defined() ? weight.const_data_ptr<accscalar_t>() : nullptr,
shift.defined() ? shift.const_data_ptr<accscalar_t>() : nullptr,
output_data_ptr,
reduction_size,
stride,
fuse_relu);
@@ -1611,18 +1793,55 @@
AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward_xpu", [&] {
using accscalar_t = at::acc_type_device<scalar_t, kXPU>;

auto input_data_ptr = input.const_data_ptr<scalar_t>();
auto output_data_ptr = output.mutable_data_ptr<scalar_t>();
auto z_data_ptr =
z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr;
auto weight_data_ptr =
weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr;
auto shift_data_ptr =
shift.defined() ? shift.const_data_ptr<scalar_t>() : nullptr;

if (can_use_batch_norm_cnl_vec_kernel<scalar_t, VEC_SIZE>(
(char*)input_data_ptr,
(char*)output_data_ptr,
(char*)z_data_ptr,
(char*)weight_data_ptr,
(char*)shift_data_ptr,
stride)) {
auto kfn =
BatchNormTransformInputChannelsLastVectorizedKernelFunctor<
scalar_t,
accscalar_t,
scalar_t,
VEC_SIZE>(
input_data_ptr,
z_data_ptr,
mean.const_data_ptr<accscalar_t>(),
inv_std.const_data_ptr<accscalar_t>(),
weight_data_ptr,
shift_data_ptr,
output_data_ptr,
reduction_size,
stride,
fuse_relu);
sycl_kernel_submit(wg_sz * num_wg_for_vec, wg_sz, queue, kfn);
return;
}

auto kfn = BatchNormTransformInputChannelsLastKernelFunctor<
scalar_t,
accscalar_t,
scalar_t,
ELEMENTS_PER_ITER>(
input_data_ptr,
z_data_ptr,
mean.const_data_ptr<accscalar_t>(),
inv_std.const_data_ptr<accscalar_t>(),
weight_data_ptr,
shift_data_ptr,
output_data_ptr,
reduction_size,
stride,
fuse_relu);