Skip to content

Commit

Permalink
Benchmark improvements.
Browse files (browse the repository at this point in the history)
- Use multiple devices if present
- Make FLAGS more agnostic to diverse platforms

PiperOrigin-RevId: 598990299
Change-Id: Ifbd454ca521a788324f2b62f1c79e7d4b8305a30
  • Loading branch information
erikfrey authored and copybara-github committed Jan 17, 2024
1 parent 45208a5 commit ee922be
Showing 1 changed file with 36 additions and 59 deletions.
95 changes: 36 additions & 59 deletions mjx/mujoco/mjx/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,76 +27,51 @@

FLAGS = flags.FLAGS

_PATHS = {
'humanoid': 'benchmark/model/humanoid/humanoid.xml',
'barkour': 'benchmark/model/barkour_v0/assets/barkour_v0_mjx.xml',
'shadow_hand': 'benchmark/model/shadow_hand/scene_right.xml',
}

_BATCH_SIZE = {
('barkour', 'tpu_v5e'): 1024,
('humanoid', 'tpu_v5e'): 1024,
('shadow_hand', 'tpu_v5e'): 1024,
('barkour', 'gpu_a100'): 8192,
('humanoid', 'gpu_a100'): 8192,
('shadow_hand', 'gpu_a100'): 4096,
('barkour', 'cpu'): 64,
('humanoid', 'cpu'): 64,
('shadow_hand', 'cpu'): 64,
}

_SOLVER_CONFIG = {
('barkour', 'tpu_v5e'): (mujoco.mjtSolver.mjSOL_CG, 4, 6),
('humanoid', 'tpu_v5e'): (mujoco.mjtSolver.mjSOL_CG, 6, 6),
('shadow_hand', 'tpu_v5e'): (mujoco.mjtSolver.mjSOL_CG, 8, 6),
('humanoid', 'gpu_a100'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
('barkour', 'gpu_a100'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
('shadow_hand', 'gpu_a100'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
('barkour', 'cpu'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
('humanoid', 'cpu'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
('shadow_hand', 'cpu'): (mujoco.mjtSolver.mjSOL_NEWTON, 1, 4),
}


flags.DEFINE_string('model', 'humanoid', 'Model to benchmark')
flags.DEFINE_enum('device', 'cpu', ('cpu', 'tpu_v5e', 'gpu_a100'),
'Device benchmark is running on')


def _measure_fn(state, init_fn, step_fn, batch_size: int = 1024) -> float:
"""Reports jit time and op time for a function."""
flags.DEFINE_string('mjcf', None, 'path to model', required=True)
flags.DEFINE_integer('step_count', 1000, 'number of steps per rollout')
flags.DEFINE_integer('batch_size', 1024, 'number of parallel rollouts')
flags.DEFINE_integer('unroll', 1, 'loop unroll length')
flags.DEFINE_enum('solver', 'cg', ['cg', 'newton'], 'constraint solver')
flags.DEFINE_integer('iterations', 1, 'number of solver iterations')
flags.DEFINE_integer('ls_iterations', 4, 'number of linesearch iterations')


step_count = 100 if FLAGS.device == 'cpu' else 1000
def _measure(state, init_fn, step_fn) -> float:
"""Reports jit time and op time for a function."""

@jax.jit
@jax.pmap
def run_batch(seed: jp.ndarray):
batch_size = FLAGS.batch_size // jax.device_count()
rngs = jax.random.split(jax.random.PRNGKey(seed), batch_size)
init_state = jax.vmap(init_fn)(rngs)
state = jax.vmap(init_fn)(rngs)

@jax.vmap
def run(state):
def step(state, _):
state = step_fn(state)
return state, ()

return jax.lax.scan(step, state, (), length=step_count)
def step(state, _):
state = step_fn(state)
return state, None

return run(init_state)
state, _ = jax.lax.scan(
step, state, None, length=FLAGS.step_count, unroll=FLAGS.unroll
)
return state

# run once to jit
beg = time.perf_counter()
jax.tree_util.tree_map(lambda x: x.block_until_ready(), run_batch(0))
seed = 0
seeds = jp.arange(seed, seed + jax.device_count(), dtype=int)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), run_batch(seeds))
first_t = time.perf_counter() - beg

times = []
while state:
seed += jax.device_count()
seeds = jp.arange(seed, seed + jax.device_count(), dtype=int)
beg = time.perf_counter()
batch = run_batch(jp.array(len(times)))
jax.tree_util.tree_map(lambda x: x.block_until_ready(), batch)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), run_batch(seeds))
times.append(time.perf_counter() - beg)

op_time = jp.mean(jp.array(times))
batch_sps = batch_size * step_count / op_time
batch_sps = FLAGS.batch_size * FLAGS.step_count / op_time

state.counters['jit_time'] = first_t - op_time
state.counters['batch_sps'] = batch_sps
Expand All @@ -106,11 +81,14 @@ def step(state, _):
def _run(state: benchmark.State):
"""Benchmark a model."""

f = epath.resource_path('mujoco.mjx') / _PATHS[FLAGS.model]
f = epath.resource_path('mujoco.mjx') / 'benchmark/model' / FLAGS.mjcf
m = mujoco.MjModel.from_xml_path(f.as_posix())
m.opt.solver, m.opt.iterations, m.opt.ls_iterations = _SOLVER_CONFIG[
(FLAGS.model, FLAGS.device)
]
m.opt.solver = {
'cg': mujoco.mjtSolver.mjSOL_CG,
'newton': mujoco.mjtSolver.mjSOL_NEWTON,
}[FLAGS.solver.lower()]
m.opt.iterations = FLAGS.iterations
m.opt.ls_iterations = FLAGS.ls_iterations
m = mjx.device_put(m)

def init(rng):
Expand All @@ -122,11 +100,10 @@ def init(rng):
def step(d):
return mjx.step(m, d)

batch_size = _BATCH_SIZE[(FLAGS.model, FLAGS.device)]
_measure_fn(state, init, step, batch_size=batch_size)
_measure(state, init, step)


if __name__ == '__main__':
FLAGS(sys.argv)
benchmark.register(_run, name=FLAGS.model + '_' + FLAGS.device)
benchmark.register(_run, name=sys.argv[0].split('/')[-1])
benchmark.main()

0 comments on commit ee922be

Please sign in to comment.