Skip to content

Commit

Permalink
tidy code
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 2, 2024
1 parent ef96388 commit e7ab127
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 21 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,6 @@ slurm-*.out
?.*
?
_version.py
*.pickle
hub/
.vscode/settings.json
50 changes: 29 additions & 21 deletions src/ai_models_aurora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ class AuroraModel(Model):

expver = "auro"

use_an = False

def run(self):

# TODO: control location of cache

model = self.klass(use_lora=False)
LOG.info(f"Model is {self.__class__.__name__}, use_lora={self.use_lora}")

model = self.klass(use_lora=self.use_lora)
model = model.to(self.device)
LOG.info("Downloading Aurora model %s", self.checkpoint)
model.load_checkpoint("microsoft/aurora", self.checkpoint, strict=False)
Expand Down Expand Up @@ -158,10 +158,16 @@ class Aurora2p5(AuroraModel):

# https://microsoft.github.io/aurora/models.html#aurora-0-25-pretrained
class Aurora2p5Pretrained(Aurora2p5):
use_lora = False
checkpoint = "aurora-0.25-pretrained.ckpt"


class UseIFSMixin:
# https://microsoft.github.io/aurora/models.html#aurora-0-25-fine-tuned
class Aurora2p5FineTuned(Aurora2p5):
use_lora = True
checkpoint = "aurora-0.25-finetuned.ckpt"

# We want FC, step=0
def patch_retrieve_request(self, r):
if r.get("class", "od") != "od":
return
Expand All @@ -172,10 +178,7 @@ def patch_retrieve_request(self, r):
if r.get("stream", "oper") not in ("oper", "scda"):
return

if self.use_an:
r["type"] = "an"
else:
r["type"] = "fc"
r["type"] = "fc"

time = r.get("time", 12)

Expand All @@ -187,25 +190,30 @@ def patch_retrieve_request(self, r):
}[time]


# https://microsoft.github.io/aurora/models.html#aurora-0-25-fine-tuned
class Aurora2p5FineTuned(UseIFSMixin, Aurora2p5):
checkpoint = "aurora-0.25-finetuned.ckpt"


class Aurora0p1(AuroraModel):
klass = AuroraHighRes

# https://microsoft.github.io/aurora/models.html#aurora-0-1-fine-tuned
class Aurora0p1FineTuned(AuroraModel):
download_files = ("aurora-0.1-static.pickle",)
# Input
area = [90, 0, -90, 360 - 0.1]
grid = [0.1, 0.1]


# https://microsoft.github.io/aurora/models.html#aurora-0-1-fine-tuned
class Aurora0p1FineTuned(Aurora0p1):
klass = AuroraHighRes

use_lora = True
checkpoint = "aurora-0.1-finetuned.ckpt"


model = Aurora0p1FineTuned
# model = Aurora0p1FineTuned


def model(model_version, **kwargs):

# select with --model-version

models = {
"0.25-pretrained": Aurora2p5Pretrained,
"0.25-finetuned": Aurora2p5FineTuned,
"0.1-finetuned": Aurora0p1FineTuned,
"default": Aurora0p1FineTuned,
"latest": Aurora0p1FineTuned, # Backward compatibility
}
return models[model_version](**kwargs)

0 comments on commit e7ab127

Please sign in to comment.