[ModelZoo] Support Co_Action Network #344
base: main
@@ -0,0 +1,9 @@
export PATH="~/anaconda4/bin:$PATH"
wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Books.json.gz
wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Books.json.gz
gunzip reviews_Books.json.gz
gunzip meta_Books.json.gz
# pass the reviews file unpacked above
python script/process_data.py meta_Books.json reviews_Books.json
python script/local_aggretor.py
python script/split_by_user.py
python script/generate_voc.py
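
If this pipeline follows the DIEN data layout it appears to be based on, the scripts above should leave the train/test splits and vocabulary pickles in the working directory. A quick sanity check (a sketch; the split file names are assumptions carried over from DIEN, while the .pkl names come from generate_voc.py below):

import os

# Files the preparation pipeline is expected to produce (names partly assumed).
expected = ["local_train_splitByUser", "local_test_splitByUser",
            "uid_voc.pkl", "mid_voc.pkl", "cat_voc.pkl",
            "item_carte_voc.pkl", "cate_carte_voc.pkl"]
for name in expected:
    print(name, "ok" if os.path.exists(name) else "MISSING")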
@@ -0,0 +1,35 @@
import tensorflow as tf


def dice(_x, axis=-1, epsilon=0.000000001, name=''):
    # Dice activation (from the DIN paper): gates the input with a sigmoid of
    # its batch-normalized value, blended by a learnable per-channel alpha.
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        alphas = tf.get_variable('alpha'+name, _x.get_shape()[-1],
                                 initializer=tf.constant_initializer(0.0),
                                 dtype=tf.float32)
    input_shape = list(_x.get_shape())

    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[axis] = input_shape[axis]

    # case: train mode (uses stats of the current batch)
    mean = tf.reduce_mean(_x, axis=reduction_axes)
    broadcast_mean = tf.reshape(mean, broadcast_shape)
    std = tf.reduce_mean(tf.square(_x - broadcast_mean) + epsilon, axis=reduction_axes)
    std = tf.sqrt(std)
    broadcast_std = tf.reshape(std, broadcast_shape)
    x_normed = (_x - broadcast_mean) / (broadcast_std + epsilon)
    # x_normed = tf.layers.batch_normalization(_x, center=False, scale=False)
    x_p = tf.sigmoid(x_normed)

    # Blend the raw signal and its alpha-scaled copy with the learned gate.
    return alphas * (1.0 - x_p) * _x + x_p * _x


def parametric_relu(_x):
    # PReLU: like ReLU, but the negative slope is a learnable per-channel alpha.
    alphas = tf.get_variable('alpha', _x.get_shape()[-1],
                             initializer=tf.constant_initializer(0.0),
                             dtype=tf.float32)
    pos = tf.nn.relu(_x)
    neg = alphas * (_x - abs(_x)) * 0.5

    return pos + neg
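
For reference, a minimal usage sketch of the two activations (TensorFlow 1.x graph mode; the layer sizes and names are illustrative, not part of this diff):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64])      # illustrative input batch
h = tf.layers.dense(x, 80, name='fc1')
h = dice(h, name='dice_fc1')                    # Dice: per-channel learnable alpha
p = tf.layers.dense(x, 80, name='fc2')
p = parametric_relu(p)                          # PReLU alternative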
@@ -0,0 +1,14 @@
import tensorflow as tf

# Count the trainable parameters of a saved model by restoring its meta graph.
ckpt = tf.train.get_checkpoint_state("./ckpt_path/").model_checkpoint_path
saver = tf.train.import_meta_graph(ckpt+'.meta')
variables = tf.trainable_variables()
total_parameters = 0
for variable in variables:
    shape = variable.get_shape()
    variable_parameters = 1
    for dim in shape:
        variable_parameters *= dim.value
    total_parameters += variable_parameters
print(total_parameters)
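
The loop multiplies out each variable's shape, so the count is the sum over all trainables of the product of their dimensions. A compact equivalent under the same TF 1.x assumptions:

import numpy as np

total = int(sum(np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()))
print(total)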
@@ -0,0 +1,228 @@
import numpy
import json
#import cPickle as pkl
[Review comment] Same as above.
import _pickle as cPickle
import random

import gzip

import shuffle

def unicode_to_utf8(d):
    return dict((key.encode("UTF-8"), value) for (key,value) in d.items())
def dict_unicode_to_utf8(d):
    print('d={}'.format(d))
    return dict(((key[0].encode("UTF-8"), key[1].encode("UTF-8")), value) for (key,value) in d.items())

def load_dict(filename):
    # Vocabularies may be stored as JSON or as pickles (plain or tuple-keyed).
    try:
        with open(filename, 'rb') as f:
            return unicode_to_utf8(json.load(f))
    except:
        try:
            with open(filename, 'rb') as f:
                return unicode_to_utf8(cPickle.load(f))
        except:
            with open(filename, 'rb') as f:
                return dict_unicode_to_utf8(cPickle.load(f))


def fopen(filename, mode='r'):
    if filename.endswith('.gz'):
        return gzip.open(filename, mode)
    return open(filename, mode)


class DataIterator:

    def __init__(self, source,
                 uid_voc,
                 mid_voc,
                 cat_voc,
                 batch_size=128,
                 maxlen=100,
                 skip_empty=False,
                 shuffle_each_epoch=False,
                 sort_by_length=True,
                 max_batch_size=20,
                 minlen=None,
                 label_type=1):
        if shuffle_each_epoch:
            self.source_orig = source
            self.source = shuffle.main(self.source_orig, temporary=True)
        else:
            self.source = fopen(source, 'r')
        self.source_dicts = []
        #for source_dict in [uid_voc, mid_voc, cat_voc, cat_voc, cat_voc]:# 'item_carte_voc.pkl', 'cate_carte_voc.pkl']:
        for source_dict in [uid_voc, mid_voc, cat_voc, '/home/test/modelzoo/CAN/data/item_carte_voc.pkl', '/home/test/modelzoo/CAN/data/cate_carte_voc.pkl']:
            self.source_dicts.append(load_dict(source_dict))

        # Map each item index to its category index via the item-info metadata.
        f_meta = open("/home/test/modelzoo/CAN/data/item-info", "r")
        meta_map = {}
        for line in f_meta:
            arr = line.strip().split("\t")
            if arr[0] not in meta_map:
                meta_map[arr[0]] = arr[1]
        self.meta_id_map = {}
        for key in meta_map:
            val = meta_map[key]
            if key in self.source_dicts[1]:
                mid_idx = self.source_dicts[1][key]
            else:
                mid_idx = 0
            if val in self.source_dicts[2]:
                cat_idx = self.source_dicts[2][val]
            else:
                cat_idx = 0
            self.meta_id_map[mid_idx] = cat_idx

        f_review = open("/home/test/modelzoo/CAN/data/reviews-info", "r")
[Review comment] Don't use an absolute path here.
        # Pool of item indices used to sample negative (no-click) examples.
        self.mid_list_for_random = []
        for line in f_review:
            arr = line.strip().split("\t")
            tmp_idx = 0
            if arr[1] in self.source_dicts[1]:
                tmp_idx = self.source_dicts[1][arr[1]]
            self.mid_list_for_random.append(tmp_idx)

        self.batch_size = batch_size
        self.maxlen = maxlen
        self.minlen = minlen
        self.skip_empty = skip_empty

        self.n_uid = len(self.source_dicts[0])
        self.n_mid = len(self.source_dicts[1])
        self.n_cat = len(self.source_dicts[2])
        self.n_carte = [len(self.source_dicts[3]), len(self.source_dicts[4])]
        print("n_uid=%d, n_mid=%d, n_cat=%d" % (self.n_uid, self.n_mid, self.n_cat))
[Review comment] Redundant print.
        self.shuffle = shuffle_each_epoch
        self.sort_by_length = sort_by_length

        self.source_buffer = []
        self.k = batch_size * max_batch_size

        self.end_of_data = False
        self.label_type = label_type

    def get_n(self):
        return self.n_uid, self.n_mid, self.n_cat, self.n_carte

    def __iter__(self):
        return self

    def reset(self):
        if self.shuffle:
            self.source = shuffle.main(self.source_orig, temporary=True)
        else:
            self.source.seek(0)

    def __next__(self):
        if self.end_of_data:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        source = []
        target = []

        if len(self.source_buffer) == 0:
            for k_ in range(self.k):
                ss = self.source.readline()
                if ss == "":
                    break
                self.source_buffer.append(ss.strip("\n").split("\t"))

            # sort by history behavior length
            if self.sort_by_length:
                # behavior sequences are joined by the control character \x02
                his_length = numpy.array([len(s[4].split("\x02")) for s in self.source_buffer])
                tidx = his_length.argsort()

                _sbuf = [self.source_buffer[i] for i in tidx]
                self.source_buffer = _sbuf
            else:
                self.source_buffer.reverse()

        if len(self.source_buffer) == 0:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        try:

            # actual work here
            while True:

                # read from source file and map to word index
                try:
                    ss = self.source_buffer.pop()
                except IndexError:
                    break

                uid = self.source_dicts[0][ss[1]] if ss[1] in self.source_dicts[0] else 0
                mid = self.source_dicts[1][ss[2]] if ss[2] in self.source_dicts[1] else 0
                cat = self.source_dicts[2][ss[3]] if ss[3] in self.source_dicts[2] else 0

                # item history plus (target item, history item) cross features
                tmp = []
                item_carte = []
                for fea in ss[4].split("\x02"):
                    m = self.source_dicts[1][fea] if fea in self.source_dicts[1] else 0
                    tmp.append(m)
                    i_c = self.source_dicts[3][(ss[2], fea)] if (ss[2], fea) in self.source_dicts[3] else 0
                    item_carte.append(i_c)
                mid_list = tmp

                # category history plus (target category, history category) cross features
                tmp1 = []
                cate_carte = []
                for fea in ss[5].split("\x02"):
                    c = self.source_dicts[2][fea] if fea in self.source_dicts[2] else 0
                    tmp1.append(c)
                    c_c = self.source_dicts[4][(ss[3], fea)] if (ss[3], fea) in self.source_dicts[4] else 0
                    cate_carte.append(c_c)
                cat_list = tmp1

                if self.minlen != None:
                    if len(mid_list) <= self.minlen:
                        continue
                if self.skip_empty and (not mid_list):
                    continue

                # sample 5 random negative (no-click) items per positive item
                noclk_mid_list = []
                noclk_cat_list = []
                for pos_mid in mid_list:
                    noclk_tmp_mid = []
                    noclk_tmp_cat = []
                    noclk_index = 0
                    while True:
                        noclk_mid_indx = random.randint(0, len(self.mid_list_for_random)-1)
                        noclk_mid = self.mid_list_for_random[noclk_mid_indx]
                        if noclk_mid == pos_mid:
                            continue
                        noclk_tmp_mid.append(noclk_mid)
                        noclk_tmp_cat.append(self.meta_id_map[noclk_mid])
                        noclk_index += 1
                        if noclk_index >= 5:
                            break
                    noclk_mid_list.append(noclk_tmp_mid)
                    noclk_cat_list.append(noclk_tmp_cat)
                carte_list = [item_carte, cate_carte]
                source.append([uid, mid, cat, mid_list, cat_list, noclk_mid_list, noclk_cat_list, carte_list])
                if self.label_type == 1:
                    target.append([float(ss[0])])
                else:
                    target.append([float(ss[0]), 1-float(ss[0])])

                if len(source) >= self.batch_size or len(target) >= self.batch_size:
                    break
        except IOError:
            self.end_of_data = True

        # all sentence pairs in maxibatch filtered out because of length
        if len(source) == 0 or len(target) == 0:
            source, target = self.__next__()

        return source, target
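
A hypothetical construction of the iterator, for orientation only (the split and vocabulary file names are assumptions based on the data-prep scripts, not part of this diff):

train_data = DataIterator("local_train_splitByUser",
                          "uid_voc.pkl", "mid_voc.pkl", "cat_voc.pkl",
                          batch_size=128, maxlen=100)
n_uid, n_mid, n_cat, n_carte = train_data.get_n()
for source, target in train_data:
    # source: [uid, mid, cat, mid_list, cat_list, noclk_mid_list, noclk_cat_list, carte_list]
    # target: [label] (or [label, 1-label] when label_type != 1)
    break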
@@ -0,0 +1,91 @@
import pickle as pk

f_train = open("/home/test/modelzoo/DIEN/data/local_train_splitByUser", "r")
[Review comment] Same as above, don't use absolute paths.
uid_dict = {}
mid_dict = {}
cat_dict = {}
item_carte_dict = {}
cate_carte_dict = {}

iddd = 0
[Review comment] Is this iddd variable actually used?
for line in f_train:
    arr = line.strip("\n").split("\t")
    clk = arr[0]
    uid = arr[1]
    mid = arr[2]
    cat = arr[3]
    mid_list = arr[4]
    cat_list = arr[5]
    if uid not in uid_dict:
        uid_dict[uid] = 0
    uid_dict[uid] += 1
    if mid not in mid_dict:
        mid_dict[mid] = 0
    mid_dict[mid] += 1
    if cat not in cat_dict:
        cat_dict[cat] = 0
    cat_dict[cat] += 1
    if len(mid_list) == 0:
        continue
    # history items are joined by the control character \x02
    for m in mid_list.split("\x02"):
        if m not in mid_dict:
            mid_dict[m] = 0
        mid_dict[m] += 1
        # cross feature: (target item, history item) pair counts
        if (mid, m) not in item_carte_dict:
            item_carte_dict[(mid, m)] = 0
        item_carte_dict[(mid, m)] += 1
    #print iddd
    iddd += 1
    for c in cat_list.split("\x02"):
        if c not in cat_dict:
            cat_dict[c] = 0
        cat_dict[c] += 1
        # cross feature: (target category, history category) pair counts
        if (cat, c) not in cate_carte_dict:
            cate_carte_dict[(cat, c)] = 0
        cate_carte_dict[(cat, c)] += 1

# sort by frequency so more frequent entries get smaller indices
sorted_uid_dict = sorted(uid_dict.items(), key=lambda x: x[1], reverse=True)
sorted_mid_dict = sorted(mid_dict.items(), key=lambda x: x[1], reverse=True)
sorted_cat_dict = sorted(cat_dict.items(), key=lambda x: x[1], reverse=True)
sorted_item_carte_dict = sorted(item_carte_dict.items(), key=lambda x: x[1], reverse=True)
sorted_cate_carte_dict = sorted(cate_carte_dict.items(), key=lambda x: x[1], reverse=True)

uid_voc = {}
index = 0
for key, value in sorted_uid_dict:
    uid_voc[key] = index
    index += 1

mid_voc = {}
mid_voc["default_mid"] = 0
index = 1
for key, value in sorted_mid_dict:
    mid_voc[key] = index
    index += 1

cat_voc = {}
cat_voc["default_cat"] = 0
index = 1
for key, value in sorted_cat_dict:
    cat_voc[key] = index
    index += 1

item_carte_voc = {}
item_carte_voc["default_item_carte"] = 0
index = 1
for key, value in sorted_item_carte_dict:
    item_carte_voc[key] = index
    index += 1

cate_carte_voc = {}
cate_carte_voc["default_cate_carte"] = 0
index = 1
for key, value in sorted_cate_carte_dict:
    cate_carte_voc[key] = index
    index += 1

pk.dump(uid_voc, open("uid_voc.pkl", "wb"))
pk.dump(mid_voc, open("mid_voc.pkl", "wb"))
pk.dump(cat_voc, open("cat_voc.pkl", "wb"))
pk.dump(item_carte_voc, open("item_carte_voc.pkl", "wb"))
pk.dump(cate_carte_voc, open("cate_carte_voc.pkl", "wb"))
[Review comment] Remove the unused prints or commented-out code.
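
To verify the dumped vocabularies, a small load-back sketch (assumes the .pkl files written above exist in the working directory; note the carte vocabularies are keyed by (id, id) tuples):

import pickle as pk

with open("uid_voc.pkl", "rb") as f:
    uid_voc = pk.load(f)
with open("item_carte_voc.pkl", "rb") as f:
    item_carte_voc = pk.load(f)
print(len(uid_voc), len(item_carte_voc))  # vocabulary sizes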