Skip to content

Commit

Permalink
Handle NaN columns in the marks
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Jun 30, 2021
1 parent 51d9a1a commit d2bfac4
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions replay_trajectory_classification/multiunit_likelihood_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,11 @@ def fit_multiunit_likelihood_integer(position,
estimate_intensity(marginal_density, occupancy, mean_rates[-1])
+ np.spacing(1))

encoding_marks.append(
multiunit[is_spike & not_nan_position].astype(np.int64))
multiunit = multiunit[is_spike & not_nan_position]
not_nan_marks = np.all(~np.isnan(multiunit), axis=0)
multiunit = multiunit[:, not_nan_marks]

encoding_marks.append(multiunit.astype(np.int64))
encoding_positions.append(position[is_spike & not_nan_position])

summed_ground_process_intensity = np.sum(
Expand Down Expand Up @@ -257,8 +260,12 @@ def estimate_multiunit_likelihood_integer(multiunits,
for multiunit, enc_marks, enc_pos, mean_rate in zip(
multiunits, encoding_marks, encoding_positions, mean_rates):
is_spike = np.any(~np.isnan(multiunit), axis=1)
multiunit = multiunit[is_spike]
not_nan_marks = np.all(~np.isnan(multiunit), axis=0)
multiunit = multiunit[:, not_nan_marks]

decoding_marks = da.from_array(
multiunit[is_spike].astype(np.int64), chunks=chunks)
multiunit.astype(np.int64), chunks=chunks)
log_joint_mark_intensities.append(
decoding_marks.map_blocks(
estimate_log_joint_mark_intensity,
Expand Down

0 comments on commit d2bfac4

Please sign in to comment.