-
Currently I have a piece of code that looks like this (using e3nn's PyTorch version):

    import torch
    from torch import nn
    from e3nn.o3 import Irreps, FullyConnectedTensorProduct

    class WrappedFCTensorProduct(nn.Module):
        def __init__(self, irreps_in1, irreps_in2, irreps_out, **kwargs):
            super().__init__()
            irreps_in1 = Irreps(irreps_in1)
            irreps_in2 = Irreps(irreps_in2)
            irreps_out = Irreps(irreps_out)
            self.tp = FullyConnectedTensorProduct(
                irreps_in1, irreps_in2, irreps_out,
                internal_weights=False, shared_weights=True, **kwargs)
            self.weights = nn.Parameter(torch.ones(self.tp.weight_numel))

        def forward(self, x1, x2):
            return self.tp(x1, x2, self.weights)

I don't know whether there's a better way of doing this, but it's what exists now. I'm trying to move this to e3nn-jax. I don't see an equivalent for [...].

My (currently broken and not running) attempt at this looks like so:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    from e3nn_jax.flax import Linear
    import e3nn_jax as e
    from e3nn_jax import Irreps, IrrepsArray

    class WrappedFCTensorProduct(nn.Module):
        irreps_in1 = Irreps('1x0e + 1x1e')
        irreps_in2 = Irreps('1x0e + 1x1e')
        full_irreps_out = e.tensor_product(irreps_in1, irreps_in2)
        irreps_out = Irreps('1x0e + 1x1e')
        linear = Linear(irreps_out, irreps_in=full_irreps_out)

        def __call__(self, irarray_in1, irarray_in2):
            out = e.tensor_product(irarray_in1, irarray_in2, filter_ir_out=self.irreps_out)
            return self.linear(out)

The call that triggers the error is:

    tp = WrappedFCTensorProduct()
    in1 = IrrepsArray('1x0e + 1x1e', jnp.array([-1.5074, 0.8150, -1.8354, -1.2662]))
    in2 = IrrepsArray('1x0e + 1x1e', jnp.array([0.5179, 0.2734, -1.5536, 0.5348]))
    params = tp.init(jax.random.PRNGKey(0), in1, in2)

I've looked up the corresponding error in the Flax documentation, but unfortunately wasn't able to find a way around it (mostly because I am still getting used to the structure of Flax). I'd appreciate any pointers in this direction. The end goal is an implementation of FullyConnectedTensorProduct as a Flax-based module (or as close as we can get to it with a reasonable amount of effort).

EDIT: Just saw #39, so I should mention that I'm not tied to Flax and would in principle be fine discussing this as an Equinox implementation as well.
-
Try adding @nn.compact before the __call__ method. See https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/setup_or_nncompact.html

Yes, we could easily add Equinox support for Linear. @ameya98 had a PR for that; I wonder what happened with it.
-
I think you need something like this:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    import e3nn_jax as e3nn

    class WrappedFCTensorProduct(nn.Module):
        irreps_out: e3nn.Irreps

        @nn.compact
        def __call__(self, input_1: e3nn.IrrepsArray, input_2: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
            # Unweighted tensor product, followed by a learnable equivariant linear mix.
            output = e3nn.tensor_product(input_1, input_2)
            output = e3nn.flax.Linear(irreps_out=self.irreps_out)(output)
            return output

and you can initialize and call the module like this:

    input_1 = e3nn.IrrepsArray("0e + 1e + 2e + 3e", jnp.ones(16))
    input_2 = e3nn.IrrepsArray("0e + 1e + 2e + 3e", jnp.ones(16))
    tp = WrappedFCTensorProduct(irreps_out="2x0e + 5x1e + 8x2e + 11x3e")
    params = tp.init(jax.random.PRNGKey(0), input_1, input_2)
    output = tp.apply(params, input_1, input_2)

Take a look at my notebook here: https://gist.github.com/ameya98/bf21a3b43ed3d6e02526eb8289a9895a