Introduce vectorized parallel welford batchnorm for channels last path #1317

Open
wants to merge 9 commits into
base: main

xytintel
Contributor

@xytintel xytintel commented Jan 22, 2025

This PR mainly adds vectorization and optimizes the workload of the channels-last batch norm path, achieving a speedup of up to 1.169x in certain cases.
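
For readers unfamiliar with the technique named in the title, here is a minimal pure-Python sketch of parallel Welford statistics over a channels-last layout. It is not the XPU kernel from this PR: the names `WelfordState`, `welford_update`, `welford_merge`, and `channels_last_stats`, and the strided split of rows across `num_workers`, are assumptions made for illustration only. Each simulated worker keeps one running (count, mean, M2) triple per channel, and the partial results are then combined with the standard pairwise merge formula.

```python
from dataclasses import dataclass


@dataclass
class WelfordState:
    count: int = 0
    mean: float = 0.0
    m2: float = 0.0  # running sum of squared deviations from the mean


def welford_update(state, x):
    """Fold a single value into a running Welford state."""
    count = state.count + 1
    delta = x - state.mean
    mean = state.mean + delta / count
    m2 = state.m2 + delta * (x - mean)
    return WelfordState(count, mean, m2)


def welford_merge(a, b):
    """Combine two partial states computed over disjoint subsets."""
    if a.count == 0:
        return b
    if b.count == 0:
        return a
    count = a.count + b.count
    delta = b.mean - a.mean
    mean = a.mean + delta * b.count / count
    m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / count
    return WelfordState(count, mean, m2)


def channels_last_stats(rows, num_workers):
    """rows holds N*H*W entries, each a list of C channel values (NHWC order).

    Hypothetical work split: worker w handles rows w, w + num_workers, ...
    """
    num_channels = len(rows[0])
    partials = []
    for w in range(num_workers):
        states = [WelfordState() for _ in range(num_channels)]
        for row in rows[w::num_workers]:
            # Adjacent channels are contiguous in channels-last memory, so a
            # real kernel could load several of them with one vector load.
            states = [welford_update(s, x) for s, x in zip(states, row)]
        partials.append(states)

    merged = partials[0]
    for states in partials[1:]:
        merged = [welford_merge(a, b) for a, b in zip(merged, states)]

    mean = [s.mean for s in merged]
    var = [s.m2 / s.count for s in merged]  # biased (population) variance
    return mean, var


if __name__ == "__main__":
    rows = [[float(i), float(2 * i + 1)] for i in range(10)]
    print(channels_last_stats(rows, num_workers=3))
```

The merge step is what makes the algorithm parallel: partial states can be computed independently and combined in any order, which is why the sketch can split rows arbitrarily across workers.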

@xytintel
Contributor Author

Depends on #1306.

@xytintel xytintel changed the title Introduce parallel fused welford batchnorm for channels last path Introduce parallel welford batchnorm for channels last path Jan 24, 2025
@xytintel
Contributor Author

Test case:

import torch
import torch.nn as nn

# Channels-last bfloat16 input on the XPU device
N, C, H, W = 4, 24, 160, 256
x = torch.randn(N, C, H, W).to(memory_format=torch.channels_last).bfloat16().xpu()
bn = nn.BatchNorm2d(C).xpu()

# Profile both host-side (CPU) and device-side (XPU) activity
prof_xpu = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.XPU,
    ],
)
with prof_xpu:
    for i in range(120):
        output = bn(x)
print(prof_xpu.key_averages(group_by_input_shape=True).table(sort_by="self_xpu_time_total", row_limit=100000))
print(output.dtype)

Old:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::native_batch_norm        32.46%      36.835ms        55.59%      63.086ms     525.716us       9.036ms        94.62%       9.284ms      77.364us           120  
at::native::xpu::BatchNormCollectStatisticsChannelsL...         0.00%       0.000us         0.00%       0.000us       0.000us       5.028ms        52.65%       5.028ms      41.899us           120  
at::native::xpu::BatchNormTransformInputChannelsLast...         0.00%       0.000us         0.00%       0.000us       0.000us       3.707ms        38.81%       3.707ms      30.889us           120  
at::native::xpu::UnrolledElementwiseForMultiOutputsK...         0.00%       0.000us         0.00%       0.000us       0.000us     301.760us         3.16%     301.760us       2.515us           120  
                                             aten::add_        39.55%      44.881ms        43.19%      49.021ms     408.507us     266.400us         2.79%     266.400us       2.220us           120  
at::native::xpu::VectorizedElementwiseKernel<2, at::...         0.00%       0.000us         0.00%       0.000us       0.000us     266.400us         2.79%     266.400us       2.220us           120  
                                            aten::fill_        10.76%      12.215ms        13.91%      15.782ms     131.518us     247.360us         2.59%     247.360us       2.061us           120  
at::native::xpu::VectorizedElementwiseKernel<4, at::...         0.00%       0.000us         0.00%       0.000us       0.000us     247.360us         2.59%     247.360us       2.061us           120  
                                  urEnqueueKernelLaunch        14.10%      16.002ms        14.10%      16.002ms      26.670us       0.000us         0.00%       0.000us       0.000us           600  
                                       aten::batch_norm         0.32%     358.626us        56.81%      64.473ms     537.271us       0.000us         0.00%       9.284ms      77.364us           120  
                           aten::_batch_norm_impl_index         0.59%     672.126us        56.49%      64.114ms     534.282us       0.000us         0.00%       9.284ms      77.364us           120  
                                            aten::empty         0.77%     877.312us         0.77%     877.312us       1.462us       0.000us         0.00%       0.000us       0.000us           600  
                                       aten::empty_like         0.26%     291.281us         0.90%       1.016ms       8.469us       0.000us         0.00%       0.000us       0.000us           120  
                                    aten::empty_strided         0.48%     543.179us         0.64%     725.000us       6.042us       0.000us         0.00%       0.000us       0.000us           120  
                                            aten::zeros         0.31%     347.857us        14.58%      16.543ms     137.856us       0.000us         0.00%     247.360us       2.061us           120  
                                            aten::zero_         0.25%     287.867us        14.16%      16.070ms     133.917us       0.000us         0.00%     247.360us       2.061us           120  
                                       urUSMDeviceAlloc         0.16%     181.821us         0.16%     181.821us     181.821us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 113.493ms
Self XPU time total: 9.550ms

Optimized:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::native_batch_norm        31.41%      36.805ms        57.46%      67.325ms     561.045us       7.634ms        93.41%       7.904ms      65.865us           120  
at::native::xpu::WelfordBatchNormStatChannelsLastVec...         0.00%       0.000us         0.00%       0.000us       0.000us       4.044ms        49.48%       4.044ms      33.696us           120  
at::native::xpu::BatchNormTransformInputChannelsLast...         0.00%       0.000us         0.00%       0.000us       0.000us       3.279ms        40.12%       3.279ms      27.325us           120  
at::native::xpu::UnrolledElementwiseForMultiOutputsK...         0.00%       0.000us         0.00%       0.000us       0.000us     311.360us         3.81%     311.360us       2.595us           120  
                                            aten::fill_        10.60%      12.422ms        13.69%      16.040ms     133.669us     269.920us         3.30%     269.920us       2.249us           120  
at::native::xpu::VectorizedElementwiseKernel<4, at::...         0.00%       0.000us         0.00%       0.000us       0.000us     269.920us         3.30%     269.920us       2.249us           120  
                                             aten::add_        37.56%      44.005ms        41.29%      48.382ms     403.184us     268.800us         3.29%     268.800us       2.240us           120  
at::native::xpu::VectorizedElementwiseKernel<2, at::...         0.00%       0.000us         0.00%       0.000us       0.000us     268.800us         3.29%     268.800us       2.240us           120  
                                  urEnqueueKernelLaunch        17.36%      20.341ms        17.36%      20.341ms      33.902us       0.000us         0.00%       0.000us       0.000us           600  
                                       aten::batch_norm         0.32%     375.749us        58.71%      68.788ms     573.231us       0.000us         0.00%       7.904ms      65.865us           120  
                           aten::_batch_norm_impl_index         0.60%     701.913us        58.39%      68.412ms     570.100us       0.000us         0.00%       7.904ms      65.865us           120  
                                            aten::empty         0.77%     897.311us         0.77%     897.311us       1.496us       0.000us         0.00%       0.000us       0.000us           600  
                                       aten::empty_like         0.26%     308.416us         0.85%     994.647us       8.289us       0.000us         0.00%       0.000us       0.000us           120  
                                    aten::empty_strided         0.47%     554.852us         0.59%     686.231us       5.719us       0.000us         0.00%       0.000us       0.000us           120  
                                            aten::zeros         0.32%     372.505us        14.33%      16.792ms     139.929us       0.000us         0.00%     269.920us       2.249us           120  
                                            aten::zero_         0.22%     255.474us        13.91%      16.296ms     135.798us       0.000us         0.00%     269.920us       2.249us           120  
                                       urUSMDeviceAlloc         0.11%     131.379us         0.11%     131.379us     131.379us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 117.170ms
Self XPU time total: 8.173ms
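
To make the comparison between the two tables easier to read, the short script below computes old/new ratios from figures copied out of the profiler output above. The dictionary labels are shorthand chosen here, since the kernel names in the tables are truncated; the numbers are the "Self XPU" totals in milliseconds over 120 calls.

```python
# (old_ms, new_ms) pairs taken from the "Self XPU" column of the two tables.
timings_ms = {
    "Self XPU time total": (9.550, 8.173),
    "statistics kernel": (5.028, 4.044),
    "transform kernel": (3.707, 3.279),
}

# Speedup is the old time divided by the new time.
speedups = {name: old / new for name, (old, new) in timings_ms.items()}

for name, ratio in speedups.items():
    print(f"{name}: {ratio:.3f}x")
```

With these figures the statistics kernel improves by roughly 1.24x and total XPU time by roughly 1.17x for this single test case.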

@xytintel xytintel changed the title Introduce parallel welford batchnorm for channels last path Introduce vectorized parallel welford batchnorm for channels last path Jan 24, 2025
Contributor

@dvrogozh dvrogozh left a comment

@xytintel: I have been waiting for a Welford algorithm implementation for XPU, not for performance reasons but for functional correctness and accuracy. We have a known issue, pytorch/pytorch#141642, which affects Transformers tests and is related to the missing Welford implementation for LayerNorm. Could you please extend this change to cover LayerNorm and address the known accuracy problems?

@xytintel
Contributor Author

@dvrogozh Yes, we will apply Welford to all Norm operators. This PR focuses solely on performance, while @min-jean-cho is working on LayerNorm.
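
As background for the accuracy point raised in this exchange, the sketch below compares Welford's update with the single-pass sum-of-squares formula E[x^2] - E[x]^2 in double precision. It does not describe how the existing XPU LayerNorm or BatchNorm kernels are written; the function names `naive_variance` and `welford_variance` and the sample data are assumptions chosen to illustrate why Welford is generally preferred when values have a large mean relative to their spread.

```python
def naive_variance(values):
    """Population variance via E[x^2] - E[x]^2 (prone to cancellation)."""
    n = len(values)
    mean = sum(values) / n
    mean_sq = sum(v * v for v in values) / n
    return mean_sq - mean * mean


def welford_variance(values):
    """Population variance via Welford's running update."""
    count, mean, m2 = 0, 0.0, 0.0
    for v in values:
        count += 1
        delta = v - mean
        mean += delta / count
        m2 += delta * (v - mean)
    return m2 / count


# Small spread around a large offset; the true population variance is 22.5.
data = [1e9 + v for v in (4.0, 7.0, 13.0, 16.0)]

print("welford:", welford_variance(data))
print("naive:  ", naive_variance(data))
```

The naive formula subtracts two numbers near 1e18, where adjacent doubles are more than 100 apart, so the small true variance is lost in rounding. Welford only ever works with deviations from the running mean and avoids that cancellation.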
