Skip to content

Commit

Permalink
WIP: add typing
Browse files Browse the repository at this point in the history
  • Loading branch information
adfaure committed Dec 6, 2023
1 parent 59e209e commit f489fd6
Show file tree
Hide file tree
Showing 18 changed files with 457 additions and 353 deletions.
2 changes: 1 addition & 1 deletion oar/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from oar.lib.basequery import BaseQuery, BaseQueryCollection

# from oar.lib.models import (db, Job, Resource)
from oar.lib.utils import cached_property, row2dict
from oar.lib.utils import row2dict

# from flask import abort, current_app
# TODO: This whole file is to review since it has been adapted from flask and now use in fastapi.
Expand Down
2 changes: 1 addition & 1 deletion oar/cli/db/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def load_configuration_file(ctx, param, value):
return config


def default_database_url():
def default_database_url(): # pragma: no cover
try:
return config.get_sqlalchemy_uri()
except Exception:
Expand Down
22 changes: 18 additions & 4 deletions oar/kao/kamelot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import sys
from ast import List

from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from oar.kao.karma import karma_jobs_sorting
from oar.kao.multifactor_priority import multifactor_jobs_sorting
from oar.kao.platform import Platform
from oar.kao.quotas import Quotas
from oar.kao.scheduling import schedule_id_jobs_ct, set_slots_with_prev_scheduled_jobs
from oar.kao.slot import MAX_TIME, SlotSet
from oar.lib.configuration import Configuration
from oar.lib.globals import get_logger, init_oar
from oar.lib.job_handling import NO_PLACEHOLDER, JobPseudo
from oar.lib.plugins import find_plugin_function
Expand Down Expand Up @@ -56,7 +58,13 @@ def jobs_sorting(session, config, queues, now, waiting_jids, waiting_jobs, plt):


def internal_schedule_cycle(
session, config, plt, now, all_slot_sets, job_security_time, queues
session: Session,
config: Configuration,
plt: Platform,
now: int,
all_slot_sets,
job_security_time: int,
queues,
):
resource_set = plt.resource_set(session, config)

Expand Down Expand Up @@ -104,7 +112,13 @@ def internal_schedule_cycle(
logger.info("no waiting jobs")


def schedule_cycle(session, config, plt, now, queues=["default"]):
def schedule_cycle(
session: Session,
config: Configuration,
plt: Platform,
now: int,
queues: List[str] = ["default"],
):
logger.info(
"Begin scheduling....now: {}, queue(s): {}".format(
now, " ".join([q for q in queues])
Expand Down Expand Up @@ -202,7 +216,7 @@ def schedule_cycle(session, config, plt, now, queues=["default"]):
#
# Main function
#
def main(session=None, config=None):
def main(session: Session = None, config: Configuration = None):
if not session:
config, engine, log = init_oar(config)

Expand Down
8 changes: 6 additions & 2 deletions oar/kao/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Scheduling functions used by :py:mod:`oar.kao.kamelot`.
"""
import copy
from typing import Any, Tuple

from procset import ProcSet

Expand All @@ -11,6 +12,7 @@
from oar.lib.globals import get_logger, init_oar
from oar.lib.hierarchy import find_resource_hierarchies_scattered
from oar.lib.job_handling import ALLOW, JobPseudo
from oar.lib.models import Job

# for quotas
from oar.lib.resource import ResourceSet
Expand Down Expand Up @@ -115,7 +117,7 @@ def find_resource_hierarchies_job(itvs_slots, hy_res_rqts, hy):


def find_first_suitable_contiguous_slots_quotas(
slots_set: SlotSet, job, res_rqt, hy, min_start_time: int
slots_set: SlotSet, job, res_rqt: Tuple[int, int, Any], hy, min_start_time: int
):
"""
Loop through time slices from a :py:class:`oar.kao.slot.SlotSet` that are long enough for the job's walltime.
Expand Down Expand Up @@ -327,7 +329,9 @@ def find_first_suitable_contiguous_slots(
)


def assign_resources_mld_job_split_slots(slots_set: SlotSet, job, hy, min_start_time):
def assign_resources_mld_job_split_slots(
slots_set: SlotSet, job: Job, hy, min_start_time
):
"""
According to a resources a :class:`SlotSet` find the time and the resources to launch a job.
This function supports the moldable jobs. In case of multiple moldable job corresponding to the request
Expand Down
55 changes: 38 additions & 17 deletions oar/lib/accounting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
from typing import Any

from sqlalchemy import func, or_, text
from sqlalchemy.orm import Session

from oar.lib.models import (
Accounting,
Expand All @@ -14,7 +17,13 @@
# get_sum_accounting_window() -> see Karma.py


def get_accounting_summary(session, start_time, stop_time, user="", sql_property=""):
def get_accounting_summary(
session: Session,
start_time: int,
stop_time: int,
user: str = "",
sql_property: str = "",
) -> dict[Any, Any]:
"""Get an array of consumptions by users
p.arams: start date, ending date, optional user"""

Expand Down Expand Up @@ -52,8 +61,13 @@ def get_accounting_summary(session, start_time, stop_time, user="", sql_property


def get_accounting_summary_byproject(
session, start_time, stop_time, user="", limit="", offset=""
):
session: Session,
start_time: int,
stop_time: int,
user: str = "",
limit: str = "",
offset: str = "",
) -> dict[Any, Any]:
""" "Get an array of consumptions by project for a given user
params: start date, ending date, user"""

Expand Down Expand Up @@ -93,15 +107,15 @@ def get_accounting_summary_byproject(


def update_accounting(
session,
start_time,
stop_time,
window_size,
user,
project,
queue_name,
c_type,
nb_resources,
session: Session,
start_time: int,
stop_time: int,
window_size: int,
user: str,
project: str,
queue_name: str,
c_type: str,
nb_resources: int,
):
"""Insert accounting data in table accounting
# params : start date in second, stop date in second, window size, user, queue, type(ASKED or USED)
Expand Down Expand Up @@ -135,7 +149,14 @@ def update_accounting(


def add_accounting_row(
session, window_start, window_stop, user, project, queue_name, c_type, consumption
session: Session,
window_start: int,
window_stop: int,
user: str,
project: str,
queue_name: str,
c_type: str,
consumption: str,
):
# Insert or update one row according to consumption

Expand Down Expand Up @@ -213,7 +234,7 @@ def add_accounting_row(
)


def check_accounting_update(session, window_size):
def check_accounting_update(session: Session, window_size: int):
"""Check jobs that are not treated in accounting table
params : base, window size"""

Expand Down Expand Up @@ -290,23 +311,23 @@ def check_accounting_update(session, window_size):


def delete_all_from_accounting(
session,
session: Session,
):
"""Empty the table accounting and update the jobs table."""
session.query(Accounting).delete(synchronize_session=False)
session.query(Job).update({Job.accounted: "NO"}, synchronize_session=False)
session.commit()


def delete_accounting_windows_before(session, window_stop):
def delete_accounting_windows_before(session: Session, window_stop: int):
"""Remove windows from accounting."""
session.query(Accounting).filter(Accounting.window_stop <= window_stop).delete(
synchronize_session=False
)
session.commit()


def get_last_project_karma(session, user, project, date):
def get_last_project_karma(session: Session, user: str, project: str, date: int):
"""Get the last project Karma of user at a given date
params: user, project, date"""

Expand Down
18 changes: 13 additions & 5 deletions oar/lib/configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import sys
from io import open
from typing import Any

from .exceptions import InvalidConfiguration
from .utils import reraise, try_convert_decimal
Expand Down Expand Up @@ -133,7 +134,12 @@ def load_default_config(self, silent=True):
self.load_file(self.DEFAULT_CONFIG_FILE, silent=silent)

def load_file(
self, filename, comment_char="#", strip_quotes=True, silent=False, clear=False
self,
filename: str,
comment_char: str = "#",
strip_quotes: bool = True,
silent: bool = False,
clear: bool = False,
):
"""Updates the values in the config from a config file.
:param filename: the filename of the config. This can either be an
Expand Down Expand Up @@ -175,7 +181,7 @@ def load_file(

return True

def get_sqlalchemy_uri(self, read_only=False):
def get_sqlalchemy_uri(self, read_only: bool = False): # pragma: no cover
if read_only:
login = "base_login_ro"
passwd = "base_passwd_ro"
Expand All @@ -199,12 +205,14 @@ def get_sqlalchemy_uri(self, read_only=False):
keys = tuple(("DB_%s" % i.upper() for i in e.args))
raise InvalidConfiguration("Cannot find %s" % keys)

def setdefault_config(self, default_config):
def setdefault_config(self, default_config: dict[str, Any]):
# import pdb; pdb.set_trace()
for k, v in default_config.items():
self.setdefault(k, v)

def get_namespace(self, namespace, lowercase=True, trim_namespace=True):
def get_namespace(
self, namespace: str, lowercase: bool = True, trim_namespace: bool = True
):
"""Returns a dictionary containing a subset of configuration options
that match the specified namespace/prefix. Example usage::
Expand Down Expand Up @@ -243,5 +251,5 @@ def get_namespace(self, namespace, lowercase=True, trim_namespace=True):
rv[key] = v
return rv

def __str__(self):
def __str__(self) -> str:
return f"{dict(self)}"
27 changes: 18 additions & 9 deletions oar/lib/event.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# coding: utf-8

from typing import List, Optional

from sqlalchemy import desc, func
from sqlalchemy.orm import Session

from oar.lib import tools
from oar.lib.models import EventLog, EventLogHostname


def add_new_event(session, ev_type, job_id, description, to_check="YES"):
def add_new_event(
session: Session, ev_type: str, job_id: int, description: str, to_check: str = "YES"
):
"""Add a new entry in event_log table"""
event_data = EventLog(
type=ev_type,
Expand All @@ -19,7 +24,9 @@ def add_new_event(session, ev_type, job_id, description, to_check="YES"):
session.commit()


def add_new_event_with_host(session, ev_type, job_id, description, hostnames):
def add_new_event_with_host(
session: Session, ev_type: str, job_id: int, description: str, hostnames: List[str]
):
ins = EventLog.__table__.insert().values(
{
"type": ev_type,
Expand All @@ -40,7 +47,7 @@ def add_new_event_with_host(session, ev_type, job_id, description, hostnames):
session.commit()


def is_an_event_exists(session, job_id, event):
def is_an_event_exists(session: Session, job_id: int, event: str):
res = (
session.query(func.count(EventLog.id))
.filter(EventLog.job_id == job_id)
Expand All @@ -50,7 +57,7 @@ def is_an_event_exists(session, job_id, event):
return res


def get_job_events(session, job_id):
def get_job_events(session: Session, job_id: int):
"""Get events for the specified job"""
result = (
session.query(EventLog)
Expand All @@ -61,7 +68,7 @@ def get_job_events(session, job_id):
return result


def get_jobs_events(session, job_ids):
def get_jobs_events(session: Session, job_ids: List[int]):
"""Get events for the specified jobs"""
result = (
session.query(EventLog)
Expand All @@ -72,7 +79,7 @@ def get_jobs_events(session, job_ids):
return result


def get_to_check_events(session):
def get_to_check_events(session: Session):
""" "Get all events with toCheck field on YES"""
result = (
session.query(EventLog)
Expand All @@ -83,7 +90,7 @@ def get_to_check_events(session):
return result


def check_event(session, event_type, job_id):
def check_event(session: Session, event_type: str, job_id: int):
"""Turn the field toCheck into NO"""
session.query(EventLog).filter(EventLog.job_id == job_id).filter(
EventLog.type == event_type
Expand All @@ -93,7 +100,7 @@ def check_event(session, event_type, job_id):
session.commit()


def get_hostname_event(session, event_id):
def get_hostname_event(session: Session, event_id: int):
"""Get hostnames corresponding to an event Id"""
res = (
session.query(EventLogHostname.hostname)
Expand All @@ -103,7 +110,9 @@ def get_hostname_event(session, event_id):
return [h[0] for h in res]


def get_events_for_hostname_from(session, host, date=None):
def get_events_for_hostname_from(
session: Session, host: str, date: Optional[int] = None
) -> List[EventLog]:
"""Get events for the hostname given as parameter
If date is given, returns events since that date, else return the 30 last events.
"""
Expand Down
Loading

0 comments on commit f489fd6

Please sign in to comment.