Skip to content

Commit

Permalink
Python bindings: Store a reference to MjModel in MjDataWrapper.
Browse files Browse the repository at this point in the history
Before this change, MjDataWrapper contained a custom data structure called MjDataMetadata, which contained the sizes of various arrays needed to create MjData. This was used for serialization and deserialization, as well as copying MjData instances.

This worked fine for native MuJoCo models, but as soon as plugins were used, there was information in the model that was needed and not available in MjDataMetadata. Since plugins are so general, keeping MjDataMetadata was untenable.

PiperOrigin-RevId: 578205355
Change-Id: I11b9ee797d1da5aaf1fecbbe170b3b0f1f1a3e6c
  • Loading branch information
nimrod-gileadi authored and copybara-github committed Oct 31, 2023
1 parent 64a59bb commit 084facc
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 230 deletions.
10 changes: 6 additions & 4 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,22 @@ Python bindings

6. Fix the macOS ``mjpython`` launcher to work with the Python interpreter from Apple Command Line
Tools.
7. Fixed a crash when copying instances of ``mujoco.MjData`` for models that use plugins. Introduced a ``model``
attribute to ``MjData`` which is reference to the model that was used to create that ``MjData`` instance.

Simulate
^^^^^^^^
7. :ref:`simulate<saSimulate>`: correct handling of "Pause update", "Fullscreen" and "VSync" buttons.
8. :ref:`simulate<saSimulate>`: correct handling of "Pause update", "Fullscreen" and "VSync" buttons.

Documentation
^^^^^^^^^^^^^
8. Added documentation for the :ref:`UI` framework.
9. Fixed typos and supported fields in docs (fixes :github:issue:`1105` and :github:issue:`1106`).
9. Added documentation for the :ref:`UI` framework.
10. Fixed typos and supported fields in docs (fixes :github:issue:`1105` and :github:issue:`1106`).


Bug fixes
^^^^^^^^^
10. Fixed bug relating to welds modified with :ref:`torquescale<equality-weld-torquescale>`.
11. Fixed bug relating to welds modified with :ref:`torquescale<equality-weld-torquescale>`.

Version 3.0.0 (October 18, 2023)
--------------------------------
Expand Down
1 change: 0 additions & 1 deletion python/mujoco/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ target_sources(
structs_header
INTERFACE indexer_xmacro.h
indexers.h
mjdata_meta.h
structs.h
)
set_target_properties(structs_header PROPERTIES PUBLIC_HEADER structs.h)
Expand Down
58 changes: 54 additions & 4 deletions python/mujoco/bindings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,21 @@
"""

TEST_XML_PLUGIN = r"""
<mujoco model="test">
<mujoco>
<option gravity="0 0 0"/>
<extension>
<plugin plugin="mujoco.elasticity.cable"/>
</extension>
<worldbody>
<composite type="cable" curve="s" count="41 1 1" size="1" offset="0 0 1" initial="none">
<plugin plugin="mujoco.elasticity.cable">
<config key="twist" value="1e6"/>
<config key="bend" value="1e9"/>
</plugin>
<joint kind="main" damping="2"/>
<geom type="capsule" size=".005" density="1"/>
</composite>
</worldbody>
</mujoco>
"""

Expand Down Expand Up @@ -1048,10 +1059,24 @@ def test_mjcb_control_not_leak_memory(self):
while data_instances:
d = data_instances.pop()
self.assertEqual(sys.getrefcount(d), 2)
del d
while model_instances:
m = model_instances.pop()
self.assertEqual(sys.getrefcount(m), 2)

# This test is disabled on PyPy as it uses sys.getrefcount
# However PyPy is not officially supported by MuJoCo
@absltest.skipIf(sys.implementation.name == 'pypy',
reason='requires sys.getrefcount')
def test_mjdata_holds_ref_to_model(self):
data = mujoco.MjData(mujoco.MjModel.from_xml_string('<mujoco/>'))
model = data.model
# references: one in `data.model, one in `model`, one in the temporary ref
# passed to getrefcount.
self.assertEqual(sys.getrefcount(data.model), 3)
del data
self.assertEqual(sys.getrefcount(model), 2)

def test_can_initialize_mjv_structs(self):
self.assertIsInstance(mujoco.MjvScene(), mujoco.MjvScene)
self.assertIsInstance(mujoco.MjvCamera(), mujoco.MjvCamera)
Expand Down Expand Up @@ -1287,6 +1312,34 @@ def test_indexer_name_id(self):
self.assertEqual(data.geom(3).xpos[2], 4)
self.assertEqual(data.geom(4).xpos[2], 5)

def test_load_plugin(self):
model = mujoco.MjModel.from_xml_string(TEST_XML_PLUGIN)
data = mujoco.MjData(model)
mujoco.mj_forward(model, data)

def test_copy_mjdata_with_plugin(self):
model = mujoco.MjModel.from_xml_string(TEST_XML_PLUGIN)
data1 = mujoco.MjData(model)
self.assertIs(data1.model, model)
mujoco.mj_step(model, data1)
data2 = copy.copy(data1)
mujoco.mj_step(model, data1)
mujoco.mj_step(model, data2)
np.testing.assert_array_equal(data1.qpos, data2.qpos)
self.assertIs(data1.model, data2.model)

def test_deepcopy_mjdata_with_plugin(self):
model = mujoco.MjModel.from_xml_string(TEST_XML_PLUGIN)
data1 = mujoco.MjData(model)
self.assertIs(data1.model, model)
mujoco.mj_step(model, data1)
data2 = copy.deepcopy(data1)
mujoco.mj_step(model, data1)
mujoco.mj_step(model, data2)
np.testing.assert_array_equal(data1.qpos, data2.qpos)
self.assertIsNot(data1.model, data2.model)
self.assertNotEqual(data1.model._address, data2.model._address)

def _assert_attributes_equal(self, actual_obj, expected_obj, attr_to_compare):
for name in attr_to_compare:
actual_value = getattr(actual_obj, name)
Expand All @@ -1300,9 +1353,6 @@ def _assert_attributes_equal(self, actual_obj, expected_obj, attr_to_compare):
self.fail("Attribute '{}' differs from expected value: {}".format(
name, str(e)))

def test_load_plugin(self):
mujoco.MjModel.from_xml_string(TEST_XML_PLUGIN)


if __name__ == '__main__':
absltest.main()
5 changes: 3 additions & 2 deletions python/mujoco/functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include <Eigen/Core>
#include <mujoco/mjxmacro.h>
#include <mujoco/mujoco.h>
#include "function_traits.h"
#include "functions.h"
#include "private.h"
Expand Down Expand Up @@ -1367,7 +1368,7 @@ PYBIND11_MODULE(_functions, pymodule) {
}

#undef MJ_M
#define MJ_M(x) d.metadata().x
#define MJ_M(x) d.model().get()->x
#undef MJ_D
#define MJ_D(x) data->x
#define X(type, name, nr, nc) \
Expand All @@ -1379,7 +1380,7 @@ PYBIND11_MODULE(_functions, pymodule) {
}

MJDATA_ARENA_POINTERS_PRIMAL
if (d.metadata().is_dual) {
if (mj_isDual(d.model().get())) {
MJDATA_ARENA_POINTERS_DUAL
}
#undef X
Expand Down
34 changes: 13 additions & 21 deletions python/mujoco/indexers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@
// limitations under the License.

#include <algorithm>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <variant>
#include <vector>

#include "indexers.h"
#include "mjdata_meta.h"
#include "raw.h"
#include "util/crossplatform.h"

Expand Down Expand Up @@ -76,7 +73,6 @@ IDToName MakeIDToName(int count, IntPtr name_offsets, CharPtr names) {
// the MuJoCo category of the entity itself, e.g. nbody indicates that the
// field belongs to a body.
// T: Scalar data type of the field.
// M: Either raw::MjModel or MjDataMetadata.
//
// Args:
// base_ptr: Pointer to the first entry in the entire field.
Expand All @@ -87,35 +83,35 @@ IDToName MakeIDToName(int count, IntPtr name_offsets, CharPtr names) {
// additional dimension of the size len(qvel) of the particular joint.
// m: Used for dereferencing MjSize.
// owner: The base object whose lifetime is tied to the returned array.
template <auto MjSize, typename T, typename M>
template <auto MjSize, typename T>
py::array_t<T> MakeArray(T* base_ptr, int index, std::vector<int>&& shape,
const M& m, py::handle owner) {
const raw::MjModel& m, py::handle owner) {
int offset;
if (MjSize == &M::nq) {
if (MjSize == &raw::MjModel::nq) {
offset = m.jnt_qposadr[index];
shape.insert(
shape.begin(),
((index < m.njnt-1) ? m.jnt_qposadr[index+1] : m.nq) - offset);
} else if (MjSize == &M::nv) {
} else if (MjSize == &raw::MjModel::nv) {
offset = m.jnt_dofadr[index];
shape.insert(
shape.begin(),
((index < m.njnt-1) ? m.jnt_dofadr[index+1] : m.nv) - offset);
} else if (MjSize == &M::nhfielddata) {
} else if (MjSize == &raw::MjModel::nhfielddata) {
offset = m.hfield_adr[index];
shape.insert(shape.begin(), m.hfield_ncol[index]);
shape.insert(shape.begin(), m.hfield_nrow[index]);
} else if (MjSize == &M::ntexdata) {
} else if (MjSize == &raw::MjModel::ntexdata) {
offset = m.tex_adr[index];
shape.insert(shape.begin(), m.tex_width[index]);
shape.insert(shape.begin(), m.tex_height[index]);
} else if (MjSize == &M::nsensordata) {
} else if (MjSize == &raw::MjModel::nsensordata) {
offset = m.sensor_adr[index];
shape.insert(shape.begin(), m.sensor_dim[index]);
} else if (MjSize == &M::nnumericdata) {
} else if (MjSize == &raw::MjModel::nnumericdata) {
offset = m.numeric_adr[index];
shape.insert(shape.begin(), m.numeric_size[index]);
} else if (MjSize == &M::ntupledata) {
} else if (MjSize == &raw::MjModel::ntupledata) {
offset = m.tuple_adr[index];
shape.insert(shape.begin(), m.tuple_size[index]);
} else {
Expand All @@ -135,9 +131,7 @@ py::array_t<T> MakeArray(T* base_ptr, int index, std::vector<int>&& shape,
}
} // namespace

// M is either a raw::MjModel or MjDataMetadata.
template <typename M>
NameToIDMappings::NameToIDMappings(const M& m)
NameToIDMappings::NameToIDMappings(const raw::MjModel& m)
: body(MakeNameToID(m.nbody, m.name_bodyadr, m.names)),
jnt(MakeNameToID(m.njnt, m.name_jntadr, m.names)),
geom(MakeNameToID(m.ngeom, m.name_geomadr, m.names)),
Expand All @@ -160,9 +154,7 @@ NameToIDMappings::NameToIDMappings(const M& m)
tuple(MakeNameToID(m.ntuple, m.name_tupleadr, m.names)),
key(MakeNameToID(m.nkey, m.name_keyadr, m.names)) {}

// M is either a raw::MjModel or MjDataMetadata.
template <typename M>
IDToNameMappings::IDToNameMappings(const M& m)
IDToNameMappings::IDToNameMappings(const raw::MjModel& m)
: body(MakeIDToName(m.nbody, m.name_bodyadr, m.names)),
jnt(MakeIDToName(m.njnt, m.name_jntadr, m.names)),
geom(MakeIDToName(m.ngeom, m.name_geomadr, m.names)),
Expand Down Expand Up @@ -223,7 +215,7 @@ MJMODEL_VIEW_GROUPS
MJMODEL_VIEW_GROUPS
#undef XGROUP

MjDataIndexer::MjDataIndexer(raw::MjData* d, const MjDataMetadata* m,
MjDataIndexer::MjDataIndexer(raw::MjData* d, const raw::MjModel* m,
py::handle owner)
: d_(d),
m_(m),
Expand Down Expand Up @@ -369,7 +361,7 @@ MJMODEL_KEYFRAME
#define X(type, prefix, var, dim0, dim1) \
py::array_t<type> XGROUP::var() { \
if (!var##_.has_value()) { \
var##_.emplace(MakeArray<&MjDataMetadata::dim0>( \
var##_.emplace(MakeArray<&raw::MjModel::dim0>( \
d_->prefix##var, index_, MAKE_SHAPE(dim1), *m_, owner_)); \
} \
return *var##_; \
Expand Down
17 changes: 6 additions & 11 deletions python/mujoco/indexers.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <absl/container/flat_hash_map.h>
#include <mujoco/mjxmacro.h>
#include "indexer_xmacro.h"
#include "mjdata_meta.h"
#include "raw.h"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
Expand All @@ -35,9 +34,7 @@ using NameToID = absl::flat_hash_map<std::string, int>;
using IDToName = std::vector<std::string>;

struct NameToIDMappings {
// M is either a raw::MjModel or MjDataMetadata.
template <typename M>
explicit NameToIDMappings(const M& m);
explicit NameToIDMappings(const raw::MjModel& m);

NameToID body;
NameToID jnt;
Expand All @@ -63,9 +60,7 @@ struct NameToIDMappings {
};

struct IDToNameMappings {
// M is either a raw::MjModel or MjDataMetadata.
template <typename M>
explicit IDToNameMappings(const M& m);
explicit IDToNameMappings(const raw::MjModel& m);

IDToName body;
IDToName jnt;
Expand Down Expand Up @@ -163,7 +158,7 @@ class MjModelIndexer {
class MjDataGroupedViewsBase {
public:
MjDataGroupedViewsBase(int index, std::string_view name, raw::MjData* d,
const MjDataMetadata* m,
const raw::MjModel* m,
pybind11::handle owner)
: index_(index), name_(name), d_(d), m_(m), owner_(owner) {}

Expand All @@ -175,7 +170,7 @@ class MjDataGroupedViewsBase {
int index_;
std::string name_;
raw::MjData* d_;
const MjDataMetadata* m_;
const raw::MjModel* m_;
pybind11::handle owner_;
};

Expand All @@ -201,7 +196,7 @@ MJDATA_VIEW_GROUPS
// (e.g. a particular geom or joint) either by name or by ID.
class MjDataIndexer {
public:
MjDataIndexer(raw::MjData* d, const MjDataMetadata* m,
MjDataIndexer(raw::MjData* d, const raw::MjModel* m,
pybind11::handle owner);

#define XGROUP(MjDataGroupedViews, field, nfield, FIELD_XMACROS) \
Expand All @@ -213,7 +208,7 @@ class MjDataIndexer {

private:
raw::MjData* d_;
const MjDataMetadata* m_;
const raw::MjModel* m_;
pybind11::handle owner_;
NameToIDMappings name_to_id_;
IDToNameMappings id_to_name_;
Expand Down
Loading

0 comments on commit 084facc

Please sign in to comment.