diff --git a/mjx/mujoco/mjx/_src/support.py b/mjx/mujoco/mjx/_src/support.py index 748be165fa..aeb0f7ed85 100644 --- a/mjx/mujoco/mjx/_src/support.py +++ b/mjx/mujoco/mjx/_src/support.py @@ -296,71 +296,70 @@ def __init__(self, model: Model, specs: Sequence[Any]): specs = [specs] ids = [] for spec in specs: - match spec: - case mujoco.MjsBody(): - self.prefix = 'body_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_BODY, spec.name) - case mujoco.MjsJoint(): - self.prefix = 'jnt_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_JOINT, spec.name) - case mujoco.MjsGeom(): - self.prefix = 'geom_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_GEOM, spec.name) - case mujoco.MjsSite(): - self.prefix = 'site_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_SITE, spec.name) - case mujoco.MjsLight(): - self.prefix = 'light_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_LIGHT, spec.name) - case mujoco.MjsCamera(): - self.prefix = 'cam_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, spec.name) - case mujoco.MjsMesh(): - self.prefix = 'mesh_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_MESH, spec.name) - case mujoco.MjsHField(): - self.prefix = 'hfield_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_HFIELD, spec.name) - case mujoco.MjsPair(): - self.prefix = 'pair_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_PAIR, spec.name) - case mujoco.MjsTendon(): - self.prefix = 'tendon_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.name) - case mujoco.MjsActuator(): - self.prefix = 'actuator_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, spec.name) - case mujoco.MjsSensor(): - self.prefix = 'sensor_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, spec.name) - case mujoco.MjsNumeric(): - self.prefix = 'numeric_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_NUMERIC, spec.name) - case mujoco.MjsText(): - self.prefix = 'text_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_TEXT, spec.name) - case mujoco.MjsTuple(): - self.prefix = 'tuple_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_TUPLE, spec.name) - case mujoco.MjsKey(): - self.prefix = 'key_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_KEY, spec.name) - case mujoco.MjsEquality(): - self.prefix = 'eq_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_EQUALITY, spec.name) - case mujoco.MjsExclude(): - self.prefix = 'exclude_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_EXCLUDE, spec.name) - case mujoco.MjsSkin(): - self.prefix = 'skin_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_SKIN, spec.name) - case mujoco.MjsMaterial(): - self.prefix = 'material_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_MATERIAL, spec.name) - case _: - raise ValueError('invalid spec type') + if isinstance(spec, mujoco.MjsBody): + self.prefix = 'body_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_BODY, spec.name) + elif isinstance(spec, mujoco.MjsJoint): + self.prefix = 'jnt_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_JOINT, spec.name) + elif isinstance(spec, mujoco.MjsGeom): + self.prefix = 'geom_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_GEOM, spec.name) + elif isinstance(spec, mujoco.MjsSite): + self.prefix = 'site_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_SITE, spec.name) + elif isinstance(spec, mujoco.MjsLight): + self.prefix = 'light_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_LIGHT, spec.name) + elif isinstance(spec, mujoco.MjsCamera): + self.prefix = 'cam_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, spec.name) + elif isinstance(spec, mujoco.MjsMesh): + self.prefix = 'mesh_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_MESH, spec.name) + elif isinstance(spec, mujoco.MjsHField): + self.prefix = 'hfield_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_HFIELD, spec.name) + elif isinstance(spec, mujoco.MjsPair): + self.prefix = 'pair_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_PAIR, spec.name) + elif isinstance(spec, mujoco.MjsTendon): + self.prefix = 'tendon_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.name) + elif isinstance(spec, mujoco.MjsActuator): + self.prefix = 'actuator_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, spec.name) + elif isinstance(spec, mujoco.MjsSensor): + self.prefix = 'sensor_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, spec.name) + elif isinstance(spec, mujoco.MjsNumeric): + self.prefix = 'numeric_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_NUMERIC, spec.name) + elif isinstance(spec, mujoco.MjsText): + self.prefix = 'text_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_TEXT, spec.name) + elif isinstance(spec, mujoco.MjsTuple): + self.prefix = 'tuple_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_TUPLE, spec.name) + elif isinstance(spec, mujoco.MjsKey): + self.prefix = 'key_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_KEY, spec.name) + elif isinstance(spec, mujoco.MjsEquality): + self.prefix = 'eq_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_EQUALITY, spec.name) + elif isinstance(spec, mujoco.MjsExclude): + self.prefix = 'exclude_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_EXCLUDE, spec.name) + elif isinstance(spec, mujoco.MjsSkin): + self.prefix = 'skin_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_SKIN, spec.name) + elif isinstance(spec, mujoco.MjsMaterial): + self.prefix = 'material_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_MATERIAL, spec.name) + else: + raise ValueError('invalid spec type') if idx < 0: - raise KeyError(f'invalid name: {spec.name}') + raise KeyError(f'invalid name: {spec.name}') # pytype: disable=attribute-error ids.append(idx) if len(ids) == 1: self.id = ids[0] @@ -388,41 +387,40 @@ def __init__(self, data: Data, model: Model, specs: Sequence[Any]): specs = [specs] ids = [] for spec in specs: - match spec: - case mujoco.MjsBody(): - self.prefix = '' - idx = name2id(model, mujoco.mjtObj.mjOBJ_BODY, spec.name) - case mujoco.MjsJoint(): - self.prefix = 'jnt_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_JOINT, spec.name) - case mujoco.MjsGeom(): - self.prefix = 'geom_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_GEOM, spec.name) - case mujoco.MjsSite(): - self.prefix = 'site_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_SITE, spec.name) - case mujoco.MjsLight(): - self.prefix = 'light_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_LIGHT, spec.name) - case mujoco.MjsCamera(): - self.prefix = 'cam_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, spec.name) - case mujoco.MjsTendon(): - self.prefix = 'ten_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.name) - case mujoco.MjsActuator(): - self.prefix = 'actuator_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, spec.name) - case mujoco.MjsSensor(): - self.prefix = 'sensor_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, spec.name) - case mujoco.MjsEquality(): - self.prefix = 'eq_' - idx = name2id(model, mujoco.mjtObj.mjOBJ_EQUALITY, spec.name) - case _: - raise ValueError('invalid spec type') + if isinstance(spec, mujoco.MjsBody): + self.prefix = '' + idx = name2id(model, mujoco.mjtObj.mjOBJ_BODY, spec.name) + elif isinstance(spec, mujoco.MjsJoint): + self.prefix = 'jnt_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_JOINT, spec.name) + elif isinstance(spec, mujoco.MjsGeom): + self.prefix = 'geom_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_GEOM, spec.name) + elif isinstance(spec, mujoco.MjsSite): + self.prefix = 'site_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_SITE, spec.name) + elif isinstance(spec, mujoco.MjsLight): + self.prefix = 'light_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_LIGHT, spec.name) + elif isinstance(spec, mujoco.MjsCamera): + self.prefix = 'cam_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, spec.name) + elif isinstance(spec, mujoco.MjsTendon): + self.prefix = 'ten_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_TENDON, spec.name) + elif isinstance(spec, mujoco.MjsActuator): + self.prefix = 'actuator_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_ACTUATOR, spec.name) + elif isinstance(spec, mujoco.MjsSensor): + self.prefix = 'sensor_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_SENSOR, spec.name) + elif isinstance(spec, mujoco.MjsEquality): + self.prefix = 'eq_' + idx = name2id(model, mujoco.mjtObj.mjOBJ_EQUALITY, spec.name) + else: + raise ValueError('invalid spec type') if idx < 0: - raise KeyError(f'invalid name: {spec.name}') + raise KeyError(f'invalid name: {spec.name}') # pytype: disable=attribute-error ids.append(idx) if len(ids) == 1: self.id = ids[0]