refactoring the lifting operator, adding more toggles
scaomath committed Jun 20, 2024
1 parent a1b344a commit b40b738
Showing 3 changed files with 117 additions and 84 deletions.
90 changes: 55 additions & 35 deletions sfno/sfno.py
@@ -52,7 +52,7 @@ def forward(self, x: torch.Tensor):

class MLP(nn.Module):
def __init__(
self, in_channels, out_channels, mid_channels, activation: str = "GELU"
self, in_channels, out_channels, mid_channels, activation: str = "ReLU"
):
super().__init__()
self.mlp1 = nn.Conv3d(in_channels, mid_channels, 1)
@@ -70,46 +70,43 @@ class PositionalEncoding(nn.Module):
https://pytorch.org/tutorials/beginner/transformer_tutorial.html
a modified sinusoidal PE inspired by the Transformers
input is (batch, *, nx, ny, t)
output is (batch, C+*, nx, ny, t)
output is (batch, C, nx, ny, t)
1 comes from the input
time_exponential_scale comes from the a priori estimate of the Navier-Stokes equations
the random feature bases are added with a pointwise conv3d
"""

def __init__(
self,
modes_x: int = 16,
modes_y: int = 16,
modes_t: int = 5,
num_channels: int = 20,
input_shape=(64, 64, 10),
channel_expansion=False,
spatial_random_feats=False,
max_time_steps=100,
time_exponential_scale=1e-2,
**kwargs,
):
super().__init__()
if channel_expansion:
self.num_channels = modes_x * modes_y * modes_t + 3 # the Euclidean coords
self.proj = nn.Identity()
else:
self.num_channels = num_channels # the Euclidean coords
self.proj = nn.Conv3d(
modes_x * modes_y * modes_t + 3, num_channels, kernel_size=1
)

assert num_channels % 2 == 0 and num_channels > 3
self.num_channels = num_channels # the Euclidean coords
self.max_time_steps = max_time_steps
self.time_exponential_scale = time_exponential_scale
self.modes_x = modes_x
self.modes_y = modes_y
self.modes_t = modes_t
assert modes_x % 2 == 0
assert modes_y % 2 == 0
assert modes_t % 2 == 0

self.max_time_steps = max_time_steps
self.time_exponential_scale = time_exponential_scale
self._pe = self._pe_expanded if spatial_random_feats else self._pe
self._pe(*input_shape)
if spatial_random_feats:
in_chan = modes_x * modes_y * modes_t + 3
self.proj = nn.Conv3d(in_chan, num_channels, kernel_size=1)
else:
self.proj = nn.Identity()


def _pe(self, *shape):
def _pe_expanded(self, *shape):
nx, ny, nt = shape
gridx = torch.linspace(0, 1, nx)
gridy = torch.linspace(0, 1, ny)
@@ -131,18 +128,33 @@ def _pe(self, *shape):
* basis_y(torch.pi * j * gridy)
* basis_t(torch.pi * k * gridt)
)
if i == 0 and j == 0 and k == 0:
print(basis.shape)
pe.append(basis)
pe = torch.stack(pe).unsqueeze(0) # (1, modes_x*modes_y*modes_t + 3, nx, ny, nt)
self.pe = pe

def forward(self, v):
def _pe(self, *shape):
nx, ny, nt = shape
gridx = torch.linspace(0, 1, nx)
gridy = torch.linspace(0, 1, ny)
gridt = torch.linspace(0, 1, self.max_time_steps+1)[1:nt+1]
gridx, gridy, _gridt = torch.meshgrid(gridx, gridy, gridt, indexing="ij")
pe = [gridx, gridy, _gridt]
for k in range(self.num_channels - 3):
basis = torch.sin if k % 2 == 0 else torch.cos
_gridt = torch.exp(self.time_exponential_scale * gridt) * basis(
torch.pi * (k + 1) * gridt
)
_gridt = _gridt.reshape(1, 1, nt).repeat(nx, ny, 1)
pe.append(_gridt)
pe = torch.stack(pe).unsqueeze(0) # (1, num_channels, nx, ny, nt)
self.pe = pe

def forward(self, v: torch.Tensor):
if self.pe is None or self.pe.shape[-3:] != v.shape[-3:]:
*_, nx, ny, nt = v.size() # (batch, 1, x, y, t)
self._pe(nx, ny, nt)
pe = self.pe.to(v.dtype).to(v.device)
return self.proj(v + pe)
return v + self.proj(pe)


class Helmholtz(nn.Module):
@@ -229,10 +241,11 @@ def __init__(
modes_t,
latent_steps=10,
norm="backward",
activation: str = "GELU",
beta=0.1,
pe_channel_expansion=False,
spatial_random_feats=False,
channel_expansion=128,
nonlinear=True,
**kwargs,
) -> None:
"""
@@ -243,13 +256,12 @@ def __init__(
pe_modes_t = modes_t - 1

self.pe = PositionalEncoding(
modes_x // 2,
modes_y // 2,
pe_modes_t,
modes_x//2,
modes_y//2,
pe_modes_t//2,
num_channels=width,
time_exponential_scale=beta,
channel_expansion=pe_channel_expansion,
)
spatial_random_feats=spatial_random_feats,)

in_channels = self.pe.num_channels
self.norm = LayerNorm3d(in_channels)
@@ -262,13 +274,17 @@ def __init__(
norm=norm,
bias=False,
)
self.nonlinear = getattr(nn, activation)()
self.mlp = MLP(width, width, channel_expansion)

if nonlinear:
self.activation = getattr(nn, activation)()
self.mlp = MLP(width, width, channel_expansion, activation)
else:
self.activation = nn.Identity()
self.mlp = nn.Conv3d(width, width, kernel_size=1)

def forward(self, v):
for b in [self.pe, self.norm, self.proj]:
v = b(v)
v = self.nonlinear(v + self.mlp(self.sconv(v)))
v = self.activation(v + self.mlp(self.sconv(v)))
return v


@@ -519,6 +535,8 @@ def __init__(
spatial_padding: int = 0,
temporal_padding: bool = True,
channel_expansion: int = 128,
spatial_random_feats: bool = False,
lift_activation: bool = True,
latent_steps: int = 10,
output_steps: int = None,
debug=False,
@@ -535,7 +553,7 @@ def __init__(
1. New lifting operator
- new PE: since the treatment of the grid differs from the official FNO code, which gave autograd trouble, the new PE is similar to the one used in Transformers; the time dimension's PE follows the a priori estimate of the NSE. The PE occupies the extra channels.
- new LayerNorm3d: instead of normalizing the input/output pointwise when preparing the data as the original FNO did. The normalization prevents predicting arbitrary time steps.
- new LayerNorm3d: instead of normalizing the input/output pointwise when preparing the data as the original FNO did, this makes the normalization agnostic to the number of input steps. Note that the global normalization by the mean/std of an (n, n, n_t)-shaped tensor in the original FNO3d prevents predicting arbitrary time steps.
- the channel lifting now works much like a depth-wise conv but uses the global spectral convolution as FNO does. Since there is no need to treat the time steps as channels, it can now accept an arbitrary number of time steps in the input.
2. new out projection: it maps the latent time steps to a given number of output time steps using the FFT's natural super-resolution.
- outputs an arbitrary number of steps.
@@ -576,6 +594,8 @@ def __init__(
beta=beta,
activation=activation,
channel_expansion=channel_expansion,
spatial_random_feats=spatial_random_feats,
nonlinear=lift_activation,
)

act_func = getattr(nn, activation)
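For orientation before the next file, here is a minimal, self-contained sketch of the time-direction encoding that the new `_pe` in the diff above assembles: alternating sin/cos bases along t, modulated by the exponential-in-time factor from the a priori Navier-Stokes estimate, and broadcast as constants over the spatial grid. The sizes and the scale below are illustrative assumptions, not values taken from this commit.

import torch

# Sketch of the time PE built by `_pe` above; all sizes here are assumptions.
nx, ny, nt = 64, 64, 10
num_channels, max_time_steps, beta = 20, 100, 1e-2  # beta ~ time_exponential_scale

gridx = torch.linspace(0, 1, nx)
gridy = torch.linspace(0, 1, ny)
gridt = torch.linspace(0, 1, max_time_steps + 1)[1 : nt + 1]
mx, my, mt = torch.meshgrid(gridx, gridy, gridt, indexing="ij")

pe = [mx, my, mt]  # the 3 Euclidean coordinate channels
for k in range(num_channels - 3):
    basis = torch.sin if k % 2 == 0 else torch.cos  # alternate sin/cos bases
    feat = torch.exp(beta * gridt) * basis(torch.pi * (k + 1) * gridt)  # (nt,)
    pe.append(feat.reshape(1, 1, nt).expand(nx, ny, nt))  # constant in space
pe = torch.stack(pe).unsqueeze(0)
assert pe.shape == (1, num_channels, nx, ny, nt)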
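The docstring's normalization point can be made concrete: dataset statistics computed once on a fixed (n, n, n_t)-shaped tensor, as in the original FNO3d preprocessing, bake n_t into the pipeline, while a per-sample normalization over the channel dimension is agnostic to the number of input steps. A sketch of the contrast in plain torch ops (this is not the repo's `LayerNorm3d`, whose definition lies outside this diff):

import torch

x10 = torch.randn(4, 20, 64, 64, 10)  # batch with 10 input steps
x25 = torch.randn(4, 20, 64, 64, 25)  # batch with 25 input steps

# FNO3d-style global normalization: the statistics have shape (..., 10),
# so they cannot be applied to a 25-step input.
mean, std = x10.mean(dim=0), x10.std(dim=0)  # (20, 64, 64, 10)
# (x25 - mean) / std  # would raise a shape-mismatch error

# Steps-agnostic alternative: per-sample statistics over the channel dim.
def channel_norm(x, eps=1e-5):
    mu = x.mean(dim=1, keepdim=True)
    sigma = x.var(dim=1, keepdim=True, unbiased=False).sqrt()
    return (x - mu) / (sigma + eps)

y10, y25 = channel_norm(x10), channel_norm(x25)  # works for any number of steps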
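Finally, a usage sketch of the new toggles, assuming the constructor signature visible in this diff and in train.py (`SFNO(modes, modes, modes_t, width, beta, ...)`); the import path, batch size, and grid sizes are illustrative:

import torch
from sfno.sfno import SFNO  # assumed module path for this commit

model = SFNO(
    32, 32, 5, 10, 0.1,         # modes_x, modes_y, modes_t, width, beta
    spatial_random_feats=True,  # random feature basis projected by a pointwise Conv3d
    lift_activation=False,      # linear lifting: Identity activation + 1x1 Conv3d
)
v = torch.randn(4, 1, 64, 64, 10)  # (batch, 1, nx, ny, t); the PE adapts to t in forward
out = model(v)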
54 changes: 27 additions & 27 deletions sfno/train.py
@@ -50,9 +50,11 @@ def main(args):

log_filename = os.path.join(LOG_PATH, f"{current_time}_{log_name}.log")
logger = get_logger(log_filename)
logger.info(f"Saving log at {log_filename}")


all_args = {k: v for k, v in vars(args).items() if not callable(v)}
logger.info(" | ".join(f"{k}={v}" for k, v in all_args.items()))
logger.info("Arguments: "+" | ".join(f"{k}={v}" for k, v in all_args.items()))

example = args.example
Ntrain = args.num_samples
@@ -72,19 +74,19 @@ def main(args):
modes = args.modes
modes_t = args.modes_t
width = args.width
num_layers = args.num_layers
beta = args.beta
activation = args.activation
spatial_padding = args.spatial_padding
pe_trainable = args.pe_trainable
pe_channel_expansion = args.pe_channel_expansion
pe_experimental = args.pe_experimental
lift_experimental = args.lift_experimental
spatial_random_feats = args.spatial_random_feats
lift_activation = not args.lift_linear

seed = args.seed
eval_only = args.eval_only
train_only = args.train_only

get_seed(seed, quiet=True)
get_seed(seed, quiet=False, logger=logger)

beta_str = f"{beta:.0e}".replace("e-0", "e-").replace("e+0", "e")
model_name = f"sfno_ex_{example}_ep{epochs}_m{modes}_w{width}_b{beta_str}.pt"
@@ -120,14 +122,14 @@ def main(args):
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

torch.cuda.empty_cache()
model = SFNO(modes, modes, modes_t, width, beta,
num_spectral_layers=num_layers,
output_steps=out_steps,
spatial_padding=spatial_padding,
activation=activation,
pe_trainable=pe_trainable,
pe_channel_expansion=pe_channel_expansion,
pe_experimental=pe_experimental,
lift_experimental=lift_experimental)
spatial_random_feats=spatial_random_feats,
lift_activation=lift_activation)
logger.info(f"Number of parameters: {get_num_params(model)}")
model.to(device)

@@ -144,17 +146,16 @@ def main(args):
)

loss_func = SobolevLoss(n_grid=n, norm_order=norm_order, relative=True)
logger.info(f"Loss func: {loss_func}")
get_config(loss_func, logger=logger)

for ep in range(epochs):
model.train()
train_l2 = 0.0

with tqdm(train_loader) as pbar:
time_epoch = datetime.now().strftime("%d-%b-%Y %H:%M:%S")
pbar.set_description(
f"{time_epoch} - Epoch [{ep+1}/{epochs}] train rel L2: {train_l2:.4e}"
)
t_ep = datetime.now().strftime("%d-%b-%Y %H:%M:%S")
tr_loss_str = f"current train rel L2: 0.0"
pbar.set_description(f"{t_ep} - Epoch [{ep+1:3d}/{epochs}] {tr_loss_str:>35}")
for i, data in enumerate(train_loader):
l2 = train_batch_ns(
model,
@@ -171,9 +172,8 @@ def main(args):
scheduler.step()

if i % 4 == 0:
pbar.set_description(
f"{time_epoch} - Epoch [{ep+1}/{epochs}] train rel L2: {l2.item():.4e}"
)
tr_loss_str = f"current train rel L2: {l2.item():.4e}"
pbar.set_description(f"{t_ep} - Epoch [{ep+1:3d}/{epochs}] {tr_loss_str:>35}")
pbar.update(4)
val_l2_min = 1e4
val_l2 = eval_epoch_ns(
@@ -187,10 +187,10 @@ def main(args):
if val_l2 < val_l2_min:
torch.save(model.state_dict(), path_model)
val_l2_min = val_l2
logger.info(
f"Epoch [{ep+1}/{epochs}] avg train rel L2: {train_l2/len(train_loader):.4e}"
)
logger.info(f"Epoch [{ep+1}/{epochs}] avg val rel L2: {val_l2:.4e}")
tr_loss_str = f"avg train rel L2: {train_l2/len(train_loader):.4e}"
val_loss_str = f"avg val rel L2: {val_l2:.4e}"
logger.info(f"Epoch [{ep+1:3d}/{epochs}] {tr_loss_str:>35}")
logger.info(f"Epoch [{ep+1:3d}/{epochs}] {val_loss_str:>35}")

logger.info(f"{epochs} epochs training complete. Model saved to {path_model}")

@@ -215,12 +215,12 @@ def main(args):
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
torch.cuda.empty_cache()
model = SFNO(modes, modes, modes_t, width, beta,
num_spectral_layers=num_layers,
spatial_padding=spatial_padding,
activation=activation,
pe_trainable=pe_trainable,
pe_channel_expansion=pe_channel_expansion,
pe_experimental=pe_experimental,
lift_experimental=lift_experimental).to(device)
spatial_random_feats=spatial_random_feats,
lift_activation=lift_activation).to(device)
model.load_state_dict(torch.load(path_model))
logger.info(f"Loaded model from {path_model}")
eval_metric = SobolevLoss(n_grid=n_test, norm_order=norm_order, relative=True)
@@ -276,6 +276,7 @@ def main(args):
parser.add_argument("--width", type=int, default=10)
parser.add_argument("--modes", type=int, default=32)
parser.add_argument("--modes-t", type=int, default=5)
parser.add_argument("--num-layers", type=int, default=4)
parser.add_argument("--spatial-padding", type=int, default=0)
parser.add_argument("--time-steps", type=int, default=10)
parser.add_argument("--out-time-steps", type=int, default=10)
Expand All @@ -284,9 +285,8 @@ def main(args):
parser.add_argument("--beta", type=float, default=0.0)
parser.add_argument("--activation", type=str, default="GELU")
parser.add_argument("--pe-trainable", default=False, action="store_true")
parser.add_argument("--pe-experimental", default=False, action="store_true")
parser.add_argument("--pe-channel-expansion", default=False, action="store_true")
parser.add_argument("--lift-experimental", default=False, action="store_true")
parser.add_argument("--spatial-random-feats", default=False, action="store_true")
parser.add_argument("--lift-linear", default=False, action="store_true")
parser.add_argument("--double", default=False, action="store_true")
parser.add_argument("--norm-order", type=float, default=0.0)
parser.add_argument("--eval-only", default=False, action="store_true")
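Since `--lift-linear` is a negative flag, the mapping from CLI toggles to model kwargs is easy to misread; here is a small sketch of the wiring, mirroring the argparse lines and the `lift_activation = not args.lift_linear` assignment above (the two flags come from the diff, the rest is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--spatial-random-feats", default=False, action="store_true")
parser.add_argument("--lift-linear", default=False, action="store_true")
args = parser.parse_args(["--lift-linear"])

lift_activation = not args.lift_linear  # --lift-linear turns the lifting nonlinearity off
assert args.spatial_random_feats is False and lift_activation is False
# model = SFNO(..., spatial_random_feats=args.spatial_random_feats,
#              lift_activation=lift_activation)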