Performance comparison between e3nn and cuEquivariance of TensorProduct Interface #45
I encountered the same issue. I installed cuequivariance from source (commit id: 3557c31) and ran a series of tests on an A100 GPU. Here are the relevant details about my environment and test script:

# GPU info
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:53:00.0 Off |                    0 |
| N/A   29C    P0    63W / 400W |      2MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

# python packages
Package Version Editable project location
----------------------------- ---------- ----------------------------------------------
cuequivariance 0.1.0 /workspace/cuequivariance/cuequivariance
cuequivariance-ops-torch-cu12 0.1.0
cuequivariance-torch 0.1.0 /workspace/cuequivariance/cuequivariance_torch
e3nn 0.5.4
filelock 3.16.1
fsspec 2024.10.0
Jinja2 3.1.4
MarkupSafe 3.0.2
mpmath 1.3.0
networkx 3.4.2
numpy 2.2.0
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
opt_einsum 3.4.0
opt-einsum-fx 0.1.4
packaging 24.2
pip 24.2
scipy 1.14.1
setuptools 75.1.0
sympy 1.13.1
torch 2.5.1
triton 3.1.0
typing_extensions 4.12.2
wheel 0.44.0

import logging
import statistics
from typing import Callable, Optional
import time
import sys
import cuequivariance as cue
import cuequivariance_torch as cuet
from e3nn import o3
import numpy as np
import torch
from torch.profiler import profile, record_function, ProfilerActivity
logging.basicConfig(level=logging.INFO)
@torch.no_grad()
def timing(repeat: int, warmup: int, fn: Callable, *args):
    # Warm-up calls (at least one) are excluded from the measurements.
    warmup = 1 if warmup <= 0 else warmup
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    repeat = 1 if repeat <= 0 else repeat
    times = []
    for _ in range(repeat):
        start = time.monotonic_ns()
        fn(*args)
        # Wait for all queued CUDA work so the interval covers the kernels.
        torch.cuda.synchronize()
        end = time.monotonic_ns()
        times.append(end - start)
    ns_to_ms = 1e6
    print(f'min = {min(times) / ns_to_ms:.3f} ms')
    print(f'max = {max(times) / ns_to_ms:.3f} ms')
    avg = statistics.mean(times)
    print(f'avg = {avg / ns_to_ms:.3f} ms')
    print(f'median = {statistics.median(times) / ns_to_ms:.3f} ms')
    print(f'stdev = {statistics.pstdev(times, avg) / ns_to_ms:.3f} ms')
# compare FullyConnectedTensorProduct
in_irreps_str = '32x0e'
out_irreps_str = '32x0e+8x1o'
sh_irreps_str = '1x0e+1x1o+1x2e'
in_size = [15496, 32]
sh_size = [15496, 9]
weight_size = [15496, 1280]
in_tensor = torch.randn(in_size).to('cuda')
sh_tensor = torch.randn(sh_size).to('cuda')
weight_tensor = torch.randn(weight_size).to('cuda')
warmup = 10
repeat = 1000
print('-----cue fc-----')
in_irreps = cue.Irreps(cue.O3, in_irreps_str)
sh_irreps = cue.Irreps(cue.O3, sh_irreps_str)
out_irreps = cue.Irreps(cue.O3, out_irreps_str)
in_layout = cuet.primitives.transpose.TransposeIrrepsLayout(
    in_irreps,
    source=cue.mul_ir,
    target=cue.ir_mul,
    device=in_tensor.device
)
sh_layout = cuet.primitives.transpose.TransposeIrrepsLayout(
    sh_irreps,
    source=cue.mul_ir,
    target=cue.ir_mul,
    device=sh_tensor.device
)
out_layout = cuet.primitives.transpose.TransposeIrrepsLayout(
    out_irreps,
    source=cue.ir_mul,
    target=cue.mul_ir,
    device=in_tensor.device
)
cu_tp = cuet.FullyConnectedTensorProduct(
    in_irreps,
    sh_irreps,
    out_irreps,
    layout=cue.IrrepsLayout.ir_mul,
    shared_weights=False,
    device=in_tensor.device
)
# Timing uses the raw inputs: for 32x0e and for multiplicity-1 irreps,
# the mul_ir and ir_mul layouts have the same memory order.
timing(repeat, warmup, cu_tp, in_tensor, sh_tensor, weight_tensor)
cu_out = cu_tp(in_layout(in_tensor), sh_layout(sh_tensor), weight_tensor)
# The output contains 8x1o, whose layouts differ, so transpose back to mul_ir.
cu_out = out_layout(cu_out)
print('-----e3nn fc-----')
e3nn_tp = o3.FullyConnectedTensorProduct(
    in_irreps_str,
    sh_irreps_str,
    out_irreps_str,
    shared_weights=False)
timing(repeat, warmup, e3nn_tp, in_tensor, sh_tensor, weight_tensor)
e3nn_out = e3nn_tp(in_tensor, sh_tensor, weight_tensor)
assert torch.allclose(cu_out, e3nn_out, atol=1e-5)
# compare spherical_harmonics
lmax = 2
sh_irreps = o3.Irreps.spherical_harmonics(lmax)
ls = [0, 1, 2]
vec_size = [15496, 3]
vec_tensor = torch.randn(vec_size).to('cuda') * 10
repeat = 10
print('-----cue sh-----')
def spherical_harmonics(
    ls: list[int],
    vectors: torch.Tensor,
    normalize: bool = True,
    use_fallback: Optional[bool] = None,
    optimize_fallback: Optional[bool] = None,
) -> torch.Tensor:
    if isinstance(ls, int):
        ls = [ls]
    assert ls == sorted(set(ls))
    assert vectors.shape[-1] == 3
    if normalize:
        vectors = torch.nn.functional.normalize(vectors, dim=-1)
    x = vectors.reshape(-1, 3)
    # The module is built on every call, so timing this function
    # includes construction cost, not only the forward pass.
    m = cuet.EquivariantTensorProduct(
        cue.descriptors.spherical_harmonics(cue.O3(1, -1), ls),
        layout=cue.ir_mul,
        device=x.device,
        math_dtype=x.dtype,
        use_fallback=use_fallback,
        optimize_fallback=optimize_fallback,
    )
    y = m([x])
    y = y.reshape(vectors.shape[:-1] + (y.shape[-1],))
    return y
sh_in_layout = cuet.primitives.transpose.TransposeIrrepsLayout(
    cue.Irreps(cue.O3, '1o'),
    source=cue.mul_ir,
    target=cue.ir_mul,
    device=vec_tensor.device
)
sh_out_layout = cuet.primitives.transpose.TransposeIrrepsLayout(
    cue.Irreps(cue.O3, '1x0e+1x1o+1x2e'),
    source=cue.ir_mul,
    target=cue.mul_ir,
    device=vec_tensor.device
)
# Positional arguments: normalize=True, use_fallback=True, optimize_fallback=True.
# use_fallback=True requests the PyTorch fallback instead of a CUDA kernel.
timing(repeat, warmup, spherical_harmonics, ls, vec_tensor, True, True, True)
cu_out = spherical_harmonics(ls, sh_in_layout(vec_tensor), True, True, True)
cu_out = sh_out_layout(cu_out)
print('-----e3nn sh-----')
timing(repeat, warmup, o3.spherical_harmonics, ls, vec_tensor, True, 'component')
e3nn_out = o3.spherical_harmonics(sh_irreps, vec_tensor, normalization='component', normalize=True)
e3nn_out_ls = o3.spherical_harmonics(ls, vec_tensor, normalization='component', normalize=True)
if not torch.allclose(cu_out, e3nn_out, atol=1e-5):
    print(cu_out)
    print(e3nn_out)
    print(cu_out - e3nn_out)

Upon executing the provided script, I observed the following results:
It appears that no CUDA kernels are used for the spherical harmonics operation in cuEquivariance, which leads to significantly longer execution times than e3nn. Is there a planned update or fix to address this? Additionally, I noticed that the results from cuet.spherical_harmonics do not match those from e3nn.o3.spherical_harmonics. Could you explain why these discrepancies occur? Consistency between implementations is crucial for our applications.
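The `weight_size = [15496, 1280]` in the script can be checked by hand. Below is a minimal pure-Python sketch, assuming the e3nn convention for `FullyConnectedTensorProduct`: every path (l1, l2) → l3 allowed by the triangle rule and by parity carries mul1 × mul2 × mul_out weights. For these irreps only 0e ⊗ 0e → 0e (32·1·32 = 1024) and 0e ⊗ 1o → 1o (32·1·8 = 256) survive, giving 1280. `parse_irreps` and `fc_weight_numel` are hypothetical helpers, not e3nn or cuEquivariance APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Irrep:
    mul: int  # multiplicity
    l: int    # degree
    p: int    # parity: +1 for "e", -1 for "o"


def parse_irreps(s: str) -> list[Irrep]:
    """Hypothetical minimal parser for O3 irreps strings such as "32x0e+8x1o"."""
    irreps = []
    for term in s.replace(" ", "").split("+"):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        irreps.append(Irrep(int(mul), int(ir[:-1]), 1 if ir[-1] == "e" else -1))
    return irreps


def fc_weight_numel(irreps1: str, irreps2: str, irreps_out: str) -> int:
    """Hypothetical helper: weight count of a fully connected TP, assuming one
    weight per (u, v, w) multiplicity triple on every allowed path."""
    total = 0
    for a in parse_irreps(irreps1):
        for b in parse_irreps(irreps2):
            for c in parse_irreps(irreps_out):
                triangle = abs(a.l - b.l) <= c.l <= a.l + b.l
                parity = a.p * b.p == c.p
                if triangle and parity:
                    total += a.mul * b.mul * c.mul
    return total


print(fc_weight_numel("32x0e", "1x0e+1x1o+1x2e", "32x0e+8x1o"))  # 1280
```

Note that the 2e component of the spherical harmonics contributes no weights here, because no 2e irrep appears in the output.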
Hi, we currently provide custom CUDA kernels only for a limited set of tensor products, and we aim to add more progressively. To give you a better answer, I'm adding a page to the documentation that lists our custom kernels and the cases they are useful for; see #48. When our kernels do not provide good performance, please use the argument
Hi @mariogeiger and others, first of all, thank you for implementing and sharing cuEquivariance. Its flexibility across different kinds of irreps is amazing. I am mainly interested in using it for a fully connected TP (both inputs are 32x0e+32x1e+32x2e), and, like others, I am not seeing a performance improvement over e3nn. My guess is that this is expected for now, because the underlying custom kernels are only meant for tensor products with small operand sizes? Also, does the "baseline" implementation of DiffDock and MACE refer to the original implementation by the authors? Thanks, and looking forward to seeing more custom kernels!
You mean this operation, right?

import cuequivariance as cue

irreps = cue.Irreps("O3", "32x0e + 32x1e + 32x2e")
e = cue.descriptors.fully_connected_tensor_product(
    irreps, irreps, irreps
).flatten_coefficient_modes()

And yes, I think this operation's inputs are too big to fit in shared memory, and that's why our

Yes, by "baseline" implementation we mean the original implementation by the authors.
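A back-of-the-envelope check of why this TP is too big: a minimal sketch, assuming the kernel needs all weights (one per (u, v, w) multiplicity triple on every path allowed by the triangle rule and parity) plus one vector of each input and of the output resident at once, in float32. This accounting is an assumption, not a description of cuEquivariance's kernels; under it the estimate does reproduce the 241.6875 figure given in this thread for 16x0 + 16x1 + 16x2. `parse_irreps` and `fc_tp_footprint_kib` are hypothetical helpers.

```python
from itertools import product


def parse_irreps(s: str) -> list[tuple[int, int, int]]:
    """Hypothetical minimal parser: '32x0e + 32x1e' (O3) or '16x0 + 16x1' (SO3)
    into (mul, l, parity) tuples; SO3 irreps get parity +1."""
    irreps = []
    for term in s.replace(" ", "").split("+"):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        if ir[-1] in "eo":
            irreps.append((int(mul), int(ir[:-1]), 1 if ir[-1] == "e" else -1))
        else:
            irreps.append((int(mul), int(ir), 1))
    return irreps


def fc_tp_footprint_kib(irreps1: str, irreps2: str, irreps_out: str, itemsize: int = 4) -> float:
    """Rough footprint in KiB (assumption): all weights of a fully connected TP
    plus one vector each of the two inputs and the output."""
    a, b, c = parse_irreps(irreps1), parse_irreps(irreps2), parse_irreps(irreps_out)
    weights = sum(
        m1 * m2 * m3
        for (m1, l1, p1), (m2, l2, p2), (m3, l3, p3) in product(a, b, c)
        if abs(l1 - l2) <= l3 <= l1 + l2 and p1 * p2 == p3
    )
    vectors = sum(m * (2 * l + 1) for irreps in (a, b, c) for m, l, _ in irreps)
    return (weights + vectors) * itemsize / 1024


irreps = "32x0e + 32x1e + 32x2e"
print(fc_tp_footprint_kib(irreps, irreps, irreps))  # 1923.375
```

With 15 allowed paths and 32³ weights per path, the weights dominate, and the total is several times larger than the roughly 256 KB quoted for Hopper.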
Here is the dimension of the irreps that I am tinkering with.
If you could also let me know the upper bound of
The upper bound is the shared memory; for example, it's about 256 KB on Hopper (source). For instance:

import torch
import cuequivariance as cue

in_irreps1 = cue.Irreps("SO3", "16x0 + 16x1 + 16x2")
in_irreps2 = cue.Irreps("SO3", "16x0 + 16x1 + 16x2")
out_irreps = cue.Irreps("SO3", "16x0 + 16x1 + 16x2")
d = cue.descriptors.fully_connected_tensor_product(in_irreps1, in_irreps2, out_irreps).d
dtype = torch.float32
# Total size of all operands (weights, both inputs, output) in KB
print(sum(ope.size for ope in d.operands) * dtype.itemsize / 1024)
# 241.6875

This should just fit on a Hopper GPU (I didn't test it, but that's my understanding).
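Regarding the upper-bound question: a rough sketch that finds the largest multiplicity `u` whose footprint fits a given budget, assuming the same accounting as the 241.6875 calculation (all weights plus one vector per operand) and SO3 or all-even O3 irreps of the form `ux0 + ux1 + ... + uxL` in every operand. Two budgets are shown because they differ: NVIDIA's Hopper tuning guide gives 256 KB as the combined L1/shared-memory capacity per SM, but at most 227 KB of shared memory per thread block. Which limit applies to cuEquivariance's kernels is not stated here. `num_paths`, `footprint_kib`, and `max_uniform_mul` are hypothetical helpers.

```python
def num_paths(lmax: int) -> int:
    """Number of (l1, l2, l3) triples with 0 <= l <= lmax obeying the triangle rule."""
    return sum(
        1
        for l1 in range(lmax + 1)
        for l2 in range(lmax + 1)
        for l3 in range(lmax + 1)
        if abs(l1 - l2) <= l3 <= l1 + l2
    )


def footprint_kib(mul: int, lmax: int = 2, itemsize: int = 4) -> float:
    """Assumed footprint: mul**3 weights per path plus three vectors of
    mul * (lmax + 1)**2 entries (two inputs and the output)."""
    elements = num_paths(lmax) * mul**3 + 3 * mul * (lmax + 1) ** 2
    return elements * itemsize / 1024


def max_uniform_mul(budget_kib: float, lmax: int = 2, itemsize: int = 4) -> int:
    """Largest mul with footprint_kib(mul) <= budget_kib (0 if none fits)."""
    mul = 0
    while footprint_kib(mul + 1, lmax, itemsize) <= budget_kib:
        mul += 1
    return mul


print(footprint_kib(16))     # 241.6875
print(max_uniform_mul(256))  # 16
print(max_uniform_mul(227))  # 15
```

Under this model, 16x0 + 16x1 + 16x2 fits the 256 KB figure but not the 227 KB per-block limit, and switching to float64 (`itemsize=8`) lowers the result to 12 at 227 KB.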
@LHJ1098826475 👋 Hello from the Product team on cuEquivariance. Can you please share what you are using cuEquivariance for?
Thank you for your team's interest. Our research is based on the Allegro model, and we hope to understand cuEquivariance's optimizations in order to accelerate it.
Hi,
I'm glad you shared this technical blog (https://developer.nvidia.com/blog/accelerate-drug-and-material-discovery-with-new-math-library-nvidia-cuequivariance/).
I compared the performance of e3nn and cuEquivariance using the test case from the tutorial (https://docs.nvidia.com/cuda/cuequivariance/tutorials/stp.html), and the results showed no performance benefit for cuEquivariance. Is there a problem with my test script?
Besides the speedups reported for cuEquivariance in the MACE and DiffDock models, have you benchmarked the TensorProduct interface itself? Could you provide test scripts at the TensorProduct-interface level?
Thanks