diff --git a/docs/generate_templates.rst b/docs/generate_templates.rst index 84c1e94..eb5e3c5 100644 --- a/docs/generate_templates.rst +++ b/docs/generate_templates.rst @@ -128,9 +128,9 @@ Drifting parameters summary max_drift: 100 # max distance from the initial and final cell position min_drift: 30 # min distance from the initial and final cell position drift_steps: 50 # number of drift steps - drift_x_lim: [-10, 10] # drift limits in the x-direction - drift_y_lim: [-10, 10] # drift limits in the y-direction - drift_z_lim: [20, 80] # drift limits in the z-direction + drift_xlim: [-10, 10] # drift limits in the x-direction + drift_ylim: [-10, 10] # drift limits in the y-direction + drift_zlim: [20, 80] # drift limits in the z-direction Running template generation using Python diff --git a/src/MEArec/drift_tools.py b/src/MEArec/drift_tools.py index 25e689b..f2ffae9 100644 --- a/src/MEArec/drift_tools.py +++ b/src/MEArec/drift_tools.py @@ -137,17 +137,17 @@ def generate_drift_dict_from_params( # triangle / sine frequency depends on the velocity freq = 1.0 / (2 * half_period) - times = np.arange(end_drift_index - start_drift_index) / drift_fs + drift_times = np.arange(end_drift_index - start_drift_index) / drift_fs if slow_drift_waveform == "triangluar": - triangle = np.abs(scipy.signal.sawtooth(2 * np.pi * freq * times + np.pi / 2)) + triangle = np.abs(scipy.signal.sawtooth(2 * np.pi * freq * drift_times + np.pi / 2)) triangle *= slow_drift_amplitude triangle -= slow_drift_amplitude / 2.0 drift_vector_um[start_drift_index:end_drift_index] = triangle drift_vector_um[end_drift_index:] = triangle[-1] elif slow_drift_waveform == "sine": - sine = np.cos(2 * np.pi * freq * times + np.pi / 2) + sine = np.cos(2 * np.pi * freq * drift_times + np.pi / 2) sine *= slow_drift_amplitude / 2.0 drift_vector_um[start_drift_index:end_drift_index] = sine diff --git a/src/MEArec/generators/recordinggenerator.py b/src/MEArec/generators/recordinggenerator.py index 2a3880d..aaedc8b 100644 --- a/src/MEArec/generators/recordinggenerator.py +++ b/src/MEArec/generators/recordinggenerator.py @@ -825,7 +825,9 @@ def generate_recordings( template_locs = np.array(locs)[reordered_idx_cells] template_rots = np.array(rots)[reordered_idx_cells] template_bin = np.array(bin_cat)[reordered_idx_cells] - templates = np.array(eaps)[reordered_idx_cells] + templates = np.empty((len(reordered_idx_cells), *eaps.shape[1:]), dtype=eaps.dtype) + for i, reordered_idx in enumerate(reordered_idx_cells): + templates[i] = eaps[reordered_idx] self.template_ids = reordered_idx_cells else: print(f"Using provided template ids: {self.template_ids}") @@ -991,7 +993,7 @@ def generate_recordings( if verbose_1: print("Smoothing templates") - templates = templates * window + templates *= window # delete temporary preprocessed templates del templates_rs, templates_pad diff --git a/src/MEArec/simulate_cells.py b/src/MEArec/simulate_cells.py index 03f8d71..f671ca8 100644 --- a/src/MEArec/simulate_cells.py +++ b/src/MEArec/simulate_cells.py @@ -875,7 +875,7 @@ def calc_extracellular( if verbose >= 1: print(f"Done generating EAPs for {cell_name}") - saved_eaps = np.array(saved_eaps) + saved_eaps = np.array(saved_eaps, dtype=np.float32) saved_positions = np.array(saved_positions) saved_rotations = np.array(saved_rotations) @@ -1268,14 +1268,14 @@ def check_solidangle(matrix, pre, post, polarlim): cell.set_rotation(x=x_rot, y=y_rot, z=z_rot) rot = [x_rot, y_rot, z_rot] - lfp = electrodes.get_transformation_matrix() @ cell.imem + lfp = np.array(electrodes.get_transformation_matrix() @ cell.imem, dtype=np.float32) # Reverse rotation to bring cell back into initial rotation state if rotation is not None: rev_rot = [-r for r in rot] cell.set_rotation(rev_rot[0], rev_rot[1], rev_rot[2], rotation_order="zyx") - return 1000 * lfp, pos, rot, found_position + return 1e3 * lfp, pos, rot, found_position def str2bool(v): diff --git a/src/MEArec/tools.py b/src/MEArec/tools.py index b53a69c..42176a5 100755 --- a/src/MEArec/tools.py +++ b/src/MEArec/tools.py @@ -213,7 +213,7 @@ def load_tmp_eap(templates_folder, celltypes=None, samples_per_cat=None, verbose print("loading cell type: ", f) if celltypes is not None: if celltype in celltypes: - eaps = np.load(str(eaplist[idx])) + eaps = np.load(str(eaplist[idx]), mmap_mode="r") locs = np.load(str(loclist[idx])) rots = np.load(str(rotlist[idx])) @@ -230,7 +230,7 @@ def load_tmp_eap(templates_folder, celltypes=None, samples_per_cat=None, verbose else: ignored_categories.add(celltype) else: - eaps = np.load(str(eaplist[idx])) + eaps = np.load(str(eaplist[idx]), mmap_mode="r") locs = np.load(str(loclist[idx])) rots = np.load(str(rotlist[idx])) @@ -245,10 +245,17 @@ def load_tmp_eap(templates_folder, celltypes=None, samples_per_cat=None, verbose cat_list.extend([celltype] * samples_to_read) loaded_categories.add(celltype) + if len(eap_list) > 0: + all_eaps = np.lib.format.open_memmap(templates_folder / "all_eaps.npy", mode="w+", dtype=eaps[0].dtype, shape=(len(eap_list), *eap_list[0].shape)) + for i in range(len(eap_list)): + all_eaps[i, ...] = eap_list[i] + else: + all_eaps = np.array([]) + if verbose: print("Done loading spike data ...") - return np.array(eap_list), np.array(loc_list), np.array(rot_list), np.array(cat_list, dtype=str) + return all_eaps, np.array(loc_list), np.array(rot_list), np.array(cat_list, dtype=str) def load_templates(templates, return_h5_objects=True, verbose=False, check_suffix=True): @@ -553,7 +560,7 @@ def save_template_generator(tempgen, filename=None, verbose=True): print("\nSaved templates in", filename, "\n") -def save_recording_generator(recgen, filename=None, verbose=False): +def save_recording_generator(recgen, filename=None, verbose=False, include_spike_traces: bool = True): """ Save recordings to disk. @@ -565,6 +572,8 @@ def save_recording_generator(recgen, filename=None, verbose=False): Path to .h5 file verbose : bool If True output is verbose + include_spike_traces: bool, default=True + If True, will include the spike traces (which can be large for many units) """ filename = Path(filename) if not filename.parent.is_dir(): @@ -573,12 +582,12 @@ def save_recording_generator(recgen, filename=None, verbose=False): with h5py.File(filename, "w") as f: f.attrs["mearec_version"] = mearec_version f.attrs["date"] = datetime.now().strftime("%y-%m-%d %H:%M:%S") - save_recording_to_file(recgen, f) + save_recording_to_file(recgen, f, include_spike_traces=include_spike_traces) if verbose: print("\nSaved recordings in", filename, "\n") -def save_recording_to_file(recgen, f, path=""): +def save_recording_to_file(recgen, f, path="", include_spike_traces: bool = True): """ Save recordings to file handler. @@ -588,6 +597,8 @@ def save_recording_to_file(recgen, f, path=""): RecordingGenerator object to be saved filename : _io.TextIOWrapper File handler + include_spike_traces: bool, default=True + If True, will include the spike traces (can be heavy) """ save_dict_to_hdf5(recgen.info, f, path + "info/") if len(recgen.voltage_peaks) > 0: @@ -598,7 +609,7 @@ def save_recording_to_file(recgen, f, path=""): f.create_dataset(path + "recordings", data=recgen.recordings) if recgen.gain_to_uV is not None: f["recordings"].attrs["gain_to_uV"] = recgen.gain_to_uV - if len(recgen.spike_traces) > 0: + if len(recgen.spike_traces) > 0 and include_spike_traces: f.create_dataset(path + "spike_traces", data=recgen.spike_traces) if len(recgen.spiketrains) > 0: for ii in range(len(recgen.spiketrains)):