Skip to content

Commit

Permalink
bug fix in test
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Arbel committed Dec 7, 2023
1 parent 791dd81 commit ed4578e
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 25 deletions.
8 changes: 4 additions & 4 deletions mlxp/_internal/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def _build_config(config_path, config_name, co_filename, overrides, interactive_
cfg = _update_config(default_cfg, overrides_config, overrides_mlxp)

im_handler = InteractiveModeHandler(cfg["mlxp"]["interactive_mode"], interactive_mode_file)

update_default_config = _set_scheduler(default_cfg, overrides_mlxp["mlxp"], im_handler)
scheduler_settings = _get_scheduler_settings(default_cfg, overrides_mlxp)
update_default_config = _set_scheduler(default_cfg, scheduler_settings, im_handler)

mlxp_file = os.path.join(config_path, "mlxp.yaml")
if not os.path.exists(mlxp_file) or update_default_config:
Expand Down Expand Up @@ -200,8 +200,7 @@ def _get_mlxp_configs(mlxp_file, default_config_mlxp):
return mlxp_config


def _set_scheduler(default_config, overrides, im_handler):
scheduler_settings = _get_scheduler_settings(default_config, overrides)
def _set_scheduler(default_config, scheduler_settings, im_handler):
scheduler_name, scheduler_name_default, using_scheduler, interactive_mode = scheduler_settings
update_default_config = False

Expand Down Expand Up @@ -230,6 +229,7 @@ def _get_scheduler_settings(default_config, overrides):
scheduler_name = scheduler_name_default
interactive_mode = default_config.mlxp.interactive_mode
if overrides:
overrides = overrides["mlxp"]
if "use_scheduler" in overrides:
using_scheduler = overrides["use_scheduler"]
if "scheduler" in overrides:
Expand Down
2 changes: 2 additions & 0 deletions mlxp/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def hydra_decorator(task_function: TaskFunction) -> Callable[[], None]:
@functools.wraps(task_function)
def decorated_main(cfg_passthrough: Optional[DictConfig] = None) -> Any:
processed_config_path = _process_config_path(config_path, task_function.__code__.co_filename)
os.makedirs(processed_config_path, exist_ok=True)

if cfg_passthrough is not None:
return task_function(cfg_passthrough)
else:
Expand Down
11 changes: 10 additions & 1 deletion tests/test_examples/script.sh → tests/script.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
#!/bin/bash


python main.py \
#HYDRA_FULL_ERROR=1 OC_CAUSE=1 python -m ipdb launch.py

cd test_examples

python launch.py \
optimizer.lr=10.,1.,0.1\
seed=1,2,3,4\
+mlxp.use_scheduler=False\
+mlxp.use_version_manager=False\
+mlxp.use_logger=True\


python read.py

rm -r logs
2 changes: 1 addition & 1 deletion tests/test_examples/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ data:
d_int: 10
device: 'cpu'
optimizer:
lr: 10.
lr: 10.
4 changes: 2 additions & 2 deletions tests/test_examples/configs/mlxp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ scheduler:
option_cmd: []
version_manager:
name: GitVM
parent_work_dir: ./.workdir
parent_work_dir: ./.work_dir
compute_requirements: false
use_version_manager: false
use_scheduler: false
use_logger: true
interactive_mode: false
interactive_mode: true
2 changes: 1 addition & 1 deletion tests/test_examples/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
def read_outputs():

# Create a reader object to access the results stored by the logger.
parent_log_dir = './test_examples/logs'
parent_log_dir = './logs'
reader = mlxp.Reader(parent_log_dir)

# Displaying all fields accessible in the database.
Expand Down
22 changes: 6 additions & 16 deletions tests/test_mlxp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,19 @@
import os
import sys
import pytest
import subprocess


scripts = pathlib.Path('test_examples').resolve().glob('launch.py')
scripts = pathlib.Path('.').resolve().glob('script.sh')

@pytest.mark.parametrize('script', scripts)
def test_launching(script):

parent_path = str(script.parent)
sys.path.insert(0,parent_path)
runpy.run_path(str(script),run_name='__main__')


scripts = pathlib.Path('test_examples').resolve().glob('read.py')


@pytest.mark.parametrize('script', scripts)
def test_reading(script):

parent_path = str(script.parent)
sys.path.insert(0,parent_path)
#runpy.run_path(str(script),run_name='__main__')

def test_path_windows():
file_name = os.path.join('.','test_examples','launch.py')
assert os.path.exists(file_name)
with open(script, 'rb') as file:
script_code = file.read()
rc = subprocess.call(script_code, shell=True)
assert rc==0

0 comments on commit ed4578e

Please sign in to comment.