diff --git a/annolid/behavior/models/classifier.py b/annolid/behavior/models/classifier.py index 63cf47d..81eda6f 100644 --- a/annolid/behavior/models/classifier.py +++ b/annolid/behavior/models/classifier.py @@ -50,7 +50,7 @@ class BehaviorClassifier(nn.Module): Args: feature_extractor (nn.Module): The feature extraction module. - d_model (int, optional): The embedding dimension. Defaults to 512. + d_model (int, optional): The embedding dimension. Defaults to 768. nhead (int, optional): The number of attention heads. Defaults to 8. num_layers (int, optional): The number of transformer encoder layers. Defaults to 6. dim_feedforward (int, optional): The dimension of the feedforward network. Defaults to 2048. @@ -58,7 +58,7 @@ class BehaviorClassifier(nn.Module): num_classes (int, optional): The number of behavior classes. Defaults to 5. """ - def __init__(self, feature_extractor: nn.Module, d_model: int = 512, nhead: int = 8, + def __init__(self, feature_extractor: nn.Module, d_model: int = 768, nhead: int = 8, num_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, num_classes: int = 5): super().__init__()