Skip to content

Commit

Permalink
Test model forward
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Jul 30, 2024
1 parent 23407ef commit 355ea08
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 8 deletions.
8 changes: 7 additions & 1 deletion aurora/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

from aurora.batch import Batch, Metadata
from aurora.model.aurora import Aurora, AuroraSmall

__all__ = ["Aurora", "AuroraSmall"]
__all__ = [
"Aurora",
"AuroraSmall",
"Batch",
"Metadata",
]
10 changes: 8 additions & 2 deletions aurora/model/perceiver_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def forward(
self,
x: torch.Tensor,
batch: Batch,
pres: Int3Tuple,
patch_res: Int3Tuple,
lead_time: timedelta,
) -> Batch:
"""Forward pass of MultiScaleEncoder.
Expand All @@ -135,7 +135,13 @@ def forward(
H, W = lat.shape[0], lon.shape[-1]

# Unwrap the latent level dimension.
x = rearrange(x, "B (C H W) D -> B (H W) C D", C=pres[0], H=pres[1], W=pres[2])
x = rearrange(
x,
"B (C H W) D -> B (H W) C D",
C=patch_res[0],
H=patch_res[1],
W=patch_res[2],
)

# Decode surface vars.
x_surf = self.surf_head(x[..., :1, :]) # (B, L, 1, V_S*p*p)
Expand Down
6 changes: 3 additions & 3 deletions aurora/model/perceiver_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ def forward(self, batch: Batch, lead_time: timedelta) -> torch.Tensor:
atmos_vars = tuple(batch.atmos_vars.keys())
atmos_levels = batch.metadata.atmos_levels

x_surf = torch.stack(batch.surf_vars.values(), dim=2)
x_static = torch.stack(batch.static_vars.values(), dim=2)
x_atmos = torch.stack(batch.atmos_vars.values(), dim=2)
x_surf = torch.stack(tuple(batch.surf_vars.values()), dim=2)
x_static = torch.stack(tuple(batch.static_vars.values()), dim=2)
x_atmos = torch.stack(tuple(batch.atmos_vars.values()), dim=2)

B, T, _, C, H, W = x_atmos.size()
assert x_surf.shape[:2] == (
Expand Down
7 changes: 5 additions & 2 deletions aurora/model/swin_3d_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import itertools
from datetime import timedelta
from functools import lru_cache

import torch
Expand Down Expand Up @@ -650,7 +651,7 @@ def get_encoder_specs(self, patch_res: Int3Tuple) -> tuple[list[Int3Tuple], list
return all_res, padded_outs

def forward(
self, x: torch.Tensor, t: torch.Tensor, rollout_step: int, patch_res: Int3Tuple
self, x: torch.Tensor, lead_time: timedelta, rollout_step: int, patch_res: Int3Tuple
) -> torch.Tensor:
assert (
x.shape[1] == patch_res[0] * patch_res[1] * patch_res[2]
Expand All @@ -664,7 +665,9 @@ def forward(

all_enc_res, padded_outs = self.get_encoder_specs(patch_res)

c = self.time_mlp(lead_time_expansion(t, self.embed_dim).to(dtype=x.dtype))
lead_hours = lead_time / timedelta(hours=1)
lead_times = lead_hours * torch.ones(x.shape[0], dtype=torch.float32, device=x.device)
c = self.time_mlp(lead_time_expansion(lead_times, self.embed_dim).to(dtype=x.dtype))

skips = []
for i, layer in enumerate(self.encoder_layers):
Expand Down

0 comments on commit 355ea08

Please sign in to comment.