Skip to content

Commit

Permalink
Compiler is now complete
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Feb 14, 2024
1 parent 08f5813 commit 8e8fd49
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 203 deletions.
34 changes: 26 additions & 8 deletions include/tl2cgen/detail/compiler/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,27 @@
#ifndef TL2CGEN_DETAIL_COMPILER_CODEGEN_CODEGEN_H_
#define TL2CGEN_DETAIL_COMPILER_CODEGEN_CODEGEN_H_

#include <filesystem>
#include <map>
#include <ostream>
#include <string>
#include <vector>

namespace tl2cgen::compiler::detail {
/* Forward declarations */
namespace treelite {

namespace ast {
class Model;

} // namespace treelite

namespace tl2cgen::compiler {

struct CompilerParam;

} // namespace tl2cgen::compiler

namespace tl2cgen::compiler::detail::ast {

// Forward declarations
class ASTNode;
class MainNode;
class FunctionNode;
Expand All @@ -26,13 +37,16 @@ class TranslationUnitNode;
class QuantizerNode;
class ModelMeta;

} // namespace ast
} // namespace tl2cgen::compiler::detail::ast

namespace codegen {
namespace tl2cgen::compiler::detail::codegen {

class CodeCollection; // forward declaration

void GenerateCodeFromAST(ast::ASTNode const* node, CodeCollection& gencode);
void WriteCodeToDisk(std::filesystem::path const& dirpath, CodeCollection const& collection);
void WriteBuildRecipeToDisk(std::filesystem::path const& dirpath,
std::string const& native_lib_name, CodeCollection const& collection);

// Codegen implementation for each AST node type
void HandleMainNode(ast::MainNode const* node, CodeCollection& gencode);
Expand Down Expand Up @@ -70,6 +84,9 @@ class SourceFile {
void ChangeIndent(int n_tabs_delta); // Add or remove indent
void PushFragment(std::string content);
friend std::ostream& operator<<(std::ostream&, CodeCollection const&);
friend void WriteCodeToDisk(std::filesystem::path const& dirpath, CodeCollection const&);
friend void WriteBuildRecipeToDisk(
std::filesystem::path const&, std::string const&, CodeCollection const&);
friend class CodeCollection;
};

Expand All @@ -88,11 +105,12 @@ class CodeCollection {
void PushFragment(std::string content);

friend std::ostream& operator<<(std::ostream&, CodeCollection const&);
friend void WriteCodeToDisk(std::filesystem::path const&, CodeCollection const&);
friend void WriteBuildRecipeToDisk(
std::filesystem::path const&, std::string const&, CodeCollection const&);
};
std::ostream& operator<<(std::ostream& os, CodeCollection const& collection);

} // namespace codegen

} // namespace tl2cgen::compiler::detail
} // namespace tl2cgen::compiler::detail::codegen

#endif // TL2CGEN_DETAIL_COMPILER_CODEGEN_CODEGEN_H_
15 changes: 0 additions & 15 deletions include/tl2cgen/detail/filesystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,6 @@ namespace tl2cgen::detail::filesystem {
*/
void CreateDirectoryIfNotExist(std::filesystem::path const& dirpath);

/*!
* \brief Write a sequence of strings to a text file, with newline character (\n) inserted between
* strings. This function is suitable for creating multi-line text files.
* \param path Path to text file
* \param content A sequence of strings to be written.
*/
void WriteToFile(std::filesystem::path const& path, std::string const& content);

/*!
* \brief Write a sequence of bytes to a text file
* \param path Path to text file
* \param content A sequence of bytes to be written.
*/
void WriteToFile(std::filesystem::path const& path, std::vector<char> const& content);

} // namespace tl2cgen::detail::filesystem

#endif // TL2CGEN_DETAIL_FILESYSTEM_H_
14 changes: 4 additions & 10 deletions python/tl2cgen/shortcuts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Convenience functions"""

import pathlib
import shutil
from tempfile import TemporaryDirectory
Expand All @@ -17,12 +18,11 @@ def export_lib(
toolchain: str,
libpath: Union[str, pathlib.Path],
params: Optional[Dict[str, Any]] = None,
compiler: str = "ast_native",
*,
nthread: Optional[int] = None,
verbose: bool = False,
options: Optional[List[str]] = None,
):
): # pylint: disable=too-many-arguments
"""
Convenience function: Generate prediction code and immediately turn it
into a dynamic shared library. A temporary directory will be created to
Expand All @@ -41,9 +41,6 @@ def export_lib(
Parameters to be passed to the compiler. See
:py:doc:`this page </compiler_param>` for the list of compiler
parameters.
compiler :
Kind of C code generator to use. Currently, there are two possible values:
{"ast_native", "failsafe"}
nthread :
Number of threads to use in creating the shared library.
Defaults to the number of cores in the system.
Expand Down Expand Up @@ -78,7 +75,7 @@ def export_lib(
long_build_time_warning = not (params and "parallel_comp" in params)

with TemporaryDirectory() as tempdir:
generate_c_code(model, tempdir, params, compiler, verbose=verbose)
generate_c_code(model, tempdir, params, verbose=verbose)
temp_libpath = create_shared(
toolchain,
tempdir,
Expand All @@ -98,7 +95,6 @@ def export_srcpkg(
pkgpath: Union[str, pathlib.Path],
libname: str,
params: Optional[Dict[str, Any]] = None,
compiler: str = "ast_native",
*,
verbose: bool = False,
options: Optional[List[str]] = None,
Expand All @@ -123,8 +119,6 @@ def export_srcpkg(
Parameters to be passed to the compiler. See
:py:doc:`this page </compiler_param>` for the list of compiler
parameters.
compiler :
Name of compiler to use in C code generation
verbose :
Whether to produce extra messages
options :
Expand Down Expand Up @@ -168,7 +162,7 @@ def export_srcpkg(
if params is None:
params = {}
params["native_lib_name"] = target
generate_c_code(model, dirpath, params, compiler, verbose=verbose)
generate_c_code(model, dirpath, params, verbose=verbose)
if toolchain == "cmake":
generate_cmakelists(dirpath, options)
else:
Expand Down
1 change: 0 additions & 1 deletion src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ int TL2cgenGenerateCode(

std::filesystem::path dirpath_
= std::filesystem::weakly_canonical(std::filesystem::u8path(std::string(dirpath)));
detail::filesystem::CreateDirectoryIfNotExist(dirpath_);

/* Compile model */
auto param = compiler::CompilerParam::ParseFromJSON(compiler_params_json_str);
Expand Down
47 changes: 47 additions & 0 deletions src/compiler/codegen/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
* \author Hyunsu Cho
*/

#include <rapidjson/ostreamwrapper.h>
#include <rapidjson/prettywriter.h>
#include <tl2cgen/detail/compiler/ast/ast.h>
#include <tl2cgen/detail/compiler/codegen/codegen.h>
#include <tl2cgen/detail/compiler/codegen/format_util.h>
#include <tl2cgen/detail/filesystem.h>
#include <tl2cgen/logging.h>

#include <fstream>
#include <string>
#include <type_traits>
#include <variant>
Expand Down Expand Up @@ -56,6 +60,49 @@ void GenerateCodeFromAST(ast::ASTNode const* node, CodeCollection& gencode) {
}
}

void WriteCodeToDisk(std::filesystem::path const& dirpath, CodeCollection const& collection) {
namespace fs = tl2cgen::detail::filesystem;
for (auto const& [file_name, source_file] : collection.sources_) {
std::ofstream of(dirpath / file_name);
for (auto const& fragment : source_file.fragments_) {
of << IndentMultiLineString(fragment.content_, fragment.indent_) << "\n";
}
of << "\n";
}
}

void WriteBuildRecipeToDisk(std::filesystem::path const& dirpath,
std::string const& native_lib_name, CodeCollection const& collection) {
std::ofstream ofs(dirpath / "recipe.json");
rapidjson::OStreamWrapper ofs_wrapped(ofs);
rapidjson::PrettyWriter<rapidjson::OStreamWrapper> writer(ofs_wrapped);
writer.SetFormatOptions(rapidjson::PrettyFormatOptions::kFormatSingleLineArray);

writer.StartObject();
writer.Key("target");
writer.String(native_lib_name);
writer.Key("sources");
writer.StartArray();
for (auto const& [file_name, source_file] : collection.sources_) {
if (file_name.compare(file_name.length() - 2, 2, ".c") == 0) {
std::size_t line_count = 0;
for (auto const& fragment : source_file.fragments_) {
line_count += std::count(fragment.content_.begin(), fragment.content_.end(), '\n');
}
writer.StartObject();
writer.Key("name");
std::string name = file_name.substr(0, file_name.length() - 2);
writer.String(name);
writer.Key("length");
writer.Uint64(line_count);
writer.EndObject();
}
}
writer.EndArray();
writer.EndObject();
ofs << "\n"; // Add newline at the end, for convention's sake
}

// Convenience overload: forwards the metadata object attached to `node`
// (`node->meta_`) to the GetThresholdCType() overload that accepts it, and
// returns that overload's result.
std::string GetThresholdCType(ast::ASTNode const* node) {
  auto& model_meta = *node->meta_;
  return GetThresholdCType(model_meta);
}
Expand Down
12 changes: 6 additions & 6 deletions src/compiler/codegen/main_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ union Entry {{
{extern_array_is_categorical}
{dllexport}void predict(union Entry* data, int pred_margin, {threshold_type}* result);
void postprocess({threshold_type}* result);
{dllexport}void predict(union Entry* data, int pred_margin, {leaf_output_type}* result);
void postprocess({leaf_output_type}* result);
)TL2CGENTEMPLATE";

char const* const main_start_template =
Expand All @@ -80,17 +80,17 @@ char const* const main_start_template =
{array_is_categorical}
void predict(union Entry* data, int pred_margin, {threshold_type}* result) {{
void predict(union Entry* data, int pred_margin, {leaf_output_type}* result) {{
)TL2CGENTEMPLATE";

void HandleMainNode(ast::MainNode const* node, CodeCollection& gencode) {
auto threshold_ctype_str = GetThresholdCType(node);
auto leaf_output_ctype_str = GetLeafOutputCType(node);
std::int32_t const num_target = node->meta_->num_target_;
std::vector<std::int32_t>& num_class = node->meta_->num_class_;
std::int32_t const max_num_class = *std::max_element(num_class.begin(), num_class.end());

gencode.SwitchToSourceFile("header.h");
gencode.PushFragment(fmt::format(header_template, "threshold_type"_a = threshold_ctype_str,
gencode.PushFragment(fmt::format(header_template, "leaf_output_type"_a = leaf_output_ctype_str,
"dllexport"_a = DLLEXPORT_KEYWORD, "num_target"_a = num_target,
"max_num_class"_a = max_num_class,
"extern_array_is_categorical"_a
Expand All @@ -100,7 +100,7 @@ void HandleMainNode(ast::MainNode const* node, CodeCollection& gencode) {
gencode.SwitchToSourceFile("main.c");
gencode.PushFragment(fmt::format(main_start_template,
"array_is_categorical"_a = RenderIsCategoricalArray(node->meta_->is_categorical_),
"threshold_type"_a = threshold_ctype_str));
"leaf_output_type"_a = leaf_output_ctype_str));
gencode.ChangeIndent(1);
TL2CGEN_CHECK_EQ(node->children_.size(), 1);
GenerateCodeFromAST(node->children_[0], gencode);
Expand Down
Loading

0 comments on commit 8e8fd49

Please sign in to comment.