Skip to content

Commit

Permalink
add helpful typings
Browse files Browse the repository at this point in the history
  • Loading branch information
LuiggiTenorioK committed Dec 27, 2024
1 parent 0ad5a34 commit ed9d53a
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 18 deletions.
40 changes: 34 additions & 6 deletions autosubmit/autosubmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
from configparser import ConfigParser
from distutils.util import strtobool
from pathlib import Path
from autosubmit.job.job import Job
from autosubmit.platforms.submitter import Submitter
from ruamel.yaml import YAML
from typing import Dict, Set, Tuple, Union
from typing import Dict, Optional, Set, Tuple, Union

from autosubmit.database.db_common import update_experiment_descrip_version
from autosubmit.helpers.parameters import PARAMETERS
Expand Down Expand Up @@ -1800,8 +1802,14 @@ def job_notify(as_conf,expid,job,job_prev_status,job_changes_tracker):
Status.VALUE_TO_KEY[job.status],
as_conf.experiment_data["MAIL"]["TO"])
return job_changes_tracker

@staticmethod
def check_wrappers(as_conf, job_list, platforms_to_test, expid):
def check_wrappers(
as_conf: AutosubmitConfig,
job_list: JobList,
platforms_to_test: Set[Platform],
expid: str,
) -> Tuple[Dict[str, List[List[Job]]], Dict[str, Tuple[Status, Status]]]:
"""
Check wrappers and inner jobs status; also order the non-wrapped jobs to be submitted by active platforms.
:param as_conf: an AutosubmitConfig object
Expand All @@ -1810,8 +1818,8 @@ def check_wrappers(as_conf, job_list, platforms_to_test, expid):
:param expid: a string with the experiment id
:return: non-wrapped jobs to check and a dictionary with the changes in the jobs status
"""
jobs_to_check = dict()
job_changes_tracker = dict()
jobs_to_check: Dict[str, List[List[Job]]] = dict()
job_changes_tracker: Dict[str, Tuple[Status, Status]] = dict()
for platform in platforms_to_test:
queuing_jobs = job_list.get_in_queue_grouped_id(platform)
Log.debug('Checking jobs for platform={0}'.format(platform.name))
Expand Down Expand Up @@ -1891,6 +1899,7 @@ def check_wrapper_stored_status(as_conf,job_list):
None, jobs[0].platform, as_conf, jobs[0].hold)
job_list.job_package_map[jobs[0].id] = wrapper_job
return job_list

@staticmethod
def get_historical_database(expid, job_list, as_conf):
"""
Expand Down Expand Up @@ -1941,9 +1950,27 @@ def process_historical_data_iteration(job_list,job_changes_tracker, expid):
exp_history.process_job_list_changes_to_experiment_totals(job_list.get_job_list())
Autosubmit.database_backup(expid)
return exp_history

@staticmethod
def prepare_run(expid, notransitive=False, start_time=None, start_after=None,
run_only_members=None, recover = False, check_scripts= False, submitter=None):
def prepare_run(
expid: str,
notransitive: bool = False,
start_time: str = None,
start_after: str = None,
run_only_members: str = None,
recover: bool = False,
check_scripts: bool = False,
submitter=None
) -> Tuple[
JobList,
Submitter,
Optional[ExperimentHistory],
Optional[str],
AutosubmitConfig,
Set[Platform],
JobPackagePersistence,
bool,
]:
"""
Prepare the run of the experiment.
:param expid: a string with the experiment id.
Expand Down Expand Up @@ -2103,6 +2130,7 @@ def prepare_run(expid, notransitive=False, start_time=None, start_after=None,
return job_list, submitter , exp_history, host , as_conf, platforms_to_test, packages_persistence, False
else:
return job_list, submitter, None, None, as_conf, platforms_to_test, packages_persistence, True

@staticmethod
def get_iteration_info(as_conf,job_list):
"""
Expand Down
27 changes: 21 additions & 6 deletions autosubmit/job/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,7 @@ def write_stats(self, last_retrial: int) -> None:
except BaseException as e:
Log.printlog("Trace {0} \n Failed to write the {1} e=6001".format(str(e), self.name))

def retrieve_logfiles(self, platform: Any, raise_error: bool = False) -> Dict[str, int]:
def retrieve_logfiles(self, platform: Platform, raise_error: bool = False) -> Dict[str, int]:
"""
Retrieves log files from remote host.
Expand Down Expand Up @@ -1316,7 +1316,7 @@ def is_over_wallclock(self, start_time, wallclock):
return True
return False

def update_status(self, as_conf, failed_file=False):
def update_status(self, as_conf: AutosubmitConfig, failed_file: bool = False) -> Status:
"""
Updates job status, checking COMPLETED file if needed
Expand Down Expand Up @@ -1374,6 +1374,9 @@ def update_status(self, as_conf, failed_file=False):
self.retrieve_logfiles(self.platform)
else:
self.platform.add_job_to_log_recover(self)

# TODO Read and store metrics here

return self.status

@staticmethod
Expand Down Expand Up @@ -2604,7 +2607,19 @@ class WrapperJob(Job):
:type as_config: AutosubmitConfig object \n
"""

def __init__(self, name, job_id, status, priority, job_list, total_wallclock, num_processors, platform, as_config, hold):
def __init__(
self,
name: str,
job_id: int,
status: str,
priority: int,
job_list: List[Job],
total_wallclock: str,
num_processors: int,
platform: Platform,
as_config: AutosubmitConfig,
hold: bool,
):
super(WrapperJob, self).__init__(name, job_id, status, priority)
self.failed = False
self.job_list = job_list
Expand Down Expand Up @@ -2686,7 +2701,7 @@ def check_status(self, status):
if not still_running:
self.cancel_failed_wrapper_job()

def check_inner_jobs_completed(self, jobs):
def check_inner_jobs_completed(self, jobs: List[Job]):
not_completed_jobs = [
job for job in jobs if job.status != Status.COMPLETED]
not_completed_job_names = [job.name for job in not_completed_jobs]
Expand Down Expand Up @@ -2772,7 +2787,7 @@ def _check_inner_job_wallclock(self, job):
return False

def _check_running_jobs(self):
not_finished_jobs_dict = OrderedDict()
not_finished_jobs_dict: OrderedDict[str, Job] = OrderedDict()
self.inner_jobs_running = list()
not_finished_jobs = [job for job in self.job_list if job.status not in [
Status.COMPLETED, Status.FAILED]]
Expand Down Expand Up @@ -2857,7 +2872,7 @@ def _check_running_jobs(self):
if retries == 0 or over_wallclock:
self.status = Status.FAILED

def _check_finished_job(self, job, failed_file=False):
def _check_finished_job(self, job: Job, failed_file: bool = False):
job.new_status = Status.FAILED
if not failed_file:
wait = 2
Expand Down
10 changes: 8 additions & 2 deletions autosubmit/platforms/paramiko_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import socket
import os
from typing import List, TYPE_CHECKING
import paramiko
import datetime
import select
Expand All @@ -23,6 +24,11 @@
from paramiko.agent import Agent
import time

if TYPE_CHECKING:
# Avoid circular imports
from autosubmit.job.job import Job


def threaded(fn):
def wrapper(*args, **kwargs):
thread = Thread(target=fn, args=args, kwargs=kwargs, name=f"{args[0].name}_X11")
Expand Down Expand Up @@ -693,7 +699,7 @@ def _check_jobid_in_queue(self, ssh_output, job_list_cmd):
if job not in ssh_output:
return False
return True
def parse_joblist(self, job_list):
def parse_joblist(self, job_list: List[List['Job']]):
"""
Convert a list of job_list to job_list_cmd
:param job_list: list of jobs
Expand All @@ -714,7 +720,7 @@ def parse_joblist(self, job_list):
job_list_cmd=job_list_cmd[:-1]

return job_list_cmd
def check_Alljobs(self, job_list, as_conf, retries=5):
def check_Alljobs(self, job_list: List[List['Job']], as_conf, retries=5):
"""
Checks jobs running status
:param job_list: list of jobs
Expand Down
9 changes: 7 additions & 2 deletions autosubmit/platforms/pjmplatform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import os
from contextlib import suppress
from time import sleep
from typing import List, Union
from typing import List, Union, TYPE_CHECKING

from autosubmit.job.job_common import Status
from autosubmit.platforms.paramiko_platform import ParamikoPlatform
Expand All @@ -29,6 +29,11 @@
from log.log import AutosubmitCritical, AutosubmitError, Log

import textwrap

if TYPE_CHECKING:
# Avoid circular imports
from autosubmit.job.job import Job

class PJMPlatform(ParamikoPlatform):
"""
Class to manage jobs to host using PJM scheduler
Expand Down Expand Up @@ -306,7 +311,7 @@ def queuing_reason_cancel(self, reason):
return False
except Exception as e:
return False
def get_queue_status(self, in_queue_jobs, list_queue_jobid, as_conf):
def get_queue_status(self, in_queue_jobs: List['Job'], list_queue_jobid, as_conf):
if not in_queue_jobs:
return
cmd = self.get_queue_status_cmd(list_queue_jobid)
Expand Down
8 changes: 6 additions & 2 deletions autosubmit/platforms/slurmplatform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from time import mktime
from time import sleep
from time import time
from typing import List, Union
from typing import List, Union, TYPE_CHECKING
from xml.dom.minidom import parseString

from autosubmit.job.job_common import Status, parse_output_number
Expand All @@ -32,6 +32,10 @@
from autosubmit.platforms.wrappers.wrapper_factory import SlurmWrapperFactory
from log.log import AutosubmitCritical, AutosubmitError, Log

if TYPE_CHECKING:
# Avoid circular imports
from autosubmit.job.job import Job

class SlurmPlatform(ParamikoPlatform):
"""
Class to manage jobs to host using SLURM scheduler
Expand Down Expand Up @@ -622,7 +626,7 @@ def parse_queue_reason(self, output, job_id):
return ''.join(reason)
return reason

def get_queue_status(self, in_queue_jobs, list_queue_jobid, as_conf):
def get_queue_status(self, in_queue_jobs: List['Job'], list_queue_jobid, as_conf):
if not in_queue_jobs:
return
cmd = self.get_queue_status_cmd(list_queue_jobid)
Expand Down

0 comments on commit ed9d53a

Please sign in to comment.