Skip to content

Commit

Permalink
Undo accidental change.
Browse files Browse the repository at this point in the history
  • Loading branch information
ameya98 committed Jan 4, 2024
1 parent 973eedf commit 22c139a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
6 changes: 5 additions & 1 deletion e3nn_jax/_src/s2grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,9 +1229,13 @@ def m0_values_to_irrepsarray(m0_values, lmax, p_val, p_arg) -> e3nn.IrrepsArray:
lmax + 1
)
irreps = s2_irreps(lmax, p_val, p_arg)
m0 = jnp.zeros((*m0_values.shape[:-1], (lmax + 1) ** 2))
m0 = m0.at[:, m0_indices].set(m0_values)
return e3nn.IrrepsArray(
irreps,
jnp.zeros((lmax + 1) ** 2).at[m0_indices].set(m0_values),
jnp.zeros((*m0_values.shape[:-1], (lmax + 1) ** 2))
.at[:, m0_indices]
.set(m0_values),
)


Expand Down
11 changes: 3 additions & 8 deletions tests/_src/s2grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,12 @@ def test_legendre_transforms(keys, lmax, p_val, p_arg, quadrature, fft_to, batch
m0_indices = jnp.cumsum(jnp.repeat(jnp.arange(lmax + 1), 2))[::2] + jnp.arange(
lmax + 1
)
np.testing.assert_allclose(
a.array[m0_indices],
res_m0,
rtol=1e-5,
atol=1e-5,
)
np.testing.assert_allclose(a.array[:, m0_indices], res_m0, rtol=1e-5, atol=1e-5)
irrepsarray_m0 = m0_values_to_irrepsarray(res_m0, lmax, p_val, p_arg)
assert a.irreps == irrepsarray_m0.irreps
np.testing.assert_allclose(
a.array[m0_indices],
irrepsarray_m0.array[m0_indices],
a.array[:, m0_indices],
irrepsarray_m0.array[:, m0_indices],
rtol=1e-5,
atol=1e-5,
)
Expand Down

0 comments on commit 22c139a

Please sign in to comment.