Introduce vectorized parallel Welford batchnorm for channels-last path #1317
base: main
Conversation
Depends on #1306.
Test case:

```python
import torch
import torch.nn as nn

N, C, H, W = 4, 24, 160, 256
x = torch.randn(N, C, H, W).to(memory_format=torch.channels_last).bfloat16().xpu()
bn = nn.BatchNorm2d(C).xpu()

prof_xpu = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.XPU,
    ],
)
with prof_xpu:
    for i in range(120):
        output = bn(x)

print(prof_xpu.key_averages(group_by_input_shape=True).table(sort_by="self_xpu_time_total", row_limit=100000))
print(output.dtype)
```

Old:

```
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self XPU Self XPU % XPU total XPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::native_batch_norm 32.46% 36.835ms 55.59% 63.086ms 525.716us 9.036ms 94.62% 9.284ms 77.364us 120
at::native::xpu::BatchNormCollectStatisticsChannelsL... 0.00% 0.000us 0.00% 0.000us 0.000us 5.028ms 52.65% 5.028ms 41.899us 120
at::native::xpu::BatchNormTransformInputChannelsLast... 0.00% 0.000us 0.00% 0.000us 0.000us 3.707ms 38.81% 3.707ms 30.889us 120
at::native::xpu::UnrolledElementwiseForMultiOutputsK... 0.00% 0.000us 0.00% 0.000us 0.000us 301.760us 3.16% 301.760us 2.515us 120
aten::add_ 39.55% 44.881ms 43.19% 49.021ms 408.507us 266.400us 2.79% 266.400us 2.220us 120
at::native::xpu::VectorizedElementwiseKernel<2, at::... 0.00% 0.000us 0.00% 0.000us 0.000us 266.400us 2.79% 266.400us 2.220us 120
aten::fill_ 10.76% 12.215ms 13.91% 15.782ms 131.518us 247.360us 2.59% 247.360us 2.061us 120
at::native::xpu::VectorizedElementwiseKernel<4, at::... 0.00% 0.000us 0.00% 0.000us 0.000us 247.360us 2.59% 247.360us 2.061us 120
urEnqueueKernelLaunch 14.10% 16.002ms 14.10% 16.002ms 26.670us 0.000us 0.00% 0.000us 0.000us 600
aten::batch_norm 0.32% 358.626us 56.81% 64.473ms 537.271us 0.000us 0.00% 9.284ms 77.364us 120
aten::_batch_norm_impl_index 0.59% 672.126us 56.49% 64.114ms 534.282us 0.000us 0.00% 9.284ms 77.364us 120
aten::empty 0.77% 877.312us 0.77% 877.312us 1.462us 0.000us 0.00% 0.000us 0.000us 600
aten::empty_like 0.26% 291.281us 0.90% 1.016ms 8.469us 0.000us 0.00% 0.000us 0.000us 120
aten::empty_strided 0.48% 543.179us 0.64% 725.000us 6.042us 0.000us 0.00% 0.000us 0.000us 120
aten::zeros 0.31% 347.857us 14.58% 16.543ms 137.856us 0.000us 0.00% 247.360us 2.061us 120
aten::zero_ 0.25% 287.867us 14.16% 16.070ms 133.917us 0.000us 0.00% 247.360us 2.061us 120
urUSMDeviceAlloc 0.16% 181.821us 0.16% 181.821us 181.821us 0.000us 0.00% 0.000us 0.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 113.493ms
Self XPU time total: 9.550ms
```

Optimized:

```
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self XPU Self XPU % XPU total XPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::native_batch_norm 31.41% 36.805ms 57.46% 67.325ms 561.045us 7.634ms 93.41% 7.904ms 65.865us 120
at::native::xpu::WelfordBatchNormStatChannelsLastVec... 0.00% 0.000us 0.00% 0.000us 0.000us 4.044ms 49.48% 4.044ms 33.696us 120
at::native::xpu::BatchNormTransformInputChannelsLast... 0.00% 0.000us 0.00% 0.000us 0.000us 3.279ms 40.12% 3.279ms 27.325us 120
at::native::xpu::UnrolledElementwiseForMultiOutputsK... 0.00% 0.000us 0.00% 0.000us 0.000us 311.360us 3.81% 311.360us 2.595us 120
aten::fill_ 10.60% 12.422ms 13.69% 16.040ms 133.669us 269.920us 3.30% 269.920us 2.249us 120
at::native::xpu::VectorizedElementwiseKernel<4, at::... 0.00% 0.000us 0.00% 0.000us 0.000us 269.920us 3.30% 269.920us 2.249us 120
aten::add_ 37.56% 44.005ms 41.29% 48.382ms 403.184us 268.800us 3.29% 268.800us 2.240us 120
at::native::xpu::VectorizedElementwiseKernel<2, at::... 0.00% 0.000us 0.00% 0.000us 0.000us 268.800us 3.29% 268.800us 2.240us 120
urEnqueueKernelLaunch 17.36% 20.341ms 17.36% 20.341ms 33.902us 0.000us 0.00% 0.000us 0.000us 600
aten::batch_norm 0.32% 375.749us 58.71% 68.788ms 573.231us 0.000us 0.00% 7.904ms 65.865us 120
aten::_batch_norm_impl_index 0.60% 701.913us 58.39% 68.412ms 570.100us 0.000us 0.00% 7.904ms 65.865us 120
aten::empty 0.77% 897.311us 0.77% 897.311us 1.496us 0.000us 0.00% 0.000us 0.000us 600
aten::empty_like 0.26% 308.416us 0.85% 994.647us 8.289us 0.000us 0.00% 0.000us 0.000us 120
aten::empty_strided 0.47% 554.852us 0.59% 686.231us 5.719us 0.000us 0.00% 0.000us 0.000us 120
aten::zeros 0.32% 372.505us 14.33% 16.792ms 139.929us 0.000us 0.00% 269.920us 2.249us 120
aten::zero_ 0.22% 255.474us 13.91% 16.296ms 135.798us 0.000us 0.00% 269.920us 2.249us 120
urUSMDeviceAlloc 0.11% 131.379us 0.11% 131.379us 131.379us 0.000us 0.00% 0.000us 0.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 117.170ms
Self XPU time total: 8.173ms
```
@xytintel: I have been waiting for a Welford algorithm implementation for XPU, though not for performance reasons but for functional and accuracy reasons. We have the known pytorch/pytorch#141642 affecting Transformers tests, which is related to the missing Welford implementation for LayerNorm. Can you please extend the fix to cover LayerNorm to address the known accuracy problems?
@dvrogozh Yes, we will apply Welford to all Norm operators. This PR focuses solely on performance, while @min-jean-cho is working on LayerNorm.
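(Illustrative aside, not code from this PR.) The accuracy concern exists because the naive single-pass formula Var(x) = E[x²] − (E[x])² cancels catastrophically in low precision when the mean is large relative to the spread, which is exactly the failure mode Welford's update avoids. A minimal Python sketch of the effect:

```python
import torch

torch.manual_seed(0)
# Large mean, unit spread: the regime where E[x^2] - (E[x])^2 cancels badly in fp32.
x = (torch.randn(1_000_000) + 10_000.0).float()

# Naive single-pass formula in float32.
mean = x.mean()
var_naive = (x * x).mean() - mean * mean

# Stable reference from PyTorch's built-in reduction (no cancelling subtraction of
# two huge, nearly equal numbers).
var_ref = x.var(unbiased=False)

print(var_naive.item())  # far from 1.0, may even come out negative
print(var_ref.item())    # ~1.0
```

The same cancellation shows up in mixed-precision norm layers when statistics are accumulated with the naive formula, which is why the linked issue asks for Welford in LayerNorm as well.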
Mainly adds vectorization and optimizes the workload, achieving a performance improvement of up to 1.169x in certain cases (in the test above, self XPU time drops from 9.550ms to 8.173ms).
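For readers unfamiliar with the technique, the core of a parallel Welford reduction is that each chunk of the data produces a partial (count, mean, M2) triple that can be merged associatively (Chan et al.). The sketch below is only a Python model of that combination with hypothetical helper names, not the SYCL kernel in this PR; the real kernel presumably accumulates such partials per work-item, vectorized over the channel dimension for the channels-last layout as the title suggests, and merges them across the work-group.

```python
import torch

def welford_chunk(x):
    # Partial statistics for one chunk. (A real kernel would build this triple
    # with Welford's streaming update; a two-pass computation per chunk yields
    # the same (n, mean, M2) partial and keeps the sketch short.)
    n = x.shape[0]
    mean = x.mean(dim=0)
    m2 = ((x - mean) ** 2).sum(dim=0)
    return n, mean, m2

def welford_merge(a, b):
    # Chan et al. combination of two Welford partials; associative, so the
    # merge order across chunks/work-items does not matter.
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * (n_b / n)
    m2 = m2_a + m2_b + delta * delta * (n_a * n_b / n)
    return n, mean, m2

# Toy channels-last batchnorm reduction: per-channel stats over N*H*W samples.
N, C, H, W = 4, 24, 160, 256
x = torch.randn(N, C, H, W).to(memory_format=torch.channels_last).float()
flat = x.permute(0, 2, 3, 1).reshape(-1, C)   # rows = N*H*W, cols = C

# Reduce in independent chunks, then merge the partials.
chunks = flat.chunk(8, dim=0)
n, mean, m2 = welford_chunk(chunks[0])
for c in chunks[1:]:
    n, mean, m2 = welford_merge((n, mean, m2), welford_chunk(c))

var = m2 / n
print(torch.allclose(mean, flat.mean(dim=0), atol=1e-4))
print(torch.allclose(var, flat.var(dim=0, unbiased=False), atol=1e-4))
```

In the channels-last layout, consecutive elements in memory belong to consecutive channels, so each partial maps naturally onto a vector of per-channel accumulators, which is what makes the vectorized formulation attractive for this path.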