Skip to content

Commit

Permalink
Merge branch 'layerwise' of https://github.com/Xtra-Computing/FedTree
Browse files Browse the repository at this point in the history
…into layerwise
  • Loading branch information
QinbinLi committed Jan 26, 2023
2 parents 3ed5354 + efd8caf commit 82da2f3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
26 changes: 13 additions & 13 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@
from shutil import copyfile
from sys import platform

# dirname = path.dirname(path.abspath(__file__))
dirname = path.dirname(path.abspath(__file__))

# if platform == "linux" or platform == "linux2":
# lib_path = path.abspath(path.join(dirname, '../build/lib/libFedTree.so'))
# elif platform == "win32":
# lib_path = path.abspath(path.join(dirname, '../build/bin/Debug/libFedTree.dll'))
# elif platform == "darwin":
# lib_path = path.abspath(path.join(dirname, '../build/lib/libFedTree.dylib'))
# else:
# print("OS not supported!")
# exit()
if platform == "linux" or platform == "linux2":
lib_path = path.abspath(path.join(dirname, '../build/lib/libFedTree.so'))
elif platform == "win32":
lib_path = path.abspath(path.join(dirname, '../build/bin/Debug/libFedTree.dll'))
elif platform == "darwin":
lib_path = path.abspath(path.join(dirname, '../build/lib/libFedTree.dylib'))
else:
print("OS not supported!")
exit()

# if not path.exists(path.join(dirname, "fedtree", path.basename(lib_path))):
# copyfile(lib_path, path.join(dirname, "fedtree", path.basename(lib_path)))
if not path.exists(path.join(dirname, "fedtree", path.basename(lib_path))):
copyfile(lib_path, path.join(dirname, "fedtree", path.basename(lib_path)))

lib_path = "./fedtree/libFedTree.so"
# lib_path = "./fedtree/libFedTree.so"


setuptools.setup(name="fedtree",
Expand Down
1 change: 0 additions & 1 deletion src/FedTree/Tree/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ void GBDT::predict_raw(const GBDTParam &model_param, const DataSet &dataSet, Syn

int total_num_node = num_iter * num_class * num_node;
y_predict.resize(n_instances * num_class);
std::cout<<"num_class in predict_raw:"<<num_class<<std::endl;
SyncArray<Tree::TreeNode> model(total_num_node);
auto model_data = model.host_data();
int tree_cnt = 0;
Expand Down

0 comments on commit 82da2f3

Please sign in to comment.