From 355ea0830ea625009e78237658321663ab6bfefd Mon Sep 17 00:00:00 2001
From: Wessel Bruinsma
Date: Tue, 30 Jul 2024 15:23:02 +0200
Subject: [PATCH] Test model forward

---
 aurora/__init__.py                |  8 +++++++-
 aurora/model/perceiver_decoder.py | 10 ++++++++--
 aurora/model/perceiver_encoder.py |  6 +++---
 aurora/model/swin_3d_block.py     |  7 +++++--
 4 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/aurora/__init__.py b/aurora/__init__.py
index 410d493..1f15436 100644
--- a/aurora/__init__.py
+++ b/aurora/__init__.py
@@ -1,5 +1,11 @@
 """Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
 
+from aurora.batch import Batch, Metadata
 from aurora.model.aurora import Aurora, AuroraSmall
 
-__all__ = ["Aurora", "AuroraSmall"]
+__all__ = [
+    "Aurora",
+    "AuroraSmall",
+    "Batch",
+    "Metadata",
+]
diff --git a/aurora/model/perceiver_decoder.py b/aurora/model/perceiver_decoder.py
index 86ca11a..1df09ea 100644
--- a/aurora/model/perceiver_decoder.py
+++ b/aurora/model/perceiver_decoder.py
@@ -108,7 +108,7 @@ def forward(
         self,
         x: torch.Tensor,
         batch: Batch,
-        pres: Int3Tuple,
+        patch_res: Int3Tuple,
         lead_time: timedelta,
     ) -> Batch:
         """Forward pass of MultiScaleEncoder.
@@ -135,7 +135,13 @@ def forward(
         H, W = lat.shape[0], lon.shape[-1]
 
         # Unwrap the latent level dimension.
-        x = rearrange(x, "B (C H W) D -> B (H W) C D", C=pres[0], H=pres[1], W=pres[2])
+        x = rearrange(
+            x,
+            "B (C H W) D -> B (H W) C D",
+            C=patch_res[0],
+            H=patch_res[1],
+            W=patch_res[2],
+        )
 
         # Decode surface vars.
         x_surf = self.surf_head(x[..., :1, :])  # (B, L, 1, V_S*p*p)
diff --git a/aurora/model/perceiver_encoder.py b/aurora/model/perceiver_encoder.py
index 3ee2393..cb7c9f3 100644
--- a/aurora/model/perceiver_encoder.py
+++ b/aurora/model/perceiver_encoder.py
@@ -165,9 +165,9 @@ def forward(self, batch: Batch, lead_time: timedelta) -> torch.Tensor:
         atmos_vars = tuple(batch.atmos_vars.keys())
         atmos_levels = batch.metadata.atmos_levels
 
-        x_surf = torch.stack(batch.surf_vars.values(), dim=2)
-        x_static = torch.stack(batch.static_vars.values(), dim=2)
-        x_atmos = torch.stack(batch.atmos_vars.values(), dim=2)
+        x_surf = torch.stack(tuple(batch.surf_vars.values()), dim=2)
+        x_static = torch.stack(tuple(batch.static_vars.values()), dim=2)
+        x_atmos = torch.stack(tuple(batch.atmos_vars.values()), dim=2)
         B, T, _, C, H, W = x_atmos.size()
 
         assert x_surf.shape[:2] == (
diff --git a/aurora/model/swin_3d_block.py b/aurora/model/swin_3d_block.py
index 82e9f2b..a941ebf 100644
--- a/aurora/model/swin_3d_block.py
+++ b/aurora/model/swin_3d_block.py
@@ -7,6 +7,7 @@
 """
 
 import itertools
+from datetime import timedelta
 from functools import lru_cache
 
 import torch
@@ -650,7 +651,7 @@ def get_encoder_specs(self, patch_res: Int3Tuple) -> tuple[list[Int3Tuple], list
         return all_res, padded_outs
 
     def forward(
-        self, x: torch.Tensor, t: torch.Tensor, rollout_step: int, patch_res: Int3Tuple
+        self, x: torch.Tensor, lead_time: timedelta, rollout_step: int, patch_res: Int3Tuple
     ) -> torch.Tensor:
         assert (
             x.shape[1] == patch_res[0] * patch_res[1] * patch_res[2]
@@ -664,7 +665,9 @@ def forward(
 
         all_enc_res, padded_outs = self.get_encoder_specs(patch_res)
 
-        c = self.time_mlp(lead_time_expansion(t, self.embed_dim).to(dtype=x.dtype))
+        lead_hours = lead_time / timedelta(hours=1)
+        lead_times = lead_hours * torch.ones(x.shape[0], dtype=torch.float32, device=x.device)
+        c = self.time_mlp(lead_time_expansion(lead_times, self.embed_dim).to(dtype=x.dtype))
 
         skips = []
         for i, layer in enumerate(self.encoder_layers):
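
Reviewer note: for anyone who wants to try the new timedelta-based interface of the backbone's forward, the standalone sketch below pulls out the lead-time handling that the swin_3d_block.py hunks introduce. It assumes only torch and the standard library; the batch size, embedding width, and token count are made-up values, and the sinusoidal lead_time_expansion defined here is a hypothetical stand-in for the package's own helper, included only so the snippet runs on its own.

    from datetime import timedelta

    import torch


    def lead_time_expansion(lead_times: torch.Tensor, dim: int) -> torch.Tensor:
        """Hypothetical stand-in for the package's sinusoidal lead-time embedding."""
        half = dim // 2
        freqs = torch.exp(-torch.arange(half, dtype=torch.float32) / half)
        angles = lead_times[:, None] * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


    B, embed_dim = 2, 64                       # assumed batch size and embedding width
    x = torch.randn(B, 8 * 4 * 4, embed_dim)   # assumed latent tokens of shape (B, C*H*W, D)
    lead_time = timedelta(hours=6)

    # Mirrors the new logic in the backbone's forward: the timedelta is converted to
    # fractional hours and broadcast to one value per batch element before the expansion.
    lead_hours = lead_time / timedelta(hours=1)
    lead_times = lead_hours * torch.ones(x.shape[0], dtype=torch.float32, device=x.device)

    c = lead_time_expansion(lead_times, embed_dim).to(dtype=x.dtype)
    print(c.shape)  # torch.Size([2, 64])

Passing the lead time as a timedelta keeps the unit conversion in one place inside the backbone; in the patched code the expanded embedding is then fed through self.time_mlp to condition the encoder layers.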