-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from arabian9ts/bgr-transfer
BGR transfer
- Loading branch information
Showing
5 changed files
with
183 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
""" | ||
inference script | ||
date: 3/17 | ||
author: arabian9ts | ||
""" | ||
|
||
import cv2 | ||
import sys | ||
from util.util import * | ||
from model.ssd300 import * | ||
|
||
def inference(image_name): | ||
if image_name is None: | ||
return Exception('not specified image name to be drawed') | ||
|
||
fontType = cv2.FONT_HERSHEY_SIMPLEX | ||
img, w, h, _, = preprocess('./voc2007/'+image_name) | ||
pred_confs, pred_locs = ssd.infer(images=[img]) | ||
locs, labels = ssd.ssd.detect_objects(pred_confs, pred_locs) | ||
img = deprocess(img, w, h) | ||
if len(labels) and len(locs): | ||
for label, loc in zip(labels, locs): | ||
loc = center2corner(loc) | ||
loc = convert2diagonal_points(loc) | ||
cv2.rectangle(img, (int(loc[0]*w), int(loc[1]*h)), (int(loc[2]*w), int(loc[3]*h)), (0, 0, 255), 1) | ||
cv2.putText(img, str(int(label)), (int(loc[0]*w), int(loc[1]*h)), fontType, 0.7, (0, 0, 255), 1) | ||
|
||
return img | ||
|
||
|
||
# detect objects on a specified image. | ||
if 2 == len(sys.argv): | ||
sess = tf.Session() | ||
# tensorflow session | ||
ssd = SSD300(sess) | ||
sess.run(tf.global_variables_initializer()) | ||
|
||
# parameter saver | ||
saver = tf.train.Saver() | ||
saver.restore(sess, './checkpoints/params.ckpt') | ||
img = inference(sys.argv[1]) | ||
cv2.imwrite('./evaluated/'+sys.argv[1], img) | ||
cv2.namedWindow("img", cv2.WINDOW_NORMAL) | ||
cv2.imshow("img", img) | ||
cv2.waitKey(0) | ||
cv2.destroyAllWindows() | ||
sys.exit() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
""" | ||
SSD300 is SSD wrapper class. | ||
date: 10/18 | ||
author: arabian9ts | ||
""" | ||
|
||
import tensorflow as tf | ||
import numpy as np | ||
|
||
from model.ssd import * | ||
from matcher import Matcher | ||
from model.computation import * | ||
from model.default_box import * | ||
|
||
|
||
class SSD300: | ||
def __init__(self, sess): | ||
""" | ||
initialize SSD model as SSD300 whose input size is 300x300 | ||
""" | ||
self.sess = sess | ||
|
||
# define input placeholder and initialize ssd instance | ||
self.input = tf.placeholder(shape=[None, 300, 300, 3], dtype=tf.float32) | ||
self.ssd = SSD() | ||
|
||
# build ssd network => feature-maps and confs and locs tensor is returned | ||
fmaps, confs, locs = self.ssd.build(self.input, is_training=True) | ||
|
||
# zip running set of tensor | ||
self.pred_set = [fmaps, confs, locs] | ||
|
||
# required param from default-box and loss function | ||
fmap_shapes = [map.get_shape().as_list() for map in fmaps] | ||
# print('fmap shapes is '+str(fmap_shapes)) | ||
self.dboxes = generate_boxes(fmap_shapes) | ||
print(len(self.dboxes)) | ||
|
||
# required placeholder for loss | ||
loss, loss_conf, loss_loc, self.pos, self.neg, self.gt_labels, self.gt_boxes = self.ssd.loss(len(self.dboxes)) | ||
self.train_set = [loss, loss_conf, loss_loc] | ||
# optimizer = tf.train.AdamOptimizer(0.05) | ||
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3, beta1=0.9, beta2=0.999, epsilon=1e-08, use_locking=False, name='Adam') | ||
self.train_step = optimizer.minimize(loss) | ||
|
||
# provides matching method | ||
self.matcher = Matcher(fmap_shapes, self.dboxes) | ||
|
||
# inference process | ||
def infer(self, images): | ||
feature_maps, pred_confs, pred_locs = self.sess.run(self.pred_set, feed_dict={self.input: images}) | ||
return pred_confs, pred_locs | ||
|
||
# training process | ||
def train(self, images, actual_data): | ||
# ================ RESET / EVAL ================ # | ||
positives = [] | ||
negatives = [] | ||
ex_gt_labels = [] | ||
ex_gt_boxes = [] | ||
# ===================== END ===================== # | ||
|
||
# call prepare_loss per image | ||
# because matching method works with only one image | ||
def prepare_loss(pred_confs, pred_locs, actual_labels, actual_locs): | ||
pos_list, neg_list, t_gtl, t_gtb = self.matcher.matching(pred_confs, pred_locs, actual_labels, actual_locs) | ||
positives.append(pos_list) | ||
negatives.append(neg_list) | ||
ex_gt_labels.append(t_gtl) | ||
ex_gt_boxes.append(t_gtb) | ||
|
||
|
||
feature_maps, pred_confs, pred_locs = self.sess.run(self.pred_set, feed_dict={self.input: images}) | ||
|
||
for i in range(len(images)): | ||
actual_labels = [] | ||
actual_locs = [] | ||
# extract ground truth info | ||
for obj in actual_data[i]: | ||
loc = obj[:4] | ||
label = np.argmax(obj[4:]) | ||
|
||
# transform location for voc2007 | ||
loc = convert2wh(loc) | ||
loc = corner2center(loc) | ||
|
||
actual_locs.append(loc) | ||
actual_labels.append(label) | ||
|
||
prepare_loss(pred_confs[i], pred_locs[i], actual_labels, actual_locs) | ||
|
||
batch_loss, batch_conf, batch_loc = \ | ||
self.sess.run(self.train_set, \ | ||
feed_dict={self.input: images, self.pos: positives, self.neg: negatives, self.gt_labels: ex_gt_labels, self.gt_boxes: ex_gt_boxes}) | ||
|
||
self.sess.run(self.train_step, \ | ||
feed_dict={self.input: images, self.pos: positives, self.neg: negatives, self.gt_labels: ex_gt_labels, self.gt_boxes: ex_gt_boxes}) | ||
|
||
return pred_confs, pred_locs, batch_loc, batch_conf, batch_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
995a014
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prediction result is ok? thanks