Skip to content

Commit

Permalink
Add AMSMonitor interface and unify both RMQ API (#32)
Browse files Browse the repository at this point in the history
Signed-off-by: Loic Pottier <[email protected]>
  • Loading branch information
lpottier committed Mar 23, 2024
1 parent 3e4b410 commit aadc4a7
Show file tree
Hide file tree
Showing 7 changed files with 1,002 additions and 621 deletions.
11 changes: 8 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,18 @@ exclude = [
# E226: Missing white space around arithmetic operator

[tool.ruff]
ignore = ["E501", "W503", "E226", "BLK100", "E203"]
lint.ignore = ["E501", "E226", "E203"]
show-fixes = true

exclude = [
".git",
"__pycache__",
"*.egg-info",
"build"
]
# change the default line length number or characters.
line-length = 120
lint.select = ['E', 'F', 'W', 'A', 'PLC', 'PLE', 'PLW', 'I', 'N', 'Q']

[tool.yapf]
ignore = ["E501", "W503", "E226", "BLK100", "E203"]
column_limit = 120

308 changes: 308 additions & 0 deletions src/AMSWorkflow/ams/monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# AMSLib Project Developers
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import datetime
import json
import logging
import multiprocessing
import threading
import time
from typing import Callable, List, Union


class AMSMonitor:
"""
AMSMonitor can be used to decorate class methods and will
record automatically the duration of the tasks in a hashmap
with timestamp. The decorator will also automatically
record the values of all attributes of the class.
class ExampleTask1(Task):
def __init__(self):
self.total_bytes = 0
self.total_bytes2 = 0
# Example: we do not want to record total_bytes
# but just total_bytes2
#
# @AMSMonitor() would record all attributes
# (total_bytes and total_bytes2)
#
# @AMSMonitor(accumulate=True) would record all
# attributes and accumulate their values
# (sum up total_bytes instead of recording a
# different total_bytes per invocation)
@AMSMonitor(record=["total_bytes2"])
def __call__(self):
i = 0
with AMSMonitor(object=self, tag="while_loop"):
while (i<=3):
self.total_bytes += 10
self.total_bytes2 = 1
i += 1
Each time `ExampleTask1()` is being called, AMSMonitor will
populate `_stats` as follows (showed with two calls here):
{'ExampleTask1':
{'while_loop':
{'02/29/2024-19:27:53':
{
'total_bytes2': 1,
'amsmonitor_duration': 4.004607439041138
}
}
},
{'__call__':
{'02/29/2024-19:29:24':
{
'total_bytes2': 1,
'amsmonitor_duration': 4.10461138
}
}
}
}
Attributes:
record: attributes to record, if empty ([]) all attributes
will be recorded.
accumulate: If True, AMSMonitor will accumulate recorded
data instead of recording a new timestamp for
any subsequent call of AMSMonitor on the same method.
For example, if we record a method from the class `A`
with some attributes, `attr1`, `attr2`. The first
time `AMSMonitor` is being called on `A`, a new entry
for `A` is created and values for `attr1 and `attr2`
are populated, if `ams_accumulate=True`, for each call
to `AMSMonitor()` the value of `attr1` and `attr2` will
be added to their previous values.
object: Mandatory if using `with` statement, `object` is
the main object should be provided (i.e., self).
tag: Mandatory if using `with` statement, `tag` is the
name that will appear in the record for that
context manager statement.
"""

_manager = multiprocessing.Manager()
_stats = _manager.dict()
_ts_format = "%m/%d/%Y-%H:%M:%S"
_reserved_keys = ["amsmonitor_duration"]
_lock = threading.Lock()
_count = 0

def __init__(self, record=None, accumulate=False, obj=None, tag=None, logger: logging.Logger = None, **kwargs):
self.accumulate = accumulate
self.kwargs = kwargs
self.record = record
if not isinstance(record, list):
self.record = None
# We make sure we do not overwrite protected attributes managed by AMSMonitor
if self.record:
self.record = self._remove_reserved_keys(self.record)
self.object = obj
self.start_time = 0
self.internal_ts = 0
self.tag = tag
AMSMonitor._count += 1
self.logger = logger if logger else logging.getLogger(__name__)

def __str__(self) -> str:
return AMSMonitor.info() if AMSMonitor._stats != {} else "{}"

def __repr__(self) -> str:
return self.__str__()

def lock(self):
AMSMonitor._lock.acquire()

def unlock(self):
AMSMonitor._lock.release()

def __enter__(self):
if not self.object or not self.tag:
self.logger.error('missing parameter "object" or "tag" when using context manager syntax')
return
self.start_monitor()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_monitor()

@classmethod
def info(cls) -> str:
s = ""
if cls._stats == {}:
return "{}"
for k, v in cls._stats.items():
s += f"{k}\n"
for i, j in v.items():
s += f" {i}\n"
for p, z in j.items():
s += f" {p:<10}\n"
for r, q in z.items():
s += f" {r:<30} => {q}\n"
return s.rstrip()

@classmethod
@property
def stats(cls):
return AMSMonitor._stats

@classmethod
@property
def format_ts(cls):
return AMSMonitor._ts_format

@classmethod
def convert_ts(cls, ts: str) -> datetime.datetime:
return datetime.strptime(ts, cls.format_ts)

@classmethod
def json(cls, json_output: str):
"""
Write the collected metrics to a JSON file.
"""
with open(json_output, "w") as fp:
# we have to use .copy() as DictProxy is not serializable
json.dump(cls._stats.copy(), fp, indent=4)
# To avoid partial line at the end of the file
fp.write("\n")

def start_monitor(self, *args, **kwargs):
self.start_time = time.time()
self.internal_ts = datetime.datetime.now().strftime(self._ts_format)

def stop_monitor(self):
end = time.time()
class_name = self.object.__class__.__name__
func_name = self.tag

new_data = vars(self.object)
# Filter out multiprocessing which cannot be stored without causing RuntimeError
new_data = self._filter_out_object(new_data)
# We remove stuff we do not want (attribute of the calling class captured by vars())
if self.record != []:
new_data = self._filter(new_data, self.record)
# We inject some data we want to record
new_data["amsmonitor_duration"] = end - self.start_time
self._update_db(new_data, class_name, func_name, self.internal_ts)

# We reinitialize some variables
self.start_time = 0
self.internal_ts = 0

def __call__(self, func: Callable):
"""
The main decorator.
"""

def wrapper(*args, **kwargs):
ts = datetime.datetime.now().strftime(self._ts_format)
start = time.time()
value = func(*args, **kwargs)
end = time.time()
if not hasattr(args[0], "__dict__"):
return value
class_name = args[0].__class__.__name__
func_name = self.tag if self.tag else func.__name__
new_data = vars(args[0])

# Filter out multiprocessing which cannot be stored without causing RuntimeError
new_data = self._filter_out_object(new_data)

# We remove stuff we do not want (attribute of the calling class captured by vars())
new_data = self._filter(new_data, self.record)
new_data["amsmonitor_duration"] = end - start
self._update_db(new_data, class_name, func_name, ts)
return value

return wrapper

def _update_db(self, new_data: dict, class_name: str, func_name: str, ts: str):
"""
This function update the hashmap containing all the records.
"""
self.lock()
if class_name not in AMSMonitor._stats:
AMSMonitor._stats[class_name] = {}

if func_name not in AMSMonitor._stats[class_name]:
temp = AMSMonitor._stats[class_name]
temp.update({func_name: {}})
AMSMonitor._stats[class_name] = temp
temp = AMSMonitor._stats[class_name]

# We accumulate for each class with a different name
if self.accumulate and temp[func_name] != {}:
ts = self._get_ts(class_name, func_name)
temp[func_name][ts] = self._acc(temp[func_name][ts], new_data)
else:
temp[func_name][ts] = {}
for k, v in new_data.items():
temp[func_name][ts][k] = v
# This trick is needed because AMSMonitor._stats is a manager.dict (not shared memory)
AMSMonitor._stats[class_name] = temp
self.unlock()

def _remove_reserved_keys(self, d: Union[dict, List]) -> dict:
for key in self._reserved_keys:
if key in d:
self.logger.warning(f"attribute {key} is protected and will be ignored ({d})")
if isinstance(d, list):
idx = d.index(key)
d.pop(idx)
elif isinstance(d, dict):
del d[key]
return d

def _acc(self, original: dict, new_data: dict) -> dict:
"""
Sum up element-wise two hashmaps (ignore fields that are not common)
"""
for k, v in new_data.items():
# We accumalate variable internally managed by AMSMonitor (duration etc)
if k in AMSMonitor._reserved_keys:
original[k] = float(original[k]) + float(v)
else:
original[k] = v
return original

def _filter_out_object(self, data: dict) -> dict:
"""
Filter out a hashmap to remove objects which can cause errors
"""

def is_serializable(x):
try:
json.dumps(x)
return True
except (TypeError, OverflowError):
return False

new_dict = {k: v for k, v in data.items() if is_serializable(v)}

return new_dict

def _filter(self, data: dict, keys: List[str]) -> dict:
"""
Filter out a hashmap to contains only keys listed by list of keys
"""
if not self.record:
return data
return {k: v for k, v in data.items() if k in keys}

def _get_ts(self, class_name: str, tag: str) -> str:
"""
Return initial timestamp for a given monitored function.
"""
ts = datetime.datetime.now().strftime(self._ts_format)
if class_name not in AMSMonitor._stats or tag not in AMSMonitor._stats[class_name]:
return ts

init_ts = list(AMSMonitor._stats[class_name][tag].keys())
if len(init_ts) > 1:
self.logger.warning(f"more than 1 timestamp detected for {class_name} / {tag}")
return ts if init_ts == [] else init_ts[0]
Loading

0 comments on commit aadc4a7

Please sign in to comment.