Skip to content

Commit

Permalink
Handle ctrl special case explicitly in bind().
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714055763
Change-Id: I020008723af4a590926268a2ea615af54e20aa0b
  • Loading branch information
quagla authored and copybara-github committed Jan 10, 2025
1 parent 3f32cc2 commit da32b0d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
19 changes: 10 additions & 9 deletions mjx/mujoco/mjx/_src/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,17 @@ def __init__(self, data: Data, model: Model, specs: Sequence[Any]):
self.prefix = 'cam_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, spec.name))
case mujoco.MjsTendon():
self.prefix = 'tendon_'
self.prefix = 'ten_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.name))
case mujoco.MjsActuator():
self.prefix = 'actuator_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, spec.name))
case mujoco.MjsSensor():
self.prefix = 'sensor_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, spec.name))
case mujoco.MjsEquality():
self.prefix = 'eq_'
ids.append(name2id(model, mujoco.mjtObj.mjOBJ_EQUALITY, spec.name))
case _:
raise ValueError('invalid spec type')
if len(ids) == 1:
Expand All @@ -420,15 +423,13 @@ def __init__(self, data: Data, model: Model, specs: Sequence[Any]):
self.id = ids

def __getname(self, name: str):
try:
getattr(self.data, self.prefix + name)
return self.prefix + name
except AttributeError:
try:
getattr(self.data, name)
if name == 'ctrl':
if self.prefix == 'actuator_':
return name
except AttributeError as e:
raise ValueError(f'invalid name: {name}') from e
else:
raise AttributeError('ctrl is not available for this type')
else:
return self.prefix + name

def __getattr__(self, name: str):
return getattr(self.data, self.__getname(name))[self.id, ...]
Expand Down
6 changes: 4 additions & 2 deletions mjx/mujoco/mjx/_src/support_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,11 @@ def test_bind(self):
np.testing.assert_array_equal(dx.bind(mx, s.actuators).ctrl, [0, 0, 0])

# test invalid name
with self.assertRaises(ValueError):
with self.assertRaises(AttributeError):
print(dx.bind(mx, s.geoms).ctrl)
with self.assertRaises(AttributeError):
print(dx.bind(mx, s.actuators).actuator_ctrl)
with self.assertRaises(ValueError):
with self.assertRaises(AttributeError):
print(dx.bind(mx, s.actuators).set('actuator_ctrl', [1, 2, 3]))

_CONTACTS = """
Expand Down

0 comments on commit da32b0d

Please sign in to comment.