Logging argparsing #34

Open
wants to merge 4 commits into
base: ablation_10layers

Contributor

@amangupta2 amangupta2 commented Dec 18, 2024

#28 and #29

-- Built on branch ablation_10layers
-- Replaced all print statements with logger.info
-- Command line arguments are now parsed with argparse. Agreed, it is much neater and more convenient.

If all looks good, we can merge this into ablation_10layers and then merge that into main.

Collaborator

@TomMelt TomMelt left a comment

Thanks @amangupta2, these changes will greatly improve the scripts, making them easier to run and more robust.

argparse now adds additional sanity checks on your inputs, which will help prevent user error. Great job 👍

I have made a couple of suggestions. We can discuss them in the next meeting if you prefer.

@@ -119,7 +116,7 @@ def Training_ANN_CNN(


def Inference_and_Save_ANN_CNN(
- model, testset, testloader, bs_test, device, stencil, log_filename, outfile
+ model, testset, testloader, bs_test, device, stencil, logger, outfile
Collaborator

@TomMelt TomMelt Jan 16, 2025

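FYI, you shouldn't pass the logger as an argument. Each file should instead get its own logger at the file (module) level. This will give better logging information, since each log record then carries the name of the module that emitted it.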

Suggested change
- model, testset, testloader, bs_test, device, stencil, logger, outfile
+ model, testset, testloader, bs_test, device, stencil, outfile

Use something like this at the start of each file, e.g., for this specific file it would look like:

from netCDF4 import Dataset
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import logging

logger = logging.getLogger(__name__)

- def Inference_and_Save_AttentionUNet(
-     model, testset, testloader, bs_test, device, log_filename, outfile
- ):
+ def Inference_and_Save_AttentionUNet(model, testset, testloader, bs_test, device, logger, outfile):
Collaborator

same here

Suggested change
- def Inference_and_Save_AttentionUNet(model, testset, testloader, bs_test, device, logger, outfile):
+ def Inference_and_Save_AttentionUNet(model, testset, testloader, bs_test, device, outfile):

Comment on lines +65 to +66
parser.add_argument("-c", "--ckpt_dir", default=".", help="Checkpoint directory")
parser.add_argument("-o", "--output_dir", default=".", help="Output directory to save outputs")
Collaborator

For path arguments I would set their type to Path, e.g., something like:

Suggested change
- parser.add_argument("-c", "--ckpt_dir", default=".", help="Checkpoint directory")
- parser.add_argument("-o", "--output_dir", default=".", help="Output directory to save outputs")
+ from pathlib import Path
+ parser.add_argument("-c", "--ckpt_dir", default=Path.cwd(), help="Checkpoint directory", type=Path)
+ parser.add_argument("-o", "--output_dir", default=Path.cwd(), help="Output directory to save outputs", type=Path)

Unfortunately, one downside of argparse compared to click is that it does not verify paths. type=Path only converts the string, so you can still pass in anything, including a path that does not exist.

The benefit of this approach, however, is that it is platform independent, as Path.cwd() should work on Windows too.

Path.cwd() points to the same directory as Path("."), the Current Working Directory (CWD), but returns it as an absolute path.

Using Path rather than a plain string gives you some extra protection.
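If path verification is wanted with argparse, one option is a small `type=` callable. The sketch below assumes a hypothetical helper called `existing_dir` (not part of this PR) that converts the argument to a `Path` and rejects anything that is not an existing directory. The output directory is left as a plain `Path`, on the assumption that the script may create it later.

```python
import argparse
from pathlib import Path


def existing_dir(value):
    """Hypothetical helper: convert to Path and reject anything that is not a directory."""
    path = Path(value)
    if not path.is_dir():
        # argparse turns this exception into a normal usage error message.
        raise argparse.ArgumentTypeError(f"{value!r} is not an existing directory")
    return path


def build_parser():
    parser = argparse.ArgumentParser(prog="demo")
    parser.add_argument(
        "-c", "--ckpt_dir", default=Path.cwd(), type=existing_dir,
        help="Checkpoint directory",
    )
    parser.add_argument(
        "-o", "--output_dir", default=Path.cwd(), type=Path,
        help="Output directory to save outputs",
    )
    return parser


args = build_parser().parse_args(["-c", "."])
print(args.ckpt_dir, args.output_dir)
```

Passing a directory that does not exist to `-c` makes argparse exit with a usage error instead of failing later when the checkpoint is loaded.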

"-m",
"--month",
type=int,
choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
Collaborator

you could simplify this a little:

Suggested change
- choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+ choices=range(1, 13),
+ metavar="{1,2,...,12}",

metavar isn't required, but it makes the help output look nicer: without it, argparse lists every individual choice in the usage line.
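A minimal, self-contained sketch of the suggestion, to show that `range(1, 13)` still validates the value while `metavar` only changes how the option is displayed. The `prog` name and the help text here are placeholders, not taken from the PR.

```python
import argparse

parser = argparse.ArgumentParser(prog="demo")
parser.add_argument(
    "-m",
    "--month",
    type=int,
    choices=range(1, 13),    # accepts 1 through 12 inclusive
    metavar="{1,2,...,12}",  # display only; does not affect validation
    help="Month to process (placeholder help text)",
)

args = parser.parse_args(["--month", "3"])
print(args.month)
```

A value outside the range, such as `--month 13`, is rejected by argparse with an "invalid choice" error before any of the script's own code runs.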

@@ -21,15 +25,65 @@

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser()
Collaborator

similar changes here to argparse as before for ann_cnn_training/inference.py

"-m",
"--month",
type=int,
choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
Collaborator

similar changes here to argparse as before for ann_cnn_training/inference.py

type=int,
help="checkpoint (epoch)of the model to be used for transfer learning",
)
parser.add_argument(
Collaborator

similar changes here to argparse as before for ann_cnn_training/inference.py

Comment on lines +52 to +66
# if model == "attention":
# test_file = [
# f"test_files/test_1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling08.nc"
# ]
# elif model == "ann":
# if stencil == 1:
# test_file = [
# f"test_files/test_1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling08.nc"
# ]
# elif stencil == 3:
# test_file = [
# f"test_files/test_nonlocal_3x3_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling08.nc"
# ]


Collaborator

Suggested change
- # if model == "attention":
- # test_file = [
- # f"test_files/test_1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling08.nc"
- # ]
- # elif model == "ann":
- # if stencil == 1:
- # test_file = [
- # f"test_files/test_1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling08.nc"
- # ]
- # elif stencil == 3:
- # test_file = [
- # f"test_files/test_nonlocal_3x3_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling08.nc"
- # ]

It is best practice to remove unused code rather than "commenting" it out. We can always go back to an earlier version (through git/GitHub) if we need it again.

Comment on lines +89 to +91
# print(
# f"Model created. \n --- model size: {model.totalsize():.2f} MBs,\n --- Num params: {model.totalparams()/10**6:.3f} mil. "
# )
Collaborator

same here. We can either delete it, or move it to the logger
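If the message is worth keeping, moving it to the logger could look roughly like the sketch below. `TinyModel` is a hypothetical stand-in with made-up values; it only mimics the `totalsize()` and `totalparams()` methods used in the commented-out print, and `log_model_summary` is likewise an illustrative name.

```python
import logging

logger = logging.getLogger(__name__)


class TinyModel:
    """Hypothetical stand-in exposing the two methods used in the commented-out print."""

    def totalsize(self):
        return 12.3456  # size in MB (made-up value)

    def totalparams(self):
        return 3_200_000  # number of parameters (made-up value)


def log_model_summary(model):
    """Log the model summary instead of printing it, and return the message."""
    message = (
        f"Model created. model size: {model.totalsize():.2f} MBs, "
        f"num params: {model.totalparams() / 10**6:.3f} mil."
    )
    logger.info(message)
    return message


log_model_summary(TinyModel())
```

The message then follows whatever handlers and levels the entry-point script configures, so it can be silenced or redirected without editing the code again.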

Comment on lines +103 to +105
# print(
# f"Model created. \n --- model size: {model.totalsize():.2f} MBs,\n --- Num params: {model.totalparams()/10**6:.3f} mil. "
# )
Collaborator

and here

2 participants