Skip to content

Commit

Permalink
Merge branch 'main' into minjean/nested_tensor_dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
xytintel authored Jan 9, 2025
2 parents 4511b17 + 8988335 commit 3533fc8
Show file tree
Hide file tree
Showing 12 changed files with 471 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .github/scripts/apply_torch_pr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"https://github.com/pytorch/pytorch/pull/126516",
# Modify the tolerance level in TIMM benchmark
"https://github.com/pytorch/pytorch/pull/143739",
# Fix build error caused by incorrect namespace change by #144014
"https://github.com/pytorch/pytorch/pull/144450",
]
)
parser.add_argument('--extra-pr-list', '-e', nargs='+',default=[])
Expand Down
47 changes: 47 additions & 0 deletions src/ATen/native/xpu/TensorAdvancedIndexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,24 @@
#include <ATen/core/op_registration/adaption.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/IndexKernel.h>
#include <ATen/native/ReductionType.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/TensorAdvancedIndexingUtils.h>
#include <ATen/native/TensorIterator.h>
//#include <ATen/native/TensorFactories.cpp>
#include <ATen/native/xpu/sycl/IndexingKernels.h>
#include <ATen/native/xpu/sycl/ScatterGatherKernels.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#include <comm/RegisterUtils.h>
#include <comm/xpu_aten.h>
#include <torch/library.h>

#include <ATen/ops/index_add_meta.h>
#include <ATen/ops/index_reduce_meta.h>
#include <xpu/ATen/ops/index_add_native.h>
#include <xpu/ATen/ops/index_reduce_native.h> //generated
//#include <xpu/ATen/ops/index_reduce_prod_native.h> //generated

namespace at {

Expand All @@ -42,6 +49,7 @@ REGISTER_XPU_DISPATCH(index_fill_stub, &xpu::index_fill_kernel);
REGISTER_XPU_DISPATCH(index_copy_stub, &xpu::index_copy_kernel);
REGISTER_XPU_DISPATCH(put_stub, &xpu::put_kernel);
REGISTER_XPU_DISPATCH(take_stub, &xpu::take_kernel);
// REGISTER_XPU_DISPATCH(index_reduce_stub, &xpu::index_reduce_kernel);

TORCH_IMPL_FUNC(index_add_xpu_out)
(const Tensor& self,
Expand Down Expand Up @@ -126,5 +134,44 @@ Tensor count_nonzero_xpu(const Tensor& self, IntArrayRef dims) {
return (self != 0).sum(dims);
}

TORCH_IMPL_FUNC(index_reduce_xpu_out)
(const Tensor& self,
int64_t dim,
const Tensor& index,
const Tensor& source,
const c10::string_view reduce,
bool include_self,
const Tensor& result) {
TORCH_WARN_ONCE(
"index_reduce() is in beta and the API may change at any time.");
if (reduce == "prod") {
xpu::index_reduce_prod_kernel(
self, dim, index, source, include_self, ReductionType::PROD, result);
} else if (reduce == "mean") {
xpu::index_reduce_mean_kernel(
self, dim, index, source, include_self, ReductionType::MEAN, result);
auto counts = include_self ? ones_like(result) : zeros_like(result);
counts.index_add_(dim, index, ones_like(source));
counts.masked_fill_(counts == 0, 1);
if (result.is_floating_point() || result.is_complex()) {
result.div_(counts);
} else {
result.div_(counts, "floor");
}
} else if (reduce == "amax") {
xpu::index_reduce_amax_kernel(
self, dim, index, source, include_self, ReductionType::MAX, result);
} else if (reduce == "amin") {
xpu::index_reduce_amin_kernel(
self, dim, index, source, include_self, ReductionType::MIN, result);
} else {
TORCH_CHECK(
false,
"Only support prod, mean, amax or amin reduce operator. Input was ",
reduce,
".");
}
}

} // namespace native
} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_fft_r2c",
"_flash_attention_forward",
"geqrf",
"index_reduce.out",
"linalg_cholesky_ex.L",
"_linalg_det.result",
"linalg_eig",
Expand Down
6 changes: 6 additions & 0 deletions src/ATen/native/xpu/sycl/Atomics.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,8 @@ SYCL_ATOMIC_INTEGER(Mul, std::multiplies<int8_t>()(a, b), int8_t)
SYCL_ATOMIC_INTEGER(Mul, std::multiplies<int16_t>()(a, b), int16_t)
SYCL_ATOMIC_INTEGER(Mul, std::multiplies<int32_t>()(a, b), int32_t)
SYCL_ATOMIC_INTEGER(Mul, std::multiplies<int64_t>()(a, b), int64_t)
SYCL_ATOMIC_INTEGER(Mul, std::multiplies<uint32_t>()(a, b), uint32_t)
SYCL_ATOMIC_INTEGER(Mul, std::multiplies<uint64_t>()(a, b), uint64_t)

SYCL_ATOMIC_FP(Mul, std::multiplies<float>()(a, b), float)
SYCL_ATOMIC_FP(Mul, std::multiplies<double>()(a, b), double)
Expand Down Expand Up @@ -391,6 +393,8 @@ SYCL_ATOMIC_INTEGER(Max, safe_max<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int16_t>(a, b), int16_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int32_t>(a, b), int32_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int64_t>(a, b), int64_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<uint32_t>(a, b), uint32_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<uint64_t>(a, b), uint64_t)

SYCL_ATOMIC_FP(Max, safe_max<float>(a, b), float)
SYCL_ATOMIC_FP(Max, safe_max<double>(a, b), double)
Expand All @@ -403,6 +407,8 @@ SYCL_ATOMIC_INTEGER(Min, safe_min<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER(Min, safe_min<int16_t>(a, b), int16_t)
SYCL_ATOMIC_INTEGER(Min, safe_min<int32_t>(a, b), int32_t)
SYCL_ATOMIC_INTEGER(Min, safe_min<int64_t>(a, b), int64_t)
SYCL_ATOMIC_INTEGER(Min, safe_min<uint32_t>(a, b), uint32_t)
SYCL_ATOMIC_INTEGER(Min, safe_min<uint64_t>(a, b), uint64_t)

SYCL_ATOMIC_FP(Min, safe_min<float>(a, b), float)
SYCL_ATOMIC_FP(Min, safe_min<double>(a, b), double)
Expand Down
Loading

0 comments on commit 3533fc8

Please sign in to comment.