# Watermark_train.py (forked from painfulloop/Watermark-DnCNN)
# -*- coding: utf-8 -*-
import os

import cv2
import numpy as np
import tensorflow as tf

import DnCNNModel

np.random.seed(0)

# Shape of the trigger/verification images: [batch, height, width, channels].
spec_size = [1, 40, 40, 1]


def transition(w):
    # Identity pass-through; placeholder for an optional transform of the
    # extracted watermark response before the loss.
    return w


def watermark_loss(gen, gt):
    # gt = tf.reshape(gt, [special_num, np.prod(gt.get_shape().as_list())])
    # loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=gt, logits=gen)
    # loss = tf.reduce_mean(loss)
    loss = tf.reduce_sum(tf.square(gen - gt), axis=[1, 2, 3])
    loss = tf.reduce_mean(loss)
    return loss
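

# watermark_loss computes, per trigger image, the squared error summed over
# all pixels, then averages over the special_num images:
#     loss = mean_i || gen_i - gt_i ||_2^2
# The commented-out lines above are an earlier sigmoid cross-entropy variant
# of the same objective, kept for reference.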


def ft_DnCNN_optimizer(dncnn_loss, line_loss, lr, lambda_):
    # Joint objective: denoising loss plus the lambda_-weighted watermark loss.
    loss = dncnn_loss + lambda_ * line_loss
    optimizer = tf.train.AdamOptimizer(lr, name='AdamOptimizer')
    ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(ops):
        var_list = [t for t in tf.trainable_variables()]
        gradient = optimizer.compute_gradients(loss, var_list=var_list)
        train_op = optimizer.apply_gradients(gradient)
    return train_op
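

# Note: compute_gradients followed by apply_gradients over the full
# trainable-variable list is equivalent to optimizer.minimize(loss) here;
# building the ops under the UPDATE_OPS control dependency makes the
# batch-norm moving statistics refresh before each optimization step.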


def sobel():
    # 7x7 separable derivative filters: a smoothed-derivative row `a` combined
    # with a binomial smoothing column `b` (an enlarged Sobel-style operator).
    a = np.array([1, 4, 5, 0, -5, -4, -1])
    b = np.array([1, 6, 15, 20, 15, 6, 1])
    a = np.reshape(a, [1, a.shape[0]])
    b = np.reshape(b, [1, b.shape[0]])
    c = np.matmul(b.transpose(), a)
    sobel_base = tf.constant(c, tf.float32)
    sobel_x_filter = tf.reshape(sobel_base, [7, 7, 1, 1])
    sobel_y_filter = tf.transpose(sobel_x_filter, [1, 0, 2, 3])
    return sobel_x_filter, sobel_y_filter
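

# sobel() is currently unused (its call site in train() is commented out), but
# the returned filters would plug into the graph as ordinary 2-D convolutions.
# A minimal sketch, assuming a NHWC float32 tensor `img` shaped like the
# single-channel images built in train(); `sobel_edges` is illustrative, not
# part of the original pipeline:
def sobel_edges(img):
    sobel_x_filter, sobel_y_filter = sobel()
    # Horizontal and vertical smoothed-gradient responses, same spatial size.
    gx = tf.nn.conv2d(img, sobel_x_filter, strides=[1, 1, 1, 1], padding='SAME')
    gy = tf.nn.conv2d(img, sobel_y_filter, strides=[1, 1, 1, 1], padding='SAME')
    return gx, gy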


def train(train_data='./data/img_clean_pats.npy', org_model_path='./DnCNN_weight/', overwriting_path='./overwriting/',
          epochs=8, batch_size=128, learn_rate=0.0001, sigma=25, trigger_img='key_imgs/trigger_image.png',
          verification_img='key_imgs/verification_image.png', lambda_=0.001):
    # './DnCNN_weight/' folder containing the weights of the original DnCNN
    # './overwriting/'  folder receiving the new weights created by this script
    #                   (the model fine-tuned with the trigger key)
    special_num = 5
    DnCNN_model_name = 'Black_DnCNN_cman_weight_'
    with tf.Graph().as_default():
        lr = tf.placeholder(tf.float32, shape=[], name='learning_rate')
        training = tf.placeholder(tf.bool, name='is_training')
        img_clean = tf.placeholder(tf.float32, [batch_size, spec_size[1], spec_size[2], spec_size[3]],
                                   name='clean_image')
        img_spec = tf.placeholder(tf.float32, [special_num, spec_size[1], spec_size[2], spec_size[3]],
                                  name='spec_image')
        special_gt = tf.placeholder(tf.float32, [special_num, spec_size[1], spec_size[2], spec_size[3]])
        # DnCNN model
        img_noise = img_clean + tf.random_normal(shape=tf.shape(img_clean),
                                                 stddev=sigma / 255.0)  # data with added Gaussian noise
        img_total = tf.concat([img_noise, img_spec], 0)  # concatenate the noisy batch with the trigger images
        Y, N = DnCNNModel.dncnn(img_total, is_training=training)
        # slice the joint output back into the denoised batch and the trigger response
        Y_img = tf.slice(Y, [0, 0, 0, 0], [batch_size, spec_size[1], spec_size[2], spec_size[3]])
        N_spe = tf.slice(N, [batch_size, 0, 0, 0], [special_num, spec_size[1], spec_size[2], spec_size[3]])
        # host loss
        dncnn_loss = DnCNNModel.lossing(Y_img, img_clean, batch_size)
        # sobel_x, sobel_y = sobel()
        # extract the watermark response
        dncnn_s_out = transition(N_spe)
        # mark loss
        mark_loss = watermark_loss(dncnn_s_out, special_gt)  # special_gt = verification image
        # update the model
        dncnn_opt = ft_DnCNN_optimizer(dncnn_loss, mark_loss, lr, lambda_)
        init = tf.global_variables_initializer()
        dncnn_var_list = [v for v in tf.global_variables() if v.name.startswith('block')]
        DnCNN_saver = tf.train.Saver(dncnn_var_list, max_to_keep=50)
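        # Summary of the graph above: DnCNN is fine-tuned on a joint objective
        # that keeps denoising the clean batch (dncnn_loss) while pushing the
        # residual response to the trigger images towards the verification
        # image (mark_loss), weighted by lambda_. This is what embeds the
        # watermark into the host network's weights.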
        with tf.Session() as sess:
            data_total = np.load(train_data)
            data_total = data_total.astype(np.float32) / 255.0
            num_example, row, col, channel = data_total.shape
            numBatch = num_example // batch_size
            # Trigger image: replicated special_num times along the batch axis.
            # special_input = cv2.imread('./input_data/spec_input.png', 0)
            special_input = cv2.imread(trigger_img, 0)
            special_input = special_input.astype(np.float32) / 255.0
            special_input = np.expand_dims(special_input, 0)
            special_input = np.expand_dims(special_input, 3)
            special_input = np.repeat(special_input, special_num, axis=0)
            # Verification image: the target the trigger response must match.
            # daub_Images = cv2.imread('./input_data/spec_gt.png', 0)
            daub_Images = cv2.imread(verification_img, 0)
            daub_Images = daub_Images.astype(np.float32) / 255.0
            daub_Images = np.expand_dims(daub_Images, 0)
            daub_Images = np.expand_dims(daub_Images, 3)
            daub_Images = np.repeat(daub_Images, special_num, axis=0)
            sess.run(init)
            ckpt = tf.train.get_checkpoint_state(org_model_path)
            if ckpt and ckpt.model_checkpoint_path:
                full_path = tf.train.latest_checkpoint(org_model_path)
                print(full_path)
                DnCNN_saver.restore(sess, full_path)
                print("Loading " + os.path.basename(full_path) + " into the model")
            else:
                print("The pretrained DnCNN weights must exist")
                assert ckpt is not None, 'weights do not exist'
            step = 0
            for epoch in range(0, epochs):
                np.random.shuffle(data_total)
                for batch_id in range(0, numBatch):
                    # tag = np.random.randint(0, 26)
                    # Optional jitter of the trigger input; the 0 coefficient
                    # currently disables it (no-op).
                    special_input = special_input + 0 * np.random.normal(size=special_input.shape) / 255
                    batch_images = data_total[batch_id * batch_size:(batch_id + 1) * batch_size, :, :, :]
                    if batch_id % 100 == 0:
                        dncnn_lost = sess.run(dncnn_loss, feed_dict={img_clean: batch_images, lr: learn_rate,
                                                                     img_spec: special_input, training: False})
                        mark_lost = sess.run(mark_loss, feed_dict={img_clean: batch_images, img_spec: special_input,
                                                                   lr: learn_rate,
                                                                   special_gt: daub_Images, training: False})
                        print("step = %d, dncnn_loss = %f, mark_loss = %f" % (step, dncnn_lost, mark_lost))
                    _ = sess.run(dncnn_opt, feed_dict={img_clean: batch_images, lr: learn_rate,
                                                       img_spec: special_input, special_gt: daub_Images,
                                                       training: True})
                    step += 1
                DnCNN_saver.save(sess, overwriting_path + DnCNN_model_name + str(epoch + 1) + ".ckpt")
                print("+++++ epoch " + str(epoch + 1) + " saved successfully +++++")


if __name__ == '__main__':
    train()
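

# Example invocation with non-default hyperparameters (a sketch; adjust the
# paths to your checkout, and note that epochs/batch_size below are
# illustrative, not the repository defaults):
#
#     train(train_data='./data/img_clean_pats.npy',
#           org_model_path='./DnCNN_weight/',
#           overwriting_path='./overwriting/',
#           epochs=12, batch_size=64, learn_rate=1e-4, sigma=25,
#           trigger_img='key_imgs/trigger_image.png',
#           verification_img='key_imgs/verification_image.png',
#           lambda_=0.001)
#
# Each epoch writes a checkpoint named Black_DnCNN_cman_weight_<epoch>.ckpt
# into overwriting_path.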