forked from NVlabs/ocropus3-ocroline
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdeeper-model.py
30 lines (25 loc) · 912 Bytes
/
deeper-model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def make_model(ninput=48, noutput=97):
    """Assemble the deeper conv + LSTM line model as an ``nn.Sequential``.

    The pipeline takes a BHWD batch, reorders it to BDHW, applies two
    convolution stages (with a 3x3 max-pool in between), folds the depth and
    height axes together to obtain a BDW sequence, runs a 1D LSTM followed by
    a 1x1 convolution, and finally emits the result in BWD order.

    Args:
        ninput: Expected size of the H axis of the input (checked right
            after the initial reorder).
        noutput: Size of the last axis of the output (checked on the final
            BWD tensor); also passed as first argument to the final Conv1d.

    Returns:
        The assembled ``nn.Sequential`` model.

    Note:
        ``nn``, ``layers`` and ``flex`` are not imported in this file; they
        are expected to be present in the namespace this code runs in.
    """
    # Constraints handed to layers.CheckSizes.
    # NOTE(review): the tuples look like (low, high) bounds -- confirm
    # against the CheckSizes implementation.
    batch_range = (0, 900)
    width_range = (0, 9000)
    depth_range = (0, 5000)
    height = ninput

    stages = []

    # Bring the input into Torch's channel-first layout and validate it:
    # a single input channel and a height of `ninput`.
    stages += [
        layers.Reorder("BHWD", "BDHW"),
        layers.CheckSizes(batch_range, 1, height, width_range, name="input"),
    ]

    # Convolutional feature extractor; tensors stay in BDHW layout here.
    # NOTE(review): flex.* layers presumably infer their input size on the
    # first forward pass -- confirm.
    stages += [
        flex.Conv2d(100, 3, padding=(1, 1)),
        nn.ReLU(),
        nn.MaxPool2d(3),
        flex.Conv2d(100, 3, padding=(1, 1)),
        flex.BatchNorm2d(),
        nn.ReLU(),
    ]

    # Turn the image into a sequence: keep axis 0 (batch) and axis 3 (width)
    # and merge the depth and height axes into one feature axis.
    stages += [
        layers.Fun(lambda t: t.view(t.size(0), -1, t.size(3)), "BDHW->BDW"),
        layers.CheckSizes(batch_range, depth_range, width_range),
    ]

    # 1D LSTM over the sequence, then a Conv1d (second argument 1) with
    # `noutput` as its first argument.
    stages += [
        flex.Lstm1(100),
        flex.Conv1d(noutput, 1),
    ]

    # Hand the result back in BWD order and validate the output size.
    stages += [
        layers.Reorder("BDW", "BWD"),
        layers.CheckSizes(batch_range, width_range, noutput, name="output"),
    ]

    return nn.Sequential(*stages)