From f596048dd4480e958f9bf92ad25101a077a6dd99 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Mon, 13 Jan 2025 11:26:34 -0800 Subject: [PATCH] PyTorch Threading Mixed with AMReX is icky Issues as soon as we use MPI+OMP and add our `Drift` element. --- .../run_ml_surrogate_15_stage.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py b/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py index a49a024e3..6bd150422 100644 --- a/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py +++ b/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py @@ -100,18 +100,19 @@ def download_and_unzip(url, data_dir): data_url = "https://zenodo.org/records/10810754/files/models.zip?download=1" download_and_unzip(data_url, "models.zip") -# It was found that the PyTorch multithreaded defaults interfere with MPI-enabled AMReX -# when initializing the models: https://github.com/AMReX-Codes/pyamrex/issues/322 +# It was found that the PyTorch multithreaded defaults interfere with AMReX OpenMP +# when initializing the models or iterating elements: +# https://github.com/AMReX-Codes/pyamrex/issues/322 +# https://github.com/ECP-WarpX/impactx/issues/773#issuecomment-2585043099 # So we manually set the number of threads to serial (1). -if Config.have_mpi: - n_threads = torch.get_num_threads() - torch.set_num_threads(1) +# Torch threading is not a problem with GPUs and might work when MPI is disabled. +# Could also just be a mixing of OpenMP libraries (gomp and llvm omp) when using the +# pre-built PyTorch pip packages. 
+torch.set_num_threads(1) model_list = [ surrogate_model(f"models/beam_stage_{stage_i}_model.pt", device=device) for stage_i in range(N_stage) ] -if Config.have_mpi: - torch.set_num_threads(n_threads) pp_amrex = amr.ParmParse("amrex") pp_amrex.add("the_arena_init_size", 0) @@ -328,6 +329,7 @@ def set_lens(self, pc, step, period): lpa = LPASurrogateStage(i, model_list[i], L_surrogate, L_stage_period * i) lpa.nslice = n_slice lpa.ds = L_surrogate + lpa.threadsafe = False lpa_stages.append(lpa) monitor = elements.BeamMonitor("monitor")