diff --git a/doc/python.rst b/doc/python.rst index 5a479d98c5..e749984908 100644 --- a/doc/python.rst +++ b/doc/python.rst @@ -730,11 +730,11 @@ states and sensor values. The rollouts are run in parallel with an internally ma state, sensordata = rollout.rollout(model, data, initial_state, control) -``model`` is either a single instance of MjModel or a sequence of compatible MjModel of length ``nroll``. +``model`` is either a single instance of MjModel or a sequence of compatible MjModel of length ``nbatch``. ``data`` is either a single instance of MjData or a sequence of compatible MjData of length ``nthread``. -``initial_state`` is an ``nroll x nstate`` array, with ``nroll`` initial states of size ``nstate``, where +``initial_state`` is an ``nbatch x nstate`` array, with ``nbatch`` initial states of size ``nstate``, where ``nstate = mj_stateSize(model, mjtState.mjSTATE_FULLPHYSICS)`` is the size of the -:ref:`full physics state`. ``control`` is a ``nroll x nstep x ncontrol`` array of controls. Controls are +:ref:`full physics state`. ``control`` is a ``nbatch x nstep x ncontrol`` array of controls. Controls are by default the ``mjModel.nu`` standard actuators, but any combination of :ref:`user input` arrays can be specified by passing an optional ``control_spec`` bitflag. diff --git a/python/least_squares.ipynb b/python/least_squares.ipynb index 80819e9162..3e76d2060f 100644 --- a/python/least_squares.ipynb +++ b/python/least_squares.ipynb @@ -1967,8 +1967,8 @@ " data.mocap_pos[mocapid] = target\n", "\n", " # Append the mocap targets to the controls\n", - " nroll = ctrl0.shape[0]\n", - " mocap = np.tile(data.mocap_pos[mocapid], (nroll, 1))\n", + " nbatch = ctrl0.shape[0]\n", + " mocap = np.tile(data.mocap_pos[mocapid], (nbatch, 1))\n", " ctrl0 = np.hstack((ctrl0, mocap))\n", " ctrlT = np.hstack((ctrlT, mocap))\n", "\n", @@ -1988,7 +1988,7 @@ " state = np.empty(nstate)\n", " mujoco.mj_getState(model, data, state, spec)\n", "\n", - " # Perform rollouts (sensors.shape == nroll, nstep, nsensordata)\n", + " # Perform rollouts (sensors.shape == nbatch, nstep, nsensordata)\n", " states, sensors = rollout.rollout(model, data, state, control,\n", " control_spec=control_spec)\n", "\n", diff --git a/python/mujoco/rollout.cc b/python/mujoco/rollout.cc index 42519189c1..e591af3480 100644 --- a/python/mujoco/rollout.cc +++ b/python/mujoco/rollout.cc @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include #include +#include #include #include "errors.h" @@ -46,31 +48,32 @@ Construct a rollout object containing a thread pool for parallel rollouts. )"; const auto rollout_doc = R"( -Roll out open-loop trajectories from initial states, get resulting states and sensor values. +Roll out batch of trajectories from initial states, get resulting states and sensor values. input arguments (required): - model list of MjModel instances of length nroll - data list of associated MjData instances of length nthread + model list of homogenous MjModel instances of length nbatch + data list of compatible MjData instances of length nthread nstep integer, number of steps to be taken for each trajectory control_spec specification of controls, ncontrol = mj_stateSize(m, control_spec) - state0 (nroll x nstate) nroll initial state vectors, - nstate = mj_stateSize(m, mjSTATE_FULLPHYSICS) + state0 (nbatch x nstate) nbatch initial state arrays, where + nstate = mj_stateSize(m, mjSTATE_FULLPHYSICS) input arguments (optional): - warmstart0 (nroll x nv) nroll qacc_warmstart vectors - control (nroll x nstep x ncontrol) nroll trajectories of nstep controls + warmstart0 (nbatch x nv) nbatch qacc_warmstart arrays + control (nbatch x nstep x ncontrol) nbatch trajectories of nstep controls output arguments (optional): - state (nroll x nstep x nstate) nroll nstep states - sensordata (nroll x nstep x nsendordata) nroll trajectories of nstep sensordata vectors - chunk_size integer, determines threadpool chunk size. If unspecified - chunk_size = max(1, nroll / (nthread * 10)) + state (nbatch x nstep x nstate) nbatch nstep states + sensordata (nbatch x nstep x nsendordata) nbatch trajectories of nstep sensordata arrays + chunk_size integer, determines threadpool chunk size. If unspecified, the default is + chunk_size = max(1, nbatch / (nthread * 10)) )"; // C-style rollout function, assumes all arguments are valid // all input fields of d are initialised, contents at call time do not matter // after returning, d will contain the last step of the last rollout -void _unsafe_rollout(std::vector& m, mjData* d, int start_roll, int end_roll, int nstep, unsigned int control_spec, - const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control, - mjtNum* state, mjtNum* sensordata) { +void _unsafe_rollout(std::vector& m, mjData* d, int start_roll, + int end_roll, int nstep, unsigned int control_spec, + const mjtNum* state0, const mjtNum* warmstart0, + const mjtNum* control, mjtNum* state, mjtNum* sensordata) { // sizes int nstate = mj_stateSize(m[0], mjSTATE_FULLPHYSICS); int ncontrol = mj_stateSize(m[0], control_spec); @@ -174,12 +177,12 @@ void _unsafe_rollout(std::vector& m, mjData* d, int start_roll, // C-style threaded version of _unsafe_rollout void _unsafe_rollout_threaded(std::vector& m, std::vector& d, - int nroll, int nstep, unsigned int control_spec, + int nbatch, int nstep, unsigned int control_spec, const mjtNum* state0, const mjtNum* warmstart0, const mjtNum* control, mjtNum* state, mjtNum* sensordata, ThreadPool* pool, int chunk_size) { - int nfulljobs = nroll / chunk_size; - int chunk_remainder = nroll % chunk_size; + int nfulljobs = nbatch / chunk_size; + int chunk_remainder = nbatch % chunk_size; int njobs = (chunk_remainder > 0) ? nfulljobs + 1 : nfulljobs; // Reset the pool counter @@ -213,7 +216,7 @@ void _unsafe_rollout_threaded(std::vector& m, std::vector> arg, - const char* name, int nroll, int nstep, int dim) { + const char* name, int nbatch, int nstep, int dim) { // if empty return nullptr if (!arg.has_value()) { return nullptr; @@ -223,7 +226,7 @@ mjtNum* get_array_ptr(std::optional> arg, py::buffer_info info = arg->request(); // check size - int expected_size = nroll * nstep * dim; + int expected_size = nbatch * nstep * dim; if (info.size != expected_size) { std::ostringstream msg; msg << name << ".size should be " << expected_size << ", got " << info.size; @@ -247,9 +250,9 @@ class Rollout { std::optional sensordata, std::optional chunk_size) { // get raw pointers - int nroll = state0.shape(0); - std::vector model_ptrs(nroll); - for (int r = 0; r < nroll; r++) { + int nbatch = state0.shape(0); + std::vector model_ptrs(nbatch); + for (int r = 0; r < nbatch; r++) { model_ptrs[r] = m[r].cast()->get(); } @@ -284,13 +287,13 @@ class Rollout { int nstate = mj_stateSize(model_ptrs[0], mjSTATE_FULLPHYSICS); int ncontrol = mj_stateSize(model_ptrs[0], control_spec); - mjtNum* state0_ptr = get_array_ptr(state0, "state0", nroll, 1, nstate); + mjtNum* state0_ptr = get_array_ptr(state0, "state0", nbatch, 1, nstate); mjtNum* warmstart0_ptr = - get_array_ptr(warmstart0, "warmstart0", nroll, 1, model_ptrs[0]->nv); + get_array_ptr(warmstart0, "warmstart0", nbatch, 1, model_ptrs[0]->nv); mjtNum* control_ptr = - get_array_ptr(control, "control", nroll, nstep, ncontrol); - mjtNum* state_ptr = get_array_ptr(state, "state", nroll, nstep, nstate); - mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nroll, + get_array_ptr(control, "control", nbatch, nstep, ncontrol); + mjtNum* state_ptr = get_array_ptr(state, "state", nbatch, nstep, nstate); + mjtNum* sensordata_ptr = get_array_ptr(sensordata, "sensordata", nbatch, nstep, model_ptrs[0]->nsensordata); // perform rollouts @@ -299,21 +302,21 @@ class Rollout { py::gil_scoped_release no_gil; // call unsafe rollout function, multi or single threaded - if (this->nthread_ > 0 && nroll > 1) { + if (this->nthread_ > 0 && nbatch > 1) { int chunk_size_final = 1; if (!chunk_size.has_value()) { - chunk_size_final = std::max(1, nroll / (10 * this->nthread_)); + chunk_size_final = std::max(1, nbatch / (10 * this->nthread_)); } else { chunk_size_final = *chunk_size; } InterceptMjErrors(_unsafe_rollout_threaded)( - model_ptrs, data_ptrs, nroll, nstep, control_spec, state0_ptr, + model_ptrs, data_ptrs, nbatch, nstep, control_spec, state0_ptr, warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr, this->pool_.get(), chunk_size_final); } else { InterceptMjErrors(_unsafe_rollout)( - model_ptrs, data_ptrs[0], 0, nroll, nstep, control_spec, state0_ptr, - warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr); + model_ptrs, data_ptrs[0], 0, nbatch, nstep, control_spec, + state0_ptr, warmstart0_ptr, control_ptr, state_ptr, sensordata_ptr); } } } diff --git a/python/mujoco/rollout.py b/python/mujoco/rollout.py index 68a9be23bd..24ed27e673 100644 --- a/python/mujoco/rollout.py +++ b/python/mujoco/rollout.py @@ -65,34 +65,34 @@ def rollout( """Rolls out open-loop trajectories from initial states, get subsequent state and sensor values. Python wrapper for rollout.cc, see documentation therein. - Infers nroll and nstep. + Infers nbatch and nstep. Tiles inputs with singleton dimensions. Allocates outputs if none are given. Args: - model: An instance or length nroll sequence of MjModel with the same size signature. + model: An instance or length nbatch sequence of MjModel with the same size signature. data: Associated mjData instance or sequence of instances with length nthread. initial_state: Array of initial states from which to roll out trajectories. - ([nroll or 1] x nstate) + ([nbatch or 1] x nstate) control: Open-loop controls array to apply during the rollouts. - ([nroll or 1] x [nstep or 1] x ncontrol) + ([nbatch or 1] x [nstep or 1] x ncontrol) control_spec: mjtState specification of control vectors. skip_checks: Whether to skip internal shape and type checks. nstep: Number of steps in rollouts (inferred if unspecified). initial_warmstart: Initial qfrc_warmstart array (optional). - ([nroll or 1] x nv) + ([nbatch or 1] x nv) state: State output array (optional). - (nroll x nstep x nstate) + (nbatch x nstep x nstate) sensordata: Sensor data output array (optional). - (nroll x nstep x nsensordata) + (nbatch x nstep x nsensordata) chunk_size: Determines threadpool chunk size. If unspecified, - chunk_size = max(1, nroll / (nthread * 10)) + chunk_size = max(1, nbatch / (nthread * 10)) Returns: state: - State output array, (nroll x nstep x nstate). + State output array, (nbatch x nstep x nstate). sensordata: - Sensor data output array, (nroll x nstep x nsensordata). + Sensor data output array, (nbatch x nstep x nsensordata). Raises: RuntimeError: rollout requested after thread pool shutdown. @@ -103,7 +103,7 @@ def rollout( raise RuntimeError('rollout requested after thread pool shutdown') # skip_checks shortcut: - # don't infer nroll/nstep + # don't infer nbatch or nstep # don't support singleton expansion # don't allocate output arrays # just call rollout and return @@ -159,8 +159,8 @@ def rollout( state = _ensure_3d(state) sensordata = _ensure_3d(sensordata) - # infer nroll, check for incompatibilities - nroll = _infer_dimension( + # infer nbatch, check for incompatibilities + nbatch = _infer_dimension( 0, 1, initial_state=initial_state, @@ -169,12 +169,12 @@ def rollout( state=state, sensordata=sensordata, ) - if isinstance(model, list) and nroll == 1: - nroll = len(model) + if isinstance(model, list) and nbatch == 1: + nbatch = len(model) - if isinstance(model, list) and len(model) > 1 and len(model) != nroll: + if isinstance(model, list) and len(model) > 1 and len(model) != nbatch: raise ValueError( - f'nroll inferred as {nroll} but model is length {len(model)}' + f'nbatch inferred as {nbatch} but model is length {len(model)}' ) elif not isinstance(model, list): model = [model] # Use a length 1 list to simplify code below @@ -212,16 +212,16 @@ def rollout( _check_trailing_dimension(nsensordata, sensordata=sensordata) # tile input arrays/lists if required (singleton expansion) - model = model * nroll if len(model) == 1 else model - initial_state = _tile_if_required(initial_state, nroll) - initial_warmstart = _tile_if_required(initial_warmstart, nroll) - control = _tile_if_required(control, nroll, nstep) + model = model * nbatch if len(model) == 1 else model + initial_state = _tile_if_required(initial_state, nbatch) + initial_warmstart = _tile_if_required(initial_warmstart, nbatch) + control = _tile_if_required(control, nbatch, nstep) # allocate output if not provided if state is None: - state = np.empty((nroll, nstep, nstate)) + state = np.empty((nbatch, nstep, nstate)) if sensordata is None: - sensordata = np.empty((nroll, nstep, nsensordata)) + sensordata = np.empty((nbatch, nstep, nsensordata)) # call rollout self.rollout_.rollout( @@ -276,35 +276,35 @@ def rollout( """Rolls out open-loop trajectories from initial states, get subsequent states and sensor values. Python wrapper for rollout.cc, see documentation therein. - Infers nroll and nstep. + Infers nbatch and nstep. Tiles inputs with singleton dimensions. Allocates outputs if none are given. Args: - model: An instance or length nroll sequence of MjModel with the same size signature. + model: An instance or length nbatch sequence of MjModel with the same size signature. data: Associated mjData instance or sequence of instances with length nthread. initial_state: Array of initial states from which to roll out trajectories. - ([nroll or 1] x nstate) + ([nbatch or 1] x nstate) control: Open-loop controls array to apply during the rollouts. - ([nroll or 1] x [nstep or 1] x ncontrol) + ([nbatch or 1] x [nstep or 1] x ncontrol) control_spec: mjtState specification of control vectors. skip_checks: Whether to skip internal shape and type checks. nstep: Number of steps in rollouts (inferred if unspecified). initial_warmstart: Initial qfrc_warmstart array (optional). - ([nroll or 1] x nv) + ([nbatch or 1] x nv) state: State output array (optional). - (nroll x nstep x nstate) + (nbatch x nstep x nstate) sensordata: Sensor data output array (optional). - (nroll x nstep x nsensordata) + (nbatch x nstep x nsensordata) chunk_size: Determines threadpool chunk size. If unspecified, - chunk_size = max(1, nroll / (nthread * 10)) + chunk_size = max(1, nbatch / (nthread * 10)) persistent_pool: Determines if a persistent thread pool is created or reused. Returns: state: - State output array, (nroll x nstep x nstate). + State output array, (nbatch x nstep x nstate). sensordata: - Sensor data output array, (nroll x nstep x nsensordata). + Sensor data output array, (nbatch x nstep x nsensordata). Raises: ValueError: bad shapes or sizes. diff --git a/python/mujoco/rollout_test.py b/python/mujoco/rollout_test.py index 5c8c5fbeb7..f27dc02b5a 100644 --- a/python/mujoco/rollout_test.py +++ b/python/mujoco/rollout_test.py @@ -185,11 +185,11 @@ def test_multi_step(self, model_name): nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 5 # number of rollouts + nbatch = 5 # number of rollouts nstep = 1 # number of steps - initial_state = np.random.randn(nroll, nstate) - control = np.random.randn(nroll, nstep, model.nu) + initial_state = np.random.randn(nbatch, nstate) + control = np.random.randn(nbatch, nstep, model.nu) state, sensordata = rollout.rollout(model, data, initial_state, control) mujoco.mj_resetData(model, data) @@ -198,108 +198,108 @@ def test_multi_step(self, model_name): np.testing.assert_array_equal(sensordata, py_sensordata) @parameterized.parameters(ALL_MODELS.keys()) - def test_infer_nroll_initial_state(self, model_name): + def test_infer_nbatch_initial_state(self, model_name): model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name]) nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 5 # number of rollouts + nbatch = 5 # number of rollouts nstep = 1 # number of steps - initial_state = np.random.randn(nroll, nstate) + initial_state = np.random.randn(nbatch, nstate) control = np.random.randn(nstep, model.nu) state, sensordata = rollout.rollout(model, data, initial_state, control) mujoco.mj_resetData(model, data) - control = np.tile(control, (nroll, 1, 1)) + control = np.tile(control, (nbatch, 1, 1)) py_state, py_sensordata = py_rollout(model, data, initial_state, control) np.testing.assert_array_equal(state, py_state) np.testing.assert_array_equal(sensordata, py_sensordata) @parameterized.parameters(ALL_MODELS.keys()) - def test_infer_nroll_control(self, model_name): + def test_infer_nbatch_control(self, model_name): model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name]) nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 5 # number of rollouts + nbatch = 5 # number of rollouts nstep = 1 # number of steps initial_state = np.random.randn(nstate) - control = np.random.randn(nroll, nstep, model.nu) + control = np.random.randn(nbatch, nstep, model.nu) state, sensordata = rollout.rollout(model, data, initial_state, control) mujoco.mj_resetData(model, data) - initial_state = np.tile(initial_state, (nroll, 1)) + initial_state = np.tile(initial_state, (nbatch, 1)) py_state, py_sensordata = py_rollout(model, data, initial_state, control) np.testing.assert_array_equal(state, py_state) np.testing.assert_array_equal(sensordata, py_sensordata) @parameterized.parameters(ALL_MODELS.keys()) - def test_infer_nroll_warmstart(self, model_name): + def test_infer_nbatch_warmstart(self, model_name): model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name]) nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 5 # number of rollouts + nbatch = 5 # number of rollouts nstep = 1 # number of steps initial_state = np.random.randn(nstate) control = np.random.randn(nstep, model.nu) - initial_warmstart = np.tile(data.qacc_warmstart.copy(), (nroll, 1)) + initial_warmstart = np.tile(data.qacc_warmstart.copy(), (nbatch, 1)) state, sensordata = rollout.rollout( model, data, initial_state, control, initial_warmstart=initial_warmstart ) mujoco.mj_resetData(model, data) - initial_state = np.tile(initial_state, (nroll, 1)) - control = np.tile(control, (nroll, 1, 1)) + initial_state = np.tile(initial_state, (nbatch, 1)) + control = np.tile(control, (nbatch, 1, 1)) py_state, py_sensordata = py_rollout(model, data, initial_state, control) np.testing.assert_array_equal(state, py_state) np.testing.assert_array_equal(sensordata, py_sensordata) @parameterized.parameters(ALL_MODELS.keys()) - def test_infer_nroll_state(self, model_name): + def test_infer_nbatch_state(self, model_name): model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name]) nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 5 # number of rollouts + nbatch = 5 # number of rollouts nstep = 1 # number of steps initial_state = np.random.randn(nstate) control = np.random.randn(nstep, model.nu) - state = np.empty((nroll, nstep, nstate)) + state = np.empty((nbatch, nstep, nstate)) state, sensordata = rollout.rollout( model, data, initial_state, control, state=state ) mujoco.mj_resetData(model, data) - initial_state = np.tile(initial_state, (nroll, 1)) - control = np.tile(control, (nroll, 1, 1)) + initial_state = np.tile(initial_state, (nbatch, 1)) + control = np.tile(control, (nbatch, 1, 1)) py_state, py_sensordata = py_rollout(model, data, initial_state, control) np.testing.assert_array_equal(state, py_state) np.testing.assert_array_equal(sensordata, py_sensordata) @parameterized.parameters(ALL_MODELS.keys()) - def test_infer_nroll_sensordata(self, model_name): + def test_infer_nbatch_sensordata(self, model_name): model = mujoco.MjModel.from_xml_string(ALL_MODELS[model_name]) nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 5 # number of rollouts + nbatch = 5 # number of rollouts nstep = 1 # number of steps initial_state = np.random.randn(nstate) control = np.random.randn(nstep, model.nu) - sensordata = np.empty((nroll, nstep, model.nsensordata)) + sensordata = np.empty((nbatch, nstep, model.nsensordata)) state, sensordata = rollout.rollout( model, data, initial_state, control, sensordata=sensordata ) mujoco.mj_resetData(model, data) - initial_state = np.tile(initial_state, (nroll, 1)) - control = np.tile(control, (nroll, 1, 1)) + initial_state = np.tile(initial_state, (nbatch, 1)) + control = np.tile(control, (nbatch, 1, 1)) py_state, py_sensordata = py_rollout(model, data, initial_state, control) np.testing.assert_array_equal(state, py_state) np.testing.assert_array_equal(sensordata, py_sensordata) @@ -310,13 +310,13 @@ def test_one_rollout_fixed_ctrl(self, model_name): nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 1 # number of rollouts + nbatch = 1 # number of rollouts nstep = 3 # number of steps initial_state = np.random.randn(nstate) control = np.random.randn(model.nu) - state = np.empty((nroll, nstep, nstate)) - sensordata = np.empty((nroll, nstep, model.nsensordata)) + state = np.empty((nbatch, nstep, nstate)) + sensordata = np.empty((nbatch, nstep, model.nsensordata)) rollout.rollout( model, data, initial_state, control, state=state, sensordata=sensordata ) @@ -332,11 +332,11 @@ def test_multi_rollout(self, model_name): nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 2 # number of initial states + nbatch = 2 # number of initial states nstep = 3 # number of timesteps - initial_state = np.random.randn(nroll, nstate) - control = np.random.randn(nroll, nstep, model.nu) + initial_state = np.random.randn(nbatch, nstate) + control = np.random.randn(nbatch, nstep, model.nu) state, sensordata = rollout.rollout(model, data, initial_state, control) py_state, py_sensordata = py_rollout(model, data, initial_state, control) @@ -345,26 +345,26 @@ def test_multi_rollout(self, model_name): @parameterized.parameters(ALL_MODELS.keys()) def test_multi_model(self, model_name): - nroll = 3 # number of initial states and models + nbatch = 3 # number of initial states and models nstep = 3 # number of timesteps spec = mujoco.MjSpec.from_string(ALL_MODELS[model_name]) if len(spec.bodies) > 1: model = [] - for i in range(nroll): + for i in range(nbatch): body = spec.bodies[1] assert body.name != 'world' body.pos = body.pos + i model.append(spec.compile()) else: - model = [spec.compile() for _ in range(nroll)] + model = [spec.compile() for _ in range(nbatch)] nstate = mujoco.mj_stateSize(model[0], mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model[0]) - initial_state = np.random.randn(nroll, nstate) - control = np.random.randn(nroll, nstep, model[0].nu) + initial_state = np.random.randn(nbatch, nstate) + control = np.random.randn(nbatch, nstep, model[0].nu) state, sensordata = rollout.rollout(model, data, initial_state, control) py_state, py_sensordata = py_rollout(model, data, initial_state, control) @@ -377,12 +377,12 @@ def test_multi_rollout_fixed_ctrl_infer_from_output(self, model_name): nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 2 # number of rollouts + nbatch = 2 # number of rollouts nstep = 3 # number of timesteps - initial_state = np.random.randn(nroll, nstate) - control = np.random.randn(nroll, 1, model.nu) - state = np.empty((nroll, nstep, nstate)) + initial_state = np.random.randn(nbatch, nstate) + control = np.random.randn(nbatch, 1, model.nu) + state = np.empty((nbatch, nstep, nstate)) state, sensordata = rollout.rollout( model, data, initial_state, control, state=state ) @@ -398,10 +398,10 @@ def test_py_rollout_generalized_control(self, model_name): nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 4 # number of rollouts + nbatch = 4 # number of rollouts nstep = 3 # number of timesteps - initial_state = np.random.randn(nroll, nstate) + initial_state = np.random.randn(nbatch, nstate) control_spec = ( mujoco.mjtState.mjSTATE_CTRL @@ -409,7 +409,7 @@ def test_py_rollout_generalized_control(self, model_name): | mujoco.mjtState.mjSTATE_XFRC_APPLIED ) ncontrol = mujoco.mj_stateSize(model, control_spec) - control = np.random.randn(nroll, nstep, ncontrol) + control = np.random.randn(nbatch, nstep, ncontrol) state, sensordata = rollout.rollout( model, data, initial_state, control, control_spec=control_spec @@ -426,8 +426,8 @@ def test_detect_divergence(self): nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 4 # number of rollouts - initial_state = np.empty((nroll, nstate)) + nbatch = 4 # number of rollouts + initial_state = np.empty((nbatch, nstate)) # get diverging (0, 2) and non-diverging (1, 3) states mujoco.mj_getState( @@ -446,7 +446,7 @@ def test_detect_divergence(self): nstep = 10000 # divergence after ~15s, timestep = 2e-3 - state = np.random.randn(nroll, nstep, nstate) + state = np.random.randn(nbatch, nstep, nstate) rollout.rollout(model, data, initial_state, state=state) @@ -464,19 +464,19 @@ def test_threading(self): model = mujoco.MjModel.from_xml_string(TEST_XML) nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) num_workers = 32 - nroll = 100 + nbatch = 100 nstep = 5 - initial_state = np.random.randn(nroll, nstate) - state = np.empty((nroll, nstep, nstate)) - sensordata = np.empty((nroll, nstep, model.nsensordata)) - control = np.random.randn(nroll, nstep, model.nu) + initial_state = np.random.randn(nbatch, nstate) + state = np.empty((nbatch, nstep, nstate)) + sensordata = np.empty((nbatch, nstep, model.nsensordata)) + control = np.random.randn(nbatch, nstep, model.nu) thread_local = threading.local() def thread_initializer(): thread_local.data = mujoco.MjData(model) - model_list = [copy.copy(model) for _ in range(nroll)] + model_list = [copy.copy(model) for _ in range(nbatch)] def call_rollout(initial_state, control, state, sensordata): rollout.rollout( @@ -490,7 +490,7 @@ def call_rollout(initial_state, control, state, sensordata): sensordata=sensordata, ) - n = nroll // num_workers # integer division + n = nbatch // num_workers # integer division chunks = [] # a list of tuples, one per worker for i in range(num_workers - 1): chunks.append(( @@ -526,14 +526,14 @@ def test_threading_native(self): model = mujoco.MjModel.from_xml_string(TEST_XML) nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) num_workers = 32 - nroll = 100 + nbatch = 100 nstep = 5 - initial_state = np.random.randn(nroll, nstate) - state = np.empty((nroll, nstep, nstate)) - sensordata = np.empty((nroll, nstep, model.nsensordata)) - control = np.random.randn(nroll, nstep, model.nu) + initial_state = np.random.randn(nbatch, nstate) + state = np.empty((nbatch, nstep, nstate)) + sensordata = np.empty((nbatch, nstep, model.nsensordata)) + control = np.random.randn(nbatch, nstep, model.nu) - model_list = [copy.copy(model) for _ in range(nroll)] + model_list = [copy.copy(model) for _ in range(nbatch)] data_list = [mujoco.MjData(model) for _ in range(num_workers)] rollout.rollout( @@ -555,14 +555,14 @@ def test_threading_native_persistent_object(self): model = mujoco.MjModel.from_xml_string(TEST_XML) nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) num_workers = 32 - nroll = 100 + nbatch = 100 nstep = 5 - initial_state = np.random.randn(nroll, nstate) - state = np.empty((nroll, nstep, nstate)) - sensordata = np.empty((nroll, nstep, model.nsensordata)) - control = np.random.randn(nroll, nstep, model.nu) + initial_state = np.random.randn(nbatch, nstate) + state = np.empty((nbatch, nstep, nstate)) + sensordata = np.empty((nbatch, nstep, model.nsensordata)) + control = np.random.randn(nbatch, nstep, model.nu) - model_list = [copy.copy(model) for _ in range(nroll)] + model_list = [copy.copy(model) for _ in range(nbatch)] data_list = [mujoco.MjData(model) for _ in range(num_workers)] with rollout.Rollout(nthread=num_workers) as rollout_: @@ -604,14 +604,14 @@ def test_threading_native_persistent_function(self): model = mujoco.MjModel.from_xml_string(TEST_XML) nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) num_workers = 32 - nroll = 100 + nbatch = 100 nstep = 5 - initial_state = np.random.randn(nroll, nstate) - state = np.empty((nroll, nstep, nstate)) - sensordata = np.empty((nroll, nstep, model.nsensordata)) - control = np.random.randn(nroll, nstep, model.nu) + initial_state = np.random.randn(nbatch, nstate) + state = np.empty((nbatch, nstep, nstate)) + sensordata = np.empty((nbatch, nstep, model.nsensordata)) + control = np.random.randn(nbatch, nstep, model.nu) - model_list = [copy.copy(model) for _ in range(nroll)] + model_list = [copy.copy(model) for _ in range(nbatch)] data_list = [mujoco.MjData(model) for _ in range(num_workers)] for _ in range(2): @@ -699,11 +699,11 @@ def test_intercept_mj_errors(self): nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 1 + nbatch = 1 nstep = 3 - initial_state = np.zeros((nroll, nstate)) - ctrl = np.zeros((nroll, nstep, model.nu)) + initial_state = np.zeros((nbatch, nstate)) + ctrl = np.zeros((nbatch, nstep, model.nu)) model.opt.solver = 10 # invalid solver type with self.assertRaisesWithLiteralMatch( @@ -716,9 +716,9 @@ def test_invalid(self): nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 1 + nbatch = 1 - initial_state = np.zeros((nroll, nstate)) + initial_state = np.zeros((nbatch, nstate)) control = 'string' with self.assertRaisesWithLiteralMatch( @@ -737,31 +737,31 @@ def test_bad_sizes(self): nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) data = mujoco.MjData(model) - nroll = 1 + nbatch = 1 nstep = 3 - initial_state = np.random.randn(nroll, nstate + 1) + initial_state = np.random.randn(nbatch, nstate + 1) with self.assertRaisesWithLiteralMatch( ValueError, 'trailing dimension of initial_state must be 6, got 7' ): rollout.rollout(model, data, initial_state) - initial_state = np.random.randn(nroll, nstate) + initial_state = np.random.randn(nbatch, nstate) control = np.random.randn(1, nstep, model.nu + 1) with self.assertRaisesWithLiteralMatch( ValueError, 'trailing dimension of control must be 2, got 3' ): rollout.rollout(model, data, initial_state, control) - control = np.random.randn(nroll, nstep, model.nu) - state = np.random.randn(nroll, nstep + 1, nstate) # incompatible nstep + control = np.random.randn(nbatch, nstep, model.nu) + state = np.random.randn(nbatch, nstep + 1, nstate) # incompatible nstep with self.assertRaisesWithLiteralMatch( ValueError, 'dimension 1 inferred as 3 but state has 4' ): rollout.rollout(model, data, initial_state, control, state=state) - initial_state = np.random.randn(nroll, nstate) - control = np.random.randn(nroll, nstep, model.nu) + initial_state = np.random.randn(nbatch, nstate) + control = np.random.randn(nbatch, nstep, model.nu) bad_spec = mujoco.mjtState.mjSTATE_ACT with self.assertRaisesWithLiteralMatch( ValueError, 'control_spec can only contain bits in mjSTATE_USER' @@ -946,17 +946,17 @@ def py_rollout( ): initial_state = ensure_2d(initial_state) control = ensure_3d(control) - nroll = initial_state.shape[0] + nbatch = initial_state.shape[0] nstep = control.shape[1] if isinstance(model, mujoco.MjModel): - model = [copy.copy(model) for _ in range(nroll)] + model = [copy.copy(model) for _ in range(nbatch)] nstate = mujoco.mj_stateSize(model[0], mujoco.mjtState.mjSTATE_FULLPHYSICS) - state = np.empty((nroll, nstep, nstate)) - sensordata = np.empty((nroll, nstep, model[0].nsensordata)) - for r in range(nroll): + state = np.empty((nbatch, nstep, nstate)) + sensordata = np.empty((nbatch, nstep, model[0].nsensordata)) + for r in range(nbatch): state_r, sensordata_r = one_rollout( model[r], data, initial_state[r], control[r], control_spec )