Logging argparsing #34
base: ablation_10layers
Conversation
…all folders, ann, attention, and TL
… Load one input checkpoints, infer, return output.
…n UNet. Not implemented for TL so far.
… Attention UNets, and Transfer Learning
Thanks @amangupta2, these changes will greatly improve the scripts and make them easier to run and more robust.
The argparse now adds additional sanity checks to your inputs which will help prevent user error. Great job 👍
I have made a couple of suggestions. We can discuss them in the next meeting if you prefer.
@@ -119,7 +116,7 @@ def Training_ANN_CNN(

 def Inference_and_Save_ANN_CNN(
-    model, testset, testloader, bs_test, device, stencil, log_filename, outfile
+    model, testset, testloader, bs_test, device, stencil, logger, outfile
FYI, you shouldn't pass the logger as an argument. It should be configured at the file level. This will give better logging information.
Suggested change:
- model, testset, testloader, bs_test, device, stencil, logger, outfile
+ model, testset, testloader, bs_test, device, stencil, outfile
Use something like this at the start of each file, e.g., for this specific file it would look like:
from netCDF4 import Dataset
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import logging
logger = logging.getLogger(__name__)
-def Inference_and_Save_AttentionUNet(
-    model, testset, testloader, bs_test, device, log_filename, outfile
-):
+def Inference_and_Save_AttentionUNet(model, testset, testloader, bs_test, device, logger, outfile):
same here
Suggested change:
- def Inference_and_Save_AttentionUNet(model, testset, testloader, bs_test, device, logger, outfile):
+ def Inference_and_Save_AttentionUNet(model, testset, testloader, bs_test, device, outfile):
parser.add_argument("-c", "--ckpt_dir", default=".", help="Checkpoint directory")
parser.add_argument("-o", "--output_dir", default=".", help="Output directory to save outputs")
For path arguments I would set their type to be Path, e.g., something like:
Suggested change:
- parser.add_argument("-c", "--ckpt_dir", default=".", help="Checkpoint directory")
- parser.add_argument("-o", "--output_dir", default=".", help="Output directory to save outputs")
+ from pathlib import Path
+ parser.add_argument("-c", "--ckpt_dir", default=Path.cwd(), help="Checkpoint directory", type=Path)
+ parser.add_argument("-o", "--output_dir", default=Path.cwd(), help="Output directory to save outputs", type=Path)
Unfortunately, one downside of argparse compared to click is that it does not verify paths, so you can still pass in anything. The benefit of this approach, however, is that it is platform independent: Path.cwd() should work on Windows too. Path.cwd() is equivalent to Path(".") and is the path of the current working directory (CWD). This gives you some extra protection.
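If stricter checking is wanted, the argparse type callable can do the validation itself, so bad inputs fail at parse time. A sketch, where existing_dir is a hypothetical helper, not part of this PR:

```python
import argparse
from pathlib import Path

def existing_dir(value: str) -> Path:
    """Hypothetical argparse 'type' callable: reject anything that is
    not an existing directory, instead of failing later at file I/O."""
    path = Path(value)
    if not path.is_dir():
        raise argparse.ArgumentTypeError(f"not a directory: {value}")
    return path

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--ckpt_dir", default=Path.cwd(), type=existing_dir,
                    help="Checkpoint directory")

# "." always exists, so this parses fine; a missing path would exit with an error.
args = parser.parse_args(["-c", "."])
```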
    "-m",
    "--month",
    type=int,
    choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
you could simplify this a little:
Suggested change:
- choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+ choices=range(1, 13),
+ metavar="{1,2,...,12}",
metavar isn't required, but it makes the help output look nicer.
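A quick sketch of the suggested pattern, showing that out-of-range months are rejected at parse time:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-m", "--month",
    type=int,
    choices=range(1, 13),    # accepts 1..12; membership is checked after type conversion
    metavar="{1,2,...,12}",  # keeps --help compact instead of listing all 12 values
)

args = parser.parse_args(["-m", "7"])
# parser.parse_args(["-m", "13"]) would exit with an "invalid choice" error
```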
@@ -21,15 +25,65 @@

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

 parser = argparse.ArgumentParser()
similar changes here to argparse as before for ann_cnn_training/inference.py
    "-m",
    "--month",
    type=int,
    choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
similar changes here to argparse as before for ann_cnn_training/inference.py
    type=int,
    help="checkpoint (epoch)of the model to be used for transfer learning",
)
parser.add_argument(
similar changes here to argparse as before for ann_cnn_training/inference.py
# if model == "attention":
#     test_file = [
#         f"test_files/test_1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling08.nc"
#     ]
# elif model == "ann":
#     if stencil == 1:
#         test_file = [
#             f"test_files/test_1x1_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling08.nc"
#         ]
#     elif stencil == 3:
#         test_file = [
#             f"test_files/test_nonlocal_3x3_inputfeatures_u_v_theta_w_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling08.nc"
#         ]
Suggested change: delete this commented-out block.
Best practice is to remove unused code rather than commenting it out. We can always go back to an earlier version (through git/GitHub) if we need it again.
# print(
#     f"Model created. \n --- model size: {model.totalsize():.2f} MBs,\n --- Num params: {model.totalparams()/10**6:.3f} mil. "
# )
same here. We can either delete it, or move it to the logger
# print(
#     f"Model created. \n --- model size: {model.totalsize():.2f} MBs,\n --- Num params: {model.totalparams()/10**6:.3f} mil. "
# )
and here
#28 and #29
-- Built on branch ablation_10layers
-- Replaced all print statements with logger.info
-- And now parsing command line arguments using argparse. Agreed, it is much neater and more convenient.
If all looks good, we can merge this with ablation_10layers and merge that with main.