diff --git a/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py b/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py
index a49a024e3..6bd150422 100644
--- a/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py
+++ b/examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py
@@ -100,18 +100,19 @@ def download_and_unzip(url, data_dir):
 data_url = "https://zenodo.org/records/10810754/files/models.zip?download=1"
 download_and_unzip(data_url, "models.zip")
 
-# It was found that the PyTorch multithreaded defaults interfere with MPI-enabled AMReX
-# when initializing the models: https://github.com/AMReX-Codes/pyamrex/issues/322
+# It was found that the PyTorch multithreaded defaults interfere with AMReX OpenMP
+# when initializing the models or iterating elements:
+# https://github.com/AMReX-Codes/pyamrex/issues/322
+# https://github.com/ECP-WarpX/impactx/issues/773#issuecomment-2585043099
 # So we manually set the number of threads to serial (1).
-if Config.have_mpi:
-    n_threads = torch.get_num_threads()
-    torch.set_num_threads(1)
+# Torch threading is not a problem with GPUs and might work when MPI is disabled.
+# Could also just be a mixing of OpenMP libraries (gomp and llvm omp) when using the
+# pre-built PyTorch pip packages.
+torch.set_num_threads(1)
 
 model_list = [
     surrogate_model(f"models/beam_stage_{stage_i}_model.pt", device=device)
     for stage_i in range(N_stage)
 ]
-if Config.have_mpi:
-    torch.set_num_threads(n_threads)
 
 pp_amrex = amr.ParmParse("amrex")
 pp_amrex.add("the_arena_init_size", 0)
@@ -328,6 +329,7 @@ def set_lens(self, pc, step, period):
     lpa = LPASurrogateStage(i, model_list[i], L_surrogate, L_stage_period * i)
     lpa.nslice = n_slice
     lpa.ds = L_surrogate
+    lpa.threadsafe = False
     lpa_stages.append(lpa)
 
 monitor = elements.BeamMonitor("monitor")
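
For context, the workaround in the first hunk reduces to a simple pattern: pin PyTorch's intra-op thread pool to a single thread before any model is loaded or evaluated, so it cannot spawn worker threads that clash with AMReX OpenMP. A minimal standalone sketch follows; the model paths are hypothetical and torch.jit.load stands in for the example's surrogate_model helper, which this diff does not show.

    import torch

    # Pin PyTorch to one intra-op thread *before* loading or running models.
    # PyTorch's multithreaded defaults can interfere with AMReX OpenMP (see
    # pyamrex issue 322 and impactx issue 773); serial Torch is harmless here
    # since the surrogate models are small.
    torch.set_num_threads(1)

    # Hypothetical TorchScript model files, loaded only after the thread
    # count is pinned; torch.jit.load replaces the example's surrogate_model.
    N_stage = 15
    model_list = [
        torch.jit.load(f"models/beam_stage_{stage_i}_model.pt", map_location="cpu")
        for stage_i in range(N_stage)
    ]

Unconditionally pinning the thread count (rather than saving and restoring it only when Config.have_mpi is set, as the removed lines did) keeps the workaround in effect during tracking as well, which matches the second hunk's lpa.threadsafe = False marking the surrogate element as not safe to run multithreaded.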