From 1e2e0b3053e8c1866441e154569b6f56a46fc42f Mon Sep 17 00:00:00 2001 From: Alessio Quaglino Date: Mon, 15 Jan 2024 06:38:11 -0800 Subject: [PATCH] Remove all visual-only assets if `discardvisual` is selected, not only the geoms. PiperOrigin-RevId: 598598250 Change-Id: I12e60ce41ad436d037b4877d197decf6bdce0fc8 --- doc/XMLreference.rst | 22 ++- doc/changelog.rst | 16 +- src/user/user_mesh.cc | 17 +++ src/user/user_model.cc | 163 ++++++++++++++++++++- src/user/user_model.h | 28 ++-- src/user/user_objects.cc | 95 ++++++++---- src/user/user_objects.h | 17 ++- src/xml/xml_native_reader.cc | 6 - test/user/testdata/discardvisual.xml | 73 +++++++++ test/user/testdata/discardvisual_false.xml | 73 +++++++++ test/user/user_model_test.cc | 112 ++++++++++++++ 11 files changed, 557 insertions(+), 65 deletions(-) create mode 100644 test/user/testdata/discardvisual.xml create mode 100644 test/user/testdata/discardvisual_false.xml diff --git a/doc/XMLreference.rst b/doc/XMLreference.rst index 8fb1edd3e1..6100d9c165 100644 --- a/doc/XMLreference.rst +++ b/doc/XMLreference.rst @@ -306,14 +306,20 @@ any effect. The settings here are global and apply to the entire model. .. _compiler-discardvisual: :at:`discardvisual`: :at-val:`[false, true], "false" for MJCF, "true" for URDF` - This attribute instructs the parser to discard "visual geoms", defined as geoms whose contype and conaffinity - attributes are both set to 0. This functionality is useful for models that contain two sets of geoms, one for - collisions and the other for visualization. Note that URDF models are usually constructed in this way. It rarely - makes sense to have two sets of geoms in the model, especially since MuJoCo uses convex hulls for collisions, so we - recommend using this feature to discard redundant geoms. Keep in mind however that geoms considered visual per the - above definition can still participate in collisions, if they appear in the explicit list of contact - :ref:`pairs `. The parser does not check this list before discarding geoms; it relies solely on the geom - attributes to make the determination. + This attribute instructs the compiler to discard all model elements which are purely visual and have no effect on the + physics (with one exception, see below). This often enables smaller :ref:`mjModel` structs and faster simulation. + + - All materials are discarded. + - All textures are discarded. + - All geoms with :ref:`contype`=:ref:`conaffinity`=0 are discarded, if they + are not referenced in another MJCF element. If a discarded geom was used for inferring body inertia, an explicit + :ref:`inertial` element is added to the body. + - All meshes which are not referenced by any geom (in particular those discarded above) are discarded. + + The resulting compiled model will have exactly the same dynamics as the original model, with the exception of + raycasting, as used for example by :ref:`rangefinder`, since raycasting reports distances to + visual geoms. When visualizing models compiled with this flag, it is important to remember that colliding geoms are + often placed in a :ref:`group` which is invisible by default. .. _compiler-convexhull: diff --git a/doc/changelog.rst b/doc/changelog.rst index 9665e7a54a..fb886273a8 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -5,16 +5,22 @@ Changelog Upcoming version (not yet released) ----------------------------------- +General +^^^^^^^ +1. Improved the :ref:discardvisual compiler flag, which now discards all visual-only assets. See + :ref:discardvisual for details. + MJX ^^^ -1. Added :ref:`dyntype` ``filterexact``. -2. Added :at:`site` transmission. -3. Updated MJX colab tutorial with more stable quadruped environment. -4. Added ``mjx.ray`` which mirrors :ref:`mj_ray` for planes, spheres, capsules, and boxes. +2. Added :ref:`dyntype` ``filterexact``. +3. Added :at:`site` transmission. +4. Updated MJX colab tutorial with more stable quadruped environment. +5. Added ``mjx.ray`` which mirrors :ref:`mj_ray` for planes, spheres, capsules, and boxes. Bug fixes ^^^^^^^^^ -5. Fixed a bug that prevented the use of pins with plugins if flexes are not in the worldbody. Fixes :github:issue:`1270`. +6. Fixed a bug that prevented the use of pins with plugins if flexes are not in the worldbody. Fixes + :github:issue:`1270`. Version 3.1.1 (December 18, 2023) diff --git a/src/user/user_mesh.cc b/src/user/user_mesh.cc index cc494576f5..d3a15e2d9b 100644 --- a/src/user/user_mesh.cc +++ b/src/user/user_mesh.cc @@ -166,6 +166,7 @@ mjCMesh::mjCMesh(mjCModel* _model, mjCDef* _def) { valideigenvalue_ = true; validinequality_ = true; processed_ = false; + visual_ = true; // reset to default if given if (_def) { @@ -708,6 +709,13 @@ void mjCMesh::CopyGraph(int* arr) const { +void mjCMesh::DelTexcoord() { + if (texcoord_) mju_free(texcoord_); + ntexcoord_ = 0; +} + + + // set geom size to match mesh void mjCMesh::FitGeom(mjCGeom* geom, double* meshpos) { // copy mesh pos into meshpos @@ -2229,6 +2237,15 @@ mjCFlex::mjCFlex(mjCModel* _model) { } +bool mjCFlex::HasTexcoord() const { + return !texcoord.empty(); +} + + +void mjCFlex::DelTexcoord() { + texcoord.clear(); +} + // compiler void mjCFlex::Compile(const mjVFS* vfs) { diff --git a/src/user/user_model.cc b/src/user/user_model.cc index 7ceb163a30..a3407b8ee7 100644 --- a/src/user/user_model.cc +++ b/src/user/user_model.cc @@ -649,9 +649,126 @@ void mjCModel::MakeLists(mjCBody* body) { } +// delete material with given name or all materials if the name is omitted +template +static void DeleteMaterial(std::vector& list, std::string_view name = "") { + for (T* plist : list) { + if (name.empty() || plist->material == name) { + plist->material.clear(); + } + } +} + + +// delete texture with given name or all textures if the name is omitted +template +static void DeleteTexture(std::vector& list, std::string_view name = "") { + for (T* plist : list) { + if (name.empty() || plist->texture == name) { + plist->texture.clear(); + } + } +} + + +// delete all texture coordinates +template +static void DeleteTexcoord(std::vector& list) { + for (T* plist : list) { + if (plist->HasTexcoord()) { + plist->DelTexcoord(); + } + } +} + + +// returns a vector that stores the reference correction for each entry +template +static void DeleteElements(std::vector& elements, + const std::vector& discard) { + if (elements.empty()) { + return; + } + + std::vector ndiscard(elements.size(), 0); + + int i = 0; + for (int j=0; jid > 0) { + element->id -= ndiscard[element->id]; + } + } +} + + +template <> +void mjCModel::Delete(std::vector& elements, + const std::vector& discard) { + // update bodies + for (mjCBody* body : bodies) { + body->geoms.erase( + std::remove_if(body->geoms.begin(), body->geoms.end(), + [&discard](mjCGeom* geom) { return discard[geom->id]; }), + body->geoms.end()); + } + + // remove geoms from the main vector + DeleteElements(elements, discard); +} + + +template <> +void mjCModel::Delete(std::vector& elements, + const std::vector& discard) { + DeleteElements(elements, discard); +} + + +template <> +void mjCModel::DeleteAll(std::vector& elements) { + DeleteMaterial(geoms); + DeleteMaterial(skins); + DeleteMaterial(sites); + DeleteMaterial(tendons); + for (mjCMaterial* element : elements) { + delete element; + } + elements.clear(); +} + + +template <> +void mjCModel::DeleteAll(std::vector& elements) { + DeleteTexture(materials); + for (mjCTexture* element : elements) { + delete element; + } + elements.clear(); +} + // index assets -void mjCModel::IndexAssets(void) { +void mjCModel::IndexAssets(bool discard) { // assets referenced in geoms for (int i=0; imeshname.empty()) { mjCBase* m = FindObject(mjOBJ_MESH, pgeom->meshname); if (m) { - pgeom->mesh = (mjCMesh*)m; + if (discard && geoms[i]->visual_) { + // do not associate with a mesh + pgeom->mesh = nullptr; + } else { + // associate mesh with geom + pgeom->mesh = (mjCMesh*)m; + + // mark mesh as not visual + // this is irreversible so only performed when IndexAssets is called with discard + if (discard) { + pgeom->mesh->SetNotVisual(); + } + } } else { throw mjCError(pgeom, "mesh '%s' not found in geom %d", pgeom->meshname.c_str(), i); } @@ -746,6 +875,20 @@ void mjCModel::IndexAssets(void) { } } } + + // discard visual meshes and geoms + if (discard) { + std::vector discard_mesh(meshes.size(), false); + std::vector discard_geom(geoms.size(), false); + + std::transform(meshes.begin(), meshes.end(), discard_mesh.begin(), + [](const mjCMesh* mesh) { return mesh->IsVisual(); }); + std::transform(geoms.begin(), geoms.end(), discard_geom.begin(), + [](const mjCGeom* geom) { return geom->IsVisual(); }); + + Delete(meshes, discard_mesh); + Delete(geoms, discard_geom); + } } @@ -2049,8 +2192,8 @@ void mjCModel::CopyObjects(mjModel* m) { // geom pairs to include for (int i=0; ipair_dim[i] = pairs[i]->condim; - m->pair_geom1[i] = pairs[i]->geom1; - m->pair_geom2[i] = pairs[i]->geom2; + m->pair_geom1[i] = pairs[i]->geom1->id; + m->pair_geom2[i] = pairs[i]->geom2->id; m->pair_signature[i] = pairs[i]->signature; copyvec(m->pair_solref+mjNREF*i, pairs[i]->solref, mjNREF); copyvec(m->pair_solreffriction+mjNREF*i, pairs[i]->solreffriction, mjNREF); @@ -2711,8 +2854,16 @@ void mjCModel::TryCompile(mjModel*& m, mjData*& d, const mjVFS* vfs) { } } + // delete visual assets + if (discardvisual) { + DeleteAll(materials); + DeleteTexcoord(flexes); + DeleteTexcoord(meshes); + DeleteAll(textures); + } + // convert names into indices - IndexAssets(); + IndexAssets(false); // mark meshes that need convex hull for (int i=0; i + void Delete(std::vector& elements, + const std::vector& discard); // delete elements marked as discard=true + + template + void DeleteAll(std::vector& elements); // delete all elements + //------------------------ API for access to model elements (outside tree) int NumObjects(mjtObj type); // number of objects in specified list mjCBase* GetObject(mjtObj type, int id); // pointer to specified object @@ -185,16 +193,16 @@ class mjCModel { void SetDefaultNames(std::vector& assets); //------------------------ compile phases - void MakeLists(mjCBody* body); // make lists of bodies, geoms, joints, sites - void IndexAssets(void); // convert asset names into indices - void CheckEmptyNames(void); // check empty names - void SetSizes(void); // compute sizes - void AutoSpringDamper(mjModel*);// automatic stiffness and damping computation - void LengthRange(mjModel*, mjData*); // compute actuator lengthrange - void CopyNames(mjModel*); // copy names, compute name addresses - void CopyPaths(mjModel*); // copy paths, compute path addresses - void CopyObjects(mjModel*); // copy objects outside kinematic tree - void CopyTree(mjModel*); // copy objects inside kinematic tree + void MakeLists(mjCBody* body); // make lists of bodies, geoms, joints, sites + void IndexAssets(bool discard); // convert asset names into indices + void CheckEmptyNames(void); // check empty names + void SetSizes(void); // compute sizes + void AutoSpringDamper(mjModel*); // automatic stiffness and damping computation + void LengthRange(mjModel*, mjData*); // compute actuator lengthrange + void CopyNames(mjModel*); // copy names, compute name addresses + void CopyPaths(mjModel*); // copy paths, compute path addresses + void CopyObjects(mjModel*); // copy objects outside kinematic tree + void CopyTree(mjModel*); // copy objects inside kinematic tree //------------------------ sizes // sizes set from object list lengths diff --git a/src/user/user_objects.cc b/src/user/user_objects.cc index bffee237e8..4797363dea 100644 --- a/src/user/user_objects.cc +++ b/src/user/user_objects.cc @@ -314,6 +314,7 @@ int mjCBoundingVolumeHierarchy::MakeBVH(std::vector& e return -1; } + bool is_visual = true; int nelements = elements.size(); mjtNum AAMM[6] = {mjMAXVAL, mjMAXVAL, mjMAXVAL, -mjMAXVAL, -mjMAXVAL, -mjMAXVAL}; @@ -325,6 +326,8 @@ int mjCBoundingVolumeHierarchy::MakeBVH(std::vector& e // skip visual objects if (elements[i]->conaffinity==0 && elements[i]->contype==0) { continue; + } else { + is_visual = false; } // transform element aabb to aamm format @@ -360,6 +363,11 @@ int mjCBoundingVolumeHierarchy::MakeBVH(std::vector& e } } + // a body with only visual geoms does not have a bvh + if (is_visual) { + return nbvh; + } + // inflate flat AABBs for (int i=0; i<3; i++) { if (mju_abs(AAMM[i]-AAMM[i+3])name); } } + + if (!model->discardvisual) { + return; + } + + // set inertial to explicit for bodies containing visual geoms + for (int j=0; jIsVisual()) { + explicitinertial = true; + break; + } + } } @@ -1350,6 +1370,7 @@ mjCGeom::mjCGeom(mjCModel* _model, mjCDef* _def) { matid = -1; mesh = nullptr; hfield = nullptr; + visual_ = false; // reset to default if given if (_def) { @@ -1737,6 +1758,9 @@ void mjCGeom::Compile(void) { name.c_str(), id); } + // check if can collide + visual_ = !contype && !conaffinity; + // normalize quaternion mjuu_normvec(quat, 4); @@ -3145,7 +3169,9 @@ mjCPair::mjCPair(mjCModel* _model, mjCDef* _def) { friction[4] = 0.0001; // clear internal variables - geom1 = geom2 = signature = -1; + geom1 = nullptr; + geom2 = nullptr; + signature = -1; // reset to default if given if (_def) { @@ -3167,46 +3193,48 @@ void mjCPair::Compile(void) { } // find geom 1 - mjCGeom* pg1 = (mjCGeom*)model->FindObject(mjOBJ_GEOM, geomname1); - if (!pg1) { + geom1 = (mjCGeom*)model->FindObject(mjOBJ_GEOM, geomname1); + if (!geom1) { throw mjCError(this, "geom '%s' not found in collision %d", geomname1.c_str(), id); } // find geom 2 - mjCGeom* pg2 = (mjCGeom*)model->FindObject(mjOBJ_GEOM, geomname2); - if (!pg2) { + geom2 = (mjCGeom*)model->FindObject(mjOBJ_GEOM, geomname2); + if (!geom2) { throw mjCError(this, "geom '%s' not found in collision %d", geomname2.c_str(), id); } + // mark geoms as not visual + geom1->SetNotVisual(); + geom2->SetNotVisual(); + // swap if body1 > body2 - if (pg1->body->id > pg2->body->id) { + if (geom1->body->id > geom2->body->id) { string nametmp = geomname1; geomname1 = geomname2; geomname2 = nametmp; - mjCGeom* geomtmp = pg1; - pg1 = pg2; - pg2 = geomtmp; + mjCGeom* geomtmp = geom1; + geom1 = geom2; + geom2 = geomtmp; } // get geom ids and body signature - geom1 = pg1->id; - geom2 = pg2->id; - signature = ((pg1->body->id)<<16) + pg2->body->id; + signature = ((geom1->body->id)<<16) + geom2->body->id; // set undefined margin: max if (!mjuu_defined(margin)) { - margin = mjMAX(pg1->margin, pg2->margin); + margin = mjMAX(geom1->margin, geom2->margin); } // set undefined gap: max if (!mjuu_defined(gap)) { - gap = mjMAX(pg1->gap, pg2->gap); + gap = mjMAX(geom1->gap, geom2->gap); } // set undefined condim, friction, solref, solimp: different priority - if (pg1->priority!=pg2->priority) { - mjCGeom* pgh = (pg1->priority>pg2->priority ? pg1 : pg2); + if (geom1->priority!=geom2->priority) { + mjCGeom* pgh = (geom1->priority>geom2->priority ? geom1 : geom2); // condim if (condim<0) { @@ -3239,23 +3267,23 @@ void mjCPair::Compile(void) { else { // condim: max if (condim<0) { - condim = mjMAX(pg1->condim, pg2->condim); + condim = mjMAX(geom1->condim, geom2->condim); } // friction: max if (!mjuu_defined(friction[0])) { - friction[0] = friction[1] = mju_max(pg1->friction[0], pg2->friction[0]); - friction[2] = mju_max(pg1->friction[1], pg2->friction[1]); - friction[3] = friction[4] = mju_max(pg1->friction[2], pg2->friction[2]); + friction[0] = friction[1] = mju_max(geom1->friction[0], geom2->friction[0]); + friction[2] = mju_max(geom1->friction[1], geom2->friction[1]); + friction[3] = friction[4] = mju_max(geom1->friction[2], geom2->friction[2]); } // solver mix factor double mix; - if (pg1->solmix>=mjMINVAL && pg2->solmix>=mjMINVAL) { - mix = pg1->solmix / (pg1->solmix + pg2->solmix); - } else if (pg1->solmixsolmixsolmix>=mjMINVAL && geom2->solmix>=mjMINVAL) { + mix = geom1->solmix / (geom1->solmix + geom2->solmix); + } else if (geom1->solmixsolmixsolmixsolmix0) { for (int i=0; isolref[i] + (1-mix)*pg2->solref[i]; + solref[i] = mix*geom1->solref[i] + (1-mix)*geom2->solref[i]; } } // direct: min else { for (int i=0; isolref[i], pg2->solref[i]); + solref[i] = mju_min(geom1->solref[i], geom2->solref[i]); } } } @@ -3281,7 +3309,7 @@ void mjCPair::Compile(void) { // impedance if (!mjuu_defined(solimp[0])) { for (int i=0; isolimp[i] + (1-mix)*pg2->solimp[i]; + solimp[i] = mix*geom1->solimp[i] + (1-mix)*geom2->solimp[i]; } } } @@ -3666,6 +3694,9 @@ void mjCTendon::Compile(void) { "tendon '%s' (id = %d): geom at pos %d not bracketed by sites", name.c_str(), id, i); } + + // mark geoms as non visual + model->geoms[path[i]->obj->id]->SetNotVisual(); break; case mjWRAP_JOINT: @@ -4107,6 +4138,11 @@ void mjCSensor::Compile(void) { name.c_str(), id); } + // if geom mark it as non visual + if (objtype == mjOBJ_GEOM) { + ((mjCGeom*)obj)->SetNotVisual(); + } + // get sensorized object id } else if (type != mjSENS_CLOCK && type != mjSENS_PLUGIN && type != mjSENS_USER) { throw mjCError(this, "invalid type in sensor '%s' (id = %d)", name.c_str(), id); @@ -4579,6 +4615,11 @@ void mjCTuple::Compile(void) { throw mjCError(this, "unrecognized object '%s' in tuple %d", objname[i].c_str(), id); } + // if geom mark it as non visual + if (objtype[i] == mjOBJ_GEOM) { + ((mjCGeom*)res)->SetNotVisual(); + } + // assign id obj[i] = res; } diff --git a/src/user/user_objects.h b/src/user/user_objects.h index 4b6df1bdaf..6168786815 100644 --- a/src/user/user_objects.h +++ b/src/user/user_objects.h @@ -381,6 +381,8 @@ class mjCGeom : public mjCBase { public: double GetVolume(void); // compute geom volume void SetInertia(void); // compute and set geom inertia + bool IsVisual(void) const { return visual_; } + void SetNotVisual(void) { visual_ = false; } // Compute all coefs modeling the interaction with the surrounding fluid. void SetFluidCoefs(void); @@ -432,6 +434,7 @@ class mjCGeom : public mjCBase { double GetRBound(void); // compute bounding sphere radius void ComputeAABB(void); // compute axis-aligned bounding box + bool visual_; // true: geom does not collide and is unreferenced int matid; // id of geom's material mjCMesh* mesh; // geom's mesh mjCHField* hfield; // geom's hfield @@ -591,6 +594,9 @@ class mjCFlex: public mjCBase { std::vector elem; // element vertex ids std::vector texcoord; // vertex texture coordinates + bool HasTexcoord() const; // texcoord not null + void DelTexcoord(); // delete texcoord + private: mjCFlex(mjCModel* = 0); // constructor void Compile(const mjVFS* vfs); // compiler @@ -680,6 +686,9 @@ class mjCMesh: public mjCBase { double& GetVolumeRef(mjtMeshType type); // get volume void FitGeom(mjCGeom* geom, double* meshpos); // approximate mesh with simple geom bool HasTexcoord() const; // texcoord not null + void DelTexcoord(); // delete texcoord + bool IsVisual(void) const { return visual_; } // is geom visual + void SetNotVisual(void) { visual_ = false; } // mark mesh as not visual void CopyVert(float* arr) const; // copy vert data into array void CopyNormal(float* arr) const; // copy normal data into array @@ -693,6 +702,7 @@ class mjCMesh: public mjCBase { void SetBoundingVolume(int faceid); private: + bool visual_; // true: the mesh is only visual std::string content_type_; // content type of file std::string file_; // mesh file double refpos_[3]; // reference position (translate) @@ -847,6 +857,8 @@ class mjCTexture : public mjCBase { friend class mjXWriter; public: + ~mjCTexture(); // destructor + std::string get_file() const { return file; } mjtTexture type; // texture type @@ -876,7 +888,6 @@ class mjCTexture : public mjCBase { private: mjCTexture(mjCModel*); // constructor - ~mjCTexture(); // destructior void Compile(const mjVFS* vfs); // compiler void Builtin2D(void); // make builtin 2D @@ -959,8 +970,8 @@ class mjCPair : public mjCBase { mjCPair(mjCModel* = 0, mjCDef* = 0); // constructor void Compile(void); // compiler - int geom1; // id of geom1 - int geom2; // id of geom2 + mjCGeom* geom1; // geom1 + mjCGeom* geom2; // geom2 int signature; // body1<<16 + body2 }; diff --git a/src/xml/xml_native_reader.cc b/src/xml/xml_native_reader.cc index bd438c5e51..48de27c2b6 100644 --- a/src/xml/xml_native_reader.cc +++ b/src/xml/xml_native_reader.cc @@ -3019,12 +3019,6 @@ void mjXReader::Body(XMLElement* section, mjCBody* pbody, mjCFrame* frame) { mjCGeom* pgeom = pbody->AddGeom(def); OneGeom(elem, pgeom); pgeom->SetFrame(frame); - - // discard visual - if (!pgeom->contype && !pgeom->conaffinity && model->discardvisual) { - delete pbody->geoms.back(); - pbody->geoms.pop_back(); - } } // site sub-element diff --git a/test/user/testdata/discardvisual.xml b/test/user/testdata/discardvisual.xml new file mode 100644 index 0000000000..2ca12e42ea --- /dev/null +++ b/test/user/testdata/discardvisual.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/user/testdata/discardvisual_false.xml b/test/user/testdata/discardvisual_false.xml new file mode 100644 index 0000000000..582ffadc4a --- /dev/null +++ b/test/user/testdata/discardvisual_false.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/user/user_model_test.cc b/test/user/user_model_test.cc index 2079b1d199..bc0da348e5 100644 --- a/test/user/user_model_test.cc +++ b/test/user/user_model_test.cc @@ -15,6 +15,7 @@ // Tests for user/user_model.cc. #include +#include #include #include @@ -283,5 +284,116 @@ TEST_F(FuseStaticTest, FuseStaticEquivalent) { mj_deleteModel(m_no_fuse); } +// ------------- test discardvisual -------------------------------------------- + +using DiscardVisualTest = MujocoTest; +TEST_F(DiscardVisualTest, DiscardVisualKeepsInertia) { + static constexpr char xml[] = R"( + + + + + + + + + + + + + + + + + + + )"; + + std::array error; + mjModel* model = LoadModelFromString(xml, error.data(), error.size()); + EXPECT_THAT(model, NotNull()) << error.data(); + EXPECT_THAT(model->nmesh, 1); + EXPECT_THAT(model->body_inertia[3], model->body_inertia[6]); + EXPECT_THAT(model->body_inertia[4], model->body_inertia[7]); + EXPECT_THAT(model->body_inertia[5], model->body_inertia[8]); + mj_deleteModel(model); +} + +TEST_F(DiscardVisualTest, DiscardVisualEquivalent) { + char error[1024]; + size_t error_sz = 1024; + + static const char* const kDiscardvisualPath = + "user/testdata/discardvisual.xml"; + static const char* const kDiscardvisualFalsePath = + "user/testdata/discardvisual_false.xml"; + + const std::string xml_path1 = GetTestDataFilePath(kDiscardvisualPath); + mjModel* model1 = mj_loadXML(xml_path1.c_str(), 0, error, error_sz); + EXPECT_THAT(model1, NotNull()) << error; + + const std::string xml_path2 = GetTestDataFilePath(kDiscardvisualFalsePath); + mjModel* model2 = mj_loadXML(xml_path2.c_str(), 0, error, error_sz); + EXPECT_THAT(model2, NotNull()) << error; + + EXPECT_THAT(model1->nq, model2->nq); + EXPECT_THAT(model1->nmat, 0); + EXPECT_THAT(model1->ntex, 0); + EXPECT_THAT(model2->ngeom-model1->ngeom, 3); + EXPECT_THAT(model2->nmesh-model1->nmesh, 2); + EXPECT_THAT(model1->npair, model2->npair); + EXPECT_THAT(model1->nsensor, model2->nsensor); + EXPECT_THAT(model1->nwrap, model2->nwrap); + + for (int i = 0; i < model1->ngeom; i++) { + std::string name = std::string(model1->names + model1->name_geomadr[i]); + EXPECT_NE(name.find("kept"), std::string::npos); + EXPECT_EQ(name.find("discard"), std::string::npos); + } + + for (int i = 0; i < model1->npair; i++) { + int adr1 = model1->name_geomadr[model1->pair_geom1[i]]; + int adr2 = model2->name_geomadr[model2->pair_geom1[i]]; + EXPECT_STREQ(model1->names + adr1, model2->names + adr2); + adr1 = model1->name_geomadr[model1->pair_geom2[i]]; + adr2 = model2->name_geomadr[model2->pair_geom2[i]]; + EXPECT_STREQ(model1->names + adr1, model2->names + adr2); + } + + for (int i = 0; i < model1->nsensor; i++) { + int adr1 = model1->name_geomadr[model1->sensor_objid[i]]; + int adr2 = model2->name_geomadr[model2->sensor_objid[i]]; + EXPECT_STREQ(model1->names + adr1, model2->names + adr2); + } + + for (int i = 0; i < model1->nwrap; i++) { + int adr1 = model1->name_geomadr[model1->wrap_objid[i]]; + int adr2 = model2->name_geomadr[model2->wrap_objid[i]]; + EXPECT_STREQ(model1->names + adr1, model2->names + adr2); + } + + mjData *d1 = mj_makeData(model1); + mjData *d2 = mj_makeData(model2); + for (int i = 0; i < 100; i++) { + mj_step(model1, d1); + mj_step(model2, d2); + } + + for (int i = 0; i < model1->nq; i++) { + EXPECT_THAT(d1->qpos[i], d2->qpos[i]); + } + + mj_deleteModel(model1); + mj_deleteModel(model2); + mj_deleteData(d1); + mj_deleteData(d2); +} + } // namespace } // namespace mujoco