Commit

Rename rollout variable nroll to nbatch.
PiperOrigin-RevId: 718196100
Change-Id: Ic36b1a96ea1af2de351539115d4eee051d195aaf
yuvaltassa authored and copybara-github committed Jan 22, 2025
1 parent ac76540 commit 0090e1e
Showing 5 changed files with 163 additions and 160 deletions.
6 changes: 3 additions & 3 deletions doc/python.rst
@@ -730,11 +730,11 @@ states and sensor values. The rollouts are run in parallel with an internally ma
state, sensordata = rollout.rollout(model, data, initial_state, control)
``model`` is either a single instance of MjModel or a sequence of compatible MjModel of length ``nroll``.
``model`` is either a single instance of MjModel or a sequence of compatible MjModel of length ``nbatch``.
``data`` is either a single instance of MjData or a sequence of compatible MjData of length ``nthread``.
``initial_state`` is an ``nroll x nstate`` array, with ``nroll`` initial states of size ``nstate``, where
``initial_state`` is an ``nbatch x nstate`` array, with ``nbatch`` initial states of size ``nstate``, where
``nstate = mj_stateSize(model, mjtState.mjSTATE_FULLPHYSICS)`` is the size of the
:ref:`full physics state<geFullPhysics>`. ``control`` is a ``nroll x nstep x ncontrol`` array of controls. Controls are
:ref:`full physics state<geFullPhysics>`. ``control`` is a ``nbatch x nstep x ncontrol`` array of controls. Controls are
by default the ``mjModel.nu`` standard actuators, but any combination of :ref:`user input<geInput>` arrays can be
specified by passing an optional ``control_spec`` bitflag.
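
The shape contract described in this hunk can be illustrated with a minimal NumPy sketch. The sizes below (`nbatch = 4`, `nstep = 10`, `nstate = 7`, `ncontrol = 2`) are hypothetical placeholders chosen for illustration; in real use `nstate` would come from `mj_stateSize` and `ncontrol` would default to `mjModel.nu`. The sketch only builds arrays of the documented shapes and does not call MuJoCo.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration (not taken from any model).
nbatch, nstep, nstate, ncontrol = 4, 10, 7, 2

# One initial state per rollout in the batch: (nbatch x nstate).
initial_state = np.zeros((nbatch, nstate))

# One control vector per step, per rollout: (nbatch x nstep x ncontrol).
control = np.zeros((nbatch, nstep, ncontrol))

print(initial_state.shape, control.shape)
```

With inputs of these shapes, the docstrings in this commit describe the returned `state` as `nbatch x nstep x nstate` and `sensordata` as `nbatch x nstep x nsensordata`.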

6 changes: 3 additions & 3 deletions python/least_squares.ipynb
@@ -1967,8 +1967,8 @@
" data.mocap_pos[mocapid] = target\n",
"\n",
" # Append the mocap targets to the controls\n",
" nroll = ctrl0.shape[0]\n",
" mocap = np.tile(data.mocap_pos[mocapid], (nroll, 1))\n",
" nbatch = ctrl0.shape[0]\n",
" mocap = np.tile(data.mocap_pos[mocapid], (nbatch, 1))\n",
" ctrl0 = np.hstack((ctrl0, mocap))\n",
" ctrlT = np.hstack((ctrlT, mocap))\n",
"\n",
@@ -1988,7 +1988,7 @@
" state = np.empty(nstate)\n",
" mujoco.mj_getState(model, data, state, spec)\n",
"\n",
" # Perform rollouts (sensors.shape == nroll, nstep, nsensordata)\n",
" # Perform rollouts (sensors.shape == nbatch, nstep, nsensordata)\n",
" states, sensors = rollout.rollout(model, data, state, control,\n",
" control_spec=control_spec)\n",
"\n",
67 changes: 35 additions & 32 deletions python/mujoco/rollout.cc
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <iostream>
#include <memory>
#include <optional>
#include <sstream>
#include <vector>

#include <mujoco/mujoco.h>
#include "errors.h"
@@ -46,31 +48,32 @@ Construct a rollout object containing a thread pool for parallel rollouts.
)";

const auto rollout_doc = R"(
Roll out open-loop trajectories from initial states, get resulting states and sensor values.
Roll out batch of trajectories from initial states, get resulting states and sensor values.
input arguments (required):
model list of MjModel instances of length nroll
data list of associated MjData instances of length nthread
model list of homogenous MjModel instances of length nbatch
data list of compatible MjData instances of length nthread
nstep integer, number of steps to be taken for each trajectory
control_spec specification of controls, ncontrol = mj_stateSize(m, control_spec)
state0 (nroll x nstate) nroll initial state vectors,
nstate = mj_stateSize(m, mjSTATE_FULLPHYSICS)
state0 (nbatch x nstate) nbatch initial state arrays, where
nstate = mj_stateSize(m, mjSTATE_FULLPHYSICS)
input arguments (optional):
warmstart0 (nroll x nv) nroll qacc_warmstart vectors
control (nroll x nstep x ncontrol) nroll trajectories of nstep controls
warmstart0 (nbatch x nv) nbatch qacc_warmstart arrays
control (nbatch x nstep x ncontrol) nbatch trajectories of nstep controls
output arguments (optional):
state (nroll x nstep x nstate) nroll nstep states
sensordata (nroll x nstep x nsendordata) nroll trajectories of nstep sensordata vectors
chunk_size integer, determines threadpool chunk size. If unspecified
chunk_size = max(1, nroll / (nthread * 10))
state (nbatch x nstep x nstate) nbatch nstep states
sensordata (nbatch x nstep x nsendordata) nbatch trajectories of nstep sensordata arrays
chunk_size integer, determines threadpool chunk size. If unspecified, the default is
chunk_size = max(1, nbatch / (nthread * 10))
)";

// C-style rollout function, assumes all arguments are valid
// all input fields of d are initialised, contents at call time do not matter
// after returning, d will contain the last step of the last rollout
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll, int end_roll, int nstep, unsigned int control_spec,
const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control,
mjtNum* state, mjtNum* sensordata) {
void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,
int end_roll, int nstep, unsigned int control_spec,
const mjtNum* state0, const mjtNum* warmstart0,
const mjtNum* control, mjtNum* state, mjtNum* sensordata) {
// sizes
int nstate = mj_stateSize(m[0], mjSTATE_FULLPHYSICS);
int ncontrol = mj_stateSize(m[0], control_spec);
@@ -174,12 +177,12 @@ void _unsafe_rollout(std::vector<const mjModel*>& m, mjData* d, int start_roll,

// C-style threaded version of _unsafe_rollout
void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData*>& d,
int nroll, int nstep, unsigned int control_spec,
int nbatch, int nstep, unsigned int control_spec,
const mjtNum* state0, const mjtNum* warmstart0,
const mjtNum* control, mjtNum* state, mjtNum* sensordata,
ThreadPool* pool, int chunk_size) {
int nfulljobs = nroll / chunk_size;
int chunk_remainder = nroll % chunk_size;
int nfulljobs = nbatch / chunk_size;
int chunk_remainder = nbatch % chunk_size;
int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs;

// Reset the pool counter
@@ -213,7 +216,7 @@ void _unsafe_rollout_threaded(std::vector<const mjModel*>& m, std::vector<mjData

// check size of optional argument to rollout(), return raw pointer
mjtNum* get_array_ptr(std::optional<const py::array_t<mjtNum>> arg,
const char* name, int nroll, int nstep, int dim) {
const char* name, int nbatch, int nstep, int dim) {
// if empty return nullptr
if (!arg.has_value()) {
return nullptr;
@@ -223,7 +226,7 @@ mjtNum* get_array_ptr(std::optional<const py::array_t<mjtNum>> arg,
py::buffer_info info = arg->request();

// check size
int expected_size = nroll * nstep * dim;
int expected_size = nbatch * nstep * dim;
if (info.size != expected_size) {
std::ostringstream msg;
msg << name << ".size should be " << expected_size << ", got " << info.size;
@@ -247,9 +250,9 @@ class Rollout {
std::optional<const PyCArray> sensordata,
std::optional<int> chunk_size) {
// get raw pointers
int nroll = state0.shape(0);
std::vector<const raw::MjModel*> model_ptrs(nroll);
for (int r = 0; r < nroll; r++) {
int nbatch = state0.shape(0);
std::vector<const raw::MjModel*> model_ptrs(nbatch);
for (int r = 0; r < nbatch; r++) {
model_ptrs[r] = m[r].cast<const MjModelWrapper*>()->get();
}

@@ -284,13 +287,13 @@ class Rollout {
int nstate = mj_stateSize(model_ptrs[0], mjSTATE_FULLPHYSICS);
int ncontrol = mj_stateSize(model_ptrs[0], control_spec);

mjtNum* state0_ptr = get_array_ptr(state0, "state0", nroll, 1, nstate);
mjtNum* state0_ptr = get_array_ptr(state0, "state0", nbatch, 1, nstate);
mjtNum* warmstart0_ptr =
get_array_ptr(warmstart0, "warmstart0", nroll, 1, model_ptrs[0]->nv);
get_array_ptr(warmstart0, "warmstart0", nbatch, 1, model_ptrs[0]->nv);
mjtNum* control_ptr =
get_array_ptr(control, "control", nroll, nstep, ncontrol);
mjtNum* state_ptr = get_array_ptr(state, "state", nroll, nstep, nstate);
mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nroll,
get_array_ptr(control, "control", nbatch, nstep, ncontrol);
mjtNum* state_ptr = get_array_ptr(state, "state", nbatch, nstep, nstate);
mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nbatch,
nstep, model_ptrs[0]->nsensordata);

// perform rollouts
@@ -299,21 +302,21 @@ class Rollout {
py::gil_scoped_release no_gil;

// call unsafe rollout function, multi or single threaded
if (this->nthread_ > 0 && nroll > 1) {
if (this->nthread_ > 0 && nbatch > 1) {
int chunk_size_final = 1;
if (!chunk_size.has_value()) {
chunk_size_final = std::max(1, nroll / (10 * this->nthread_));
chunk_size_final = std::max(1, nbatch / (10 * this->nthread_));
} else {
chunk_size_final = *chunk_size;
}
InterceptMjErrors(_unsafe_rollout_threaded)(
model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr,
model_ptrs, data_ptrs, nbatch, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr,
this->pool_.get(), chunk_size_final);
} else {
InterceptMjErrors(_unsafe_rollout)(
model_ptrs, data_ptrs[0], 0, nroll, nstep, control_spec, state0_ptr,
warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
model_ptrs, data_ptrs[0], 0, nbatch, nstep, control_spec,
state0_ptr, warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr);
}
}
}
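
The chunking arithmetic in `_unsafe_rollout_threaded` and the default `chunk_size` can be sketched in a few lines of Python. The function names `default_chunk_size` and `job_ranges` are hypothetical and not part of MuJoCo, and the mapping from job index to a `[start, end)` range is an assumption, since the scheduling loop is elided from this diff. Only the formulas for `nfulljobs`, `chunk_remainder`, `njobs`, and the default chunk size are taken from the code above.

```python
def default_chunk_size(nbatch, nthread):
    """Default from the diff: max(1, nbatch / (10 * nthread)), integer division."""
    return max(1, nbatch // (10 * nthread))


def job_ranges(nbatch, chunk_size):
    """Split range(nbatch) into [start, end) chunks; the last may be shorter.

    The range layout is an assumption; the job count follows the diff.
    """
    nfulljobs = nbatch // chunk_size
    chunk_remainder = nbatch % chunk_size
    njobs = nfulljobs + 1 if chunk_remainder > 0 else nfulljobs
    return [
        (j * chunk_size, min((j + 1) * chunk_size, nbatch))
        for j in range(njobs)
    ]


print(default_chunk_size(1000, 4))
print(job_ranges(10, 3))
```

Under these assumptions, a batch of 10 rollouts with `chunk_size = 3` yields three full jobs plus one job covering the single remaining rollout.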
66 changes: 33 additions & 33 deletions python/mujoco/rollout.py
@@ -65,34 +65,34 @@ def rollout(
"""Rolls out open-loop trajectories from initial states, get subsequent state and sensor values.
Python wrapper for rollout.cc, see documentation therein.
Infers nroll and nstep.
Infers nbatch and nstep.
Tiles inputs with singleton dimensions.
Allocates outputs if none are given.
Args:
model: An instance or length nroll sequence of MjModel with the same size signature.
model: An instance or length nbatch sequence of MjModel with the same size signature.
data: Associated mjData instance or sequence of instances with length nthread.
initial_state: Array of initial states from which to roll out trajectories.
([nroll or 1] x nstate)
([nbatch or 1] x nstate)
control: Open-loop controls array to apply during the rollouts.
([nroll or 1] x [nstep or 1] x ncontrol)
([nbatch or 1] x [nstep or 1] x ncontrol)
control_spec: mjtState specification of control vectors.
skip_checks: Whether to skip internal shape and type checks.
nstep: Number of steps in rollouts (inferred if unspecified).
initial_warmstart: Initial qfrc_warmstart array (optional).
([nroll or 1] x nv)
([nbatch or 1] x nv)
state: State output array (optional).
(nroll x nstep x nstate)
(nbatch x nstep x nstate)
sensordata: Sensor data output array (optional).
(nroll x nstep x nsensordata)
(nbatch x nstep x nsensordata)
chunk_size: Determines threadpool chunk size. If unspecified,
chunk_size = max(1, nroll / (nthread * 10))
chunk_size = max(1, nbatch / (nthread * 10))
Returns:
state:
State output array, (nroll x nstep x nstate).
State output array, (nbatch x nstep x nstate).
sensordata:
Sensor data output array, (nroll x nstep x nsensordata).
Sensor data output array, (nbatch x nstep x nsensordata).
Raises:
RuntimeError: rollout requested after thread pool shutdown.
@@ -103,7 +103,7 @@ def rollout(
raise RuntimeError('rollout requested after thread pool shutdown')

# skip_checks shortcut:
# don't infer nroll/nstep
# don't infer nbatch or nstep
# don't support singleton expansion
# don't allocate output arrays
# just call rollout and return
@@ -159,8 +159,8 @@ def rollout(
state = _ensure_3d(state)
sensordata = _ensure_3d(sensordata)

# infer nroll, check for incompatibilities
nroll = _infer_dimension(
# infer nbatch, check for incompatibilities
nbatch = _infer_dimension(
0,
1,
initial_state=initial_state,
@@ -169,12 +169,12 @@ def rollout(
state=state,
sensordata=sensordata,
)
if isinstance(model, list) and nroll == 1:
nroll = len(model)
if isinstance(model, list) and nbatch == 1:
nbatch = len(model)

if isinstance(model, list) and len(model) > 1 and len(model) != nroll:
if isinstance(model, list) and len(model) > 1 and len(model) != nbatch:
raise ValueError(
f'nroll inferred as {nroll} but model is length {len(model)}'
f'nbatch inferred as {nbatch} but model is length {len(model)}'
)
elif not isinstance(model, list):
model = [model] # Use a length 1 list to simplify code below
@@ -212,16 +212,16 @@ def rollout(
_check_trailing_dimension(nsensordata, sensordata=sensordata)

# tile input arrays/lists if required (singleton expansion)
model = model * nroll if len(model) == 1 else model
initial_state = _tile_if_required(initial_state, nroll)
initial_warmstart = _tile_if_required(initial_warmstart, nroll)
control = _tile_if_required(control, nroll, nstep)
model = model * nbatch if len(model) == 1 else model
initial_state = _tile_if_required(initial_state, nbatch)
initial_warmstart = _tile_if_required(initial_warmstart, nbatch)
control = _tile_if_required(control, nbatch, nstep)

# allocate output if not provided
if state is None:
state = np.empty((nroll, nstep, nstate))
state = np.empty((nbatch, nstep, nstate))
if sensordata is None:
sensordata = np.empty((nroll, nstep, nsensordata))
sensordata = np.empty((nbatch, nstep, nsensordata))

# call rollout
self.rollout_.rollout(
@@ -276,35 +276,35 @@ def rollout(
"""Rolls out open-loop trajectories from initial states, get subsequent states and sensor values.
Python wrapper for rollout.cc, see documentation therein.
Infers nroll and nstep.
Infers nbatch and nstep.
Tiles inputs with singleton dimensions.
Allocates outputs if none are given.
Args:
model: An instance or length nroll sequence of MjModel with the same size signature.
model: An instance or length nbatch sequence of MjModel with the same size signature.
data: Associated mjData instance or sequence of instances with length nthread.
initial_state: Array of initial states from which to roll out trajectories.
([nroll or 1] x nstate)
([nbatch or 1] x nstate)
control: Open-loop controls array to apply during the rollouts.
([nroll or 1] x [nstep or 1] x ncontrol)
([nbatch or 1] x [nstep or 1] x ncontrol)
control_spec: mjtState specification of control vectors.
skip_checks: Whether to skip internal shape and type checks.
nstep: Number of steps in rollouts (inferred if unspecified).
initial_warmstart: Initial qfrc_warmstart array (optional).
([nroll or 1] x nv)
([nbatch or 1] x nv)
state: State output array (optional).
(nroll x nstep x nstate)
(nbatch x nstep x nstate)
sensordata: Sensor data output array (optional).
(nroll x nstep x nsensordata)
(nbatch x nstep x nsensordata)
chunk_size: Determines threadpool chunk size. If unspecified,
chunk_size = max(1, nroll / (nthread * 10))
chunk_size = max(1, nbatch / (nthread * 10))
persistent_pool: Determines if a persistent thread pool is created or reused.
Returns:
state:
State output array, (nroll x nstep x nstate).
State output array, (nbatch x nstep x nstate).
sensordata:
Sensor data output array, (nroll x nstep x nsensordata).
Sensor data output array, (nbatch x nstep x nsensordata).
Raises:
ValueError: bad shapes or sizes.
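
The singleton expansion mentioned in the docstring ("Tiles inputs with singleton dimensions") can be sketched with NumPy. The helper `tile_batch` below is hypothetical; the body of the real `_tile_if_required` is not shown in this diff, so this is only one plausible way to expand a `1 x nstate` input to `nbatch x nstate`.

```python
import numpy as np


def tile_batch(arr, nbatch):
    """Hypothetical helper: repeat a singleton leading dimension nbatch times."""
    arr = np.asarray(arr)
    if arr.shape[0] == 1 and nbatch > 1:
        # Repeat along axis 0 only; all other axes keep their size.
        reps = (nbatch,) + (1,) * (arr.ndim - 1)
        return np.tile(arr, reps)
    return arr


initial_state = np.arange(3.0).reshape(1, 3)  # (1 x nstate), nstate = 3
tiled = tile_batch(initial_state, nbatch=4)   # (4 x nstate)
print(tiled.shape)
```

An array whose leading dimension already equals `nbatch` passes through unchanged, which mirrors the `[nbatch or 1]` convention used in the argument descriptions.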