Skip to content

Commit

Permalink
Fix XGBoost handling of string data that contains a slash (#2869)
Browse files Browse the repository at this point in the history
* Fix XGBoost handling of string data that contains a slash

* fix ci
  • Loading branch information
typhoonzero authored Aug 28, 2020
1 parent 6cd7b24 commit a181bcb
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
6 changes: 3 additions & 3 deletions python/runtime/local/xgboost_submitter/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from runtime.feature.derivation import get_ordered_field_descs
from runtime.feature.field_desc import DataType
from runtime.model.model import Model
from runtime.xgboost.dataset import xgb_dataset
from runtime.xgboost.dataset import DMATRIX_FILE_SEP, xgb_dataset


def pred(datasource, select, result_table, pred_label_name, model):
Expand Down Expand Up @@ -148,8 +148,8 @@ def _store_predict_result(preds, result_table, result_column_names,
break

row = [
item for i, item in enumerate(line.strip().split("/"))
if i != train_label_idx
item for i, item in enumerate(line.strip().split(
DMATRIX_FILE_SEP)) if i != train_label_idx
]
row.append(str(preds[line_no]))
w.write(row)
Expand Down
7 changes: 5 additions & 2 deletions python/runtime/xgboost/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from scipy.sparse import vstack
from sklearn.datasets import load_svmlight_file, load_svmlight_files

DMATRIX_FILE_SEP = "\t"


def xgb_dataset(datasource,
fn,
Expand Down Expand Up @@ -130,7 +132,8 @@ def dump_dmatrix(filename,
feature_metas)

if raw_data_fid is not None:
raw_data_fid.write("/".join([str(r) for r in row]) + "\n")
raw_data_fid.write(
DMATRIX_FILE_SEP.join([str(r) for r in row]) + "\n")

if transform_fn:
features = transform_fn(features)
Expand Down Expand Up @@ -163,7 +166,7 @@ def dump_dmatrix(filename,
if has_label:
row_data = [str(label)] + row_data

f.write("\t".join(row_data) + "\n")
f.write(DMATRIX_FILE_SEP.join(row_data) + "\n")
row_id += 1
# batch_size == None means use all data in generator
if batch_size is None:
Expand Down
4 changes: 2 additions & 2 deletions python/runtime/xgboost/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from runtime import db
from runtime.dbapi.paiio import PaiIOConnection
from runtime.model.metadata import load_metadata
from runtime.xgboost.dataset import xgb_dataset
from runtime.xgboost.dataset import DMATRIX_FILE_SEP, xgb_dataset

SKLEARN_METRICS = [
'accuracy_score',
Expand Down Expand Up @@ -118,7 +118,7 @@ def evaluate_and_store_result(bst, dpred, feature_file_id, validation_metrics,

y_test_list = []
for line in feature_file_read:
row = [i for i in line.strip().split("\t")]
row = [i for i in line.strip().split(DMATRIX_FILE_SEP)]
# DMatrix store label in the first column
if label_meta["dtype"] == "float32":
label = float(row[0])
Expand Down
5 changes: 3 additions & 2 deletions python/runtime/xgboost/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from runtime import db
from runtime.dbapi.paiio import PaiIOConnection
from runtime.model.metadata import load_metadata
from runtime.xgboost.dataset import xgb_dataset
from runtime.xgboost.dataset import DMATRIX_FILE_SEP, xgb_dataset

DEFAULT_PREDICT_BATCH_SIZE = 10000

Expand Down Expand Up @@ -123,7 +123,8 @@ def predict_and_store_result(bst, dpred, feature_file_id, model_params,
# FIXME(typhoonzero): how to output columns that are not used
# as features, like ids?
row = [
item for i, item in enumerate(line.strip().split("/"))
item
for i, item in enumerate(line.strip().split(DMATRIX_FILE_SEP))
if i != train_label_index
]
row.append(str(preds[line_no]))
Expand Down

0 comments on commit a181bcb

Please sign in to comment.