Performance comparison between e3nn and cuEquivariance at the TensorProduct interface #45

LHJ1098826475 opened this issue Dec 12, 2024 · 8 comments


@LHJ1098826475

Hi,

Thanks for sharing this technical blog post (https://developer.nvidia.com/blog/accelerate-drug-and-material-discovery-with-new-math-library-nvidia-cuequivariance/).

I compared the performance of e3nn and cuEquivariance based on the test case provided at https://docs.nvidia.com/cuda/cuequivariance/tutorials/stp.html, and the results showed no performance benefit for cuEquivariance. Is there a problem with my test script?

# cuEquivariance
import time
import itertools
import torch
import numpy as np

import cuequivariance as cue
import cuequivariance_torch as cuet
import cuequivariance.segmented_tensor_product as stp

device = torch.device("cuda:0")


# init irreps and the linear-layer descriptor: weights "uv", input "iu", output "iv"
irreps_in = cue.Irreps("O3", "3000x0e")
irreps_out = cue.Irreps("O3", "3000x0e")
descriptor = stp.SegmentedTensorProduct.from_subscripts("uv,iu,iv")

# add segments for the input (operand 1) and output (operand 2)
for mul, ir in irreps_in:
    descriptor.add_segment(1, (ir.dim, mul))
for mul, ir in irreps_out:
    descriptor.add_segment(2, (ir.dim, mul))

# add a path (with its weight segment) for every pair of matching irreps
for (i1, (mul1, ir1)), (i2, (mul2, ir2)) in itertools.product(
    enumerate(irreps_in), enumerate(irreps_out)
):
    if ir1 == ir2:
        descriptor.add_path(None, i1, i2, c=1.0)

descriptor = descriptor.normalize_paths_for_operand(-1)

# compute
linear_torch = cuet.TensorProduct(descriptor, device=device)
w = torch.randn(descriptor.operands[0].size).to(device)
x1 = torch.randn((50000, irreps_in.dim)).to(device)
x2 = linear_torch(w, x1)  # warmup

# time statistics
loop = 10
torch.cuda.synchronize()
start = time.time()
for i in range(loop):
    x2 = linear_torch(w, x1)
torch.cuda.synchronize()
end = time.time()
print((end-start) / loop)

assert x2.shape == (50000, irreps_out.dim)
# e3nn
import time
import torch
from e3nn.o3 import Irreps, Linear

device = torch.device("cuda:0")

# init Linear
irreps_in = Irreps("3000x0e")
irreps_out = Irreps("3000x0e")
tp = Linear(irreps_in, irreps_out).to(device)

# compute
x1 = torch.randn((50000, irreps_in.dim)).to(device)
x2 = tp(x1)  # warmup

# time statistics
loop = 10
torch.cuda.synchronize()
start = time.time()
for i in range(loop):
    x2 = tp(x1)
torch.cuda.synchronize()
end = time.time()
print((end - start) / loop)

assert x2.shape == (50000, irreps_out.dim)
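
As an aside, timing with CUDA events instead of time.time() usually gives less noisy numbers. A minimal sketch, reusing tp, x1, and loop from the e3nn script above (the same pattern applies to linear_torch):

# CUDA-event timing: measures elapsed GPU time directly rather than host wall-clock time
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

torch.cuda.synchronize()
start_evt.record()
for _ in range(loop):
    x2 = tp(x1)
end_evt.record()
torch.cuda.synchronize()
print(start_evt.elapsed_time(end_evt) / loop, "ms")  # elapsed_time() returns milliseconds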

[screenshot of the timing results]

Besides the acceleration cuEquivariance brings to the MACE and DiffDock models, have you conducted a performance comparison at the TensorProduct interface level? Could you provide test scripts for that level?

Thanks

@lmcl90

lmcl90 commented Dec 16, 2024

I encountered the same issue. I installed cuEquivariance from source (commit 3557c31) and ran a series of tests on an A100 GPU. Here are the relevant details of my environment, along with my test script:

# GPU info 
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:53:00.0 Off |                    0 |
| N/A   29C    P0    63W / 400W |      2MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
# python packages

Package                       Version    Editable project location
----------------------------- ---------- ----------------------------------------------
cuequivariance                0.1.0      /workspace/cuequivariance/cuequivariance
cuequivariance-ops-torch-cu12 0.1.0
cuequivariance-torch          0.1.0      /workspace/cuequivariance/cuequivariance_torch
e3nn                          0.5.4
filelock                      3.16.1
fsspec                        2024.10.0
Jinja2                        3.1.4
MarkupSafe                    3.0.2
mpmath                        1.3.0
networkx                      3.4.2
numpy                         2.2.0
nvidia-cublas-cu12            12.4.5.8
nvidia-cuda-cupti-cu12        12.4.127
nvidia-cuda-nvrtc-cu12        12.4.127
nvidia-cuda-runtime-cu12      12.4.127
nvidia-cudnn-cu12             9.1.0.70
nvidia-cufft-cu12             11.2.1.3
nvidia-curand-cu12            10.3.5.147
nvidia-cusolver-cu12          11.6.1.9
nvidia-cusparse-cu12          12.3.1.170
nvidia-nccl-cu12              2.21.5
nvidia-nvjitlink-cu12         12.4.127
nvidia-nvtx-cu12              12.4.127
opt_einsum                    3.4.0
opt-einsum-fx                 0.1.4
packaging                     24.2
pip                           24.2
scipy                         1.14.1
setuptools                    75.1.0
sympy                         1.13.1
torch                         2.5.1
triton                        3.1.0
typing_extensions             4.12.2
wheel                         0.44.0

# test script
import logging
import statistics
from typing import Callable, Optional
import time
import sys

import cuequivariance as cue
import cuequivariance_torch as cuet
from e3nn import o3
import numpy as np
import torch
from torch.profiler import profile, record_function, ProfilerActivity

logging.basicConfig(level=logging.INFO)

@torch.no_grad()
def timing(repeat: int, warmup: int, fn: Callable, *args):
    warmup = 1 if warmup <= 0 else warmup
    for _ in range(warmup):   
        fn(*args)
    torch.cuda.synchronize()

    repeat = 1 if repeat <= 0 else repeat
    times = []
    for _ in range(repeat):
        start = time.monotonic_ns()
        fn(*args)
        torch.cuda.synchronize()
        end = time.monotonic_ns()
        times.append(end - start)
    ns_to_ms = 1e6
    print(f'min = {min(times) / ns_to_ms:.3f} ms')   
    print(f'max = {max(times) / ns_to_ms:.3f} ms')
    avg = statistics.mean(times)
    print(f'avg = {avg / ns_to_ms:.3f} ms')
    print(f'median = {statistics.median(times) / ns_to_ms:.3f} ms')
    print(f'stdev = {statistics.pstdev(times, avg) / ns_to_ms:.3f}')

# compare FullyConnectedTensorProduct
in_irreps_str = '32x0e'
out_irreps_str = '32x0e+8x1o'
sh_irreps_str = '1x0e+1x1o+1x2e'

in_size = [15496, 32]
sh_size = [15496, 9]
weight_size = [15496, 1280]
in_tensor = torch.randn(in_size).to('cuda')
sh_tensor = torch.randn(sh_size).to('cuda')
weight_tensor = torch.randn(weight_size).to('cuda')

warmup = 10
repeat = 1000

print('-----cue fc-----')
in_irreps = cue.Irreps(cue.O3, in_irreps_str)
sh_irreps = cue.Irreps(cue.O3, sh_irreps_str)
out_irreps = cue.Irreps(cue.O3, out_irreps_str)
in_layout = cuet.primitives.transpose.TransposeIrrepsLayout(
    in_irreps,
    source=cue.mul_ir,
    target=cue.ir_mul,
    device=in_tensor.device
)
sh_layout = cuet.primitives.transpose.TransposeIrrepsLayout(
    sh_irreps,
    source=cue.mul_ir,
    target=cue.ir_mul,
    device=sh_tensor.device
)
out_layout = cuet.primitives.transpose.TransposeIrrepsLayout(
    out_irreps,
    source=cue.ir_mul,
    target=cue.mul_ir,
    device=in_tensor.device
)
cu_tp = cuet.FullyConnectedTensorProduct(
    in_irreps, 
    sh_irreps, 
    out_irreps, 
    layout=cue.IrrepsLayout.ir_mul,
    shared_weights=False,
    device=in_tensor.device
    )
timing(repeat, warmup, cu_tp, in_tensor, sh_tensor, weight_tensor)
cu_out = cu_tp(in_layout(in_tensor), sh_layout(sh_tensor), weight_tensor)
cu_out = out_layout(cu_out)

print('-----e3nn fc-----')
e3nn_tp = o3.FullyConnectedTensorProduct(
    in_irreps_str, 
    sh_irreps_str, 
    out_irreps_str, 
    shared_weights=False)
timing(repeat, warmup, e3nn_tp, in_tensor, sh_tensor, weight_tensor)
e3nn_out = e3nn_tp(in_tensor, sh_tensor, weight_tensor)
assert torch.allclose(cu_out, e3nn_out, atol=1e-5)

# compare spherical_harmonics
lmax = 2
sh_irreps = o3.Irreps.spherical_harmonics(lmax)
ls = [0, 1, 2]

vec_size = [15496, 3]
vec_tensor = torch.randn(vec_size).to('cuda') * 10

repeat = 10
print('-----cue sh-----')
def spherical_harmonics(
    ls: list[int],
    vectors: torch.Tensor,
    normalize: bool = True,
    use_fallback: Optional[bool] = None,
    optimize_fallback: Optional[bool] = None,
) -> torch.Tensor:
    if isinstance(ls, int):
        ls = [ls]
    assert ls == sorted(set(ls))
    assert vectors.shape[-1] == 3

    if normalize:
        vectors = torch.nn.functional.normalize(vectors, dim=-1)

    x = vectors.reshape(-1, 3)
    m = cuet.EquivariantTensorProduct(
        cue.descriptors.spherical_harmonics(cue.O3(1, -1), ls),
        layout=cue.ir_mul,
        device=x.device,
        math_dtype=x.dtype,
        use_fallback=use_fallback,
        optimize_fallback=optimize_fallback,
    )
    y = m([x])
    y = y.reshape(vectors.shape[:-1] + (y.shape[-1],))
    return y

sh_in_layout = cuet.primitives.transpose.TransposeIrrepsLayout(
    cue.Irreps(cue.O3, '1o'),
    source=cue.mul_ir,
    target=cue.ir_mul,
    device=vec_tensor.device
)
sh_out_layout = cuet.primitives.transpose.TransposeIrrepsLayout(
    cue.Irreps(cue.O3, '1x0e+1x1o+1x2e'),
    source=cue.ir_mul,
    target=cue.mul_ir,
    device=vec_tensor.device
)
timing(repeat, warmup, spherical_harmonics, ls, vec_tensor, True, True, True)
cu_out = spherical_harmonics(ls, sh_in_layout(vec_tensor), True, True, True)
cu_out = sh_out_layout(cu_out)

print('-----e3nn sh-----')
timing(repeat, warmup, o3.spherical_harmonics, ls, vec_tensor, True, 'component')

e3nn_out = o3.spherical_harmonics(sh_irreps, vec_tensor, normalization='component', normalize=True)
e3nn_out_ls = o3.spherical_harmonics(ls, vec_tensor, normalization='component', normalize=True)

if not torch.allclose(cu_out, e3nn_out, atol=1e-5):
    print(cu_out)
    print(e3nn_out)
    print(cu_out - e3nn_out)

Upon executing the provided script, I observed the following results:

-----cue fc-----
min = 0.328 ms
max = 2.594 ms
avg = 1.313 ms
median = 0.578 ms
stdev = 0.994
-----e3nn fc-----
min = 0.225 ms
max = 2.853 ms
avg = 1.307 ms
median = 0.496 ms
stdev = 1.037
-----cue sh-----
min = 131.831 ms
max = 244.550 ms
avg = 147.421 ms
median = 137.989 ms
stdev = 32.505
-----e3nn sh-----
min = 0.329 ms
max = 2.470 ms
avg = 0.905 ms
median = 0.337 ms
stdev = 0.890

tensor([[ 1.0000,  2.3527,  1.1063,  ..., -0.1054,  1.1477,  0.5670],
        [ 1.0000,  1.6815, -0.4092,  ...,  2.1022, -0.3489,  1.0551],
        [ 1.0000,  1.1068, -0.2071,  ...,  1.5111, -0.9285,  1.9812],
        ...,
        [ 1.0000,  0.6210,  0.9093,  ..., -0.1088,  0.8024,  2.7458],
        [ 1.0000, -0.6842,  0.6020,  ...,  0.0590,  1.0362, -0.8278],
        [ 1.0000,  1.4467,  0.8359,  ..., -0.0879,  0.6471,  2.6615]],
       device='cuda:0')
tensor([[ 1.0000,  1.3527,  0.1063,  ..., -1.1054,  0.1477, -0.4330],
        [ 1.0000,  0.6815, -1.4092,  ...,  1.1022, -1.3489,  0.0551],
        [ 1.0000,  0.1068, -1.2071,  ...,  0.5111, -1.9285,  0.9812],
        ...,
        [ 1.0000, -0.3790, -0.0907,  ..., -1.1088, -0.1976,  1.7458],
        [ 1.0000, -1.6842, -0.3980,  ..., -0.9410,  0.0362, -1.8278],
        [ 1.0000,  0.4467, -0.1641,  ..., -1.0879, -0.3529,  1.6615]],
       device='cuda:0')
tensor([[0.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')

It appears that there are no CUDA kernels being utilized for the spherical harmonics operation in cuEquivariance, leading to significantly longer execution times compared to e3nn. Is there any planned update or fix to address this inefficiency?

Additionally, I noticed that the results from cuet.spherical_harmonics do not match those from e3nn.o3.spherical_harmonics. Could you provide insight into why these discrepancies occur? Ensuring consistency between implementations is crucial for our applications.

@mariogeiger (Collaborator)

Hi,

We currently provide custom CUDA kernels only for a somewhat limited scope of tensor products, and we aim to progressively add more.

To answer this better, I'm working on adding a page to the documentation that lists the custom kernels we have and the cases they are useful for. See #48.

When our kernels do not provide good performance, please pass the argument use_fallback=True to use an implementation similar to e3nn's.
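
For example, a minimal sketch of the fallback path (here use_fallback is passed to the constructor, as with the EquivariantTensorProduct wrapper earlier in this thread; check the documentation for the exact placement in your version):

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

irreps = cue.Irreps("O3", "32x0e + 32x1e + 32x2e")
# use_fallback=True skips the custom CUDA kernels and uses a PyTorch
# implementation similar to e3nn's
tp = cuet.FullyConnectedTensorProduct(
    irreps, irreps, irreps,
    layout=cue.ir_mul,
    shared_weights=False,
    device=torch.device("cuda:0"),
    use_fallback=True,
)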

@hiranumn

Hi @mariogeiger and others,

First of all, thank you for implementing and sharing cuEquivariance. Its flexibility across different sorts of irreps is amazing.

I am mainly interested in using this for a fully connected TP (both inputs are 32x0e+32x1e+32x2e), and, similar to others, I am not seeing a performance improvement over e3nn. My guess is that this is expected for now because the underlying custom kernels are only meant for tensor products with small operand sizes?

Also, do the "baseline" implementations of DiffDock and MACE refer to the original implementations by the authors?

Thanks and looking forward to seeing more custom kernels!

@mariogeiger (Collaborator)

You mean this operation, right?

irreps = cue.Irreps("O3", "32x0e + 32x1e + 32x2e")
e = cue.descriptors.fully_connected_tensor_product(
    irreps, irreps, irreps
).flatten_coefficient_modes()

And yes, I think the inputs of this operation are too big to fit in shared memory, and that's why our FusedTensorProductOp4 is slow.
@stadlmax can you confirm we are right? I obtain this size: sum(ope.size for ope in e.d.operands) * torch.float32.itemsize = 1969536 bytes (about 1.9 MB).

Yes, by "baseline" implementation we mean the original implementation by the authors.

@hiranumn

Here are the dimensions of the irreps that I am tinkering with.

in_irreps1 = cue.Irreps("SO3", "32x0 + 32x1 + 32x2")
in_irreps2 = cue.Irreps("SO3", "32x0 + 32x1 + 32x2")
out_irreps = cue.Irreps("SO3", "16x0 + 16x1 + 16x2")

If you could also let me know the upper bound of sum(ope.size for ope in e.d.operands) * torch.float32.itemsize for FusedTensorProductOp4 to work well, that would be fantastic!

@mariogeiger (Collaborator)

The upper bound is the shared memory size; for instance, it's about 256 KB on Hopper (source).

For example,

in_irreps1 = cue.Irreps("SO3", "16x0 + 16x1 + 16x2")
in_irreps2 = cue.Irreps("SO3", "16x0 + 16x1 + 16x2")
out_irreps = cue.Irreps("SO3", "16x0 + 16x1 + 16x2")
d = cue.descriptors.fully_connected_tensor_product(in_irreps1, in_irreps2, out_irreps).d
dtype = torch.float32
sum(ope.size for ope in d.operands) * dtype.itemsize / 1024
# 241.6875

should just fit on a Hopper GPU. (I didn't test it, but that's my understanding.)
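
Applying the same arithmetic to the 32-multiplicity irreps above (I haven't run this; the figure comes from counting the 15 allowed paths by hand):

in_irreps1 = cue.Irreps("SO3", "32x0 + 32x1 + 32x2")
in_irreps2 = cue.Irreps("SO3", "32x0 + 32x1 + 32x2")
out_irreps = cue.Irreps("SO3", "16x0 + 16x1 + 16x2")
d = cue.descriptors.fully_connected_tensor_product(in_irreps1, in_irreps2, out_irreps).d
sum(ope.size for ope in d.operands) * torch.float32.itemsize / 1024
# ≈ 962.8 KB: 15 paths of 32*32*16 weights plus the three data operands,
# several times the ~256 KB shared-memory budget, so the fused kernel is not expected to apply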

@becca (Collaborator)

becca commented Jan 15, 2025

@LHJ1098826475 👋 Hello from the Product team on cuEquivariance. Can you please share what you are using cuEquivariance for?

@LHJ1098826475 (Author)

Thank you for your team's interest.

We are conducting our research based on the Allegro model, and we hope to understand cuEquivariance's optimizations so we can accelerate it.
