forked from chrisruk/scf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtensor.py
54 lines (45 loc) · 2.06 KB
/
tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#!/usr/bin/python
import os
import tensorflow as tf
import pmt
import freezegraph
def save_graph(sess, output_path, checkpoint, checkpoint_state_name,
               input_graph_name, output_graph_name, output_node_names="out"):
    """Checkpoint the session's variables and write a frozen copy of its graph.

    Steps performed:
      1. Save every variable returned by ``tf.all_variables()`` to a
         checkpoint with prefix ``os.path.join(output_path, checkpoint)``
         and ``global_step=0``.
      2. Write the session's GraphDef, in the text format (the
         ``tf.train.write_graph`` default), to
         ``os.path.join(output_path, input_graph_name)``.
      3. Call ``freezegraph.freeze_graph`` on that graph and checkpoint
         (the "const conversion" step), writing its result to
         ``os.path.join(output_path, output_graph_name)``.

    Args:
        sess: Session whose graph and current variable values are saved.
        output_path: Directory that receives every file written here.
        checkpoint: File-name prefix for the checkpoint inside output_path.
        checkpoint_state_name: Passed to ``Saver.save`` as
            ``latest_filename``.
        input_graph_name: File name for the (unfrozen) text GraphDef.
        output_graph_name: File name for the frozen graph.
        output_node_names: Output node name(s) handed to ``freeze_graph``.
            Defaults to "out", the value previously hard-coded.
            Presumably comma-separated when several; confirm against
            freezegraph.
    """
    checkpoint_prefix = os.path.join(output_path, checkpoint)
    saver = tf.train.Saver(tf.all_variables())

    # Use the path Saver.save reports rather than rebuilding
    # "<prefix>-0" by hand. The hand-built form is wrong for
    # sharded savers.
    input_checkpoint_path = saver.save(
        sess, checkpoint_prefix, global_step=0,
        latest_filename=checkpoint_state_name)

    # Written after the Saver is created so its save/restore ops are part
    # of the exported graph. write_graph defaults to text output, hence
    # input_binary=False below.
    tf.train.write_graph(sess.graph.as_graph_def(), output_path,
                         input_graph_name)
    input_graph_path = os.path.join(output_path, input_graph_name)
    output_graph_path = os.path.join(output_path, output_graph_name)

    # Take the restore-op and filename-tensor names from this Saver instead
    # of hard-coding "save/restore_all" / "save/Const:0". If the graph
    # already holds a Saver, this one lives under a uniquified scope such
    # as "save_1".
    saver_def = saver.as_saver_def()

    input_saver_def_path = ""
    input_binary = False
    clear_devices = False
    freezegraph.freeze_graph(input_graph_path, input_saver_def_path,
                             input_binary, input_checkpoint_path,
                             output_node_names, saver_def.restore_op_name,
                             saver_def.filename_tensor_name,
                             output_graph_path, clear_devices,
                             "")  # presumably initializer_nodes; TODO confirm
def load_graph(output_graph_path, ckpt_path="",
               input_tensor_name="inp:0", output_tensor_name="out:0"):
    """Import a serialized GraphDef into a fresh graph and open a session on it.

    The file is read in binary mode and parsed with
    ``GraphDef.ParseFromString``. It is then imported with ``name=""``, so
    node names are used as-is with no scope prefix.

    Args:
        output_graph_path: Path to a binary-serialized GraphDef. Presumably
            the frozen graph produced by ``save_graph``; confirm with callers.
        ckpt_path: Unused. It is kept only so existing callers that pass it
            keep working. The old restore-from-checkpoint code that
            referenced it was dead (enclosed in a string literal).
        input_tensor_name: Name of the input tensor to look up.
        output_tensor_name: Name of the output tensor to look up.

    Returns:
        A tuple ``(sess, n_input, output)``:
        - ``sess`` is an OPEN session bound to the imported graph. The caller
          owns it and should call ``sess.close()`` when done.
        - ``n_input`` and ``output`` are the named tensors.
    """
    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with open(output_graph_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name="")

        # Deliberately not ``with tf.Session() as sess``: leaving that block
        # closes the session, which made the returned session unusable.
        sess = tf.Session()

    n_input = sess.graph.get_tensor_by_name(input_tensor_name)
    output = sess.graph.get_tensor_by_name(output_tensor_name)
    return sess, n_input, output