forked from mengzaiqiao/CAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
99 lines (78 loc) · 4.09 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from layers import GraphConvolution, GraphConvolutionSparse, InnerDecoder, Dense
import tensorflow as tf
# Module-level aliases for TensorFlow's flags module and its FLAGS object.
# No flags are defined in this part of the file; CAN._build() reads
# FLAGS.hidden1 and FLAGS.hidden2, so those must be defined elsewhere
# (presumably by the training script -- confirm) before a model is built.
flags = tf.flags
FLAGS = flags.FLAGS
class Model(object):
    """Base class for models that are built inside their own variable scope.

    Subclasses implement ``_build()``. ``build()`` runs it under
    ``tf.variable_scope(self.name)`` and then records the global variables
    found in that scope.

    Attributes:
        name: Scope name; defaults to the lowercased class name.
        logging: Value of the ``logging`` keyword argument (default ``False``).
        vars: Dict mapping variable name to variable, filled by ``build()``.
    """

    def __init__(self, **kwargs):
        """Store the optional ``name`` and ``logging`` settings.

        Args:
            **kwargs: Only ``name`` and ``logging`` are accepted.

        Raises:
            AssertionError: If any other keyword argument is supplied.
        """
        allowed_kwargs = {'name', 'logging'}
        # Raise explicitly instead of using `assert` so the check still runs
        # under `python -O`; AssertionError and the message are kept so that
        # existing callers observe the same failure.
        for kwarg in kwargs:
            if kwarg not in allowed_kwargs:
                raise AssertionError('Invalid keyword argument: ' + kwarg)

        # A missing or falsy name falls back to the lowercased class name.
        self.name = kwargs.get('name') or self.__class__.__name__.lower()
        self.logging = kwargs.get('logging', False)
        self.vars = {}

    def _build(self):
        """Construct the model; must be overridden by subclasses."""
        raise NotImplementedError

    def build(self):
        """Run ``_build()`` in this model's variable scope and index its variables."""
        with tf.variable_scope(self.name):
            self._build()
        # Collect every global variable whose name falls under this scope.
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
        self.vars = {var.name: var for var in variables}

    def fit(self):
        """Placeholder; does nothing in the base class."""
        pass

    def predict(self):
        """Placeholder; does nothing in the base class."""
        pass
class CAN(Model):
    """Model with two encoders whose sampled codes feed an ``InnerDecoder``.

    One encoder applies graph convolutions (using ``adj``) to the input
    features and yields ``z_u``; the other applies ``Dense`` layers to the
    sparse transpose of the input features and yields ``z_a``. Each code is
    sampled as ``mean + normal_noise * exp(log_std)``.
    """

    def __init__(self, placeholders, num_features, num_nodes, features_nonzero, **kwargs):
        """Store the inputs and immediately build the graph.

        Args:
            placeholders: Mapping providing the 'features', 'adj' and
                'dropout' entries.
            num_features: Stored as ``input_dim``.
            num_nodes: Stored as ``n_samples``.
            features_nonzero: Forwarded to ``GraphConvolutionSparse``.
            **kwargs: Passed through to ``Model.__init__``.
        """
        super(CAN, self).__init__(**kwargs)

        # Graph inputs taken from the placeholder mapping.
        self.inputs = placeholders['features']
        self.adj = placeholders['adj']
        self.dropout = placeholders['dropout']

        # Size information supplied by the caller.
        self.input_dim = num_features
        self.n_samples = num_nodes
        self.features_nonzero = features_nonzero

        self.build()

    def _build(self):
        """Create the encoder layers, the sampled codes and the decoder output."""
        hidden_dim = FLAGS.hidden1
        latent_dim = FLAGS.hidden2

        def identity(x):
            return x

        # First layer of the graph-convolutional encoder, fed the raw inputs.
        node_layer = GraphConvolutionSparse(
            input_dim=self.input_dim,
            output_dim=hidden_dim,
            adj=self.adj,
            features_nonzero=self.features_nonzero,
            act=tf.nn.relu,
            dropout=self.dropout,
            logging=self.logging,
        )
        self.hidden1 = node_layer(self.inputs)

        # First layer of the dense encoder, fed the transposed sparse inputs.
        # The layer is constructed before the transpose op, as in the
        # one-expression form, so op creation order is unchanged.
        attr_layer = Dense(
            input_dim=self.n_samples,
            output_dim=hidden_dim,
            act=tf.nn.tanh,
            sparse_inputs=True,
            dropout=self.dropout,
        )
        self.hidden2 = attr_layer(tf.sparse_transpose(self.inputs))

        # Mean and log-std heads on top of hidden1 share one configuration.
        gcn_head = dict(
            input_dim=hidden_dim,
            output_dim=latent_dim,
            adj=self.adj,
            act=identity,
            dropout=self.dropout,
            logging=self.logging,
        )
        self.z_u_mean = GraphConvolution(**gcn_head)(self.hidden1)
        self.z_u_log_std = GraphConvolution(**gcn_head)(self.hidden1)

        # Mean and log-std heads on top of hidden2 share one configuration.
        dense_head = dict(
            input_dim=hidden_dim,
            output_dim=latent_dim,
            act=identity,
            dropout=self.dropout,
        )
        self.z_a_mean = Dense(**dense_head)(self.hidden2)
        self.z_a_log_std = Dense(**dense_head)(self.hidden2)

        # Sample each code: mean + standard-normal noise scaled by exp(log_std).
        noise_u = tf.random_normal([self.n_samples, latent_dim])
        self.z_u = self.z_u_mean + noise_u * tf.exp(self.z_u_log_std)

        noise_a = tf.random_normal([self.input_dim, latent_dim])
        self.z_a = self.z_a_mean + noise_a * tf.exp(self.z_a_log_std)

        # The decoder receives both sampled codes as a pair.
        decoder = InnerDecoder(
            input_dim=latent_dim,
            act=identity,
            logging=self.logging,
        )
        self.reconstructions = decoder((self.z_u, self.z_a))