Skip to content

Commit

Permalink
Simplify pendula test data. Fix bug in scan.body_tree that led to…
Browse files Browse the repository at this point in the history
… incorrect smooth dynamics for some kinematic tree layouts.

PiperOrigin-RevId: 582125477
Change-Id: I47505ceb4be4c88052bc670f401b8c45ff320bbd
  • Loading branch information
erikfrey authored and copybara-github committed Nov 14, 2023
1 parent 97ad543 commit c814637
Show file tree
Hide file tree
Showing 16 changed files with 266 additions and 257 deletions.
1 change: 1 addition & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ MJX
- Fixed bug where mixed ``jnt_limited`` joints were not being constrained correctly.
- Made ``device_put`` type validation more verbose (fixes :github:issue:`1113`).
- Removed empty EFC rows from ``MJX``, for joints with no limits (fixes :github:issue:`1117`).
- Fixed bug in ``scan.body_tree`` that led to incorrect smooth dynamics for some kinematic tree layouts.

Python bindings
^^^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions mjx/mujoco/mjx/_src/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,10 @@ def count_constraints(m: Model, d: Data) -> Tuple[int, int, int, int]:
if m.opt.disableflags & DisableBit.EQUALITY:
ne = 0
else:
ne_weld = (m.eq_type == EqType.WELD).sum()
ne_connect = (m.eq_type == EqType.CONNECT).sum()
ne_weld = (m.eq_type == EqType.WELD).sum()
ne_joint = (m.eq_type == EqType.JOINT).sum()
ne = ne_weld * 6 + ne_connect * 3 + ne_joint
ne = ne_connect * 3 + ne_weld * 6 + ne_joint

nf = 0

Expand Down
2 changes: 1 addition & 1 deletion mjx/mujoco/mjx/_src/constraint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _assert_eq(a, b, name, step, fname, atol=5e-3, rtol=5e-3):
class ConstraintTest(parameterized.TestCase):

@parameterized.parameters(enumerate(test_util.TEST_FILES))
def testconstraints(self, seed, fname):
def test_constraints(self, seed, fname):
"""Test constraints."""
np.random.seed(seed)

Expand Down
88 changes: 55 additions & 33 deletions mjx/mujoco/mjx/_src/forward_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# ==============================================================================
"""Tests for forward functions."""

import itertools

from absl.testing import absltest
from absl.testing import parameterized
import jax
Expand All @@ -38,13 +36,12 @@ def _assert_attr_eq(a, b, attr, step, fname, atol=1e-3, rtol=1e-3):

class ForwardTest(parameterized.TestCase):

@parameterized.parameters(enumerate(test_util.TEST_FILES))
def test_forward(self, seed, fname):
@parameterized.parameters(
filter(lambda s: s not in ('equality.xml',), test_util.TEST_FILES)
)
def test_forward(self, fname):
"""Test mujoco mj forward function matches mujoco_mjx forward function."""
if fname in ('equality.xml',):
return

np.random.seed(seed)
np.random.seed(test_util.TEST_FILES.index(fname))

m = test_util.load_test_file(fname)
d = mujoco.MjData(m)
Expand All @@ -62,36 +59,61 @@ def test_forward(self, seed, fname):
_assert_attr_eq(d, dx, 'qfrc_smooth', i, fname)
_assert_attr_eq(d, dx, 'qacc_smooth', i, fname)

@parameterized.parameters(itertools.product(test_util.TEST_FILES, (0, 1)))
def test_step(self, fname, integrator_type):
@parameterized.parameters(
filter(lambda s: s not in ('equality.xml',), test_util.TEST_FILES)
)
def test_step(self, fname):
"""Test mujoco mj step matches mujoco_mjx step."""
if fname in (
'mixed_joint_pendulum.xml',
'ball_pendulum.xml',
'convex.xml',
'humanoid.xml',
'triple_pendulum.xml', # TODO(b/301485081)
'equality.xml',
):
# skip models with big constraint violations at step 0 or too slow to run
return

np.random.seed(integrator_type)
np.random.seed(test_util.TEST_FILES.index(fname))
m = test_util.load_test_file(fname)
step_jit_fn = jax.jit(forward.step)

m.opt.integrator = integrator_type
int_typ = 'euler' if integrator_type == 0 else 'rk4'
test_name = f'{fname} - {int_typ}'
steps = 100 if int_typ == 'euler' else 30
dt = m.opt.timestep
m.opt.timestep = dt if int_typ == 'euler' else dt * 3
mx = mjx.device_put(m)
d = mujoco.MjData(m)
# give the system a little kick to ensure we have non-identity rotations
d.qvel = np.random.normal(m.nv) * 0.05
for i in range(100):
# in order to avoid re-jitting, reuse the same mj_data shape
qpos, qvel = d.qpos, d.qvel
d = mujoco.MjData(m)
d.qpos, d.qvel = qpos, qvel
dx = mjx.device_put(d)

mujoco.mj_step(m, d)
dx = step_jit_fn(mx, dx)

_assert_attr_eq(d, dx, 'qvel', i, fname, atol=1e-2)
_assert_attr_eq(d, dx, 'qpos', i, fname, atol=1e-2)
_assert_attr_eq(d, dx, 'act', i, fname)
_assert_attr_eq(d, dx, 'time', i, fname)

def test_rk4(self):
m = mujoco.MjModel.from_xml_string("""
<mujoco>
<option integrator="RK4">
<flag constraint="disable"/>
</option>
<worldbody>
<light pos="0 0 1"/>
<geom type="plane" size="1 1 .01" pos="0 0 -1"/>
<body pos="0.15 0 0">
<joint type="hinge" axis="0 1 0"/>
<geom type="capsule" size="0.02" fromto="0 0 0 .1 0 0"/>
<body pos="0.1 0 0">
<joint type="slide" axis="1 0 0" stiffness="200"/>
<geom type="capsule" size="0.015" fromto="-.1 0 0 .1 0 0"/>
</body>
</body>
</worldbody>
</mujoco>
""")
step_jit_fn = jax.jit(forward.step)

mx = mjx.device_put(m)
d = mujoco.MjData(m)
# give the system a little kick to ensure we have non-identity rotations
d.qvel = np.random.normal(m.nv) * 0.05
for i in range(steps):
for i in range(100):
# in order to avoid re-jitting, reuse the same mj_data shape
qpos, qvel = d.qpos, d.qvel
d = mujoco.MjData(m)
Expand All @@ -101,10 +123,10 @@ def test_step(self, fname, integrator_type):
mujoco.mj_step(m, d)
dx = step_jit_fn(mx, dx)

_assert_attr_eq(d, dx, 'qvel', i, test_name, atol=1e-2)
_assert_attr_eq(d, dx, 'qpos', i, test_name, atol=1e-2)
_assert_attr_eq(d, dx, 'act', i, test_name)
_assert_attr_eq(d, dx, 'time', i, test_name)
_assert_attr_eq(d, dx, 'qvel', i, 'test_rk4', atol=1e-2)
_assert_attr_eq(d, dx, 'qpos', i, 'test_rk4', atol=1e-2)
_assert_attr_eq(d, dx, 'act', i, 'test_rk4')
_assert_attr_eq(d, dx, 'time', i, 'test_rk4')

def test_disable_eulerdamp(self):
m = test_util.load_test_file('ant.xml')
Expand Down
6 changes: 3 additions & 3 deletions mjx/mujoco/mjx/_src/passive_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
import numpy as np


def _assert_attr_eq(a, b, attr, step, fname, atol=1e-5, rtol=1e-5):
def _assert_attr_eq(a, b, attr, step, fname, atol=1e-4, rtol=1e-4):
err_msg = f'mismatch: {attr} at step {step} in {fname}'
a, b = getattr(a, attr), getattr(b, attr)
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=atol, rtol=rtol)


class PassiveTest(parameterized.TestCase):

@parameterized.parameters(enumerate(('ant.xml', 'mixed_joint_pendulum.xml')))
@parameterized.parameters(enumerate(('ant.xml', 'pendula.xml')))
def test_stiffness_damping(self, seed, fname):
"""Tests stiffness and damping on Ant."""
np.random.seed(seed)
Expand All @@ -60,7 +60,7 @@ def test_stiffness_damping(self, seed, fname):
_assert_attr_eq(d, dx, 'qfrc_passive', i, fname)

@parameterized.parameters(
itertools.product(range(3), ('triple_pendulum.xml',))
itertools.product(range(3), ('pendula.xml',))
)
def test_fluid(self, seed, fname):
np.random.seed(seed)
Expand Down
158 changes: 96 additions & 62 deletions mjx/mujoco/mjx/_src/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def flat(
) -> Y:
r"""Scan a function across bodies or actuators.
Scan group data according to type and batch shape then calls vmap(f) on it.\
Scan group data according to type and batch shape then calls vmap(f) on it.
Args:
m: an mjx model
Expand Down Expand Up @@ -340,48 +340,88 @@ def f(y, *args) -> y
IndexError: if function output shape does not match out_types shape
"""
_check_input(m, args, in_types)
depth_fn = lambda i, p=m.body_parentid: int(i > 0) and 1 + depth_fn(p[i])
typ_body_id = {
'j': m.jnt_bodyid,
'v': m.dof_bodyid,
'q': _q_bodyid(m),
}
key_parents = {}

# build up groupings of bodies and type ids using (level, (jnt_type,)) keys
key_typ_ids, key_body_ids = {}, {}
for body_id in np.arange(m.nbody, dtype=np.int32):
depth = depth_fn(body_id)

# create grouping key
if any(t in 'jqv' for t in in_types + out_types):
jnts = np.nonzero(typ_body_id['j'] == body_id)[0]
jnts_p = np.nonzero(typ_body_id['j'] == m.body_parentid[body_id])[0]
key = depth, tuple(m.jnt_type[jnts])
parent_key = depth - 1, tuple(m.jnt_type[jnts_p])
else:
key, parent_key = (depth, ()), (depth - 1, ())
# group together bodies that will be processed together. grouping key:
# 1) the tree depth: parent bodies are processed first, so that they are
# available as carry input to child bodies (or reverse if reverse=True)
# 2) the types of arguments passed to f, both carry and *args:
# * for 'b' arguments, there is no extra grouping
# * for 'j' arguments, we group by joint type
# * for 'q' arguments, we group by q width
# * for 'v' arguments, we group by dof width
depths = np.zeros(m.nbody, dtype=np.int32)

# map key => body id
key_body_ids = {}
for body_id in range(m.nbody):
parent_id = -1
if body_id > 0:
parent_id = m.body_parentid[body_id]
depths[body_id] = 1 + depths[parent_id]

# create grouping key: depth, carry, args
key = (depths[body_id],)

for i, t in enumerate(out_types + in_types):
id_ = parent_id if i < len(out_types) else body_id
if t == 'b':
continue
elif t == 'j':
key += (tuple(m.jnt_type[np.nonzero(m.jnt_bodyid == id_)[0]]))
elif t == 'v':
key += (len(np.nonzero(m.dof_bodyid == id_)[0]),)
elif t == 'q':
key += (len(np.nonzero(_q_bodyid(m) == id_)[0]),)

key_parents[key] = parent_key
body_ids = key_body_ids.get(key, np.array([], dtype=np.int32))
key_body_ids[key] = np.append(body_ids, body_id)

# add ids per type
for t in set(in_types + out_types):
out = key_typ_ids.setdefault(key, {})
id_ = body_id if t == 'b' else np.nonzero(typ_body_id[t] == body_id)[0]
id_ = np.expand_dims(id_, axis=0)
out[t] = np.concatenate((out[t], id_)) if t in out else id_
# find parent keys of each key. a key may have multiple parents if the
# carry output keys of distinct parents are the same. e.g.:
# - depth 0 body 1 (slide joint)
# -- depth 1 body 1 (hinge joint)
# - depth 0 body 2 (ball joint)
# -- depth 1 body 2 (hinge joint)
# given a scan with 'j' in the in_types, we would group depth 0 bodies
# separately but we may group depth 1 bodies together
key_parents = {}

key_typ_ids = list(sorted(key_typ_ids.items(), reverse=reverse))
for key, body_ids in key_body_ids.items():
body_ids = body_ids[body_ids != 0] # ignore worldbody, has no parent
if body_ids.size == 0:
continue
# find any key which has a body id that is a parent of these body_ids
pids = m.body_parentid[body_ids]
parents = {k for k, v in key_body_ids.items() if np.isin(v, pids).any()}
key_parents[key] = list(sorted(parents))

# key => take indices
key_in_take, key_y_take = {}, {}
for key, body_ids in key_body_ids.items():
for i, typ in enumerate(in_types + out_types):
if typ == 'b':
ids = body_ids
elif typ == 'j':
ids = np.stack([np.nonzero(m.jnt_bodyid == b)[0] for b in body_ids])
elif typ == 'v':
ids = np.stack([np.nonzero(m.dof_bodyid == b)[0] for b in body_ids])
elif typ == 'q':
ids = np.stack([np.nonzero(_q_bodyid(m) == b)[0] for b in body_ids])
else:
raise ValueError(f'Unknown in_type: {typ}')
if i < len(in_types):
key_in_take.setdefault(key, []).append(ids)
else:
key_y_take.setdefault(key, []).append(np.hstack(ids))

# use this grouping to take the right data subsets and call vmap(f)
keys = sorted(key_body_ids, reverse=reverse)
key_y = {}
for key, typ_ids in key_typ_ids:
for key in keys:
carry = None

if reverse:
child_keys = [k for k, v in key_parents.items() if v == key]
child_keys = [k for k, v in key_parents.items() if key in v]

for child_key in child_keys:
y = key_y[child_key]
Expand All @@ -394,39 +434,33 @@ def index_sum(x, i=id_map, s=body_ids.size):

y = jax.tree_map(index_sum, y)
carry = y if carry is None else jax.tree_map(jp.add, carry, y)
else:
parent_key = key_parents[key]
y = key_y.get(parent_key)

if y is not None:
body_ids = key_body_ids[parent_key]
parent_ids = m.body_parentid[key_body_ids[key]]
take_fn = lambda x, i=_index(body_ids, parent_ids): _take(x, i)
carry = jax.tree_map(take_fn, y)

f_args = [_take(arg, typ_ids[typ]) for arg, typ in zip(args, in_types)]
elif key in key_parents:
ys = [key_y[p] for p in key_parents[key]]
y = jax.tree_map(lambda *x: jp.concatenate(x), *ys)
body_ids = np.concatenate([key_body_ids[p] for p in key_parents[key]])
parent_ids = m.body_parentid[key_body_ids[key]]
take_fn = lambda x, i=_index(body_ids, parent_ids): _take(x, i)
carry = jax.tree_map(take_fn, y)

f_args = [_take(arg, ids) for arg, ids in zip(args, key_in_take[key])]
key_y[key] = _nvmap(f, carry, *f_args)

# slice None results from the final output
key_typ_ids = [(k, v) for k, v in key_typ_ids if key_y[k] is not None]

# concatenate back to a single tree and drop the grouping dimension
ys = [key_y[key] for key, _ in key_typ_ids]
f_ret_is_seq = isinstance(ys[0], (list, tuple))
ys = ys if f_ret_is_seq else [[y] for y in ys]
ys = [
[v if typ == 'b' else jp.concatenate(v) for v, typ in zip(y, out_types)]
for y in ys
]
ys = jax.tree_map(lambda *x: jp.concatenate(x), *ys)

# put concatenated results back into body order
reordered_ys = []
for i, (y, typ) in enumerate(zip(ys, out_types)):
ids = np.concatenate([np.hstack(v[typ]) for _, v in key_typ_ids])
take_ids = _index(ids, np.sort(ids))
_check_output(y, take_ids, typ, i)
reordered_ys.append(_take(y, take_ids))
y = reordered_ys if f_ret_is_seq else reordered_ys[0]
keys = [k for k in keys if key_y[k] is not None]

# concatenate ys, drop grouping dimensions, put back in order
y = []
for i, typ in enumerate(out_types):
y_typ = [key_y[key] for key in keys]
if len(out_types) > 1:
y_typ = [y_[i] for y_ in y_typ]
if typ != 'b':
y_typ = jax.tree_map(jp.concatenate, y_typ)
y_typ = jax.tree_map(lambda *x: jp.concatenate(x), *y_typ)
y_take = np.argsort(np.concatenate([key_y_take[key][i] for key in keys]))
_check_output(y_typ, y_take, typ, i)
y.append(_take(y_typ, y_take))

y = y[0] if len(out_types) == 1 else y

return y
Loading

0 comments on commit c814637

Please sign in to comment.