Skip to content

Commit

Permalink
add xinet train config
Browse files Browse the repository at this point in the history
  • Loading branch information
fpaissan committed Nov 27, 2023
1 parent 4b2b958 commit f8c1a6b
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions recipes/image_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,32 @@ class ImageClassification(mm.MicroMind):
def __init__(self, hparams, *args, **kwargs):
super().__init__(hparams, *args, **kwargs)

self.modules["classifier"] = PhiNet(
input_shape=hparams.input_shape,
alpha=hparams.alpha,
num_layers=hparams.num_layers,
beta=hparams.beta,
t_zero=hparams.t_zero,
compatibility=False,
divisor=hparams.divisor,
downsampling_layers=hparams.downsampling_layers,
return_layers=hparams.return_layers,
# classification-specific
include_top=True,
num_classes=hparams.num_classes,
)
if hparams.model == "phinet":
self.modules["classifier"] = PhiNet(
input_shape=hparams.input_shape,
alpha=hparams.alpha,
num_layers=hparams.num_layers,
beta=hparams.beta,
t_zero=hparams.t_zero,
compatibility=False,
divisor=hparams.divisor,
downsampling_layers=hparams.downsampling_layers,
return_layers=hparams.return_layers,
# classification-specific
include_top=True,
num_classes=hparams.num_classes,
)
elif hparams.model == "xinet":
self.modules["classifier"] = XiNet(
input_shape=hparams.input_shape,
alpha=hparams.alpha,
compression=hparams.compression,
num_layers=hparams.num_layers,
return_layers=hparams.return_layers,
# classification-specific
include_top=True,
num_classes=hparams.num_classes,
)

tot_params = 0
for m in self.modules.values():
Expand Down

0 comments on commit f8c1a6b

Please sign in to comment.