forked from NVlabs/ocropus3-ocrobin
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathocrobin-pred
executable file
·116 lines (97 loc) · 3.13 KB
/
ocrobin-pred
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/python
import os
import re
import glob
import random as pyr
import os.path
import argparse
import torch
import scipy.ndimage as ndi
import torch.nn.functional as F
from pylab import *
from torch import nn, optim, autograd
import matplotlib as mpl
import ocrobin
from dlinputs import filters
from dlinputs import gopen
from dlinputs import utils
from dlinputs import paths
# Search path handed to paths.find_file when locating the model; the MODELS
# environment variable overrides the built-in default.
# NOTE(review): presumably a colon-separated directory list -- confirm against
# dlinputs.paths.find_file.
model_path = os.environ.get("MODELS", ".:/usr/local/share/ocrobin:/usr/share/ocrobin")
default_model = "bin-000000046-005393.pt"

# Show images with a gray colormap and switch pylab to interactive mode.
rc("image", cmap="gray")
ion()

# The summary is passed as `description`: the first positional parameter of
# ArgumentParser is `prog`, so a positional string would replace the program
# name in the usage line instead of describing the tool.
parser = argparse.ArgumentParser(
    description="run an ocrobin binarization model over the images in INPUT")
parser.add_argument("-m", "--model", default=default_model,
                    help="model file to load (default: %(default)s)")
parser.add_argument("-b", "--batchsize", type=int, default=1,
                    help="number of samples per batch (default: %(default)s)")
parser.add_argument("-D", "--makesource", default=None,
                    help="Python file executed before the source is built; "
                         "it may redefine make_source")
parser.add_argument("-P", "--makepipeline", default=None,
                    help="Python file executed before the pipeline is built; "
                         "it may redefine make_pipeline")
parser.add_argument("-i", "--invert", action="store_true",
                    help="invert the normalized input images")
parser.add_argument("--display", type=int, default=0,
                    help="show every N-th batch (0 disables display)")
parser.add_argument("input", help="input dataset to read")
parser.add_argument("output", nargs="?",
                    help="optional sink receiving the binarized images")
args = parser.parse_args()

# Plain-dict copy of the parsed options.
ARGS = dict(vars(args))
def make_source():
    """Return the sample iterator for the dataset named by ``args.input``.

    The iterator is obtained from ``gopen.sharditerator_once``.  This default
    can be replaced by the file given with ``--makesource``.
    """
    dataset = args.input
    return gopen.sharditerator_once(dataset)
def make_pipeline():
    """Build the default preprocessing pipeline applied to the source.

    The pipeline composes three dlinputs filters: a rename producing the
    ``input`` field, a map applying ``fixdepth`` to that field, and batching
    with ``args.batchsize``.  This default can be replaced by the file given
    with ``--makepipeline``.

    Returns:
        The composed filter, called on a sample iterator.
    """
    def fixdepth(image):
        """Return `image` as a 2D float array scaled to the range [0, 1].

        A 3D input is averaged over its last axis.  The result is inverted
        when ``args.invert`` is set.  The input array is never modified.

        Raises:
            ValueError: if `image` is neither 2D nor 3D.
        """
        image = np.asarray(image)
        # Explicit check rather than `assert`, which is stripped under -O.
        if image.ndim not in (2, 3):
            raise ValueError(
                "expected a 2D or 3D image, got %d dimensions" % image.ndim)
        # Integer images cannot be normalized meaningfully with integer
        # arithmetic; floating-point inputs keep their dtype.
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(float)
        if image.ndim == 3:
            image = np.mean(image, 2)
        # Out-of-place arithmetic so the caller's array is left untouched.
        image = image - image.min()
        peak = image.max()
        # A constant image has peak 0 after the shift; skip the division so
        # the result is all zeros instead of NaN.
        if peak > 0:
            image = image / peak
        if args.invert:
            image = 1 - image
        return image

    # NOTE(review): the rename presumably maps the first present field among
    # the listed names to "input" -- confirm against dlinputs.filters.rename.
    return filters.compose(
        filters.rename(input="gray.png png gray.jpg jpeg jpg"),
        filters.map(input=fixdepth),
        filters.batched(args.batchsize, expand=True))
# Optional user hooks: each file is executed in this module's namespace, so
# it can rebind make_source / make_pipeline before they are called below.
# NOTE: the files run as arbitrary code; only pass trusted paths.
if args.makesource:
    execfile(args.makesource)
if args.makepipeline:
    execfile(args.makepipeline)
def pixels_to_batch(x):
    """Flatten a 4D tensor into one row per pixel.

    `x` has size ``(b, d, h, w)``; the result has size ``(b*h*w, d)``, with
    axis 1 of the input moved last before flattening.
    """
    batch, depth, height, width = x.size()
    depth_last = x.permute(0, 2, 3, 1).contiguous()
    return depth_last.view(batch * height * width, depth)
class PixelsToBatch(nn.Module):
    """Module form of ``pixels_to_batch`` for use inside a network."""

    def forward(self, x):
        """Return ``pixels_to_batch(x)``."""
        flattened = pixels_to_batch(x)
        return flattened
source = make_source()
pipeline = make_pipeline()
source = pipeline(source)
if args.output:
sink = gopen.open_sink(args.output)
mname = paths.find_file(model_path, args.model)
assert mname is not None, "model not found"
print "loading", mname
bm = ocrobin.Binarizer(mname)
print bm.model
def display_batch(image, output):
    """Show the first sample of `image` and `output` side by side.

    Either argument may be None, in which case its panel is left empty.
    `image` goes in the left panel and `output` in the right one; each is
    indexed as ``[0, :, :, 0]`` and drawn on a fixed 0..1 intensity scale.
    """
    clf()
    for panel, batch in ((121, image), (122, output)):
        if batch is None:
            continue
        subplot(panel)
        imshow(batch[0, :, :, 0], vmin=0, vmax=1)
    draw()
    # Very short ginput timeout, as in the original code.
    ginput(1, 1e-3)
for i, sample in enumerate(source):
fname = sample["__key__"]
print i, fname
image = sample["input"]
output = bm.binarize_batch(image)
#if nbatches % 10 == 0:
if args.display > 0:
if i % args.display == 0:
clf()
subplot(121); imshow(image[0], vmin=0, vmax=1)
subplot(122); imshow(output[0], vmin=0, vmax=1)
draw(); ginput(1, 1e-3)
waitforbuttonpress(0.0001)
for i in xrange(len(sample["__key__"])):
result = utils.metadict(sample, {
"__key__": fname[i],
"bin.png": output[i]
})
if args.output:
sink.write(result)
if args.output:
sink.close()