Skip to content

Commit

Permalink
Add option to threshold place bins by count
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Sep 29, 2022
1 parent 3bb6ef3 commit ef09bb5
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions replay_trajectory_classification/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import numpy as np
import pandas as pd
from numba import njit
from replay_trajectory_classification.core import atleast_2d, get_centers
from scipy import ndimage
from scipy.interpolate import interp1d
from sklearn.neighbors import NearestNeighbors
from track_linearization import plot_graph_as_1D

from replay_trajectory_classification.core import atleast_2d, get_centers


@dataclass
class Environment:
Expand Down Expand Up @@ -42,6 +43,9 @@ class Environment:
Fill holes when inferring the track
dilate : bool, optional
Inflate the available track area with binary dilation
bin_count_threshold : int, optional
Greater than this number of samples should be in the bin for it to
be considered on the track.
"""

Expand All @@ -55,6 +59,7 @@ class Environment:
infer_track_interior: bool = True
fill_holes: bool = False
dilate: bool = False
bin_count_threshold: int = 0

def __eq__(self, other):
return self.environment_name == other
Expand All @@ -70,14 +75,17 @@ def fit_place_grid(self, position=None, infer_track_interior=True):
position,
self.place_bin_size,
self.position_range,
self.infer_track_interior,
)

self.infer_track_interior = infer_track_interior

if self.is_track_interior is None and self.infer_track_interior:
self.is_track_interior_ = get_track_interior(
position, self.edges_, self.fill_holes, self.dilate
position,
self.edges_,
self.fill_holes,
self.dilate,
self.bin_count_threshold,
)
elif self.is_track_interior is None and not self.infer_track_interior:
self.is_track_interior_ = np.ones(self.centers_shape_, dtype=np.bool)
Expand Down Expand Up @@ -161,7 +169,7 @@ def get_n_bins(position, bin_size=2.5, position_range=None):
return np.ceil(extent / bin_size).astype(np.int32)


def get_grid(position, bin_size=2.5, position_range=None, infer_track_interior=True):
def get_grid(position, bin_size=2.5, position_range=None):
"""Gets the spatial grid of bins.
Parameters
Expand Down Expand Up @@ -207,7 +215,9 @@ def get_grid(position, bin_size=2.5, position_range=None, infer_track_interior=T
return edges, place_bin_edges, place_bin_centers, centers_shape


def get_track_interior(position, bins, fill_holes=False, dilate=False):
def get_track_interior(
position, bins, fill_holes=False, dilate=False, bin_count_threshold=0
):
"""Infers the interior bins of the track given positions.
Parameters
Expand All @@ -223,6 +233,9 @@ def get_track_interior(position, bins, fill_holes=False, dilate=False):
Fill any holes in the extracted track interior bins
dialate : bool, optional
Inflate the extracted track interior bins
bin_count_threshold : int, optional
Greater than this number of samples should be in the bin for it to
be considered on the track.
Returns
-------
Expand All @@ -231,7 +244,7 @@ def get_track_interior(position, bins, fill_holes=False, dilate=False):
"""
bin_counts, _ = np.histogramdd(position, bins=bins)
is_track_interior = (bin_counts > 0).astype(int)
is_track_interior = (bin_counts > bin_count_threshold).astype(int)
n_position_dims = position.shape[1]
if n_position_dims > 1:
structure = ndimage.generate_binary_structure(n_position_dims, 1)
Expand Down

0 comments on commit ef09bb5

Please sign in to comment.