diff --git a/python/mujoco/codegen/generate_spec_bindings.py b/python/mujoco/codegen/generate_spec_bindings.py index c6764a3de3..d37fc8ccca 100644 --- a/python/mujoco/codegen/generate_spec_bindings.py +++ b/python/mujoco/codegen/generate_spec_bindings.py @@ -70,6 +70,19 @@ def _value_binding_code( return f'{classname}.def_property({",".join(def_property_args)});' +def _struct_binding_code( + field: ast_nodes.AnonymousStructDecl, classname: str = '', varname: str = '' +) -> str: + code = '' + name = classname + varname.title() + # explicitly generate for nested fields with arrays + if any(isinstance(f.type, ast_nodes.ArrayType) for f in field.fields): + for subfield in field.fields: + code += _binding_code(subfield, name) + # generate for the struct itself + field = ast_nodes.ValueType(name=name) + code += _value_binding_code(field, classname, varname) + return code def _array_binding_code( field: ast_nodes.ArrayType, classname: str = '', varname: str = '' @@ -227,13 +240,7 @@ def _binding_code(field: ast_nodes.StructFieldDecl, key: str) -> str: if isinstance(field.type, ast_nodes.ValueType): return _value_binding_code(field.type, key, field.name) elif isinstance(field.type, ast_nodes.AnonymousStructDecl): - code = "" - if field.name in ['headlight', 'rgba']: - for subfield in field.type.fields: - code += _binding_code(subfield, 'mjVisual'+field.name.title()) - field.type = ast_nodes.ValueType(name='mjVisual'+field.name.title()) - code += _value_binding_code(field.type, key, field.name) - return code + return _struct_binding_code(field.type, key, field.name) elif isinstance(field.type, ast_nodes.PointerType): return _ptr_binding_code(field.type, key, field.name) elif isinstance(field.type, ast_nodes.ArrayType):