Skip to content

Commit

Permalink
Add bug fix: remove leftover debug printing (tf.Print / print) from tag-feature parsing and ODPS input reading
Browse files Browse the repository at this point in the history
  • Loading branch information
lgqfhwy committed Dec 20, 2023
1 parent e9fc1cc commit 7fc27b8
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 38 deletions.
7 changes: 1 addition & 6 deletions easy_rec/python/input/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,13 +441,8 @@ def _parse_tag_feature(self, fc, parsed_dict, field_dict):
if fc.HasField('kv_separator'):
indices = parsed_dict[feature_name].indices
tmp_kvs = parsed_dict[feature_name].values
print('tmp_kvs = ', tmp_kvs)
tmp_kvs = tf.Print(tmp_kvs, [tmp_kvs], message='print_tag_value=')
tmp_kvs = tf.string_split(tmp_kvs, fc.kv_separator, skip_empty=False)
tmp_value = tmp_kvs.values
tmp_value = tf.Print(
tmp_value, [tmp_value], message='print_tag_2_value=')
tmp_kvs = tf.reshape(tmp_value, [-1, 2])
tmp_kvs = tf.reshape(tmp_kvs.values, [-1, 2])
tmp_ks, tmp_vs = tmp_kvs[:, 0], tmp_kvs[:, 1]

check_list = [
Expand Down
33 changes: 1 addition & 32 deletions easy_rec/python/input/odps_input_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import sys

import numpy as np
import tensorflow as tf

from easy_rec.python.input.input import Input
Expand Down Expand Up @@ -72,52 +71,22 @@ def _odps_read(self):
batch_num = int(total_records_num / self._data_config.batch_size)
res_num = total_records_num - batch_num * self._data_config.batch_size
batch_defaults = [
np.array([x] * self._data_config.batch_size) for x in record_defaults
[x] * self._data_config.batch_size for x in record_defaults
]
for batch_id in range(batch_num):
batch_data_np = [x.copy() for x in batch_defaults]
for row_id, one_data in enumerate(
reader.read(self._data_config.batch_size)):
for col_id in range(len(record_defaults)):
print('first_batch_num, col_id = ', col_id, ', type = ',
type(one_data[col_id]), ', data = #', one_data[col_id], '#')
if isinstance(one_data[col_id], bytes):
one_data[col_id] = one_data[col_id].decode('utf-8')
print('second_batch_num, col_id = ', col_id, ', type = ',
type(one_data[col_id]), ', data = #', one_data[col_id], '#')
first_split = one_data[col_id].split(',')
second_split = first_split[0].split(':')
print('first_split = ', first_split, ', second_split = ',
second_split)
if one_data[col_id] not in ['', 'NULL', None]:
batch_data_np[col_id][row_id] = one_data[col_id]
print('here0, batch_data_np = ', batch_data_np)
one_data[col_id] = one_data[col_id].decode('utf-8')
batch_data_np[col_id][row_id] = one_data[col_id]
print('here1, batch_data_np = ', batch_data_np)
print('batch_data_np = ', batch_data_np)
yield tuple(batch_data_np)
if res_num > 0:
batch_data_np = [x[:res_num] for x in batch_defaults]
for row_id, one_data in enumerate(reader.read(res_num)):
for col_id in range(len(record_defaults)):
print('first_res_num, col_id = ', col_id, ', type = ',
type(one_data[col_id]), ', data = #', one_data[col_id], '#')
if isinstance(one_data[col_id], bytes):
one_data[col_id] = one_data[col_id].decode('utf-8')
print('second_res_num, col_id = ', col_id, ', type = ',
type(one_data[col_id]), ', data = #', one_data[col_id], '#')
first_split = one_data[col_id].split(',')
second_split = first_split[0].split(':')
print('first_split = ', first_split, ', second_split = ',
second_split)
if one_data[col_id] not in ['', 'NULL', None]:
batch_data_np[col_id][row_id] = one_data[col_id]
print('here00, batch_data_np = ', batch_data_np)
one_data[col_id] = one_data[col_id].decode('utf-8')
batch_data_np[col_id][row_id] = one_data[col_id]
print('here11, batch_data_np = ', batch_data_np)
print('batch_data_np = ', batch_data_np)
yield tuple(batch_data_np)
reader.close()
logging.info('finish epoch[%d]' % self._num_epoch)
Expand Down

0 comments on commit 7fc27b8

Please sign in to comment.