Skip to content

Commit

Permalink
feat: write result directory to s3 (#278)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Nov 1, 2024
1 parent 86aa063 commit 704ca15
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class GraphConfig:
min_size: float = 30.0
node_spacing: int = 1
proposals_per_leaf: int = 2
prune_depth: float = 20.0
prune_depth: float = 16.0
remove_doubles_bool: bool = False
search_radius: float = 20.0
smooth_bool: bool = True
Expand Down
40 changes: 40 additions & 0 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def __init__(
is_multimodal=False,
label_path=None,
log_runtimes=True,
save_to_s3_bool=False,
s3_dict=None,
):
"""
Initializes an object that executes the full GraphTrace inference
Expand Down Expand Up @@ -101,6 +103,10 @@ def __init__(
default is None.
log_runtimes : bool, optional
Indication of whether to log runtimes. The default is True.
save_to_s3_bool : bool, optional
Indication of whether to save result to s3. The default is False.
s3_dict : dict, optional
...
Returns
-------
Expand All @@ -113,6 +119,8 @@ def __init__(
self.model_path = model_path
self.sample_id = sample_id
self.segmentation_id = segmentation_id
self.save_to_s3_bool = save_to_s3_bool
self.s3_dict = s3_dict

# Extract config settings
self.graph_config = config.graph_config
Expand Down Expand Up @@ -330,13 +338,45 @@ def save_results(self, round_id=None):
None
"""
# Save result locally
suffix = f"-{round_id}" if round_id else ""
filename = f"corrected-processed-swcs{suffix}.zip"
path = os.path.join(self.output_dir, filename)
self.graph.to_zipped_swcs(path)
self.save_connections(round_id=round_id)
self.write_metadata()

# Save result on s3
if self.save_to_s3_bool:
self.save_to_s3()

def save_to_s3(self):
"""
Saves a corrected swc files to s3 along with metadata and runtimes.
Parameters
----------
None
Returns
-------
None
"""
# Initializations
bucket_name = self.s3_dict["bucket_name"]
date = datetime.today().strftime("%Y%m%d")
subdir_name = f"/corrected_{self.sample_id}_{self.segmentation_id}_{date}"
prefix = self.s3_dict["prefix"] + subdir_name

# Move result files
for filename in os.listdir(self.output_dir):
if filename != "processed-swcs.zip":
local_path = os.path.join(self.output_dir, filename)
s3_path = os.path.join(prefix, filename)
util.write_to_s3(local_path, bucket_name, s3_path)
print("Results written to S3 prefix -->", prefix)

# --- io ---
def save_connections(self, round_id=None):
"""
Expand Down
3 changes: 2 additions & 1 deletion src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

from deep_neurographs import generate_proposals, geometry
from deep_neurographs.groundtruth_generation import init_targets
from deep_neurographs.utils import graph_util as gutil, img_util, util
from deep_neurographs.utils import graph_util as gutil
from deep_neurographs.utils import img_util, util


class NeuroGraph(nx.Graph):
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
MIN_SIZE = 30
NODE_SPACING = 1
SMOOTH_BOOL = True
PRUNE_DEPTH = 20
PRUNE_DEPTH = 16


class GraphLoader:
Expand Down
22 changes: 20 additions & 2 deletions src/deep_neurographs/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def count_fragments(fragments_pointer, min_size=0):
print("# Connected Components:", nx.number_connected_components(graph))
print("# Nodes:", graph.number_of_nodes())
print("# Edges:", graph.number_of_edges())
del graph


def time_writer(t, unit="seconds"):
Expand Down Expand Up @@ -649,7 +650,7 @@ def reformat_number(number):

def get_memory_usage():
"""
Gets the current memory usage.
Gets the current memory usage in gigabytes.
Parameters
----------
Expand All @@ -658,12 +659,29 @@ def get_memory_usage():
Returns
-------
float
Current memory usage.
Current memory usage in gigabytes.
"""
return psutil.virtual_memory().used / 1e9


def get_memory_available():
"""
Gets the available memory in gigabytes.
Parameters
----------
None
Returns
-------
float
Available memory usage in gigabytes.
"""
return psutil.virtual_memory().available / 1e9


def spaced_idxs(container, k):
"""
Generates an array of indices based on a specified step size and ensures
Expand Down

0 comments on commit 704ca15

Please sign in to comment.