diff --git a/mjx/mujoco/mjx/benchmark/benchmark.py b/mjx/mujoco/mjx/benchmark/benchmark.py
index c3145f577c..ca29a7e1e8 100644
--- a/mjx/mujoco/mjx/benchmark/benchmark.py
+++ b/mjx/mujoco/mjx/benchmark/benchmark.py
@@ -27,76 +27,51 @@
 FLAGS = flags.FLAGS
 
-_PATHS = {
-    'humanoid': 'benchmark/model/humanoid/humanoid.xml',
-    'barkour': 'benchmark/model/barkour_v0/assets/barkour_v0_mjx.xml',
-    'shadow_hand': 'benchmark/model/shadow_hand/scene_right.xml',
-}
-
-_BATCH_SIZE = {
-    ('barkour', 'tpu_v5e'): 1024,
-    ('humanoid', 'tpu_v5e'): 1024,
-    ('shadow_hand', 'tpu_v5e'): 1024,
-    ('barkour', 'gpu_a100'): 8192,
-    ('humanoid', 'gpu_a100'): 8192,
-    ('shadow_hand', 'gpu_a100'): 4096,
-    ('barkour', 'cpu'): 64,
-    ('humanoid', 'cpu'): 64,
-    ('shadow_hand', 'cpu'): 64,
-}
-
-_SOLVER_CONFIG = {
-    ('barkour', 'tpu_v5e'): (mujoco.mjtSolver.mjSOL_CG, 4, 6),
-    ('humanoid', 'tpu_v5e'): (mujoco.mjtSolver.mjSOL_CG, 6, 6),
-    ('shadow_hand', 'tpu_v5e'): (mujoco.mjtSolver.mjSOL_CG, 8, 6),
-    ('humanoid', 'gpu_a100'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
-    ('barkour', 'gpu_a100'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
-    ('shadow_hand', 'gpu_a100'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
-    ('barkour', 'cpu'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
-    ('humanoid', 'cpu'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
-    ('shadow_hand', 'cpu'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
-}
-
-
-flags.DEFINE_string('model', 'humanoid', 'Model to benchmark')
-flags.DEFINE_enum('device', 'cpu', ('cpu', 'tpu_v5e', 'gpu_a100'),
-                  'Device benchmark is running on')
-
-
-def _measure_fn(state, init_fn, step_fn, batch_size: int = 1024) -> float:
-  """Reports jit time and op time for a function."""
+flags.DEFINE_string('mjcf', None, 'path to model', required=True)
+flags.DEFINE_integer('step_count', 1000, 'number of steps per rollout')
+flags.DEFINE_integer('batch_size', 1024, 'number of parallel rollouts')
+flags.DEFINE_integer('unroll', 1, 'loop unroll length')
+flags.DEFINE_enum('solver', 'cg', ['cg', 'newton'], 'constraint solver')
+flags.DEFINE_integer('iterations', 1, 'number of solver iterations')
+flags.DEFINE_integer('ls_iterations', 4, 'number of linesearch iterations')
+
 
-  step_count = 100 if FLAGS.device == 'cpu' else 1000
+def _measure(state, init_fn, step_fn) -> float:
+  """Reports jit time and op time for a function."""
 
-  @jax.jit
+  @jax.pmap
   def run_batch(seed: jp.ndarray):
+    batch_size = FLAGS.batch_size // jax.device_count()
     rngs = jax.random.split(jax.random.PRNGKey(seed), batch_size)
-    init_state = jax.vmap(init_fn)(rngs)
+    state = jax.vmap(init_fn)(rngs)
 
     @jax.vmap
-    def run(state):
-      def step(state, _):
-        state = step_fn(state)
-        return state, ()
-
-      return jax.lax.scan(step, state, (), length=step_count)
+    def step(state, _):
+      state = step_fn(state)
+      return state, None
 
-    return run(init_state)
+    state, _ = jax.lax.scan(
+        step, state, None, length=FLAGS.step_count, unroll=FLAGS.unroll
+    )
+    return state
 
   # run once to jit
   beg = time.perf_counter()
-  jax.tree_util.tree_map(lambda x: x.block_until_ready(), run_batch(0))
+  seed = 0
+  seeds = jp.arange(seed, seed + jax.device_count(), dtype=int)
+  jax.tree_util.tree_map(lambda x: x.block_until_ready(), run_batch(seeds))
   first_t = time.perf_counter() - beg
 
   times = []
   while state:
+    seed += jax.device_count()
+    seeds = jp.arange(seed, seed + jax.device_count(), dtype=int)
     beg = time.perf_counter()
-    batch = run_batch(jp.array(len(times)))
-    jax.tree_util.tree_map(lambda x: x.block_until_ready(), batch)
+    jax.tree_util.tree_map(lambda x: x.block_until_ready(), run_batch(seeds))
     times.append(time.perf_counter() - beg)
 
   op_time = jp.mean(jp.array(times))
-  batch_sps = batch_size * step_count / op_time
+  batch_sps = FLAGS.batch_size * FLAGS.step_count / op_time
 
   state.counters['jit_time'] = first_t - op_time
   state.counters['batch_sps'] = batch_sps
 
@@ -106,11 +81,14 @@ def step(state, _):
 
 def _run(state: benchmark.State):
   """Benchmark a model."""
-  f = epath.resource_path('mujoco.mjx') / _PATHS[FLAGS.model]
+  f = epath.resource_path('mujoco.mjx') / 'benchmark/model' / FLAGS.mjcf
   m = mujoco.MjModel.from_xml_path(f.as_posix())
-  m.opt.solver, m.opt.iterations, m.opt.ls_iterations = _SOLVER_CONFIG[
-      (FLAGS.model, FLAGS.device)
-  ]
+  m.opt.solver = {
+      'cg': mujoco.mjtSolver.mjSOL_CG,
+      'newton': mujoco.mjtSolver.mjSOL_NEWTON,
+  }[FLAGS.solver.lower()]
+  m.opt.iterations = FLAGS.iterations
+  m.opt.ls_iterations = FLAGS.ls_iterations
   m = mjx.device_put(m)
 
   def init(rng):
@@ -122,11 +100,10 @@ def init(rng):
   def step(d):
     return mjx.step(m, d)
 
-  batch_size = _BATCH_SIZE[(FLAGS.model, FLAGS.device)]
-  _measure_fn(state, init, step, batch_size=batch_size)
+  _measure(state, init, step)
 
 
 if __name__ == '__main__':
   FLAGS(sys.argv)
-  benchmark.register(_run, name=FLAGS.model + '_' + FLAGS.device)
+  benchmark.register(_run, name=sys.argv[0].split('/')[-1])
   benchmark.main()
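
Usage sketch (not part of the patch; the script path and all flag values below are illustrative assumptions): with the hardcoded _PATHS/_BATCH_SIZE/_SOLVER_CONFIG tables removed, a run is configured entirely through flags, with --mjcf resolved relative to benchmark/model just as the removed _PATHS entries were:

    python mjx/mujoco/mjx/benchmark/benchmark.py \
        --mjcf=humanoid/humanoid.xml \
        --batch_size=8192 --step_count=1000 --unroll=1 \
        --solver=newton --iterations=1 --ls_iterations=4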