Added definition for a 10-hidden-layer ANN for ablation tests #30

Open: wants to merge 1 commit into base: main
129 changes: 129 additions & 0 deletions ann_cnn_training/model_definition.py
@@ -361,3 +361,132 @@ def totalsize(self):
        # print('model size: {:.3f}MB'.format(size_all_mb))

        return size_all_mb


class ANN_CNN10(nn.Module):
    def __init__(self, idim, odim, hdim, stencil, dropout=0):
        super().__init__()

        self.idim = idim
        self.odim = odim
        self.hdim = hdim
        self.dropout_prob = dropout
        self.stencil = stencil
        self.fac = np.floor(0.5 * self.stencil)

        # assume normalized data as input
        # same activation for all layers
        # same dropout probabilities for all layers

        # Applying multiple 3x3 conv layers performs better than a single stencil x stencil layer
        if self.fac == 1:
            self.conv1 = nn.Conv2d(
                in_channels=idim, out_channels=idim, kernel_size=3, stride=1, padding=0
            )
            self.act_cnn = nn.ReLU()
            self.dropout0 = nn.Dropout(p=0.5 * self.dropout_prob)

        elif self.fac == 2:
            self.conv1 = nn.Conv2d(
                in_channels=idim, out_channels=idim, kernel_size=3, stride=1, padding=0
            )
            self.act_cnn = nn.ReLU()
            self.dropout0 = nn.Dropout(p=0.5 * self.dropout_prob)
            # print('-CNN 1')
            self.conv2 = nn.Conv2d(
                in_channels=idim, out_channels=idim, kernel_size=3, stride=1, padding=0
            )
            self.act_cnn2 = nn.ReLU()
            self.dropout0_2 = nn.Dropout(p=0.5 * self.dropout_prob)

        elif self.fac == 3:  # this is not ready yet, and might not be needed for my study
            self.conv1 = nn.Conv2d(
                in_channels=idim, out_channels=idim, kernel_size=self.stencil, stride=1, padding=0
            )
            self.act_cnn = nn.ReLU()
            self.dropout0 = nn.Dropout(p=0.5 * self.dropout_prob)

        # could also define a block here and repeat it
        self.layer1 = nn.Linear(idim, hdim)  # ,dtype=torch.float16)
        self.act1 = nn.LeakyReLU()
        # self.bnorm1 = nn.BatchNorm1d(hdim)

        self.dropout = nn.Dropout(p=self.dropout_prob)

        self.layer2 = nn.Linear(hdim, hdim)
        self.act2 = nn.LeakyReLU()
        # self.bnorm2 = nn.BatchNorm1d(hdim)
        # -------------------------------------------------------
        self.layer3 = nn.Linear(hdim, hdim)
        self.act3 = nn.LeakyReLU()
        # self.bnorm3 = nn.BatchNorm1d(hdim)
        # -------------------------------------------------------
        self.layer4 = nn.Linear(hdim, hdim)
        self.act4 = nn.LeakyReLU()
        # self.bnorm4 = nn.BatchNorm1d(2 * hdim)
        # --------------------------------------------------------
        self.layer5 = nn.Linear(hdim, hdim)
        self.act5 = nn.LeakyReLU()
        # self.bnorm5 = nn.BatchNorm1d(hdim)
        # -------------------------------------------------------

        self.layer6 = nn.Linear(hdim, hdim)
        self.act6 = nn.LeakyReLU()
        # -------------------------------------------------------
        self.layer7 = nn.Linear(hdim, hdim)
        self.act7 = nn.LeakyReLU()
        # -------------------------------------------------------
        self.layer8 = nn.Linear(hdim, hdim)
        self.act8 = nn.LeakyReLU()
        # -------------------------------------------------------
        self.layer9 = nn.Linear(hdim, hdim)
        self.act9 = nn.LeakyReLU()
        # -------------------------------------------------------
        self.layer10 = nn.Linear(hdim, 2 * odim)
        self.act10 = nn.LeakyReLU()
Comment on lines +416 to +446

Collaborator:

We should gather this into a list and loop over the layers and activation functions

Collaborator:

see here for example:

https://github.com/DataWaveProject/CAM_GW_pytorch_emulator/blob/fd780f5b23fbdbe83d0bf7b33f3ff5c3e216fede/newCAM_emulation/Model.py#L61-L68

        layers = []
        input_size = in_ver * ilev + in_nover  
        for _ in range(hidden_layers):
            layers.append(nn.Linear(input_size, hidden_size, dtype=torch.float64))
            layers.append(nn.SiLU())
            input_size = hidden_size
        layers.append(nn.Linear(hidden_size, out_ver * ilev, dtype=torch.float64))
        self.linear_stack = nn.Sequential(*layers)
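
Adapted to this class, a minimal sketch of what that could look like in __init__; the hidden_stack attribute name and the loop bounds are assumptions for illustration, not part of this PR:

        # Sketch only: would replace layer1..layer10, act1..act10, and the repeated dropout calls.
        hidden = []
        in_features = idim
        for _ in range(9):  # layer1..layer9: idim -> hdim, then hdim -> hdim
            hidden.append(nn.Linear(in_features, hdim))
            hidden.append(nn.LeakyReLU())
            hidden.append(nn.Dropout(p=self.dropout_prob))
            in_features = hdim
        hidden.append(nn.Linear(hdim, 2 * odim))  # corresponds to layer10
        hidden.append(nn.LeakyReLU())
        hidden.append(nn.Dropout(p=self.dropout_prob))
        self.hidden_stack = nn.Sequential(*hidden)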

        # -------------------------------------------------------

        self.output = nn.Linear(2 * odim, odim)

    def forward(self, x):
        if self.fac == 1:
            x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
        elif self.fac == 2:
            x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
            x = torch.squeeze(self.dropout0_2(self.act_cnn2(self.conv2(x))))

        x = self.dropout(self.act1(self.layer1(x)))
        x = self.dropout(self.act2(self.layer2(x)))
        x = self.dropout(self.act3(self.layer3(x)))
        x = self.dropout(self.act4(self.layer4(x)))
        x = self.dropout(self.act5(self.layer5(x)))
        x = self.dropout(self.act6(self.layer6(x)))
        x = self.dropout(self.act7(self.layer7(x)))
        x = self.dropout(self.act8(self.layer8(x)))
        x = self.dropout(self.act9(self.layer9(x)))
        x = self.dropout(self.act10(self.layer10(x)))
Comment on lines +458 to +467

Collaborator:

similarly here:

We should gather this into a list and loop over the layers and activation functions
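
For illustration, a rough sketch of the forward pass under that refactor, assuming the hidden layers were collected into a self.hidden_stack = nn.Sequential(...) attribute as sketched above (the attribute name is an assumption, not part of this PR):

    def forward(self, x):
        if self.fac == 1:
            x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
        elif self.fac == 2:
            x = torch.squeeze(self.dropout0(self.act_cnn(self.conv1(x))))
            x = torch.squeeze(self.dropout0_2(self.act_cnn2(self.conv2(x))))
        # one call replaces the ten explicit layer/activation/dropout lines
        x = self.hidden_stack(x)
        return self.output(x)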

        x = self.output(x)

        return x

    # calculates total number of learnable parameters
    def totalparams(self):
        param_size = 0
        for param in self.parameters():
            param_size += param.nelement()

        return param_size

    # computes total model size in MBs
    def totalsize(self):
        param_size = 0
        for param in self.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in self.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()

        size_all_mb = (param_size + buffer_size) / 1024**2
        # print('model size: {:.3f}MB'.format(size_all_mb))

        return size_all_mb
26 changes: 21 additions & 5 deletions ann_cnn_training/training.py
@@ -24,7 +24,7 @@
import pandas as pd

from dataloader_definition import Dataset_ANN_CNN
from model_definition import ANN_CNN
from model_definition import ANN_CNN, ANN_CNN10
Collaborator:
As part of the refactor we should perhaps make the number of layers a parameter rather than a new class?
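
One possible shape for that refactor, as a rough sketch only; the class name, the n_hidden_layers argument, and the omission of the convolutional front end are assumptions for illustration, not part of this PR:

import torch.nn as nn


class ANN_NLayer(nn.Module):
    """Sketch: hidden-layer count as a constructor argument instead of a separate class."""

    def __init__(self, idim, odim, hdim, dropout=0.0, n_hidden_layers=5):
        super().__init__()
        layers, in_features = [], idim
        for _ in range(n_hidden_layers - 1):
            layers += [nn.Linear(in_features, hdim), nn.LeakyReLU(), nn.Dropout(p=dropout)]
            in_features = hdim
        # final hidden block widens to 2 * odim, mirroring layer10 in ANN_CNN10
        layers += [nn.Linear(in_features, 2 * odim), nn.LeakyReLU(), nn.Dropout(p=dropout)]
        self.hidden_stack = nn.Sequential(*layers)
        self.output = nn.Linear(2 * odim, odim)

    def forward(self, x):
        return self.output(self.hidden_stack(x))


# The ablation configuration would then be requested with n_hidden_layers=10
# instead of switching to a separate ANN_CNN10 class.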

from function_training import Training_ANN_CNN

torch.set_printoptions(edgeitems=2)
@@ -49,6 +49,7 @@
restart = False
init_epoch = 1 # where to resume. Should have checkpoint saved for init_epoch-1. 1 for fresh runs.
nepochs = 100
ablation = True
# ----------------------
domain = sys.argv[1]  # 'global' or 'regional'
vertical = sys.argv[2]  # 'global', 'stratosphere_only', or 'stratosphere_update'
@@ -67,7 +68,12 @@
bs_test = bs_train
dropout = 0.1

log_filename = f"./ann_cnns_{stencil}x{stencil}_{domain}_{vertical}_{features}_epoch_{init_epoch}_to_{init_epoch+nepochs-1}.txt"
# log_filename = f"./ann_cnns_{stencil}x{stencil}_{domain}_{vertical}_{features}_epoch_{init_epoch}_to_{init_epoch+nepochs-1}.txt"

if not ablation:
    log_filename = f"./ann_cnns_{stencil}x{stencil}_{domain}_{vertical}_{features}_epoch_{init_epoch}_to_{init_epoch+nepochs-1}.txt"
else:
    log_filename = f"./ann_cnns_{stencil}x{stencil}_{domain}_{vertical}_{features}_epoch_{init_epoch}_to_{init_epoch+nepochs-1}_ABLATION_10hiddenlayers.txt"


def write_log(*args):
@@ -163,7 +169,10 @@ def write_log(*args):
hdim = 4 * idim
write_log(f"Input dim: {idim}, hidden dim: {hdim}, output dim: {odim}")

model = ANN_CNN(idim=idim, odim=odim, hdim=hdim, dropout=dropout, stencil=trainset.stencil)
if not ablation:
    model = ANN_CNN(idim=idim, odim=odim, hdim=hdim, dropout=dropout, stencil=trainset.stencil)
else:
    model = ANN_CNN10(idim=idim, odim=odim, hdim=hdim, dropout=dropout, stencil=trainset.stencil)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CyclicLR(
@@ -186,8 +195,15 @@ def write_log(*args):
# write_log('fac_created')


file_prefix = odir + f"{vertical}/ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}_"
# write_log(f'file prefix: {file_prefix}')
if not ablation:
    file_prefix = (
        odir + f"{vertical}/ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}_"
    )
else:
    file_prefix = (
        odir
        + f"{vertical}/ann_cnn_{stencil}x{stencil}_{domain}_{vertical}_era5_{features}_ABLATION_10hiddenlayers_"
    )
if restart:
    # load checkpoint before resuming training
    PATH = f"{file_prefix}_train_epoch{init_epoch-1}.pt"