example.py
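# Streamlit demo: ResNet18 ONNX image classification with an HSV color-filter preview.
# Expected assets (referenced below): models/classification.onnx and models/imagenet_labels.json.
# Run locally with: streamlit run example.py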
import streamlit as st
import pandas as pd
import numpy as np
import cv2
import onnxruntime as rt
from PIL import Image
import json


def softmax(vector):
    # Subtract the max before exponentiating for numerical stability;
    # the result is mathematically unchanged.
    e = np.exp(vector - np.max(vector))
    return e / e.sum()
def apply(im, mi):
    # The caller already converts the decoded image to RGB, so no further
    # color conversion is needed here (converting again would swap back to BGR).
    iW = im.shape[1]  # original width  (currently unused)
    iH = im.shape[0]  # original height (currently unused)
    x = cv2.resize(im, (224, 224))
    x = x.astype(np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
    x = x[np.newaxis, :, :, :]      # add batch dimension -> NCHW
    res = mi.predict(x)
    return res
class modelInference:
    def __init__(self, path2model):
        self.sess = rt.InferenceSession(path2model)
        self.input_name = self.sess.get_inputs()[0].name
        self.input_shape = self.sess.get_inputs()[0].shape
        self.input_type = self.sess.get_inputs()[0].type
        self.output_name = self.sess.get_outputs()[0].name
        self.output_shape = self.sess.get_outputs()[0].shape
        self.output_type = self.sess.get_outputs()[0].type
        with open('models/imagenet_labels.json') as f:
            self.labels = json.load(f)

    def print_model_info(self):
        # Input information
        print("input name", self.input_name)
        print("input shape", self.input_shape)
        print("input type", self.input_type)
        # Output information
        print("output name", self.output_name)
        print("output shape", self.output_shape)
        print("output type", self.output_type)

    def predict(self, x):
        x = x.astype(np.float32)
        res = self.sess.run([self.output_name], {self.input_name: x})[0]
        probs = softmax(res)[0]
        # Top-5 class indices, highest probability first
        idx = np.argsort(-probs, axis=-1)
        idxTop = idx[0:5]
        s = ""
        for i in idxTop:
            p = np.round(probs[i] * 100, 2)
            s += f"{self.labels[i]}: {p}%, "
        return s
@st.cache(suppress_st_warning=True)
def get_model():
    mi = modelInference(path2model="models/classification.onnx")
    return mi


mi = get_model()

st.title("Resnet18 Classification")

# ------------
st.sidebar.title("HSV color filtering parameters")
# Note: OpenCV stores hue in the 0-179 range for 8-bit images.
Hmin, Hmax = st.sidebar.slider("Hmin - Hmax", 0, 255, (0, 255), 1)
Smin, Smax = st.sidebar.slider("Smin - Smax", 0, 255, (0, 255), 1)
Vmin, Vmax = st.sidebar.slider("Vmin - Vmax", 0, 255, (0, 255), 1)

st.set_option('deprecation.showfileUploaderEncoding', False)
uploaded_file = st.file_uploader("Upload Image", type=["png", "jpeg", "jpg", "bmp"])
if uploaded_file is not None:
    im = cv2.imdecode(np.frombuffer(uploaded_file.read(), np.uint8), cv2.IMREAD_COLOR)
    # ----
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    hsv = cv2.cvtColor(im, cv2.COLOR_RGB2HSV)
    # HSV threshold bounds taken from the sidebar sliders
    lower_hsv = np.array([Hmin, Smin, Vmin])
    upper_hsv = np.array([Hmax, Smax, Vmax])
    # Prepare the mask and keep only the pixels inside the selected HSV range
    mask = cv2.inRange(hsv, lower_hsv, upper_hsv)
    result = cv2.bitwise_and(im, im, mask=mask)
    # ---
    st.image(result, use_column_width=True)
    # ------------
    btn = st.button("Predict")
    if btn:
        # Classification
        res = apply(im, mi)
        st.write(res)