From 3e1349f5699b427934b4b3348351bf96514939a0 Mon Sep 17 00:00:00 2001 From: dekken Date: Sun, 18 Dec 2022 22:48:40 +0100 Subject: [PATCH] run multiple simulations in // from python --- pyphare/pyphare/cpp/__init__.py | 25 +- pyphare/pyphare/pharesee/hierarchy.py | 59 +++ pyphare/pyphare/pharesee/particles.py | 8 +- pyphare/pyphare/pharesee/run.py | 22 +- pyphare/pyphare/simulator/simulator.py | 4 + pyphare/pyphare/simulator/simulators.py | 344 ++++++++++++++++++ tests/functional/alfven_wave/alfven_wave1d.py | 13 +- tools/build_compare_2_branches.py | 46 +++ tools/build_compare_2_branches.sh | 6 + tools/build_compare_top_2_commits.py | 38 ++ tools/build_compare_top_2_commits.sh | 6 + tools/python3/cmake.py | 19 +- tools/python3/git.py | 15 +- 13 files changed, 570 insertions(+), 35 deletions(-) create mode 100644 pyphare/pyphare/simulator/simulators.py create mode 100644 tools/build_compare_2_branches.py create mode 100644 tools/build_compare_2_branches.sh create mode 100644 tools/build_compare_top_2_commits.py create mode 100644 tools/build_compare_top_2_commits.sh diff --git a/pyphare/pyphare/cpp/__init__.py b/pyphare/pyphare/cpp/__init__.py index f8ab5a923..1a115cacd 100644 --- a/pyphare/pyphare/cpp/__init__.py +++ b/pyphare/pyphare/cpp/__init__.py @@ -1,6 +1,7 @@ # continue to use override if set _cpp_lib_override = None +_globals = dict(module=None) def cpp_lib(override=None): import importlib @@ -8,15 +9,21 @@ def cpp_lib(override=None): global _cpp_lib_override if override is not None: _cpp_lib_override = override - if _cpp_lib_override is not None: - return importlib.import_module(_cpp_lib_override) - - if not __debug__: - return importlib.import_module("pybindlibs.cpp") - try: - return importlib.import_module("pybindlibs.cpp_dbg") - except ImportError as err: - return importlib.import_module("pybindlibs.cpp") + + def get_module(): + if _cpp_lib_override is not None: + return importlib.import_module(_cpp_lib_override) + if not __debug__: + return importlib.import_module("pybindlibs.cpp") + try: + return importlib.import_module("pybindlibs.cpp_dbg") + except ImportError as err: + return importlib.import_module("pybindlibs.cpp") + + if _globals["module"] is None: + _globals["module"] = get_module() + print("Loaded C++ module", _globals["module"].__file__) + return _globals["module"] def cpp_etc_lib(): diff --git a/pyphare/pyphare/pharesee/hierarchy.py b/pyphare/pyphare/pharesee/hierarchy.py index 4046e6fe3..2e67ef3ef 100644 --- a/pyphare/pyphare/pharesee/hierarchy.py +++ b/pyphare/pyphare/pharesee/hierarchy.py @@ -168,6 +168,22 @@ def __init__(self, layout, field_name, data, **kwargs): self.dataset = data + def compare(self, that, atol=1e-16, rtol=0): + if isinstance(that, FieldData): + assert self.name == that.name + # drop nans + a = self.dataset + b = that.dataset + mask = ~(np.isnan(a) | np.isnan(b)) + if __debug__: + try: + np.testing.assert_allclose(a[mask], b[mask], atol=atol, rtol=rtol) + except AssertionError as e: + print(f"FieldData comparison failure in {self.name}") + raise e + return True + return np.allclose(a[mask], b[mask], atol=atol, rtol=rtol) + return False def meshgrid(self, select=None): def grid(): @@ -182,6 +198,10 @@ def grid(): return mesh + def __eq__(self, that): + return self.compare(that) + + class ParticleData(PatchData): """ Concrete type of PatchData representing particles in a region @@ -1690,3 +1710,42 @@ def get_times_from_h5(filepath): times = np.array(sorted([float(s) for s in list(f["t"].keys())])) f.close() return times + + +def hierarchy_compare(this, that): + if not isinstance(this, PatchHierarchy) or not isinstance(that, PatchHierarchy): + return False + + if this.ndim != that.ndim or this.domain_box != that.domain_box: + return False + + if this.time_hier.keys() != that.time_hier.keys(): + return False + + for tidx in this.times(): + patch_levels_ref = this.time_hier[tidx] + patch_levels_cmp = that.time_hier[tidx] + + if patch_levels_ref.keys() != patch_levels_cmp.keys(): + return False + + for level_idx in patch_levels_cmp.keys(): + patch_level_ref = patch_levels_ref[level_idx] + patch_level_cmp = patch_levels_cmp[level_idx] + + for patch_idx in range(len(patch_level_cmp.patches)): + patch_ref = patch_level_ref.patches[patch_idx] + patch_cmp = patch_level_cmp.patches[patch_idx] + + if patch_ref.patch_datas.keys() != patch_cmp.patch_datas.keys(): + return False + + for patch_data_key in patch_ref.patch_datas.keys(): + patch_data_ref = patch_ref.patch_datas[patch_data_key] + patch_data_cmp = patch_cmp.patch_datas[patch_data_key] + + if patch_data_cmp != patch_data_ref: + return False + + return True + diff --git a/pyphare/pyphare/pharesee/particles.py b/pyphare/pyphare/pharesee/particles.py index 5d7413aed..0df62f028 100644 --- a/pyphare/pyphare/pharesee/particles.py +++ b/pyphare/pyphare/pharesee/particles.py @@ -194,11 +194,11 @@ def all_assert_sorted(part1, part2): deltol = 1e-6 if any([part.deltas.dtype == np.float32 for part in [part1, part2]] ) else 1e-12 np.testing.assert_array_equal(part1.iCells[idx1], part2.iCells[idx2]) - np.testing.assert_allclose(part1.deltas[idx1], part2.deltas[idx2], atol=deltol) + np.testing.assert_allclose(part1.deltas[idx1], part2.deltas[idx2], rtol=0, atol=deltol) - np.testing.assert_allclose(part1.v[idx1,0], part2.v[idx2,0], atol=1e-12) - np.testing.assert_allclose(part1.v[idx1,1], part2.v[idx2,1], atol=1e-12) - np.testing.assert_allclose(part1.v[idx1,2], part2.v[idx2,2], atol=1e-12) + np.testing.assert_allclose(part1.v[idx1,0], part2.v[idx2,0], rtol=0, atol=1e-12) + np.testing.assert_allclose(part1.v[idx1,1], part2.v[idx2,1], rtol=0, atol=1e-12) + np.testing.assert_allclose(part1.v[idx1,2], part2.v[idx2,2], rtol=0, atol=1e-12) def any_assert(part1, part2): diff --git a/pyphare/pyphare/pharesee/run.py b/pyphare/pyphare/pharesee/run.py index 2de12b188..ba23087c6 100644 --- a/pyphare/pyphare/pharesee/run.py +++ b/pyphare/pyphare/pharesee/run.py @@ -336,14 +336,18 @@ def GetDl(self, level='finest', time=None): def GetAllAvailableQties(self, time, pops): assert self.single_hier_for_all_quantities == True # can't work otherwise - self.GetParticles(time, pops) - self.GetB(time) - self.GetE(time) - self.GetNi(time) - self.GetVi(time) - - for pop in pops: - self.GetFlux(time, pop) - self.GetN(time, pop) + for fn in [self.GetB, self.GetE, self.GetNi, self.GetVi]: + try: + fn(time) + except: + pass + + try: + self.GetParticles(time, pops) + for pop in pops: + self.GetFlux(time, pop) + self.GetN(time, pop) + except: + pass return self.hier diff --git a/pyphare/pyphare/simulator/simulator.py b/pyphare/pyphare/simulator/simulator.py index 949cf1643..3edfb67f3 100644 --- a/pyphare/pyphare/simulator/simulator.py +++ b/pyphare/pyphare/simulator/simulator.py @@ -144,6 +144,10 @@ def _auto_dump(self): return self.auto_dump and self.dump() + def finished(self): + return self.cpp_sim.currentTime() >= self.cpp_sim.endTime() + + def dump(self, *args): assert len(args) == 0 or len(args) == 2 diff --git a/pyphare/pyphare/simulator/simulators.py b/pyphare/pyphare/simulator/simulators.py new file mode 100644 index 000000000..ebdb2d023 --- /dev/null +++ b/pyphare/pyphare/simulator/simulators.py @@ -0,0 +1,344 @@ +""" + + + +""" +import os + +import sys +import json +import time +import inspect +import importlib +import concurrent + +from enum import IntEnum +from pathlib import Path +from datetime import datetime +from multiprocessing import Process +from multiprocessing import shared_memory, Lock + +import numpy as np +import pyphare.pharein as ph +from pyphare.pharesee.run import Run +from pyphare.pharesee.hierarchy import hierarchy_compare +from pyphare.simulator.simulator import Simulator + +from fastapi import FastAPI, Request + + +_globals = dict(servers=dict(), busy=0) +phare_runs_dir = Path(os.getcwd()) / "phare_runs" # can't contain period "." + +lock = Lock() + +shared_size = (200, 200) # kinda arbitrary but over allocated + +class SharedSimulationStateEnum(IntEnum): + CURRENT_TIME = 0 + IS_BUSY = 7 + CAN_ADVANCE = 8 + IS_FINISHED = 9 + + + +def create_shared_block(n_sims): + a = np.zeros(shared_size, dtype=np.int64) + shm = shared_memory.SharedMemory(create=True, size=a.nbytes) + np_array = np.ndarray(a.shape, dtype=np.int64, buffer=shm.buf) + return shm, np_array + + +def atomic_set(sim_id, shr_name, pos, val): + existing_shm = shared_memory.SharedMemory(name=shr_name) + np_array = np.ndarray(shared_size, dtype=np.int64, buffer=existing_shm.buf) + lock.acquire() + + np_array[sim_id][pos] = val + + lock.release() + existing_shm.close() + + +def poll(sim_id, shr_name): + while True: + # print("polling", sim_id) + wait_time = state_machine(sim_id, shr_name) + time.sleep(wait_time) + + +def state_machine(sim_id, shr_name): + result = 2 + finished = is_finished() + + existing_shm = shared_memory.SharedMemory(name=shr_name) + np_array = np.ndarray(shared_size, dtype=np.int64, buffer=existing_shm.buf) + + lock.acquire() # atomic operations below + should_advance = np_array[sim_id][SharedSimulationStateEnum.CAN_ADVANCE] == 1 + if should_advance: + np_array[sim_id][SharedSimulationStateEnum.IS_BUSY] = 1 + np_array[sim_id][SharedSimulationStateEnum.CAN_ADVANCE] = 0 + lock.release() # atomic operations above + + if not finished and should_advance: + advance_sim() + + lock.acquire() # atomic operations below + if should_advance: + np_array[sim_id][SharedSimulationStateEnum.IS_BUSY] = 0 + np_array[sim_id][SharedSimulationStateEnum.IS_FINISHED] = is_finished() + np_array[sim_id][SharedSimulationStateEnum.CURRENT_TIME] = 1e9 * current_time() + lock.release() # atomic operations above + + existing_shm.close() + + if should_advance and finished: + print("FINISHED", sim_id) + exit(0) + + return result + + +def set_env(dic): + for k, v in dic.items(): + os.environ[k] = v + + +def prepend_python_path(val): + sys.path = [val] + sys.path + + +def build_diag_dir(sim_id): + return str(phare_runs_dir / f"diags_{os.environ['PHARE_LOG_TIME']}-ID={sim_id}") + + +def init_sim(sim): + ph.global_vars.sim = sim + + sim.diag_options["options"]["dir"] = build_diag_dir(os.environ["SIM_ID"]) + Path(sim.diag_options["options"]["dir"]).mkdir(parents=True, exist_ok=True) + _globals["simulator"] = Simulator(sim).initialize() + return str(_globals["simulator"].currentTime()) + + +def advance_sim(): + _globals["busy"] = 1 + _globals["simulator"].advance() + _globals["busy"] = 0 + + +def is_finished(): + return _globals["simulator"].finished() + + +def current_time(): + return _globals["simulator"].currentTime() + + +def init_simulation(sim_id, sim, shr_name, dic): + set_env(dic) + + init_sim(sim) + + poll(sim_id, shr_name) + + +def start_server_process(sim_id, sim, shr_name, dic): + servers = _globals["servers"] + assert sim_id not in servers + + if "build_dir" in dic: + prepend_python_path(dic["build_dir"]) + + _globals["servers"][sim_id] = Process( + target=init_simulation, args=(sim_id, sim, shr_name, dic) + ) + _globals["servers"][sim_id].start() + + try_count = 5 + for i in range(1, try_count + 1): + time.sleep(0.5) + try: + assert servers[sim_id].exitcode is None # or it crashed/exited early + return + except Exception as e: + if i == try_count: + raise e + + +def stop_servers(): + for k, v in _globals["servers"].items(): + v.terminate() + + +def build_dir_path(path): + p = Path(os.path.realpath(path)) + if not p.exists(): + p = Path(os.getcwd()) / path + assert p.exists() + return str(p) + + +class Simulators: + def __init__(self, starting_sim_id=10): + self.simulations = [] + self.simulation_configs = [] + self.states = dict(init=False) + self.starting_sim_id = starting_sim_id + self.log_time = datetime.now().strftime("%m_%d_%Y_%H_%M_%S") + os.environ["PHARE_LOG_TIME"] = self.log_time + + # loaded during init + self.thread_pool = None + self.shr = None + self.shared_np_array = None + + def register(self, simulation, build_dir=None, diag_dir=None): + self.simulations += [simulation] + self.simulation_configs += [dict(build_dir=build_dir, diag_dir=diag_dir)] + ph.global_vars.sim = None + + def init(self): + shr, np_array = create_shared_block(len(self.simulations)) + self.shr = shr + self.shared_np_array = np_array + self._state_machine_set_per_simulation(SharedSimulationStateEnum.IS_BUSY, 1) + + self.thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=len(self.simulations) + ) + + for i, simulation in enumerate(self.simulations): + sim_id = self.starting_sim_id + i + simulation_config = self.simulation_configs[i] + + init_dict = dict(SIM_ID=str(sim_id), PHARE_LOG_TIME=self.log_time) + if simulation_config["build_dir"]: + init_dict["build_dir"] = build_dir_path(simulation_config["build_dir"]) + + start_server_process( + sim_id, + simulation, + shr.name, + init_dict, + ) + self.states["init"] = True + + def _state_machine_list_for_value(self, offset): + existing_shm = shared_memory.SharedMemory(name=self.shr.name) + np_array = np.ndarray(shared_size, dtype=np.int64, buffer=existing_shm.buf) + lock.acquire() + + results = np.ndarray(len(self.simulations)) + for i, simulation in enumerate(self.simulations): + sim_id = self.starting_sim_id + i + results[i] = np_array[sim_id][offset] + + lock.release() + existing_shm.close() + return results + + def _state_machine_set_per_simulation(self, offset, val): + existing_shm = shared_memory.SharedMemory(name=self.shr.name) + np_array = np.ndarray(shared_size, dtype=np.int64, buffer=existing_shm.buf) + lock.acquire() + + for i, simulation in enumerate(self.simulations): + sim_id = self.starting_sim_id + i + np_array[sim_id][offset] = val + + lock.release() + existing_shm.close() + + def wait_for_simulations(self): + # we have to wait for all simulations to finish the current timestep + while True: + if np.sum(self._state_machine_list_for_value(SharedSimulationStateEnum.IS_BUSY)) == 0: + return + time.sleep(1) + + def advance(self, compare=False): + if not self.states["init"]: + self.init() + + atomic_set(0, self.shr.name, 0, int(compare)) + self._state_machine_set_per_simulation(SharedSimulationStateEnum.CAN_ADVANCE, 1) + + if compare: + self._state_machine_set_per_simulation(SharedSimulationStateEnum.IS_BUSY, 1) + # we have to wait for all simulations to finish the current timestep + self.wait_for_simulations() + self.compare() + + def compare(self): + for i, simulation in enumerate(self.simulations): + assert ( + _globals["servers"][self.starting_sim_id + i].exitcode is None + ), "or it crashed early" + + sim_times = self._state_machine_list_for_value(SharedSimulationStateEnum.CURRENT_TIME) + + times = { + i: (0.0 + sim_times[i]) / 1e9 # :( it's an int so after decimal is dropped + for i, simulation in enumerate(self.simulations) + } + + if len(times) < 2: + return + + for i, simulation in enumerate(self.simulations): + for j in range(i + 1, len(times)): + if times[i] == times[j]: + run0 = Run( + build_diag_dir(self.starting_sim_id + i), + single_hier_for_all_quantities=True, + ) + run1 = Run( + build_diag_dir(self.starting_sim_id + j), + single_hier_for_all_quantities=True, + ) + + print( + f"comparing {self.starting_sim_id + i} and {self.starting_sim_id + j} at time {times[i]}" + ) + assert hierarchy_compare( + run0.GetAllAvailableQties(times[i], []), + run1.GetAllAvailableQties(times[i], []), + ) + print("OK!") + + def run(self, compare=False): + if not self.states["init"]: + self.init() + + while self.simulations: + self.advance(compare=compare) + + is_busy = self._state_machine_list_for_value(SharedSimulationStateEnum.IS_BUSY) + is_finished = self._state_machine_list_for_value(SharedSimulationStateEnum.IS_FINISHED) + + # trim finished simulations + self.simulations = [ + sim + for i, sim in enumerate(self.simulations) + if is_busy[i] or not is_finished[i] + ] + + print("running simulations", len(self.simulations)) + time.sleep(1) + + self.kill() + + def __del__(self): + self.kill() + + def kill(self): + time.sleep(2) + + if self.shr: + self.shr.close() + self.shr.unlink() + self.shr = None + + stop_servers() diff --git a/tests/functional/alfven_wave/alfven_wave1d.py b/tests/functional/alfven_wave/alfven_wave1d.py index 1e874f3f3..b5dd8306b 100644 --- a/tests/functional/alfven_wave/alfven_wave1d.py +++ b/tests/functional/alfven_wave/alfven_wave1d.py @@ -18,6 +18,11 @@ import numpy as np mpl.use('Agg') +# Override if you do want it seeded +MODEL_INIT={} +TIME_STEP_NBR=100000 +TIMESTEP=.01 + #################################################################### @@ -32,8 +37,8 @@ def config(): Simulation( smallest_patch_size=50, largest_patch_size=50, - time_step_nbr=100000, # number of time steps (not specified if time_step and final_time provided) - final_time=1000, # simulation final time (not specified if time_step and time_step_nbr provided) + time_step_nbr=TIME_STEP_NBR, # number of time steps (not specified if time_step and final_time provided) + time_step=TIMESTEP, # simulation time_step (not specified if final_time and time_step_nbr provided) boundary_types="periodic", # boundary condition, string or tuple, length == len(cell) == len(dl) cells=1000, # integer or tuple length == dimension dl=1, # mesh size of the root level, float or tuple @@ -98,7 +103,7 @@ def vthz(x): MaxwellianFluidModel( bx=bx, by=by, bz=bz, - protons={"charge": 1, "density": density, **vvv} + protons={"charge": 1, "density": density, **vvv, "init": MODEL_INIT}, ) ElectronModel(closure="isothermal", Te=0.0) @@ -120,6 +125,8 @@ def vthz(x): compute_timestamps=timestamps, ) + return gv.sim + diff --git a/tools/build_compare_2_branches.py b/tools/build_compare_2_branches.py new file mode 100644 index 000000000..5633209dd --- /dev/null +++ b/tools/build_compare_2_branches.py @@ -0,0 +1,46 @@ +import sys +import shutil +from pathlib import Path + +from pyphare.simulator.simulators import Simulators +from tools.python3 import pushd, cmake, git +from tests.functional.alfven_wave import alfven_wave1d + +# we want it seeded +alfven_wave1d.MODEL_INIT={"seed": 1337} +alfven_wave1d.TIME_STEP_NBR = 10 + +if len(sys.argv) != 3: + print("Incorrect input arguments, expects two branch names", sys.argv) + exit(0) + +# register exit handler +git.git_branch_reset_at_exit() + +branches = sys.argv[1:] + +# check branches exist +for branch in branches: + git.checkout(branch) + +build_dir = Path("build") + +for branch in branches: + b = (build_dir / branch) + b.mkdir(parents=True, exist_ok=True) + + git.checkout(branch) + with pushd(b): + cmake.config("../..") + +for branch in branches: + git.checkout(branch) + with pushd(build_dir / branch): + cmake.build() + +simulators = Simulators() +for branch in branches: + # alfven_wave1d.config() will be identical here even if different on branches + # as it is already parsed before we change branch + simulators.register(alfven_wave1d.config(), build_dir=str(build_dir / branch)) +simulators.run(compare=True) diff --git a/tools/build_compare_2_branches.sh b/tools/build_compare_2_branches.sh new file mode 100644 index 000000000..c513b2572 --- /dev/null +++ b/tools/build_compare_2_branches.sh @@ -0,0 +1,6 @@ + +set -ex +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd $SCRIPT_DIR && cd .. && CWD=$PWD # move to project root + +python3 tools/build_compare_2_branches.py $@ diff --git a/tools/build_compare_top_2_commits.py b/tools/build_compare_top_2_commits.py new file mode 100644 index 000000000..6775c4d80 --- /dev/null +++ b/tools/build_compare_top_2_commits.py @@ -0,0 +1,38 @@ +import shutil +from pathlib import Path + +from pyphare.simulator.simulators import Simulators +from tools.python3 import pushd, cmake, git +from tests.functional.alfven_wave import alfven_wave1d + +# we want it seeded +alfven_wave1d.MODEL_INIT={"seed": 1337} +alfven_wave1d.TIME_STEP_NBR = 10 + +# register exit handler +git.git_branch_reset_at_exit() + +top_2_hashes = git.hashes(2) + +build_dir = Path("build") + +for hsh in top_2_hashes: + b = (build_dir / hsh) + b.mkdir(parents=True, exist_ok=True) + + git.checkout(hsh) + with pushd(b): + cmake.config("../..") + +for hsh in top_2_hashes: + b = (build_dir / hsh) + + git.checkout(hsh) + with pushd(b): + cmake.build() + +simulators = Simulators() +for hsh in top_2_hashes: + simulators.register(alfven_wave1d.config(), build_dir=str(build_dir / hsh)) + +simulators.run(compare=True) diff --git a/tools/build_compare_top_2_commits.sh b/tools/build_compare_top_2_commits.sh new file mode 100644 index 000000000..69daac104 --- /dev/null +++ b/tools/build_compare_top_2_commits.sh @@ -0,0 +1,6 @@ + +set -ex +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd $SCRIPT_DIR && cd .. && CWD=$PWD # move to project root + +python3 tools/build_compare_top_2_commits.py diff --git a/tools/python3/cmake.py b/tools/python3/cmake.py index 6c0350a8b..58cdca073 100644 --- a/tools/python3/cmake.py +++ b/tools/python3/cmake.py @@ -4,31 +4,36 @@ def version(): pass +# set to false if you do not want this +_USE_NINJA=True + +# None == subprojects, needs building per commit if unset, so not great +_SAMRAI_DIR="/mkn/r/llnl/samrai/master" def make_config_str( - path, samrai_dir=None, cxx_flags=None, use_ninja=False, use_ccache=False, extra="" + path, samrai_dir=None, cxx_flags=None, use_ccache=False, extra="" ): """ FULL UNPORTABLE OPTIMIZATIONS = cxx_flags="-O3 -march=native -mtune=native" """ - + samrai_dir = _SAMRAI_DIR if samrai_dir is None else samrai_dir samrai_dir = "" if samrai_dir is None else f"-DSAMRAI_ROOT={samrai_dir}" cxx_flags = "" if cxx_flags is None else f'-DCMAKE_CXX_FLAGS="{cxx_flags}"' ccache = "" if use_ccache is False else "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache" - ninja = "" if not use_ninja else "-G Ninja" + ninja = "" if not _USE_NINJA else "-G Ninja" return f"cmake {path} {samrai_dir} {cxx_flags} {ninja} {ccache} {extra}" def config( - path, samrai_dir=None, cxx_flags=None, use_ninja=False, use_ccache=False, extra="" + path, samrai_dir=None, cxx_flags=None, use_ccache=False, extra="" ): - cmd = make_config_str(path, samrai_dir, cxx_flags, use_ninja, extra) + cmd = make_config_str(path, samrai_dir, cxx_flags, extra) run(cmd, capture_output=False) -def build(use_ninja=False, threads=1): - run("ninja" if use_ninja else f"make -j{threads}", capture_output=False) +def build(threads=1): + run("ninja" if _USE_NINJA else f"make -j{threads}", capture_output=False) def list_tests(): diff --git a/tools/python3/git.py b/tools/python3/git.py index 3dce25685..03fdb4430 100644 --- a/tools/python3/git.py +++ b/tools/python3/git.py @@ -1,5 +1,7 @@ -from tools.python3 import decode_bytes, run + +import atexit import subprocess +from tools.python3 import decode_bytes, run def current_branch(): @@ -30,6 +32,13 @@ def checkout(branch, create=False, recreate=False): create = True if create and not branch_exists(branch): - run(f"git checkout -b {branch}") + run(f"git checkout -b {branch}", check=True) else: - run(f"git checkout {branch}") + run(f"git checkout {branch}", check=True) + +def git_branch_reset_at_exit(): + current_git_branch = current_branch() + + @atexit.register + def _reset_(): + checkout(current_git_branch)