diff --git a/replay_trajectory_classification/environments.py b/replay_trajectory_classification/environments.py index ecbdbdc..9cdeb76 100644 --- a/replay_trajectory_classification/environments.py +++ b/replay_trajectory_classification/environments.py @@ -6,12 +6,13 @@ import numpy as np import pandas as pd from numba import njit -from replay_trajectory_classification.core import atleast_2d, get_centers from scipy import ndimage from scipy.interpolate import interp1d from sklearn.neighbors import NearestNeighbors from track_linearization import plot_graph_as_1D +from replay_trajectory_classification.core import atleast_2d, get_centers + @dataclass class Environment: @@ -42,6 +43,9 @@ class Environment: Fill holes when inferring the track dilate : bool, optional Inflate the available track area with binary dilation + bin_count_threshold : int, optional + Greater than this number of samples should be in the bin for it to + be considered on the track. """ @@ -55,6 +59,7 @@ class Environment: infer_track_interior: bool = True fill_holes: bool = False dilate: bool = False + bin_count_threshold: int = 0 def __eq__(self, other): return self.environment_name == other @@ -70,14 +75,17 @@ def fit_place_grid(self, position=None, infer_track_interior=True): position, self.place_bin_size, self.position_range, - self.infer_track_interior, ) self.infer_track_interior = infer_track_interior if self.is_track_interior is None and self.infer_track_interior: self.is_track_interior_ = get_track_interior( - position, self.edges_, self.fill_holes, self.dilate + position, + self.edges_, + self.fill_holes, + self.dilate, + self.bin_count_threshold, ) elif self.is_track_interior is None and not self.infer_track_interior: self.is_track_interior_ = np.ones(self.centers_shape_, dtype=np.bool) @@ -161,7 +169,7 @@ def get_n_bins(position, bin_size=2.5, position_range=None): return np.ceil(extent / bin_size).astype(np.int32) -def get_grid(position, bin_size=2.5, position_range=None, infer_track_interior=True): +def get_grid(position, bin_size=2.5, position_range=None): """Gets the spatial grid of bins. Parameters @@ -207,7 +215,9 @@ def get_grid(position, bin_size=2.5, position_range=None, infer_track_interior=T return edges, place_bin_edges, place_bin_centers, centers_shape -def get_track_interior(position, bins, fill_holes=False, dilate=False): +def get_track_interior( + position, bins, fill_holes=False, dilate=False, bin_count_threshold=0 +): """Infers the interior bins of the track given positions. Parameters @@ -223,6 +233,9 @@ def get_track_interior(position, bins, fill_holes=False, dilate=False): Fill any holes in the extracted track interior bins dialate : bool, optional Inflate the extracted track interior bins + bin_count_threshold : int, optional + Greater than this number of samples should be in the bin for it to + be considered on the track. Returns ------- @@ -231,7 +244,7 @@ def get_track_interior(position, bins, fill_holes=False, dilate=False): """ bin_counts, _ = np.histogramdd(position, bins=bins) - is_track_interior = (bin_counts > 0).astype(int) + is_track_interior = (bin_counts > bin_count_threshold).astype(int) n_position_dims = position.shape[1] if n_position_dims > 1: structure = ndimage.generate_binary_structure(n_position_dims, 1)