-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathsimplenet.py
583 lines (505 loc) · 23.5 KB
/
simplenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
# In the name of God the most compassionate the most merciful
""" SimpleNet
Paper: `Lets Keep it simple, Using simple architectures to outperform deeper and more complex architectures`
- https://arxiv.org/abs/1608.06037
@article{hasanpour2016lets,
title={Lets keep it simple, Using simple architectures to outperform deeper and more complex architectures},
author={Hasanpour, Seyyed Hossein and Rouhani, Mohammad and Fayyaz, Mohsen and Sabokrou, Mohammad},
journal={arXiv preprint arXiv:1608.06037},
year={2016}
}
Official Caffe impl at https://github.com/Coderx7/SimpleNet
Official Pythorch impl at https://github.com/Coderx7/SimpleNet_Pytorch
Seyyed Hossein Hasanpour
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import download_url_to_file
from typing import Union, Tuple, List, Dict, Any, cast, Optional
# from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
DEFAULT_CROP_PCT = 0.875
DEFAULT_CROP_MODE = "center"
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# Public entrypoint names for this module; an external model registry
# (e.g. timm) is expected to register each of these builder functions.
__all__ = [
    "simplenet",
    "simplenet_cifar_310k",
    "simplenet_cifar_460k",
    "simplenet_cifar_5m",
    "simplenet_cifar_5m_extra_pool",  # refers to the early pytorch impl of cifar10/100 that mistakenly used an extra pooling(see issue #5)
    "simplenetv1_small_m1_05",  # 1.5m
    "simplenetv1_small_m2_05",  # 1.5m
    "simplenetv1_small_m1_075",  # 3m
    "simplenetv1_small_m2_075",  # 3m
    "simplenetv1_5m_m1",  # 5m
    "simplenetv1_5m_m2",  # 5m
    "simplenetv1_9m_m1",  # 9m
    "simplenetv1_9m_m2",  # 9m
]  # model_registry will add each entrypoint fn to this
def _cfg(url="", **kwargs):
    """Build a pretrained-weight configuration dict for one model variant.

    Args:
        url: download URL of the pretrained checkpoint (empty if none).
        **kwargs: overrides merged on top of the defaults below.

    Returns:
        dict with the standard timm-style config keys.
    """
    config = {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "crop_pct": 0.875,
        "interpolation": "bicubic",
        "mean": IMAGENET_DEFAULT_MEAN,
        "std": IMAGENET_DEFAULT_STD,
    }
    config.update(kwargs)
    return config
# Mapping from model variant name to its pretrained-weight configuration.
# URLs point at the official SimpleNet_Pytorch GitHub release assets; only
# these variants ship pretrained ImageNet weights.
default_cfgs: Dict[str, Dict[str, Any]] = {
    "simplenetv1_small_m1_05": _cfg(
        url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0/simplenetv1_small_m1_05-be804903.pth"
    ),
    "simplenetv1_small_m2_05": _cfg(
        url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0/simplenetv1_small_m2_05-ca4b3e2b.pth"
    ),
    "simplenetv1_small_m1_075": _cfg(
        url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0/simplenetv1_small_m1_075-098acbff.pth"
    ),
    "simplenetv1_small_m2_075": _cfg(
        url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0/simplenetv1_small_m2_075-609ff4da.pth"
    ),
    "simplenetv1_5m_m1": _cfg(
        url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0/simplenetv1_5m_m1-36c4ca4d.pth"
    ),
    "simplenetv1_5m_m2": _cfg(
        url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0/simplenetv1_5m_m2-9bd6bb36.pth"
    ),
    "simplenetv1_9m_m1": _cfg(
        url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0/simplenetv1_9m_m1-524f9972.pth"
    ),
    "simplenetv1_9m_m2": _cfg(
        url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0/simplenetv1_9m_m2-59e8733b.pth"
    ),
}
class View(nn.Module):
    """Debugging helper module: prints the input tensor's shape and returns
    the input unchanged. Insert it between layers to inspect feature sizes.
    """

    def forward(self, x):
        # Intentional side effect: print the shape; the tensor passes through untouched.
        print(f"{x.shape}")
        return x
class SimpleNet(nn.Module):
    """SimpleNet: a plain CNN built only from conv/bn/relu, max-pooling and
    dropout layers, designed to outperform deeper/more complex architectures.

    Paper: https://arxiv.org/abs/1608.06037
    """

    def __init__(
        self,
        num_classes: int = 1000,
        in_chans: int = 3,
        scale: float = 1,
        network_idx: int = 0,
        mode: int = 2,
        drop_rates: Optional[Dict[int, float]] = None,
    ):
        """Instantiates a SimpleNet model. SimpleNet is comprised of the most basic building blocks of a CNN architecture.
        It uses basic principles to maximize the network performance both in terms of feature representation and speed without
        resorting to complex design or operators.
        Args:
            num_classes (int, optional): number of classes. Defaults to 1000.
            in_chans (int, optional): number of input channels. Defaults to 3.
            scale (float, optional): scale of the architecture width. Defaults to 1.0.
            network_idx (int, optional): the network index indicating the 5 million or 8 million version(0 and 1 respectively). Defaults to 0.
            mode (int, optional): stride mode of the architecture. specifies how fast the input shrinks.
                This is used for larger input sizes such as the 224x224 in imagenet training where the
                input size incurs a lot of overhead if not downsampled properly.
                you can choose between 0 meaning no change and 4. where each number denotes a specific
                downsampling strategy. For imagenet use 1-4.
                the larger the stride mode, the higher accuracy and the slower
                the network gets. stride mode 1 is the fastest and achives very good accuracy.
                Defaults to 2.
            drop_rates (Dict[int,float], optional): custom drop out rates specified per layer.
                each rate should be paired with the corrosponding layer index(pooling and cnn layers are counted only).
                Defaults to None (no custom rates).
        """
        super(SimpleNet, self).__init__()
        # BUGFIX: the original default was a mutable dict literal ({}), which is
        # shared across all instances; use None as the sentinel instead.
        drop_rates = {} if drop_rates is None else drop_rates
        # Each entry is (out_channels or "p" for a pooling layer, stride, default
        # dropout rate or None, optional "k1" marker selecting a 1x1 kernel).
        # BUGFIX: the original annotation used call syntax (Union(int, str)),
        # which raises TypeError if these annotations are ever evaluated
        # (e.g. via typing.get_type_hints); tuples also vary in length, so a
        # variadic tuple annotation is used.
        self.cfg: Dict[str, List[Tuple[Any, ...]]] = {
            "simplenet_cifar_310k": [
                (64, 1, 0.0),
                (32, 1, 0.0),
                (32, 1, 0.0),
                (32, 1, None),
                ("p", 2, 0.0),
                (32, 1, 0.0),
                (32, 1, 0.0),
                (64, 1, None),
                ("p", 2, 0.0),
                (64, 1, 0.0),
                (64, 1, None),
                ("p", 2, 0.0),
                (128, 1, 0.0),
                (256, 1, 0.0, "k1"),
                (64, 1, None, "k1"),
                ("p", 2, 0.0),
                (64, 1, None),
            ],
            "simplenet_cifar_460k": [
                (32, 1, 0.0),
                (32, 1, 0.0),
                (32, 1, 0.0),
                (64, 1, None),
                ("p", 2, 0.0),
                (64, 1, 0.0),
                (64, 1, 0.0),
                (64, 1, None),
                ("p", 2, 0.0),
                (64, 1, 0.0),
                (64, 1, None),
                ("p", 2, 0.0),
                (96, 1, 0.0),
                (96, 1, 0.0, "k1"),
                (96, 1, None, "k1"),
                ("p", 2, 0.0),
                (100, 1, None),
            ],
            "simplenet_cifar_5m": [
                (64, 1, 0.0),
                (128, 1, 0.0),
                (128, 1, 0.0),
                (128, 1, None),
                ("p", 2, 0.0),
                (128, 1, 0.0),
                (128, 1, 0.0),
                (256, 1, None),
                ("p", 2, 0.0),
                (256, 1, 0.0),
                (256, 1, None),
                ("p", 2, 0.0),
                (512, 1, 0.0),
                (2048, 1, 0.0, "k1"),
                (256, 1, None, "k1"),
                ("p", 2, 0.0),
                (256, 1, None),
            ],
            "simplenet_cifar_5m_extra_pool": [
                (64, 1, 0.0),
                (128, 1, 0.0),
                (128, 1, 0.0),
                (128, 1, None),
                ("p", 2, 0.0),
                (128, 1, 0.0),
                (128, 1, 0.0),
                (256, 1, None),
                ("p", 2, 0.0),
                (256, 1, 0.0),
                (256, 1, None),
                ("p", 2, 0.0),
                (512, 1, 0.0),
                ("p", 2, 0.0),  # extra pooling!
                (2048, 1, 0.0, "k1"),
                (256, 1, None, "k1"),
                ("p", 2, 0.0),
                (256, 1, None),
            ],
            "simplenetv1_imagenet": [
                (64, 1, 0.0),
                (128, 1, 0.0),
                (128, 1, 0.0),
                (128, 1, 0.0),
                (128, 1, 0.0),
                (128, 1, 0.0),
                ("p", 2, 0.0),
                (256, 1, 0.0),
                (256, 1, 0.0),
                (256, 1, 0.0),
                (512, 1, 0.0),
                ("p", 2, 0.0),
                (2048, 1, 0.0, "k1"),
                (256, 1, 0.0, "k1"),
                (256, 1, 0.0),
            ],
            "simplenetv1_imagenet_9m": [
                (128, 1, 0.0),
                (192, 1, 0.0),
                (192, 1, 0.0),
                (192, 1, 0.0),
                (192, 1, 0.0),
                (192, 1, 0.0),
                ("p", 2, 0.0),
                (320, 1, 0.0),
                (320, 1, 0.0),
                (320, 1, 0.0),
                (640, 1, 0.0),
                ("p", 2, 0.0),
                (2560, 1, 0.0, "k1"),
                (320, 1, 0.0, "k1"),
                (320, 1, 0.0),
            ],
        }
        # make sure values are in proper form!
        self.dropout_rates = {int(key): float(value) for key, value in drop_rates.items()}
        # 15 is the last layer of the network(including two previous pooling layers)
        # basically specifying the dropout rate for the very last layer to be used after the pooling
        self.last_dropout_rate = self.dropout_rates.get(15, 0.0)
        # Per-mode stride overrides: maps layer index -> stride for the first
        # few layers, controlling how aggressively the input is downsampled.
        self.strides = {
            0: {},
            1: {0: 2, 1: 2, 2: 2},
            2: {0: 2, 1: 2, 2: 1, 3: 2},
            3: {0: 2, 1: 2, 2: 1, 3: 1, 4: 2},
            4: {0: 2, 1: 1, 2: 2, 3: 1, 4: 2, 5: 1},
        }
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.scale = scale
        self.networks = [
            "simplenetv1_imagenet",  # 0
            "simplenetv1_imagenet_9m",  # 1
            # other archs
            "simplenet_cifar_310k",  # 2
            "simplenet_cifar_460k",  # 3
            "simplenet_cifar_5m",  # 4
            "simplenet_cifar_5m_extra_pool",  # 5
        ]
        self.network_idx = network_idx
        self.mode = mode
        self.features = self._make_layers(scale)
        # The classifier input width is the (scaled) channel count of the last conv layer.
        self.classifier = nn.Linear(round(self.cfg[self.networks[network_idx]][-1][0] * scale), num_classes)

    def forward(self, x):
        """Run the feature extractor, global max-pool, dropout, then the linear classifier."""
        out = self.features(x)
        # Global max pooling over whatever spatial size remains.
        out = F.max_pool2d(out, kernel_size=out.size()[2:])
        out = F.dropout2d(out, self.last_dropout_rate, training=self.training)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, scale: float):
        """Build the nn.Sequential feature extractor from the selected config.

        Args:
            scale: width multiplier applied to every conv layer's channel count.
        """
        layers: List[nn.Module] = []
        input_channel = self.in_chans
        stride_list = self.strides[self.mode]
        for idx, (layer, stride, default_dropout_rate, *layer_type) in enumerate(
            self.cfg[self.networks[self.network_idx]]
        ):
            # Mode-specific stride overrides apply only to the first few layers.
            stride = stride_list[idx] if len(stride_list) > idx else stride
            # check if any custom dropout rate is specified
            # for this layer, note that pooling also counts as 1 layer
            custom_dropout = self.dropout_rates.get(idx, None)
            custom_dropout = default_dropout_rate if custom_dropout is None else custom_dropout
            # dropout values must be strictly decimal. while 0 doesnt introduce any issues here
            # i.e. during training and inference, if you try to jit trace your model it will crash
            # due to using 0 as dropout value(this applies up to 1.13.1) so here is an explicit
            # check to convert any possible integer value to its decimal counterpart.
            custom_dropout = None if custom_dropout is None else float(custom_dropout)
            kernel_size = 3
            padding = 1
            # The optional "k1" marker switches this conv to a 1x1 kernel.
            if layer_type == ["k1"]:
                kernel_size = 1
                padding = 0
            if layer == "p":
                layers += [
                    nn.MaxPool2d(kernel_size=(2, 2), stride=(stride, stride)),
                    nn.Dropout2d(p=custom_dropout, inplace=True),
                ]
            else:
                filters = round(layer * scale)
                if custom_dropout is None:
                    layers += [
                        nn.Conv2d(input_channel, filters, kernel_size=kernel_size, stride=stride, padding=padding),
                        nn.BatchNorm2d(filters, eps=1e-05, momentum=0.05, affine=True),
                        nn.ReLU(inplace=True),
                    ]
                else:
                    layers += [
                        nn.Conv2d(input_channel, filters, kernel_size=kernel_size, stride=stride, padding=padding),
                        nn.BatchNorm2d(filters, eps=1e-05, momentum=0.05, affine=True),
                        nn.ReLU(inplace=True),
                        nn.Dropout2d(p=custom_dropout, inplace=False),
                    ]
                input_channel = filters
        model = nn.Sequential(*layers)
        # Xavier init with the ReLU gain for all conv weights.
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain("relu"))
        return model
def _gen_simplenet(
    model_variant: str = "simplenetv1_m2",
    num_classes: int = 1000,
    in_chans: int = 3,
    scale: float = 1.0,
    network_idx: int = 0,
    mode: int = 2,
    pretrained: bool = False,
    drop_rates: Optional[Dict[int, float]] = None,
) -> SimpleNet:
    """Instantiate a SimpleNet model and optionally load pretrained weights.

    Args:
        model_variant: key into ``default_cfgs`` used to locate pretrained weights.
        num_classes: number of output classes.
        in_chans: number of input channels.
        scale: width multiplier of the architecture.
        network_idx: index of the architecture config (see ``SimpleNet.networks``).
        mode: stride mode (0-4) controlling input downsampling.
        pretrained: when True, download and load the checkpoint for ``model_variant``.
        drop_rates: optional per-layer dropout overrides. Defaults to None.

    Raises:
        ValueError: if ``pretrained`` is True and ``model_variant`` is unknown.

    Returns:
        SimpleNet: the instantiated (and possibly pretrained) model.
    """
    # BUGFIX: avoid a mutable dict literal as the default argument.
    drop_rates = {} if drop_rates is None else drop_rates
    model = SimpleNet(num_classes, in_chans, scale=scale, network_idx=network_idx, mode=mode, drop_rates=drop_rates)
    if pretrained:
        cfg = default_cfgs.get(model_variant, None)
        if cfg is None:
            # ValueError (a subclass of Exception) keeps existing `except Exception` callers working.
            raise ValueError(f"Unknown model variant ('{model_variant}') specified!")
        url = cfg["url"]
        checkpoint_filename = url.split("/")[-1]
        # Checkpoints are cached under a local "tmp" directory.
        checkpoint_path = os.path.join("tmp", checkpoint_filename)
        print(f"saving in checkpoint_path:{checkpoint_path}")
        if not os.path.exists(checkpoint_path):
            os.makedirs("tmp", exist_ok=True)
            download_url_to_file(url, checkpoint_path)
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model
def simplenet(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Generic simplenet model builder. by default it returns `simplenetv1_5m_m2` model
    but specifying different arguments such as `netidx`, `scale` or `mode` will result in
    the corrosponding network variant.
    when pretrained is specified, if the combination of settings resemble any known variants
    specified in the `default_cfg`, their respective pretrained weights will be loaded, otherwise
    an exception will be thrown denoting Unknown model variant being specified.
    Args:
        pretrained (bool, optional): loads the model with pretrained weights only if the model is a known variant specified in default_cfg. Defaults to False.
    Raises:
        Exception: if pretrained is used with an unknown/custom model variant and exception is raised.
    Returns:
        SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
    """
    num_classes = kwargs.get("num_classes", 1000)
    in_chans = kwargs.get("in_chans", 3)
    scale = kwargs.get("scale", 1.0)
    network_idx = kwargs.get("network_idx", 0)
    mode = kwargs.get("mode", 2)
    drop_rates = kwargs.get("drop_rates", {})
    model_variant = "simplenetv1_5m_m2"
    if pretrained:
        # Map the network index to the parameter-count tag of known variants
        # (0 -> 5m, 1 -> 9m); anything else yields no known base.
        model_base = {0: 5, 1: 9}.get(network_idx)
        if math.isclose(scale, 1.0):
            config = f"{model_base}m_m{mode}"
        elif math.isclose(scale, 0.75):
            config = f"small_m{mode}_075"
        elif math.isclose(scale, 0.5):
            config = f"small_m{mode}_05"
        else:
            config = f"m{mode}_{scale:.2f}".replace(".", "")
        model_variant = f"simplenetv1_{config}"
    return _gen_simplenet(model_variant, num_classes, in_chans, scale, network_idx, mode, pretrained, drop_rates)
def _remove_network_settings(kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Removes network related settings passed in kwargs for predefined network configruations below
Returns:
Dict[str,Any]: cleaned kwargs
"""
model_args = {k: v for k, v in kwargs.items() if k not in ["scale", "network_idx", "mode", "drop_rate"]}
return model_args
# cifar10/100 models
def simplenet_cifar_310k(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the ~310k-parameter CIFAR-10/100 SimpleNet variant used in the paper.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenet_cifar_310k", network_idx=2, mode=0, pretrained=pretrained, **cleaned_kwargs)
def simplenet_cifar_460k(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the ~460k-parameter CIFAR-10/100 SimpleNet variant used in the paper.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenet_cifar_460k", network_idx=3, mode=0, pretrained=pretrained, **cleaned_kwargs)
def simplenet_cifar_5m(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the original ~5m-parameter SimpleNet trained on CIFAR-10/100 in Caffe.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenet_cifar_5m", network_idx=4, mode=0, pretrained=pretrained, **cleaned_kwargs)
def simplenet_cifar_5m_extra_pool(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the early PyTorch CIFAR variant that mistakenly used an extra pooling layer.

    The mistake was reported in https://github.com/Coderx7/SimpleNet_Pytorch/issues/5;
    this builder exists only so that weights trained with that variation (still
    available in the repository) remain loadable.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenet_cifar_5m_extra_pool", network_idx=5, mode=0, pretrained=pretrained, **cleaned_kwargs)
# imagenet models
def simplenetv1_small_m1_05(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the 0.5x-width (~1.5m parameter) SimpleNetV1 using stride mode m1,
    the fastest variant available.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    Returns:
        SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenetv1_small_m1_05", scale=0.5, network_idx=0, mode=1, pretrained=pretrained, **cleaned_kwargs)
def simplenetv1_small_m2_05(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the 0.5x-width (~1.5m parameter) SimpleNetV1 using stride mode m2,
    the second fastest variant available.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    Returns:
        SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenetv1_small_m2_05", scale=0.5, network_idx=0, mode=2, pretrained=pretrained, **cleaned_kwargs)
def simplenetv1_small_m1_075(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the 0.75x-width (~3m parameter) SimpleNetV1 using stride mode m1.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    Returns:
        SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenetv1_small_m1_075", scale=0.75, network_idx=0, mode=1, pretrained=pretrained, **cleaned_kwargs)
def simplenetv1_small_m2_075(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the 0.75x-width (~3m parameter) SimpleNetV1 using stride mode m2.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    Returns:
        SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenetv1_small_m2_075", scale=0.75, network_idx=0, mode=2, pretrained=pretrained, **cleaned_kwargs)
def simplenetv1_5m_m1(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the base ~5m-parameter SimpleNetV1 using stride mode m1 — a fast
    and performant configuration.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    Returns:
        SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenetv1_5m_m1", scale=1.0, network_idx=0, mode=1, pretrained=pretrained, **cleaned_kwargs)
def simplenetv1_5m_m2(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the base ~5m-parameter SimpleNetV1 using stride mode m2 — slightly
    more accurate than the m1 variant at the expense of slower inference.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    Returns:
        SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenetv1_5m_m2", scale=1.0, network_idx=0, mode=2, pretrained=pretrained, **cleaned_kwargs)
def simplenetv1_9m_m1(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the ~9m-parameter SimpleNetV1 using stride mode m1, which makes it run faster.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    Returns:
        SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenetv1_9m_m1", scale=1.0, network_idx=1, mode=1, pretrained=pretrained, **cleaned_kwargs)
def simplenetv1_9m_m2(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
    """Build the ~9m-parameter SimpleNetV1 using stride mode m2 — slightly more
    accurate than the m1 variant at the expense of slower inference.

    Args:
        pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
    Returns:
        SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
    """
    cleaned_kwargs = _remove_network_settings(kwargs)
    return _gen_simplenet("simplenetv1_9m_m2", scale=1.0, network_idx=1, mode=2, pretrained=pretrained, **cleaned_kwargs)
if __name__ == "__main__":
    # Smoke test: build the default (simplenetv1_5m_m2) model and run one
    # forward pass on a dummy ImageNet-sized batch.
    net = simplenet(num_classes=1000)
    dummy_input = torch.randn(size=(1, 3, 224, 224))
    out = net(dummy_input)
    print(net)
    print(f"output: {out.size()}")