diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index bbeb5d5..70b7e9f 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -14,9 +14,6 @@ __all__ = ["Aurora", "AuroraSmall"] -VariableList = tuple[str, ...] -"""type: Tuple of variable names.""" - class Aurora(torch.nn.Module): """The Aurora model. @@ -26,9 +23,9 @@ class Aurora(torch.nn.Module): def __init__( self, - surf_vars: VariableList = ("2t", "10u", "10v", "msl"), - static_vars: VariableList = ("lsm", "z", "slt"), - atmos_vars: VariableList = ("z", "u", "v", "t", "q"), + surf_vars: tuple[str, ...] = ("2t", "10u", "10v", "msl"), + static_vars: tuple[str, ...] = ("lsm", "z", "slt"), + atmos_vars: tuple[str, ...] = ("z", "u", "v", "t", "q"), window_size: Int3Tuple = (2, 6, 12), encoder_depths: tuple[int, ...] = (6, 10, 8), encoder_num_heads: tuple[int, ...] = (8, 16, 32), @@ -49,6 +46,46 @@ def __init__( use_lora: bool = True, lora_steps: int = 40, ) -> None: + """Construct an instance of the model. + + Args: + surf_vars (tuple[str, ...], optional): All surface-level variables supported by the + model. The model is sensitive to the order of `surf_vars`! + static_vars (tuple[str, ...], optional): All static variables supported by the + model. The model is sensitive to the order of `static_vars`! + atmos_vars (tuple[str, ...], optional): All atmospheric variables supported by the + model. The model is sensitive to the order of `atmos_vars`! + window_size (tuple[int, int, int], optional): Vertical height, height, and width of the + window of the underlying Swin transformer. + encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer. + encoder_num_heads (tuple[int, ...], optional): Number of attention heads in each encoder + layer. + decoder_depths (tuple[int, ...], optional): Number of blocks in each decoder layer. + decoder_num_heads (tuple[int, ...], optional): Number of attention heads in each decoder + layer. 
+ latent_levels (int, optional): Number of latent pressure levels. + patch_size (int, optional): Patch size. + embed_dim (int, optional): Patch embedding dimension. + num_heads (int, optional): Number of attention heads in the aggregation and + deaggregation blocks. + mlp_ratio (float, optional): Hidden dim. to embedding dim. ratio for MLPs. + drop_rate (float, optional): Drop-out rate. + drop_path (float, optional): Drop-path rate. + enc_depth (int, optional): Number of Perceiver blocks in the encoder. + dec_depth (int, optional): Number of Perceiver blocks in the decoder. + dec_mlp_ratio (float, optional): Hidden dim. to embedding dim. ratio for MLPs in the + decoder. The embedding dimensionality here is different, which is why this is a + separate parameter. + perceiver_ln_eps (float, optional): Epsilon in the perceiver layer norm. layers. Used + to stabilise the model. + max_history_size (int, optional): Maximum number of history steps. + use_lora (bool, optional): Use LoRA adaptation. + lora_steps (int, optional): Use different LoRA adaptation for the first so-many roll-out + steps. + + Returns: + :class:`aurora.model.aurora.Aurora`: Instance. + """ super().__init__() self.surf_vars = surf_vars self.atmos_vars = atmos_vars diff --git a/aurora/rollout.py b/aurora/rollout.py index b646f06..1363111 100644 --- a/aurora/rollout.py +++ b/aurora/rollout.py @@ -15,7 +15,7 @@ def rollout(model: Aurora, batch: Batch, steps: int) -> Generator[Batch, None, N """Perform a roll-out to make long-term predictions. Args: - model (:class:`aurora.model.aurora.Aurora`): The model to roll-out. + model (:class:`aurora.model.aurora.Aurora`): The model to roll out. batch (:class:`aurora.batch.Batch`): The batch to start the roll-out from. steps (int): The number of roll-out steps. diff --git a/docs/api.rst b/docs/api.rst index 4f2d36b..7ab0e60 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -17,6 +17,7 @@ Roll-Outs Models ------ .. 
autoclass:: aurora.Aurora + :special-members: __init__ :members: .. autoclass:: aurora.AuroraSmall