Skip to content

Commit

Permalink
Fix test_basic
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Feb 14, 2024
1 parent 19093a8 commit c6d7ba8
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 4,774 deletions.
19 changes: 17 additions & 2 deletions src/compiler/codegen/condition_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ namespace {
namespace ast = tl2cgen::compiler::detail::ast;
namespace codegen = tl2cgen::compiler::detail::codegen;

std::string GetFabsCFunc(std::string const& threshold_type) {
if (threshold_type == "float") {
return "fabsf";
} else if (threshold_type == "double") {
return "fabs";
} else {
TL2CGEN_LOG(FATAL) << "Unrecognized type: " << threshold_type;
return "";
}
}

inline std::string ExtractNumericalCondition(ast::NumericalConditionNode const* node) {
std::string const threshold_type = codegen::GetThresholdCType(node);
std::string result;
Expand Down Expand Up @@ -73,6 +84,9 @@ inline std::vector<std::uint64_t> GetCategoricalBitmap(
}

inline std::string ExtractCategoricalCondition(ast::CategoricalConditionNode const* node) {
std::string const threshold_ctype_str = codegen::GetThresholdCType(node);
std::string const fabs = GetFabsCFunc(threshold_ctype_str);

std::string result;
std::vector<std::uint64_t> bitmap = GetCategoricalBitmap(node->category_list_);
TL2CGEN_CHECK_GE(bitmap.size(), 1);
Expand All @@ -99,8 +113,9 @@ inline std::string ExtractCategoricalCondition(ast::CategoricalConditionNode con

oss << fmt::format(
"((data[{split_index}].fvalue >= 0) && "
"(fabsf(data[{split_index}].fvalue) <= (float)(1U << FLT_MANT_DIG)) && (",
"split_index"_a = node->split_index_);
"({fabs}(data[{split_index}].fvalue) <= ({threshold_ctype})(1U << FLT_MANT_DIG)) && (",
"split_index"_a = node->split_index_, "threshold_ctype"_a = threshold_ctype_str,
"fabs"_a = fabs);
oss << "(tmp >= 0 && tmp < 64 && (( (uint64_t)" << bitmap[0] << "U >> tmp) & 1) )";
for (std::size_t i = 1; i < bitmap.size(); ++i) {
oss << " || (tmp >= " << (i * 64) << " && tmp < " << ((i + 1) * 64) << " && (( (uint64_t)"
Expand Down
1 change: 1 addition & 0 deletions src/compiler/codegen/function_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using namespace fmt::literals;
namespace tl2cgen::compiler::detail::codegen {

void HandleFunctionNode(ast::FunctionNode const* node, CodeCollection& gencode) {
gencode.PushFragment("unsigned int tmp;");
for (ast::ASTNode* child : node->children_) {
GenerateCodeFromAST(child, gencode);
}
Expand Down
Loading

0 comments on commit c6d7ba8

Please sign in to comment.