Skip to content

Commit

Permalink
Add ones() and ones_like() with tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
ameya98 committed Jan 5, 2024
1 parent af22024 commit 5605097
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
4 changes: 4 additions & 0 deletions e3nn_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
as_irreps_array,
zeros,
zeros_like,
ones,
ones_like,
concatenate,
stack,
mean,
Expand Down Expand Up @@ -176,6 +178,8 @@
"as_irreps_array",
"zeros",
"zeros_like",
"ones",
"ones_like",
"concatenate",
"stack",
"mean",
Expand Down
18 changes: 17 additions & 1 deletion e3nn_jax/_src/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def as_irreps_array(array: Union[jnp.ndarray, e3nn.IrrepsArray], *, backend=None
return e3nn.IrrepsArray(f"{array.shape[-1]}x0e", array)


def zeros(irreps: IntoIrreps, leading_shape, dtype=None) -> e3nn.IrrepsArray:
def zeros(
irreps: IntoIrreps, leading_shape: Tuple = (), dtype: jnp.dtype = None
) -> e3nn.IrrepsArray:
r"""Create an IrrepsArray of zeros."""
irreps = e3nn.Irreps(irreps)
array = jnp.zeros(leading_shape + (irreps.dim,), dtype=dtype)
Expand All @@ -113,6 +115,20 @@ def zeros_like(irreps_array: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
return e3nn.zeros(irreps_array.irreps, irreps_array.shape[:-1], irreps_array.dtype)


def ones(
irreps: IntoIrreps, leading_shape: Tuple = (), dtype: jnp.dtype = None
) -> e3nn.IrrepsArray:
r"""Create an IrrepsArray of ones."""
irreps = e3nn.Irreps(irreps)
array = jnp.ones(leading_shape + (irreps.dim,), dtype=dtype)
return e3nn.IrrepsArray(irreps, array, zero_flags=(False,) * len(irreps))


def ones_like(irreps_array: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
r"""Create an IrrepsArray of ones with the same shape as another IrrepsArray."""
return e3nn.ones(irreps_array.irreps, irreps_array.shape[:-1], irreps_array.dtype)


def _align_two_irreps(
irreps1: e3nn.Irreps, irreps2: e3nn.Irreps
) -> Tuple[e3nn.Irreps, e3nn.Irreps]:
Expand Down
34 changes: 33 additions & 1 deletion tests/_src/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,42 @@


def assert_array_equals_chunks(x: e3nn.IrrepsArray):
y = e3nn.from_chunks(x.irreps, x.chunks, x.shape[:-1])
y = e3nn.from_chunks(x.irreps, x.chunks, x.shape[:-1], x.dtype)
np.testing.assert_array_equal(x.array, y.array)


def test_zeros():
x = e3nn.zeros("0e + 1e", leading_shape=(3, 5))
assert jnp.all(x.array == 0)
assert x.shape == (3, 5, 4)
assert x.irreps == "0e + 1e"
assert_array_equals_chunks(x)


def test_zeros_like():
x = e3nn.ones("0e + 1e", leading_shape=(3, 5))
y = e3nn.zeros_like(x)
assert jnp.all(y.array == 0)
assert y.shape == x.shape
assert y.irreps == x.irreps


def test_ones():
x = e3nn.ones("0e + 1e", leading_shape=(3, 5))
assert jnp.all(x.array == 1)
assert x.shape == (3, 5, 4)
assert x.irreps == "0e + 1e"
assert_array_equals_chunks(x)


def test_ones_like():
x = e3nn.zeros("0e + 1e", leading_shape=(3, 5))
y = e3nn.ones_like(x)
assert jnp.all(y.array == 1)
assert y.shape == x.shape
assert y.irreps == x.irreps


def test_concatenate1(keys):
x1 = e3nn.normal("0e + 1e", keys[0], (3,))
x2 = e3nn.normal("0e + 1e", keys[0], (2,))
Expand Down

0 comments on commit 5605097

Please sign in to comment.