diff --git a/bnpm/ca2p_preprocessing.py b/bnpm/ca2p_preprocessing.py index 45815dd..6823d90 100644 --- a/bnpm/ca2p_preprocessing.py +++ b/bnpm/ca2p_preprocessing.py @@ -2,7 +2,7 @@ import scipy.signal import matplotlib.pyplot as plt import torch -from tqdm import tqdm +from tqdm.auto import tqdm import time import gc diff --git a/bnpm/decomposition.py b/bnpm/decomposition.py index b381208..8d5f591 100644 --- a/bnpm/decomposition.py +++ b/bnpm/decomposition.py @@ -762,11 +762,12 @@ def fit(self, X): X.moveaxis(self.batch_dimension, 0), torch.arange(X.shape[self.batch_dimension]), ) - self.kwargs_dataloader.pop('batch_size', None) + kwargs_dataloader_tmp = copy.deepcopy(self.kwargs_dataloader) + kwargs_dataloader_tmp.pop('batch_size', None) dataloader = torch.utils.data.DataLoader( dataset, batch_size=self.batch_size, - **self.kwargs_dataloader, + **kwargs_dataloader_tmp, ) kwargs_tmp = copy.deepcopy(self.kwargs_CP_NN_HALS) diff --git a/bnpm/file_helpers.py b/bnpm/file_helpers.py index ffcaef7..cab2391 100644 --- a/bnpm/file_helpers.py +++ b/bnpm/file_helpers.py @@ -4,7 +4,7 @@ from pathlib import Path import zipfile -from tqdm import tqdm +from tqdm.auto import tqdm from . import path_helpers @@ -92,7 +92,7 @@ def prepare_path(path, mkdir=False, exist_ok=True): return str(path_obj) ### Custom functions for preparing paths for saving and loading files and directories -def prepare_filepath_for_saving(filepath, mkdir=False, allow_overwrite=True): +def prepare_filepath_for_saving(filepath, mkdir=False, allow_overwrite=False): return prepare_path(filepath, mkdir=mkdir, exist_ok=allow_overwrite) def prepare_filepath_for_loading(filepath, must_exist=True): path = prepare_path(filepath, mkdir=False, exist_ok=must_exist) @@ -116,7 +116,7 @@ def pickle_save( mode='wb', zipCompress=False, mkdir=False, - allow_overwrite=True, + allow_overwrite=False, library='pickle', **kwargs_zipfile, ): @@ -221,7 +221,7 @@ def pickle_load( return pickle.load(f) -def json_save(obj, filepath, indent=4, mode='w', mkdir=False, allow_overwrite=True): +def json_save(obj, filepath, indent=4, mode='w', mkdir=False, allow_overwrite=False): """ Saves an object to a json file. Uses json.dump. @@ -267,7 +267,7 @@ def json_load(filepath, mode='r'): return json.load(f) -def yaml_save(obj, filepath, indent=4, mode='w', mkdir=False, allow_overwrite=True): +def yaml_save(obj, filepath, indent=4, mode='w', mkdir=False, allow_overwrite=False): """ Saves an object to a yaml file. Uses yaml.dump. @@ -363,7 +363,7 @@ def matlab_save( obj, filepath, mkdir=False, - allow_overwrite=True, + allow_overwrite=False, clean_string=True, list_to_objArray=True, none_to_nan=True, @@ -470,7 +470,7 @@ def download_file( hash_type='MD5', hash_hex=None, mkdir=False, - allow_overwrite=True, + allow_overwrite=False, write_mode='wb', verbose=True, chunk_size=1024, diff --git a/bnpm/neural_networks.py b/bnpm/neural_networks.py index d80be38..807cc0a 100644 --- a/bnpm/neural_networks.py +++ b/bnpm/neural_networks.py @@ -2,7 +2,7 @@ import torch import numpy as np -from tqdm import tqdm +from tqdm.auto import tqdm class RegressionRNN(torch.nn.Module): """ diff --git a/bnpm/parallel_helpers.py b/bnpm/parallel_helpers.py index b089475..297436f 100644 --- a/bnpm/parallel_helpers.py +++ b/bnpm/parallel_helpers.py @@ -3,7 +3,7 @@ import multiprocessing as mp from functools import partial import numpy as np -from tqdm import tqdm +from tqdm.auto import tqdm class ParallelExecutionError(Exception): """ diff --git a/bnpm/server.py b/bnpm/server.py index 30d32be..7638140 100644 --- a/bnpm/server.py +++ b/bnpm/server.py @@ -1261,7 +1261,7 @@ def make_rsync_command( Implemented by casperdcl here: https://github.com/tqdm/tqdm/issues/311#issuecomment-387066847 """ try: - from tqdm import tqdm + from tqdm.auto import tqdm except ImportError: class _TqdmWrap(object): # tqdm not installed - construct and return dummy/basic versions diff --git a/bnpm/similarity.py b/bnpm/similarity.py index b5cebb8..f2348b5 100644 --- a/bnpm/similarity.py +++ b/bnpm/similarity.py @@ -7,7 +7,7 @@ import scipy.optimize from numba import njit, prange, jit import torch -from tqdm import tqdm +from tqdm.auto import tqdm from . import indexing, torch_helpers diff --git a/bnpm/torch_helpers.py b/bnpm/torch_helpers.py index b7ae02d..d6da368 100644 --- a/bnpm/torch_helpers.py +++ b/bnpm/torch_helpers.py @@ -10,7 +10,7 @@ import torch from torch.utils.data import Dataset import numpy as np -from tqdm import tqdm +from tqdm.auto import tqdm from . import indexing from . import misc @@ -523,7 +523,89 @@ def __getitem__( idx (int): The index of the requested sample. """ - return self.X[idx], idx + return self.X[idx] + + +class Dataset_numpy(Dataset): + """ + Creates a PyTorch dataset from a numpy array. + RH 2024 + + Args: + X (np.ndarray): + The data from which to create the dataset. + axis (int): + The dimension along which to sample the data. + device (str): + The device where the tensors will be stored. + dtype (torch.dtype): + The data type to use for the tensor. + + Attributes: + X (np.ndarray or np.memmap): + The data from the numpy file. + n_samples (int): + The number of samples in the dataset. + + Returns: + (torch.utils.data.Dataset): + A PyTorch dataset. + """ + def __init__( + self, + X: Union[np.ndarray, np.memmap], + axis: int = 0, + device: str = 'cpu', + dtype: torch.dtype = torch.float32, + ): + """ + Initializes the Dataset_NumpyFile with the provided parameters. + """ + assert isinstance(X, (np.ndarray, np.memmap)), 'X must be a numpy array or memmap.' + self.X = X + self.n_samples = self.X.shape[axis] + self.is_memmap = isinstance(self.X, np.memmap) + self.axis = axis + self.device = device + self.dtype = dtype + + def __len__(self) -> int: + """ + Returns the number of samples in the dataset. + + Returns: + (int): + n_samples (int): + The number of samples in the dataset. + """ + return self.n_samples + + def __getitem__( + self, + idx: int, + ) -> Tuple[torch.Tensor, int]: + """ + Returns a single sample and its index from the dataset. + + Args: + idx (int): + The index of the sample to return. + + Returns: + sample (torch.Tensor): + The requested sample from the dataset. + """ + arr = np.take(self.X, idx, axis=self.axis) + if self.is_memmap: + arr = np.array(arr) + return torch.as_tensor(arr, dtype=self.dtype, device=self.device) + + def close(self): + """ + Closes the numpy file. + """ + if self.is_memmap: + self.X.close() class BatchRandomSampler(torch.utils.data.Sampler): diff --git a/bnpm/video.py b/bnpm/video.py index 6c55903..52e6a9c 100644 --- a/bnpm/video.py +++ b/bnpm/video.py @@ -7,7 +7,7 @@ import torchvision import numpy as np import cv2 -from tqdm import tqdm +from tqdm.auto import tqdm ###############################################################################