Skip to content

Commit

Permalink
clean up rollout.cc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599849967
Change-Id: I899b28bf78bf0414b07314c24cfe6e24441eba30
  • Loading branch information
yuvaltassa authored and copybara-github committed Jan 19, 2024
1 parent e0864ab commit a23a368
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions python/mujoco/rollout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <array>
#include <cstdio>
#include <iostream>
#include <optional>
#include <sstream>
#include <string>

#include "functions.h"
#include <mujoco/mujoco.h>
#include "errors.h"
#include "raw.h"
#include "structs.h"
#include <pybind11/buffer_info.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
Expand Down Expand Up @@ -74,7 +73,6 @@ void _unsafe_rollout(const mjModel* m, mjData* d, int nstate, int nstep,

// loop over initial states
for (int s=0; s < nstate; s++) {

// set initial state
if (state0) {
mju_copy(d->qpos, state0 + s*nqva, nq);
Expand Down Expand Up @@ -108,9 +106,9 @@ void _unsafe_rollout(const mjModel* m, mjData* d, int nstate, int nstep,
mju_zero(d->xfrc_applied, 6*nbody);
}
if (!mocap) {
for (int j=0; j<nbody; j++) {
for (int j=0; j < nbody; j++) {
int id = m->body_mocapid[j];
if (id>=0) {
if (id >= 0) {
mju_copy3(d->mocap_pos+3*id, m->body_pos+3*j);
mju_copy4(d->mocap_quat+4*id, m->body_quat+4*j);
}
Expand Down Expand Up @@ -174,7 +172,8 @@ mjtNum* get_array_ptr(std::optional<const py::array_t<mjtNum>> arg,
int expected_size = nstate * nstep * dim;
if (info.size != expected_size) {
std::ostringstream msg;
msg << name << ".size should be " << expected_size << ", got " << info.size;
msg << name << ".size should be " << expected_size <<
", got " << info.size;
throw py::value_error(msg.str());
}
return static_cast<mjtNum*>(info.ptr);
Expand All @@ -200,7 +199,6 @@ PYBIND11_MODULE(_rollout, pymodule) {
std::optional<const PyCArray> state,
std::optional<const PyCArray> sensordata
) {

const raw::MjModel* model = m.get();
raw::MjData* data = d.get();

Expand All @@ -213,7 +211,8 @@ PYBIND11_MODULE(_rollout, pymodule) {
int nqva = model->nq + model->nv + model->na;
mjtNum* init_state_ptr =
get_array_ptr(init_state, "initial_state", nstate, 1, nqva);
mjtNum* ctrl_ptr = get_array_ptr(ctrl, "ctrl", nstate, nstep, model->nu);
mjtNum* ctrl_ptr =
get_array_ptr(ctrl, "ctrl", nstate, nstep, model->nu);
mjtNum* qfrc_ptr =
get_array_ptr(qfrc, "qfrc_applied", nstate, nstep, model->nv);
mjtNum* xfrc_ptr =
Expand All @@ -222,11 +221,11 @@ PYBIND11_MODULE(_rollout, pymodule) {
get_array_ptr(mocap, "mocap", nstate, nstep, 7*model->nmocap);
mjtNum* init_time_ptr =
get_array_ptr(init_time, "init_time", nstate, 1, 1);
mjtNum* init_warmstart_ptr =
get_array_ptr(init_warmstart, "init_warmstart", nstate, 1, model->nv);
mjtNum* init_warmstart_ptr = get_array_ptr(
init_warmstart, "init_warmstart", nstate, 1, model->nv);
mjtNum* state_ptr = get_array_ptr(state, "state", nstate, nstep, nqva);
mjtNum* sensordata_ptr =
get_array_ptr(sensordata, "sensordata", nstate, nstep, model->nsensordata);
mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nstate,
nstep, model->nsensordata);

// perform rollouts
{
Expand Down Expand Up @@ -255,9 +254,8 @@ PYBIND11_MODULE(_rollout, pymodule) {
py::arg("sensordata") = py::none(),
py::doc(rollout_doc)
);
}

} // namespace

}

}
} // namespace mujoco::python

0 comments on commit a23a368

Please sign in to comment.