From 86888a6a4a9e3b5f21ce343872ff953b47b74c8b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 28 Aug 2024 15:28:17 +0000 Subject: [PATCH] work in progress --- README.md | 2 - src/ai_models_aurora/model.py | 112 +++++++++++++++++++++++++++++++--- 2 files changed, 103 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 774c7ac..62aef71 100644 --- a/README.md +++ b/README.md @@ -4,5 +4,3 @@ See https://microsoft.github.io/aurora/intro.html See https://github.com/microsoft/aurora - - diff --git a/src/ai_models_aurora/model.py b/src/ai_models_aurora/model.py index 52f90b6..74ce6c9 100644 --- a/src/ai_models_aurora/model.py +++ b/src/ai_models_aurora/model.py @@ -7,25 +7,119 @@ import logging -from anemoi.inference.plugin import AIModelPlugin -from aurora import Aurora, AuroraSmall, rollout +import numpy as np +import torch +from ai_models.model import Model +from aurora import Aurora +from aurora import Batch +from aurora import Metadata +from aurora import rollout LOG = logging.getLogger(__name__) -class AuroraModel(AIModelPlugin): - expver = "auro" +class AuroraModel(Model): + + # Input + area = [90, 0, -90, 360 - 0.25] + grid = [0.25, 0.25] + + surf_vars = ("2t", "10u", "10v", "msl") + static_vars = ("lsm", "z", "slt") + atmos_vars = ("z", "u", "v", "t", "q") + levels = (1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50) + lagged = (-6, 0) - # Download - download_files = ["checkpoint.ckpt"] + # For the MARS requets + param_sfc = surf_vars + static_vars + param_level_pl = (atmos_vars, levels) + + # Output + + expver = "auro" def __init__(self, **kwargs): super().__init__(**kwargs) - self.model =Aurora() + self.model = Aurora() def run(self): - model = AuroraSmall() - model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt") + LOG.info("Running Aurora model") + model = Aurora(use_lora=False) # Model is not fine-tuned. + model = model.to(self.device) + LOG.info("Downloading Aurora model") + # TODO: control location of cache + model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt") + LOG.info("Loading Aurora model to device %s", self.device) + + model = model.to(self.device) + model.eval() + + fields_pl = self.fields_pl + fields_sfc = self.fields_sfc + + N, W, S, E = self.area + WE, NS = self.grid + Nj = round((N - S) / NS) + 1 + Ni = round((E - W) / WE) + 1 + + to_numpy_kwargs = dict(dtype=np.float32) + + + + + # Shape (Batch, Time, Lat, Lon) + surf_vars = {} + + for k in self.surf_vars: + f = fields_sfc.sel(param=k).order_by(valid_datetime="ascending") + f = f.to_numpy(**to_numpy_kwargs) + f = torch.from_numpy(f) + f = f.unsqueeze(0) # Add batch dimension + print(f.shape) + surf_vars[k] = f + + # Shape (Lat, Lon) + static_vars = {} + for k in self.static_vars: + f = fields_sfc.sel(param=k).order_by(valid_datetime="ascending") + f =f.to_numpy(**to_numpy_kwargs)[-1] + f = torch.from_numpy(f) + print(f.shape) + static_vars[k] = f + + # Shape (Batch, Time, Level, Lat, Lon) + atmos_vars = {} + for k in self.atmos_vars: + f = fields_pl.sel(param=k).order_by(valid_datetime="ascending", level=self.levels) + f = f.to_numpy(**to_numpy_kwargs).reshape(len(self.lagged), len(self.levels), Nj, Ni) + f = torch.from_numpy(f) + f = f.unsqueeze(0) # Add batch dimension + print(f.shape) + atmos_vars[k] = f + + + # https://microsoft.github.io/aurora/batch.html + + batch = Batch( + surf_vars=surf_vars, + static_vars=static_vars, + atmos_vars=atmos_vars, + metadata=Metadata( + lat=torch.linspace(N, S, Nj), + lon=torch.linspace(W, E, Ni), + time=self.start_datetime, + atmos_levels=self.levels, + ), + ) + + print(batch.metadata.lat.shape) + print(batch.metadata.lon.shape) + + with torch.inference_mode(): + + for pred in rollout(model, batch, steps=10): + print(pred.metadata.time) + model = AuroraModel