east_ocr_client.py
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
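
# Example client for the OCR pipeline demo (EAST text detection followed by CRNN text
# recognition) served by OpenVINO Model Server. It sends an image over gRPC, prints the
# recognized words and optionally saves the cropped text images returned by the pipeline.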
import argparse

import cv2
import grpc
import numpy as np
from tensorflow import make_tensor_proto, make_ndarray
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

parser = argparse.ArgumentParser(description='Client for OCR pipeline')
parser.add_argument('--grpc_address', required=False, default='localhost', help='Specify url to grpc service. default:localhost')
parser.add_argument('--grpc_port', required=False, default=9178, help='Specify port to grpc service. default: 9178')
parser.add_argument('--pipeline_name', required=False, default='ocr', help='Pipeline name to request. default: ocr')
parser.add_argument('--image_input_name', required=False, default='image', help='Pipeline input name for input with image. default: image')
parser.add_argument('--image_input_path', required=True, help='Input image path.')
parser.add_argument('--texts_output_name', required=False, default='texts', help='Pipeline output name for output with recognized texts. default: texts')
parser.add_argument('--text_images_output_name', required=False, default='text_images', help='Pipeline output name for cropped images with text. default: text_images')
parser.add_argument('--text_images_save_path', required=False, default='', help='If specified, images will be saved to disk.')
parser.add_argument('--image_width', required=False, default=1920, help='Original image width. default: 1920')
parser.add_argument('--image_height', required=False, default=1024, help='Original image height. default: 1024')
args = vars(parser.parse_args())

def prepare_img_input_in_nchw_format(request, name, path, resize_to_shape):
    # Read the image, resize it to the requested (height, width) and convert it to NCHW layout
    img = cv2.imread(path).astype(np.float32)  # BGR color format, shape HWC
    img = cv2.resize(img, (resize_to_shape[1], resize_to_shape[0]))
    target_shape = (img.shape[0], img.shape[1])
    img = img.transpose(2, 0, 1).reshape(1, 3, target_shape[0], target_shape[1])
    request.inputs[name].CopyFrom(make_tensor_proto(img, shape=img.shape))

def nchw_to_image(output_nd, name, location):
    # Save each cropped text image (NCHW batch) from the pipeline output as a JPEG file
    for i in range(output_nd.shape[0]):
        out = output_nd[i][0]
        out = out.transpose(1, 2, 0)
        cv2.imwrite(location + name + '_' + str(i) + '.jpg', out)

def crnn_output_to_text(output_nd):
    # Decode each CRNN output sequence by taking the argmax symbol at every position
    for i in range(output_nd.shape[0]):
        data = output_nd[i]
        alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789_'
        preds = data.argmax(2)
        word = ''
        for j in range(preds.shape[0]):
            word += alphabet[preds[j, 0]]
        print(word)

# Raise the gRPC message size limits so large images and pipeline outputs fit in a single message
address = "{}:{}".format(args['grpc_address'], args['grpc_port'])
MAX_MESSAGE_LENGTH = 1024 * 1024 * 1024
channel = grpc.insecure_channel(address,
    options=[
        ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
        ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
    ])
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = args['pipeline_name']
prepare_img_input_in_nchw_format(request, args['image_input_name'], args['image_input_path'], (int(args['image_height']), int(args['image_width'])))

try:
    # 30 second timeout; the pipeline aborts the request when no text boxes are detected
    response = stub.Predict(request, 30.0)
except grpc.RpcError as err:
    if err.code() == grpc.StatusCode.ABORTED:
        print('No text has been found in the image')
        exit(1)
    else:
        raise err

# Print every pipeline output; post-process the recognized texts and the cropped text images
for name in response.outputs:
    print(f"Output: name[{name}]")
    tensor_proto = response.outputs[name]
    output_nd = make_ndarray(tensor_proto)
    print(f" numpy => shape[{output_nd.shape}] data[{output_nd.dtype}]")
    if name == args['text_images_output_name'] and len(args['text_images_save_path']) > 0:
        nchw_to_image(output_nd, name, args['text_images_save_path'])
    if name == args['texts_output_name']:
        crnn_output_to_text(output_nd)
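
# Example invocation (a sketch, assuming an OpenVINO Model Server instance with the 'ocr'
# pipeline is already running and reachable at localhost:9178; adjust the address, port and
# paths to match your deployment):
#
#   python east_ocr_client.py --grpc_address localhost --grpc_port 9178 \
#       --image_input_path image.jpg --text_images_save_path ./results/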