diff --git a/README.md b/README.md
index e1e06bb..cab0759 100644
--- a/README.md
+++ b/README.md
@@ -48,13 +48,23 @@ The model is trained using the script `train.py` using the demo data. The optimi
 The `Demodata` folder contains the demo data used to train and test the model
 The `newCAM_emulation` folder contains the code that is required to load data, train the model and make predictions, which is structured as follows:
-> `train.py` - train the model
-> `NN-pred.py` - predict the GWD using the trained model
-
-> `loaddata.py` - load the data and reshape it to the NN input
+> `loaddata.py` - loads the data from source .nc files and normalises it before feeding it to the neural network.
-> `model.py` - define the NN model
+> `model.py` - defines the NN class.
+
+> `train.py` - trains the model for a given number of epochs using the training and validation loops, with early stopping.
+
+> `main.py` - uses the above three modules sequentially (see the sketch after this hunk) to:
+1. Read the features list (this varies depending on the GW source; currently it is convection).
+2. Take information about the data, such as `ilev` and the number of variables varying across vertical levels.
+3. Use `loaddata.py` to load data for the variables in the feature list, normalise it, build `xtrain` and `ytrain` arrays for the model, and finally create a custom dataset for easy iteration over them.
+4. Take model hyperparameters such as the learning rate, number of epochs and hidden layers, and pass them to `model.py`.
+5. Take the loss function, optimiser and early-stopping parameters and pass them to `train.py` along with the defined model and the custom dataset.
+6. Train the model and save the weights in the `trained_models` folder.
+7. The saved model can then be loaded and tested on any dataset.
+
 ## Usage Instructions
 To use the repository, the following steps are required:
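For orientation, the flow that the `main.py` list above describes condenses to the sketch below. This is a minimal sketch, assuming the demo paths and module names added in this diff (`Demodata`, `loaddata.py`, `Model.py`, `train.py`); values mirror the defaults used in `main.py` further down.

```python
# Minimal sketch of the main.py pipeline (paths and values mirror main.py).
import torch
from torch import nn
from torch.utils.data import DataLoader

from loaddata import (MyDataset, data_loader, load_mean_std,
                      load_variables, normalize_data)
from Model import FullyConnected
from train import EarlyStopper, train_with_early_stopping

# 1-2. Feature list and data dimensions (convection source, 93 levels).
features = ["U", "V", "T", "DSE", "NM", "NETDT", "Z3", "RHOI",
            "PS", "lat", "lon", "UTGWSPEC", "VTGWSPEC"]

# 3. Load, normalise, and reshape demo files 1-5 into xtrain/ytrain.
raw = load_variables("Demodata", features, 1, 5)
mean_d, std_d = load_mean_std("Demodata/mean_demo_sub.npz",
                              "Demodata/std_demo_sub.npz", features)
xtrain, ytrain = data_loader(features, normalize_data(raw, mean_d, std_d),
                             ilev=93, in_ver=8, in_nover=4, out_ver=2)
loader = DataLoader(MyDataset(X=xtrain, Y=ytrain), batch_size=128, shuffle=True)

# 4-6. Build the model and train with early stopping (for brevity the same
# loader is reused for validation; main.py does a 75/25 random split).
model = FullyConnected(ilev=93, in_ver=8, in_nover=4, out_ver=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
train_with_early_stopping(loader, loader, model, optimizer, nn.MSELoss(),
                          EarlyStopper(patience=5, min_delta=0), epochs=100)
```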
""" - def __init__(self): - """Create an instance of FullyConnected NN model.""" + def __init__( + self, ilev=93, in_ver=8, in_nover=4, out_ver=2, hidden_layers=8, hidden_size=500 + ): super(FullyConnected, self).__init__() - ilev = 93 - - self.linear_stack = nn.Sequential( - nn.Linear(8 * ilev + 4, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 500, dtype=torch.float64), - nn.SiLU(), - nn.Linear(500, 2 * ilev, dtype=torch.float64), - ) + self.ilev = ilev + self.in_ver = in_ver + self.in_nover = in_nover + self.out_ver = out_ver + self.hidden_layers = hidden_layers + self.hidden_size = hidden_size + + layers = [] + + input_size = in_ver * ilev + in_nover + + # The following for loop provides the sequential layer by layer flow + # of data in the model as the layers used in our model are identical. + for _ in range(hidden_layers): + layers.append(nn.Linear(input_size, hidden_size, dtype=torch.float64)) + layers.append(nn.SiLU()) + input_size = hidden_size + layers.append(nn.Linear(hidden_size, out_ver * ilev, dtype=torch.float64)) + self.linear_stack = nn.Sequential(*layers) def forward(self, X): """ Forward pass through the network. - Args: - X (torch.Tensor): Input tensor. + Parameters + ---------- + X : torch.Tensor + Input tensor. Returns ------- - torch.Tensor: Output tensor. + torch.Tensor + Output tensor. """ return self.linear_stack(X) - - -# training loop -def train_loop(dataloader, model, loss_fn, optimizer): - """ - Training loop. - - Args: - dataloader (DataLoader): DataLoader for training data. - model (nn.Module): Neural network model. - loss_fn (torch.nn.Module): Loss function. - optimizer (torch.optim.Optimizer): Optimizer. - - Returns - ------- - float: Average training loss. - """ - size = len(dataloader.dataset) - avg_loss = 0 - for batch, (X, Y) in enumerate(dataloader): - # Compute prediction and loss - pred = model(X) - loss = loss_fn(pred, Y) - - # Backpropagation - optimizer.zero_grad(set_to_none=True) - loss.backward() - optimizer.step() - - with torch.no_grad(): - avg_loss += loss.item() - - avg_loss /= len(dataloader) - - return avg_loss - - -# validating loop -def val_loop(dataloader, model, loss_fn): - """ - Validation loop. - - Args: - dataloader (DataLoader): DataLoader for validation data. - model (nn.Module): Neural network model. - loss_fn (torch.nn.Module): Loss function. - - Returns - ------- - float: Average validation loss. 
- """ - avg_loss = 0 - with torch.no_grad(): - for batch, (X, Y) in enumerate(dataloader): - # Compute prediction and loss - pred = model(X) - loss = loss_fn(pred, Y) - avg_loss += loss.item() - - avg_loss /= len(dataloader) - - return avg_loss diff --git a/newCAM_emulation/NN_pred.py b/newCAM_emulation/NN_pred.py deleted file mode 100644 index 8b20304..0000000 --- a/newCAM_emulation/NN_pred.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Prediction module for the neural network.""" - -import matplotlib.pyplot as plt -import Model -import netCDF4 as nc -import numpy as np -import torch -import torch.nn.functional as nnF -import torchvision -from loaddata import data_loader, newnorm -from torch import nn -from torch.utils.data import DataLoader -from torchvision import datasets, transforms -from torchvision.utils import save_image - -""" -Determine if any GPUs are available -""" -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print(device) - - -""" -Initialize Hyperparameters -""" -ilev = 93 -dim_NN = 8*ilev + 4 -dim_NNout = 2*ilev - -batch_size = 8 -learning_rate = 1e-4 -num_epochs = 1 - - - - - -## load mean and std for normalization -fm = np.load('Demodata/mean_demo.npz') -fs = np.load('Demodata/std_demo.npz') - -Um = fm['U'] -Vm = fm['V'] -Tm = fm['T'] -DSEm = fm['DSE'] -NMm = fm['NM'] -NETDTm = fm['NETDT'] -Z3m = fm['Z3'] -RHOIm = fm['RHOI'] -PSm = fm['PS'] -latm = fm['lat'] -lonm = fm['lon'] -UTGWSPECm = fm['UTGWSPEC'] -VTGWSPECm = fm['VTGWSPEC'] - -Us = fs['U'] -Vs = fs['V'] -Ts = fs['T'] -DSEs = fs['DSE'] -NMs = fs['NM'] -NETDTs = fs['NETDT'] -Z3s = fs['Z3'] -RHOIs = fs['RHOI'] -PSs = fs['PS'] -lats = fs['lat'] -lons = fs['lon'] -UTGWSPECs = fs['UTGWSPEC'] -VTGWSPECs = fs['VTGWSPEC'] - - - -""" -Initialize the network and the Adam optimizer -""" -GWnet = Model.FullyConnected() - -optimizer = torch.optim.Adam(GWnet.parameters(), lr=learning_rate) - - -s_list = list(range(5,6)) - -for iter in s_list: - if (iter > 0): - GWnet.load_state_dict(torch.load('./conv_torch.pth')) - GWnet.eval() - print ('data loader iteration',iter) - filename = './Demodata/Demo_timestep_' + str(iter).zfill(3) + '.nc' - - F = nc.Dataset(filename) - PS = np.asarray(F['PS'][0,:]) - PS = newnorm(PS, PSm, PSs) - - Z3 = np.asarray(F['Z3'][0,:,:]) - Z3 = newnorm(Z3, Z3m, Z3s) - - U = np.asarray(F['U'][0,:,:]) - U = newnorm(U, Um, Us) - - V = np.asarray(F['V'][0,:,:]) - V = newnorm(V, Vm, Vs) - - T = np.asarray(F['T'][0,:,:]) - T = newnorm(T, Tm, Ts) - - lat = F['lat'] - lat = newnorm(lat, np.mean(lat), np.std(lat)) - - lon = F['lon'] - lon = newnorm(lon, np.mean(lon), np.std(lon)) - - DSE = np.asarray(F['DSE'][0,:,:]) - DSE = newnorm(DSE, DSEm, DSEs) - - RHOI = np.asarray(F['RHOI'][0,:,:]) - RHOI = newnorm(RHOI, RHOIm, RHOIs) - - NETDT = np.asarray(F['NETDT'][0,:,:]) - NETDT = newnorm(NETDT, NETDTm, NETDTs) - - NM = np.asarray(F['NMBV'][0,:,:]) - NM = newnorm(NM, NMm, NMs) - - UTGWSPEC = np.asarray(F['BUTGWSPEC'][0,:,:]) - UTGWSPEC = newnorm(UTGWSPEC, UTGWSPECm, UTGWSPECs) - - VTGWSPEC = np.asarray(F['BVTGWSPEC'][0,:,:]) - VTGWSPEC = newnorm(VTGWSPEC, VTGWSPECm, VTGWSPECs) - - - - print('shape of PS',np.shape(PS)) - print('shape of Z3',np.shape(Z3)) - print('shape of U',np.shape(U)) - print('shape of V',np.shape(V)) - print('shape of T',np.shape(T)) - print('shape of DSE',np.shape(DSE)) - print('shape of RHOI',np.shape(RHOI)) - print('shape of NETDT',np.shape(NETDT)) - print('shape of NM',np.shape(NM)) - print('shape of UTGWSPEC',np.shape(UTGWSPEC)) - print('shape of VTGWSPEC',np.shape(VTGWSPEC)) - - 
-    x_test,y_test = data_loader (U,V,T, DSE, NM, NETDT, Z3,
-                                 RHOI, PS,lat,lon,UTGWSPEC, VTGWSPEC)
-
-    print('shape of x_test', np.shape(x_test))
-    print('shape of y_test', np.shape(y_test))
-
-
-    data = Model.myDataset(X=x_test, Y=y_test)
-    test_loader = DataLoader(data, batch_size=len(data), shuffle=False)
-    print(test_loader)
-
-
-    for batch, (X, Y) in enumerate(test_loader):
-        print(np.shape(Y))
-        pred = GWnet(X)
-        truth = Y.cpu().detach().numpy()
-        predict = pred.cpu().detach().numpy()
-
-        print(np.corrcoef(truth.flatten(), predict.flatten())[0, 1])
-        print('shape of truth ',np.shape(truth))
-        print('shape of prediction',np.shape(predict))
-
-    np.save('./pred_data_' + str(iter) + '.npy', predict)
-
-
-
-
diff --git a/newCAM_emulation/__init__.py b/newCAM_emulation/__init__.py
index e69de29..2cf3fc8 100644
--- a/newCAM_emulation/__init__.py
+++ b/newCAM_emulation/__init__.py
@@ -0,0 +1 @@
+"""Instantiate the Emulation."""
diff --git a/newCAM_emulation/loaddata.py b/newCAM_emulation/loaddata.py
index 859bf26..17b203a 100644
--- a/newCAM_emulation/loaddata.py
+++ b/newCAM_emulation/loaddata.py
@@ -1,79 +1,201 @@
 """Implementing data loader for training neural network."""
 
+import os
+import re
+
+import netCDF4 as nc
 import numpy as np
+import torch
+
+# ruff: noqa: PLR0913
+# ruff: noqa: PLR2004
+
+
+def load_variables(directory_path, variable_names, startfile, endfile):
+    """
+    Load specified variables from NetCDF files in the given directory.
+
+    Parameters
+    ----------
+    directory_path : str
+        Path to the directory containing NetCDF files.
+    variable_names : list of str
+        List of variable names to load.
+    startfile : int
+        Starting file number.
+    endfile : int
+        Ending file number.
+
+    Returns
+    -------
+    dict
+        Dictionary containing loaded variables data.
+    """
+    # Some variables are stored under a different name in the .nc files.
+    variable_mapping = {"NM": "NMBV"}
+    variable_data = {}
+    # Capture the file number so only files in [startfile, endfile] are read.
+    pattern = re.compile(r"^newCAM_demo_sub_(\d+)\.nc$")
+
+    for file_name in sorted(os.listdir(directory_path)):
+        match = pattern.match(file_name)
+        if match and startfile <= int(match.group(1)) <= endfile:
+            file_path = os.path.join(directory_path, file_name)
+            with nc.Dataset(file_path) as dataset:
+                for var_name in variable_names:
+                    mapped_name = variable_mapping.get(var_name, var_name)
+                    if mapped_name in dataset.variables:
+                        # Each matching file overwrites the previous entry, so
+                        # the returned data comes from the last file read.
+                        var_data = dataset[mapped_name][:]
+                        variable_data[var_name] = var_data
+
+    return variable_data
+
+
+def load_mean_std(file_path_mean, file_path_std, variable_names):
+    """
+    Load mean and standard deviation values for specified variables from files.
+
+    Parameters
+    ----------
+    file_path_mean : str
+        Path to the file containing mean values.
+    file_path_std : str
+        Path to the file containing standard deviation values.
+    variable_names : list of str
+        List of variable names.
+
+    Returns
+    -------
+    tuple of dict
+        Dictionaries containing mean and standard deviation values.
+    """
+    mean_data = np.load(file_path_mean)
+    std_data = np.load(file_path_std)
+    mean_dict = {var_name: mean_data[var_name] for var_name in variable_names}
+    std_dict = {var_name: std_data[var_name] for var_name in variable_names}
+    return mean_dict, std_dict
+
+
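`load_mean_std` expects each `.npz` to hold one array per variable name. A hypothetical round-trip under that assumption (the file names and values below are made up, not the shipped `Demodata` statistics):

```python
import numpy as np
from loaddata import load_mean_std

# Hypothetical statistics files: one array per variable, keyed by name.
np.savez("mean_example.npz", U=np.zeros((93, 4)), PS=np.asarray(1.0e5))
np.savez("std_example.npz", U=np.ones((93, 4)), PS=np.asarray(1.0e3))

mean_dict, std_dict = load_mean_std("mean_example.npz", "std_example.npz",
                                    ["U", "PS"])
print(mean_dict["U"].shape)   # (93, 4)
print(float(std_dict["PS"]))  # 1000.0
```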
+ """ + normalized_data = {} + for var_name, var_data in variable_data.items(): + if var_name in mean_values and var_name in std_values: + mean = mean_values[var_name] + std = std_values[var_name] + normalized_var_data = (var_data - mean) / std + normalized_data[var_name] = normalized_var_data + return normalized_data + + +def data_loader(variable_names, normalized_data, ilev, in_ver, in_nover, out_ver): + """ + Prepare the data for training by organizing it into input and output arrays. + + Parameters + ---------- + variable_names : list of str + List of variable names. + normalized_data : dict + Dictionary containing normalized data. + ilev : int + Number of vertical levels. + in_ver : int + Number of input variables that vary across vertical levels. + in_nover : int + Number of input variables that do not vary across vertical levels. + out_ver : int + Number of output variables that vary across vertical levels. + + + Returns + ------- + tuple of np.ndarray + Input and output arrays for training. + """ + Ncol = normalized_data[variable_names[1]].shape[2] + dim_NN = int(in_ver * ilev + in_nover) + dim_NNout = int(out_ver * ilev) + x_train = np.zeros([dim_NN, Ncol]) + y_train = np.zeros([dim_NNout, Ncol]) + target_var = variable_names[-2:] + y_index = 0 + x_index = 0 + for var_name, var_data in normalized_data.items(): + var_shape = var_data.shape + if var_name in target_var: + y_train[y_index * ilev : (y_index + 1) * ilev, :] = var_data + y_index += 1 + elif len(var_shape) == 2: + x_train[x_index, :] = var_data + elif len(var_shape) == 3: + new_ilev = var_shape[1] + x_train[x_index : x_index + new_ilev, :] = var_data + x_index += 1 + return x_train, y_train + + +class MyDataset(torch.utils.data.Dataset): + """ + Custom Dataset for loading features and labels. + + Parameters + ---------- + X : np.ndarray + Feature data. + Y : np.ndarray + Label data. + + Attributes + ---------- + features : torch.Tensor + Tensor containing the feature data. + labels : torch.Tensor + Tensor containing the label data. + """ + + def __init__(self, X, Y): + self.features = torch.tensor(X, dtype=torch.float64) + self.labels = torch.tensor(Y, dtype=torch.float64) + + def __len__(self): + """ + Return the length of the dataset. + + Returns + ------- + int + Length of the dataset. + """ + return len(self.features.T) + + def __getitem__(self, idx): + """ + Return a single sample from the dataset. -ilev = 93 -dim_NN =int(8*ilev+4) -dim_NNout =int(2*ilev) - -def newnorm(var, varm, varstd): - """Normalizes the input variable(s) using mean and standard deviation. - - Args: - var (numpy.ndarray): Input variable(s) to be normalized. - varm (numpy.ndarray): Mean of the variable(s). - varstd (numpy.ndarray): Standard deviation of the variable(s). - - Returns - ------- - numpy.ndarray: Normalized variable(s). - """ - dim=varm.size - if dim > 1 : - vara = var - varm[:, :] - varstdmax = varstd - varstdmax[varstd==0.0] = 1.0 - tmp = vara / varstdmax[:, :] - else: - tmp = ( var - varm ) / varstd - return tmp - - -def data_loader (U,V,T, DSE, NM, NETDT, Z3, RHOI, PS, lat, lon, UTGWSPEC, VTGWSPEC): - """ - Loads and preprocesses input data for neural network training. - - Args: - U (numpy.ndarray): Zonal wind component. - V (numpy.ndarray): Meridional wind component. - T (numpy.ndarray): Temperature. - DSE (numpy.ndarray): Dry static energy. - NM (numpy.ndarray): Northward mass flux. - NETDT (numpy.ndarray): Net downward total radiation flux. - Z3 (numpy.ndarray): Geopotential height. - RHOI (numpy.ndarray): Air density. 
diff --git a/newCAM_emulation/main.py b/newCAM_emulation/main.py
new file mode 100644
index 0000000..93ff886
--- /dev/null
+++ b/newCAM_emulation/main.py
@@ -0,0 +1,126 @@
+"""Script to load data and train the neural network."""
+
+import os
+
+import numpy as np
+import torch
+from loaddata import (
+    MyDataset,
+    data_loader,
+    load_mean_std,
+    load_variables,
+    normalize_data,
+)
+from Model import FullyConnected
+from torch import nn
+from torch.utils.data import DataLoader
+from train import EarlyStopper, train_with_early_stopping
+
+# File paths and parameters
+directory_path = "Demodata"
+file_path_mean = "Demodata/mean_demo_sub.npz"
+file_path_std = "Demodata/std_demo_sub.npz"
+trained_model_path = "trained_models/weights_conv"
+
+# variable information
+features = [
+    "U",
+    "V",
+    "T",
+    "DSE",
+    "NM",
+    "NETDT",
+    "Z3",
+    "RHOI",
+    "PS",
+    "lat",
+    "lon",
+    "UTGWSPEC",
+    "VTGWSPEC",
+]
+ilev = 93
+in_ver = 8
+in_nover = 4
+out_ver = 2
+
+# Load and preprocess data
+variable_data = load_variables(directory_path, features, 1, 5)
+mean_dict, std_dict = load_mean_std(file_path_mean, file_path_std, features)
+normalized_data = normalize_data(variable_data, mean_dict, std_dict)
+xtrain, ytrain = data_loader(
+    features,
+    normalized_data,
+    ilev=ilev,
+    in_ver=in_ver,
+    in_nover=in_nover,
+    out_ver=out_ver,
+)
+
+# Print the shapes of xtrain and ytrain
+print(f"xtrain shape: {xtrain.shape}")
+print(f"ytrain shape: {ytrain.shape}")
+
+
+# Prepare dataset and dataloaders
+data = MyDataset(X=xtrain, Y=ytrain)
+split_data = torch.utils.data.random_split(
+    data, [0.75, 0.25], generator=torch.Generator().manual_seed(42)
+)
+train_dataloader = DataLoader(split_data[0], batch_size=128, shuffle=True)
+val_dataloader = DataLoader(split_data[1], batch_size=len(split_data[1]), shuffle=True)
+
+# Model training parameters
+learning_rate = 1e-5
+epochs = 100
+hidden_layers = 8
+hidden_size = 500
+
+model = FullyConnected(ilev, in_ver, in_nover, out_ver, hidden_layers, hidden_size)
+optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+criterion = nn.MSELoss()
+early_stopper = EarlyStopper(patience=5, min_delta=0)
+
+# Train the model with early stopping
+train_losses, val_losses = train_with_early_stopping(
+    train_dataloader,
+    val_dataloader,
+    model,
+    optimizer,
+    criterion,
+    early_stopper,
+    epochs=epochs,
+)
+print(f"Train Loss: {train_losses}")
+print(f"Valid Loss: {val_losses}")
+
+# Save the trained model (creating the output directory if needed)
+os.makedirs(os.path.dirname(trained_model_path), exist_ok=True)
+torch.save(model.state_dict(), trained_model_path)
+
+# Load the trained model for prediction
+model.load_state_dict(torch.load(trained_model_path))
+model.eval()
+
+# Prepare input data for prediction.
+# Real use would feed unseen data; here demo files 4-5 are simply reused.
+test_data = load_variables(directory_path, features, 4, 5)
+normalized_test_data = normalize_data(test_data, mean_dict, std_dict)
+x_test, y_test = data_loader(
+    features,
+    normalized_test_data,
+    ilev=ilev,
+    in_ver=in_ver,
+    in_nover=in_nover,
+    out_ver=out_ver,
+)
+
+# Convert test data to tensors
+x_test_tensor = torch.tensor(x_test, dtype=torch.float64).T
+
+# Make predictions
+with torch.no_grad():
+    predictions = model(x_test_tensor).numpy()
+
+# Print predictions
+print("Predictions Shape:\n", predictions.shape)
+# print("Predictions:\n", predictions)
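Since `main.py` both trains and predicts in one run, a standalone inference session would look roughly like this. A sketch, assuming weights were already saved to `trained_models/weights_conv`; the random tensor stands in for normalised input columns.

```python
import torch
from Model import FullyConnected

# Re-create the architecture with the training hyperparameters, then load
# the weights that main.py saved.
model = FullyConnected(ilev=93, in_ver=8, in_nover=4, out_ver=2,
                       hidden_layers=8, hidden_size=500)
model.load_state_dict(torch.load("trained_models/weights_conv"))
model.eval()

# Inputs are (Ncol, 748) float64 rows, i.e. the transpose of data_loader's
# x output, exactly as main.py builds x_test_tensor.
x = torch.randn(10, 8 * 93 + 4, dtype=torch.float64)
with torch.no_grad():
    pred = model(x)
print(pred.shape)  # torch.Size([10, 186])
```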
diff --git a/newCAM_emulation/train.py b/newCAM_emulation/train.py
index 198097e..8202b0e 100644
--- a/newCAM_emulation/train.py
+++ b/newCAM_emulation/train.py
@@ -1,52 +1,89 @@
 """Training script for the neural network."""
 
-import Model
-import netCDF4 as nc
-import numpy as np
 import torch
-from loaddata import data_loader, newnorm
 from torch import nn
-from torch.backends import mps
-from torch.cuda import is_available
-from torch.utils.data import DataLoader
-
-if is_available():
-    DEVICE = "cuda"
-elif mps.is_available():
-    DEVICE = "mps"
-else:
-    DEVICE = "cpu"
-print(f"Using device: {DEVICE}")
+
+# ruff: noqa: PLR0913
 
 
 class EarlyStopper:
-    """Class for implementing early stopping during training."""
+    """
+    Early stopping utility to stop training when validation loss doesn't improve.
+
+    Parameters
+    ----------
+    patience : int, optional
+        Number of epochs to wait before stopping (default is 1).
+    min_delta : float, optional
+        Minimum change in the loss to qualify as an improvement (default is 0).
+
+    Attributes
+    ----------
+    counter : int
+        Counter for the number of epochs without improvement.
+    min_validation_loss : float
+        Minimum validation loss recorded.
+    """
 
     def __init__(self, patience=1, min_delta=0):
         """Create an instance of EarlyStopper class."""
         self.patience = patience
         self.min_delta = min_delta
         self.counter = 0
-        self.min_validation_loss = np.inf
+        self.min_validation_loss = float("inf")
 
-    def early_stop(self, validation_loss):
+    def early_stop(self, validation_loss, model=None):
         """
-        Check if early stopping condition is met.
+        Check if training should be stopped early.
 
-        Args:
-            validation_loss (float): Loss value on the validation set.
+        Parameters
+        ----------
+        validation_loss : float
+            Current validation loss.
+        model : nn.Module, optional
+            Model whose weights could be checkpointed when the validation
+            loss improves; saving is currently left to the caller (main.py).
 
         Returns
         -------
-        bool: True if early stopping condition is met, False otherwise.
+        bool
+            True if training should be stopped, False otherwise.
         """
         if validation_loss < self.min_validation_loss:
             self.min_validation_loss = validation_loss
             self.counter = 0
-
-            #save model
-            torch.save(model.state_dict(), 'conv_torch.pth')
-
         elif validation_loss > (self.min_validation_loss + self.min_delta):
             self.counter += 1
             if self.counter >= self.patience:
                 return True
         return False
+# """ +# if validation_loss < self.min_validation_loss: +# self.min_validation_loss = validation_loss +# self.counter = 0 +# # if model is not None: +# # # torch.save(model.state_dict(), 'conv_torch.pth') +# # torch.save(model.state_dict(), 'trained_models/weights_conv') +# elif validation_loss > (self.min_validation_loss + self.min_delta): +# self.counter += 1 +# if self.counter >= self.patience: +# return True +# return False + +def early_stopping(self, validation_loss, patience=1, min_delta=0, model=None): """ - Check if early stopping condition is met. + Check if training should be stopped early. - Args: - validation_loss (float): Loss value on the validation set. + Parameters + ---------- + validation_loss : float + Current validation loss. + model : nn.Module, optional + Model to save if validation loss improves (default is None). Returns ------- - bool: True if early stopping condition is met, False otherwise. + bool + True if training should be stopped, False otherwise. """ if validation_loss < self.min_validation_loss: self.min_validation_loss = validation_loss self.counter = 0 - - #save model - torch.save(model.state_dict(), 'conv_torch.pth') - + # if model is not None: + # # torch.save(model.state_dict(), 'conv_torch.pth') + # torch.save(model.state_dict(), 'trained_models/weights_conv') elif validation_loss > (self.min_validation_loss + self.min_delta): self.counter += 1 if self.counter >= self.patience: @@ -54,133 +91,106 @@ def early_stop(self, validation_loss): return False - - -## load mean and std for normalization -fm = np.load('../Demodata/mean_demo_sub.npz') -fs = np.load('../Demodata/std_demo_sub.npz') - -Um = fm['U'] -Vm = fm['V'] -Tm = fm['T'] -DSEm = fm['DSE'] -NMm = fm['NM'] -NETDTm = fm['NETDT'] -Z3m = fm['Z3'] -RHOIm = fm['RHOI'] -PSm = fm['PS'] -latm = fm['lat'] -lonm = fm['lon'] -UTGWSPECm = fm['UTGWSPEC'] -VTGWSPECm = fm['VTGWSPEC'] - -Us = fs['U'] -Vs = fs['V'] -Ts = fs['T'] -DSEs = fs['DSE'] -NMs = fs['NM'] -NETDTs = fs['NETDT'] -Z3s = fs['Z3'] -RHOIs = fs['RHOI'] -PSs = fs['PS'] -lats = fs['lat'] -lons = fs['lon'] -UTGWSPECs = fs['UTGWSPEC'] -VTGWSPECs = fs['VTGWSPEC'] - -ilev = 93 - -dim_NN =int(8*ilev+4) -dim_NNout =int(2*ilev) - -model = Model.FullyConnected() - -train_losses = [] -val_losses = [0] - -learning_rate = 1e-5 -epochs = 100 -optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # weight_decay=1e-5 - - -s_list = list(range(1, 6)) - -for iter in s_list: - if (iter > 1): - model.load_state_dict(torch.load('conv_torch.pth')) - print ('data loader iteration',iter) - filename = '../Demodata/newCAM_demo_sub_' + str(iter).zfill(1) + '.nc' - print('working on: ', filename) - - F = nc.Dataset(filename) - PS = np.asarray(F['PS'][0,:]) - PS = newnorm(PS, PSm, PSs) - - Z3 = np.asarray(F['Z3'][0,:,:]) - Z3 = newnorm(Z3, Z3m, Z3s) - - U = np.asarray(F['U'][0,:,:]) - U = newnorm(U, Um, Us) - - V = np.asarray(F['V'][0,:,:]) - V = newnorm(V, Vm, Vs) - - T = np.asarray(F['T'][0,:,:]) - T = newnorm(T, Tm, Ts) - - lat = F['lat'] - lat = newnorm(lat, np.mean(lat), np.std(lat)) - - lon = F['lon'] - lon = newnorm(lon, np.mean(lon), np.std(lon)) - - DSE = np.asarray(F['DSE'][0,:,:]) - DSE = newnorm(DSE, DSEm, DSEs) - - RHOI = np.asarray(F['RHOI'][0,:,:]) - RHOI = newnorm(RHOI, RHOIm, RHOIs) - - NETDT = np.asarray(F['NETDT'][0,:,:]) - NETDT = newnorm(NETDT, NETDTm, NETDTs) - - NM = np.asarray(F['NMBV'][0,:,:]) - NM = newnorm(NM, NMm, NMs) - - UTGWSPEC = np.asarray(F['UTGWSPEC'][0,:,:]) - UTGWSPEC = newnorm(UTGWSPEC, UTGWSPECm, UTGWSPECs) - - VTGWSPEC = 
-    VTGWSPEC = np.asarray(F['VTGWSPEC'][0,:,:])
-    VTGWSPEC = newnorm(VTGWSPEC, VTGWSPECm, VTGWSPECs)
-
-    x_train,y_train = data_loader(U,V,T, DSE, NM, NETDT, Z3,
-                                  RHOI, PS,lat,lon,UTGWSPEC, VTGWSPEC)
-
-    data = Model.myDataset(X=x_train, Y=y_train)
-
-    batch_size = 128
-
-    split_data = torch.utils.data.random_split(data, [0.75, 0.25],
-                                               generator=torch.Generator().manual_seed(42))
-    train_dataloader = DataLoader(split_data[0],
-                                  batch_size=batch_size,
-                                  shuffle=True)
-    val_dataloader = DataLoader(split_data[1],
-                                batch_size=len(split_data[1]),
-                                shuffle=True)
-
-    # training
-    early_stopper = EarlyStopper(patience=5, min_delta=0)  # Note the hyper parameters.
-    for t in range(epochs):
-        if t % 2 ==0:
-            print(f"Epoch {t+1}\n-------------------------------")
+def train_loop(dataloader, model, loss_fn, optimizer):
+    """
+    Training loop for a single epoch.
+
+    Parameters
+    ----------
+    dataloader : torch.utils.data.DataLoader
+        DataLoader for the training data.
+    model : nn.Module
+        Neural network model.
+    loss_fn : callable
+        Loss function.
+    optimizer : torch.optim.Optimizer
+        Optimizer for training.
+
+    Returns
+    -------
+    float
+        Average training loss.
+    """
+    avg_loss = 0
+    for X, Y in dataloader:
+        pred = model(X)
+        loss = loss_fn(pred, Y)
+        optimizer.zero_grad(set_to_none=True)
+        loss.backward()
+        optimizer.step()
+        avg_loss += loss.item()
+    avg_loss /= len(dataloader)
+    return avg_loss
+
+
+def val_loop(dataloader, model, loss_fn):
+    """
+    Validation loop for a single epoch.
+
+    Parameters
+    ----------
+    dataloader : torch.utils.data.DataLoader
+        DataLoader for the validation data.
+    model : nn.Module
+        Neural network model.
+    loss_fn : callable
+        Loss function.
+
+    Returns
+    -------
+    float
+        Average validation loss.
+    """
+    # Gradients are not needed for validation.
+    avg_loss = 0
+    with torch.no_grad():
+        for X, Y in dataloader:
+            avg_loss += loss_fn(model(X), Y).item()
+    avg_loss /= len(dataloader)
+    return avg_loss
+
+
+def train_with_early_stopping(
+    train_dataloader,
+    val_dataloader,
+    model,
+    optimizer,
+    criterion,
+    early_stopper,
+    epochs=100,
+):
+    """
+    Train the model with early stopping.
+
+    Parameters
+    ----------
+    train_dataloader : torch.utils.data.DataLoader
+        DataLoader for the training data.
+    val_dataloader : torch.utils.data.DataLoader
+        DataLoader for the validation data.
+    model : nn.Module
+        Neural network model.
+    optimizer : torch.optim.Optimizer
+        Optimizer for training.
+    criterion : callable
+        Loss function.
+    early_stopper : EarlyStopper
+        Early stopping utility.
+    epochs : int, optional
+        Number of epochs to train (default is 100).
+
+    Returns
+    -------
+    tuple of list of float
+        Training losses and validation losses for each epoch.
+    """
+    train_losses = []
+    val_losses = [0]  # seeded so the epoch-0 print below has a value
+    for epoch in range(epochs):
+        if epoch % 2 == 0:
+            print(f"Epoch {epoch + 1}\n-------------------------------")
             print(val_losses[-1])
-            print('counter=' + str(early_stopper.counter))
-            train_loss = Model.train_loop(train_dataloader, model, nn.MSELoss(), optimizer)
-
+            print("counter=" + str(early_stopper.counter))
+        train_loss = train_loop(train_dataloader, model, criterion, optimizer)
         train_losses.append(train_loss)
-        val_loss = Model.val_loop(val_dataloader, model, nn.MSELoss())
+        val_loss = val_loop(val_dataloader, model, criterion)
         val_losses.append(val_loss)
-        if early_stopper.early_stop(val_loss):
-            print("BREAK!")
+        if early_stopper.early_stop(val_loss, model):
             break
-
+    return train_losses, val_losses
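To exercise `train_with_early_stopping` and `EarlyStopper` in isolation, a toy run on synthetic data is enough. A sketch with dimensions shrunk for speed, not the demo configuration:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from Model import FullyConnected
from train import EarlyStopper, train_with_early_stopping

# Toy problem: 5 levels, 4 vertical input variables, 2 scalars, 1 output block.
model = FullyConnected(ilev=5, in_ver=4, in_nover=2, out_ver=1,
                       hidden_layers=2, hidden_size=32)
x = torch.randn(256, 4 * 5 + 2, dtype=torch.float64)
y = torch.randn(256, 1 * 5, dtype=torch.float64)
train_dl = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
val_dl = DataLoader(TensorDataset(x, y), batch_size=256)

train_losses, val_losses = train_with_early_stopping(
    train_dl, val_dl, model,
    torch.optim.Adam(model.parameters(), lr=1e-3),
    nn.MSELoss(), EarlyStopper(patience=5, min_delta=0), epochs=20,
)
print(train_losses[-1], val_losses[-1])
```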
+ """ + train_losses = [] + val_losses = [0] + for epoch in range(epochs): + if epoch % 2 == 0: + print(f"Epoch {epoch + 1}\n-------------------------------") print(val_losses[-1]) - print('counter=' + str(early_stopper.counter)) - train_loss = Model.train_loop(train_dataloader, model, nn.MSELoss(), optimizer) - + print("counter=" + str(early_stopper.counter)) + train_loss = train_loop(train_dataloader, model, criterion, optimizer) train_losses.append(train_loss) - val_loss = Model.val_loop(val_dataloader, model, nn.MSELoss()) + val_loss = val_loop(val_dataloader, model, criterion) val_losses.append(val_loss) - if early_stopper.early_stop(val_loss): - print("BREAK!") + if early_stopper.early_stop(val_loss, model): + # print("BREAK!") break - + return train_losses, val_losses diff --git a/pyproject.toml b/pyproject.toml index 93d8d9e..ae7c3d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,8 @@ extend-include = ["*.ipynb"] [tool.ruff.lint] # Enable: D: `pydocstyle`, PL: `pylint`, I: `isort`, W: `pycodestyle whitespace` # NPY: `numpy`, -select = ["D", "PL", "I", "E", "W", "NPY" ] + +select = ["D", "PL", "I", "E", "W", "NPY"] # Enable D417 (Missing argument description) on top of the NumPy convention. extend-select = ["D417"]