Skip to content

Commit

Permalink
Document constructor of Aurora
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Aug 15, 2024
1 parent 6251363 commit 767b313
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 7 deletions.
49 changes: 43 additions & 6 deletions aurora/model/aurora.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@

__all__ = ["Aurora", "AuroraSmall"]

VariableList = tuple[str, ...]
"""type: Tuple of variable names."""


class Aurora(torch.nn.Module):
"""The Aurora model.
Expand All @@ -26,9 +23,9 @@ class Aurora(torch.nn.Module):

def __init__(
self,
surf_vars: VariableList = ("2t", "10u", "10v", "msl"),
static_vars: VariableList = ("lsm", "z", "slt"),
atmos_vars: VariableList = ("z", "u", "v", "t", "q"),
surf_vars: tuple[str, ...] = ("2t", "10u", "10v", "msl"),
static_vars: tuple[str, ...] = ("lsm", "z", "slt"),
atmos_vars: tuple[str, ...] = ("z", "u", "v", "t", "q"),
window_size: Int3Tuple = (2, 6, 12),
encoder_depths: tuple[int, ...] = (6, 10, 8),
encoder_num_heads: tuple[int, ...] = (8, 16, 32),
Expand All @@ -49,6 +46,46 @@ def __init__(
use_lora: bool = True,
lora_steps: int = 40,
) -> None:
"""Construct an instance of the model.
Args:
surf_vars (tuple[str, ...], optional): All surface-level variables supported by the
model. The model is sensitive to the order of `surf_vars`!
static_vars (tuple[str, ...], optional): All static variables supported by the
model. The model is sensitive to the order of `static_vars`!
atmos_vars (tuple[str, ...], optional): All atmospheric variables supported by the
model. The model is sensitive to the order of `atmos-vars`!
window_size (tuple[int, int, int], optional): Vertical height, height, and width of the
window of the underlying Swin transformer.
encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
encoder_num_header (tuple[int, ...], optional) Number of attention heads in each encoder
layer.
decoder_depths (tuple[int, ...], optional): Number of blocks in each decoder layer.
decoder_num_header (tuple[int, ...], optional) Number of attention heads in each decoder
layer.
latent_levels (int, optional): Number of latent pressure levels.
patch_size (int, optional): Patch size.
embed_dim (int, optional): Patch embedding dimension.
num_heads (int, optional): Number of attention heads in the aggregation and
deaggregation blocks.
mlp_ratio (float, optional): Hidden dim. to embedding dim. ratio for MLPs.
drop_rate (float, optional): Drop-out rate.
drop_path (float, optional): Drop-path rate.
enc_depth (int, optional): Number of Perceiver blocks in the encoder.
dec_depth (int, optioanl): Number of Perceiver blocks in the decoder.
dec_mlp_ratio (float, optional): Hidden dim. to embedding dim. ratio for MLPs in the
decoder. The embedding dimensionality here is different, which is why this is a
separate parameter.
perceiver_ln_eps (float, optional): Epsilon in the perceiver layer norm. layers. Used
to stabilise the model.
max_history_size (int, optional): Maximum number of history steps.
use_lora (bool, optional): Use LoRA adaptation.
lora_steps (int, optional): Use different LoRA adaptation for the first so-many roll-out
steps.
Returns:
:class:`aurora.model.aurora.Aurora`: Instance.
"""
super().__init__()
self.surf_vars = surf_vars
self.atmos_vars = atmos_vars
Expand Down
2 changes: 1 addition & 1 deletion aurora/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def rollout(model: Aurora, batch: Batch, steps: int) -> Generator[Batch, None, N
"""Perform a roll-out to make long-term predictions.
Args:
model (:class:`aurora.model.aurora.Aurora`): The model to roll-out.
model (:class:`aurora.model.aurora.Aurora`): The model to roll out.
batch (:class:`aurora.batch.Batch`): The batch to start the roll-out from.
steps (int): The number of roll-out steps.
Expand Down
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Roll-Outs
Models
------
.. autoclass:: aurora.Aurora
:special-members: __init__
:members:

.. autoclass:: aurora.AuroraSmall
Expand Down

0 comments on commit 767b313

Please sign in to comment.