Skip to content

Commit

Permalink
PyTorch Threading Mixed with AMReX is icky
Browse files Browse the repository at this point in the history
Issues as soon as we use MPI+OMP and add our `Drift` element.
  • Loading branch information
ax3l committed Jan 13, 2025
1 parent 0fa34a3 commit 0b250f3
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,19 @@ def download_and_unzip(url, data_dir):
data_url = "https://zenodo.org/records/10810754/files/models.zip?download=1"
download_and_unzip(data_url, "models.zip")

# It was found that the PyTorch multithreaded defaults interfere with MPI-enabled AMReX
# when initializing the models: https://github.com/AMReX-Codes/pyamrex/issues/322
# It was found that the PyTorch multithreaded defaults interfere with AMReX OpenMP
# when initializing the models or iterating elements:
# https://github.com/AMReX-Codes/pyamrex/issues/322
# https://github.com/ECP-WarpX/impactx/issues/773#issuecomment-2585043099
# So we manually set the number of threads to serial (1).
if Config.have_mpi:
n_threads = torch.get_num_threads()
torch.set_num_threads(1)
# Torch threading is not a problem with GPUs and might work when MPI is disabled.
# Could also just be a mixing of OpenMP libraries (gomp and llvm omp) when using the
# pre-build PyTorch pip packages.
torch.set_num_threads(1)
model_list = [
surrogate_model(f"models/beam_stage_{stage_i}_model.pt", device=device)
for stage_i in range(N_stage)
]
if Config.have_mpi:
torch.set_num_threads(n_threads)

pp_amrex = amr.ParmParse("amrex")
pp_amrex.add("the_arena_init_size", 0)
Expand Down Expand Up @@ -328,6 +329,7 @@ def set_lens(self, pc, step, period):
lpa = LPASurrogateStage(i, model_list[i], L_surrogate, L_stage_period * i)
lpa.nslice = n_slice
lpa.ds = L_surrogate
lpa.threadsafe = False
lpa_stages.append(lpa)

monitor = elements.BeamMonitor("monitor")
Expand Down

0 comments on commit 0b250f3

Please sign in to comment.