This repository has been archived by the owner on Apr 28, 2023. It is now read-only.

Commit

Add support for strided tensors
This commit starts support for strided tensors. I made changes to
percolate a vector of TensorInfo down to emitCudaKernel so that
codegen can cast strided tensors with their actual strides. This
required updating a unit test to expect the correct cast.
Protonu Basu committed Jun 8, 2018
1 parent cc4b1eb commit ff1ed36
Showing 8 changed files with 260 additions and 15 deletions.
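For context on what the cast change is about: a strided input is a window into a larger allocation, so its row pitch (the outer stride) is larger than its logical number of columns, and the reinterpret_cast emitted for it must use the pitch rather than the extent. A minimal sketch of that decision, using a hypothetical TensorInfoSketch struct in place of the real tc::TensorInfo (which carries more fields):

#include <iostream>
#include <vector>

// Hypothetical stand-in for tc::TensorInfo; only shape and strides matter here.
struct TensorInfoSketch {
  std::vector<long> shape;    // logical sizes per dimension
  std::vector<long> strides;  // element strides per dimension
};

// Pick the array extent used in the cast for dimension i: a contiguous tensor
// uses the logical extent, a strided view uses the next-outer stride so that
// pointer arithmetic lands on the right row.
long castExtent(const TensorInfoSketch& t, int i) {
  bool innermostUnit = t.strides.back() == 1;
  bool contiguous = t.strides[i - 1] == t.shape[i] * t.strides[i];
  return (innermostUnit && !contiguous) ? t.strides[i - 1] : t.shape[i];
}

int main() {
  // 64x64 view into storage with a row stride of 128 floats.
  TensorInfoSketch view{{64, 64}, {128, 1}};
  std::cout << "cast dimension: [" << castExtent(view, 1) << "]\n";  // prints 128
}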
Binary file added .test_tc_mapper_output.txt.swp
5 changes: 4 additions & 1 deletion tc/core/cuda/cuda_tc_executor.cc
@@ -93,13 +93,16 @@ CudaCompilationResult CudaBackend::compileWithTcMapper(
auto parameters = mappedScop->scop().getParameterValues();
auto specializedName = specializeKernelName(tcName, parameters);

auto inputsInfo = makeTensorInfoVector(inputs);

// This updates the launch bounds with the actual result from compilation
// with tightening of launch_bounds. What you get is not necessarily what
// you asked for, the autotuner should adapt to that.
std::string source;
Grid grid;
Block block;
std::tie(source, grid, block) = mappedScop->codegen(specializedName);
std::tie(source, grid, block) =
mappedScop->codegen(specializedName, inputsInfo);
LOG_IF(INFO, FLAGS_dump_cuda) << "generatedCuda: " << source << "\n"
<< "grid: " << grid << " block: " << block;

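The executor-side change is mostly plumbing: build one TensorInfo per runtime input with makeTensorInfoVector(inputs) and hand the vector to codegen. Because the new parameter is defaulted (see codegen.h and mapped_scop.h below), call sites that do not supply stride metadata keep compiling unchanged. A rough sketch of that pattern, with hypothetical names (codegenSketch, TensorInfoSketch) rather than the actual TC signatures:

#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for tc::TensorInfo; only shape/strides matter here.
struct TensorInfoSketch {
  std::vector<long> shape, strides;
};

// Hypothetical mock of the codegen entry point. The extra parameter defaults
// to an empty vector, so pre-existing callers compile unchanged while the
// CUDA executor can opt in with per-input stride metadata, mirroring
// MappedScop::codegen(specializedName, inputsInfo) in this commit.
std::string codegenSketch(const std::string& specializedName,
                          const std::vector<TensorInfoSketch>& inputsInfo = {}) {
  return specializedName +
         (inputsInfo.empty() ? " [dense casts]" : " [stride-aware casts]");
}

int main() {
  std::vector<TensorInfoSketch> inputsInfo{
      TensorInfoSketch{{64, 64}, {128, 1}}};
  std::cout << codegenSketch("tensoradd_64_64") << "\n";
  std::cout << codegenSketch("tensoraddstrided_64_64", inputsInfo) << "\n";
}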
26 changes: 19 additions & 7 deletions tc/core/polyhedral/cuda/codegen.cc
@@ -183,15 +183,23 @@ void emitTensorView(
stringstream& ss,
Halide::OutputImageParam p,
const map<string, Halide::Expr>& paramValues,
bool constInput = false) {
bool constInput = false,
const TensorInfo* tinfo = NULL) {
WS ws;
stringstream ssViewType;
for (int i = 1; i < p.dimensions(); ++i) { // Skip the outermost dimension
Halide::Expr extent = p.parameter().extent_constraint(i);
extent = Halide::Internal::substitute(paramValues, extent);
CHECK(extent.defined())
<< "Undefined extent on input/output tensor. Forward bounds inference should have set these\n";
ssViewType << "[" << extent << "]";
// TODO: Handle non-unit stride in the innermost dimension
if (tinfo && tinfo->strides.size() == p.dimensions() &&
tinfo->strides[p.dimensions() - 1] == 1 &&
tinfo->strides[i - 1] != (tinfo->shape[i] * tinfo->strides[i])) {
ssViewType << "[" << tinfo->strides[i - 1] << "]";
} else {
ssViewType << "[" << extent << "]";
}
}
ss << ws.tab();
ss << (constInput ? "const " : "") << p.type() << " (*" << p.name() << ")"
@@ -216,9 +224,12 @@ void emitTensorViews(
void emitTensorViews(
stringstream& ss,
const vector<Halide::ImageParam>& params,
const map<string, Halide::Expr>& paramValues) {
for (auto p : params) {
emitTensorView(ss, p, paramValues, true);
const map<string, Halide::Expr>& paramValues,
const std::vector<TensorInfo>& inputsInfo = std::vector<TensorInfo>{}) {
for (size_t i = 0; i < params.size(); ++i) {
inputsInfo.size()
? emitTensorView(ss, params[i], paramValues, true, &inputsInfo[i])
: emitTensorView(ss, params[i], paramValues, true);
}
}

@@ -738,7 +749,8 @@ std::unordered_set<isl::id, isl::IslIdIslHash> gatherReadOnlySet(

string emitCudaKernel(
const std::string& specializedName,
const MappedScop& mscop) {
const MappedScop& mscop,
const std::vector<TensorInfo>& inputsInfo) {
// Expecting a schedule with domain root and context first child.
CHECK(mscop.schedule()->elemAs<detail::ScheduleTreeElemDomain>());
CHECK(
@@ -755,7 +767,7 @@ string emitCudaKernel(
emitKernelSignature(ss, specializedName, scop);
emitThreadIdInit(ss, mscop);
emitTensorViews(ss, scop.halide.outputs, paramValues);
emitTensorViews(ss, scop.halide.inputs, paramValues);
emitTensorViews(ss, scop.halide.inputs, paramValues, inputsInfo);
emitTmpDecl(ss, scop);
emitPromotedArrayViewsHalide(ss, scop);
NodeInfoMapType nodeInfoMap;
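The condition added to emitTensorView only fires when the innermost stride is 1 and the stride of the next-outer dimension is not shape[i] * strides[i], i.e. the tensor is a non-contiguous view; in that case the view type carries the stride instead of the extent (the TODO above notes that a non-unit innermost stride is still unhandled). A toy reconstruction of the cast line this produces for a 2-D input, with illustrative names (emitViewSketch) in place of the real Halide/TensorInfo plumbing:

#include <iostream>
#include <sstream>
#include <string>

// Toy reconstruction of the cast line emitted for a 2-D const input.
// emitViewSketch and its parameters are illustrative; the real code walks a
// Halide::OutputImageParam plus the TensorInfo passed down in this commit.
std::string emitViewSketch(const std::string& name,
                           long logicalCols,
                           long rowStride) {
  // Use the row stride when the view is not contiguous, the extent otherwise.
  const long dim = (rowStride != logicalCols) ? rowStride : logicalCols;
  std::ostringstream ss;
  ss << "const float32 (*" << name << ")[" << dim << "] = "
     << "reinterpret_cast<const float32 (*)[" << dim << "]>(p" << name << ");";
  return ss.str();
}

int main() {
  // 64 logical columns but a row stride of 128, as in the TensorAddStrided test:
  std::cout << emitViewSketch("I0_view", 64, 128) << "\n";
  // -> const float32 (*I0_view)[128] = reinterpret_cast<const float32 (*)[128]>(pI0_view);
}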
3 changes: 2 additions & 1 deletion tc/core/polyhedral/cuda/codegen.h
@@ -145,7 +145,8 @@ struct CodegenStatementContext : CodegenContext {

std::string emitCudaKernel(
const std::string& specializedName,
const MappedScop& scop);
const MappedScop& scop,
const std::vector<TensorInfo>& inputsInfo = std::vector<TensorInfo>{});

} // namespace polyhedral
} // namespace tc
7 changes: 4 additions & 3 deletions tc/core/polyhedral/cuda/mapped_scop.cc
@@ -910,7 +910,8 @@ std::unique_ptr<MappedScop> makeSpecializedMappedScop(
// the context of the original scop as top-level
// context node in schedule tree.
std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
const std::string& specializedName) const {
const std::string& specializedName,
const std::vector<TensorInfo>& inputsInfo) const {
validate(schedule());

auto mappedScopForCodegen = makeSpecializedMappedScop(*this);
@@ -927,8 +928,8 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
code << code::cuda::cubBlockReduce;
}
code << "extern \"C\" {" << std::endl
<< emitCudaKernel(specializedName, *mappedScopForCodegen) << "}"
<< std::endl;
<< emitCudaKernel(specializedName, *mappedScopForCodegen, inputsInfo)
<< "}" << std::endl;

return std::make_tuple(
code.str(),
4 changes: 3 additions & 1 deletion tc/core/polyhedral/cuda/mapped_scop.h
@@ -115,7 +115,9 @@ class MappedScop {
// Generate CUDA code at the current state of transformation provided a
// name for the generated function.
std::tuple<std::string, tc::Grid, tc::Block> codegen(
const std::string& specializedName) const;
const std::string& specializedName,
const std::vector<TensorInfo>& inputsInfo =
std::vector<TensorInfo>{}) const;

// Accessors..
// Const accessor to schedule of underlying Scop.
4 changes: 2 additions & 2 deletions test/cuda/test_tc_mapper.cc
@@ -326,8 +326,8 @@ def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
auto res = Check(TC, name, options, inputs, checkFun);
// This test should be modified when strided tensors are handled
std::string expected =
"const float32 (*I0_view)[64] = "
"reinterpret_cast<const float32 (*)[64]>(pI0_view)";
"const float32 (*I0_view)[128] = "
"reinterpret_cast<const float32 (*)[128]>(pI0_view)";
ASSERT_NE(std::string::npos, res.second.find(expected))
<< "In resulting code:\n"
<< res.second << "\nfound unexpected: " << expected;
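The expected string flips from [64] to [128] because the test feeds 64x64 views whose underlying storage has a row stride of 128 elements. A small host-side illustration (not part of the test itself) of why the stride, and not the logical extent, has to appear in the cast:

#include <cassert>
#include <cstddef>
#include <vector>

// The test's I0_view/I1_view are 64x64 windows over storage with a row stride
// of 128 floats, so element (n, m) sits at flat offset n*128 + m. Casting the
// raw pointer to float (*)[128] makes view[n][m] compute exactly that offset;
// the old float (*)[64] cast would address the wrong rows.
int main() {
  const std::size_t rows = 64, rowStride = 128;
  std::vector<float> buffer(rows * rowStride, 0.0f);
  buffer[3 * rowStride + 5] = 42.0f;  // write element (3, 5) via flat indexing

  // Read it back through the strided cast that codegen now emits.
  const float (*view)[128] = reinterpret_cast<const float (*)[128]>(buffer.data());
  assert(view[3][5] == 42.0f);
  return 0;
}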
226 changes: 226 additions & 0 deletions test_tc_mapper_output.txt
@@ -0,0 +1,226 @@
Note: Google Test filter = *Strided*
[==========] Running 1 test from 1 test case.
[----------] Global test environment set-up.
[----------] 1 test from TcCudaMapperTest
[ RUN ] TcCudaMapperTest.TensorAddStrided
WARNING:
Reduction without initialization. If O is not pre-initialized before calling the TC function, consider using the !-suffixed reduction operator +=! instead of +=:

def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
O(n, m) += I0_view(n, m) + I1_view(n, m)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~... <--- HERE
}

WARNING:
Reduction without initialization. If O is not pre-initialized before calling the TC function, consider using the !-suffixed reduction operator +=! instead of +=:

def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
O(n, m) += I0_view(n, m) + I1_view(n, m)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~... <--- HERE
}

I0607 13:02:54.070823 21973 cuda_tc_executor.cc:82] tc::CudaMappingOptions::makeNaiveMappingOptions()
.outerScheduleFusionStrategy(tc::FusionStrategy::Preserve3Coincident)
.outerScheduleAllowSkewing(false)
.outerSchedulePositiveOrthant(true)
.intraTileScheduleFusionStrategy(tc::FusionStrategy::Preserve3Coincident)
.intraTileScheduleAllowSkewing(false)
.intraTileSchedulePositiveOrthant(true)
.fixParametersBeforeScheduling(false)
.tile(32, 32, 32)
.unroll(1)
.tileImperfectlyNested(false)
.matchLibraryCalls(false)
.mapToThreads(32, 8)
.mapToBlocks(256, 256)
.useSharedMemory(false)
.usePrivateMemory(false)
.unrollCopyShared(false)
.useReadOnlyCache(false);
I0607 13:02:54.072165 21973 cuda_tc_executor.cc:83] original schedule:
domain(
[M, N] -> { S_0[O_s1_n, O_s1_m] : 0 <= O_s1_n < N and 0 <= O_s1_m < M })
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n)] }
-----------------------------------------------------------------------
band(n(1) permutable(0) coincident(0) unroll(0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m)] }
-----------------------------------------------------------------------
I0607 13:02:54.075304 21973 scop.cc:400] After scheduling:
domain(
[M, N] -> { S_0[O_s1_n, O_s1_m] : 0 <= O_s1_n < N and 0 <= O_s1_m < M })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n)] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m)] }
-----------------------------------------------------------------------
I0607 13:02:54.075870 21973 scop.cc:454] After tiling outer:
domain(
[M, N] -> { S_0[O_s1_n, O_s1_m] : 0 <= O_s1_n < N and 0 <= O_s1_m < M })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
I0607 13:02:54.078128 21973 mapped_scop.cc:1021] After mapping to threads:
domain(
[M, N] -> { S_0[O_s1_n, O_s1_m] : M = 64 and N = 64 and 0 <= O_s1_n <= 63 and 0 <= O_s1_m <= 63 })
context([M, N, t1, t0, t2, b2, b1, b0] -> { [] : t2 = 0 and b2 = 0 and 0 <= t1 <= 7 and 0 <= t0 <= 31 and 0 <= b1 <= 255 and 0 <= b0 <= 255 })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
mapping_filter(ids(t1, t0, )
[M, N, t0, t1] -> { S_0[O_s1_n, O_s1_m] : (-t1 + O_s1_n) mod 8 = 0 and (-t0 + O_s1_m) mod 32 = 0 and 0 <= t0 <= 31 and 0 <= t1 <= 7 })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
thread_specific()
I0607 13:02:54.079393 21973 schedule_transforms.cc:391] Resizing scales to 2 entries: 32 32 32
I0607 13:02:54.079439 21973 mapped_scop.cc:1029] After mapping to blocks:
domain(
[M, N] -> { S_0[O_s1_n, O_s1_m] : M = 64 and N = 64 and 0 <= O_s1_n <= 63 and 0 <= O_s1_m <= 63 })
context([M, N, t1, t0, t2, b2, b1, b0] -> { [] : t2 = 0 and b2 = 0 and 0 <= t1 <= 7 and 0 <= t0 <= 31 and 0 <= b1 <= 255 and 0 <= b0 <= 255 })
mapping_filter(ids(b1, b0, )
[M, N, b0, b1] -> { S_0[O_s1_n, O_s1_m] : -31 - 32b1 + O_s1_m <= 8192*floor((O_s1_m)/8192) <= -32b1 + O_s1_m and -31 - 32b0 + O_s1_n <= 8192*floor((O_s1_n)/8192) <= -32b0 + O_s1_n })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
mapping_filter(ids(t1, t0, )
[M, N, t0, t1] -> { S_0[O_s1_n, O_s1_m] : (-t1 + O_s1_n) mod 8 = 0 and (-t0 + O_s1_m) mod 32 = 0 and 0 <= t0 <= 31 and 0 <= t1 <= 7 })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
thread_specific()
I0607 13:02:54.079643 21973 mapped_scop.cc:1083] After outerBlockInnerThread strategy:
domain(
[M, N] -> { S_0[O_s1_n, O_s1_m] : M = 64 and N = 64 and 0 <= O_s1_n <= 63 and 0 <= O_s1_m <= 63 })
context([M, N, t1, t0, t2, b2, b1, b0] -> { [] : t2 = 0 and b2 = 0 and 0 <= t1 <= 7 and 0 <= t0 <= 31 and 0 <= b1 <= 255 and 0 <= b0 <= 255 })
mapping_filter(ids(b1, b0, )
[M, N, b0, b1] -> { S_0[O_s1_n, O_s1_m] : -31 - 32b1 + O_s1_m <= 8192*floor((O_s1_m)/8192) <= -32b1 + O_s1_m and -31 - 32b0 + O_s1_n <= 8192*floor((O_s1_n)/8192) <= -32b0 + O_s1_n })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
mapping_filter(ids(t1, t0, )
[M, N, t0, t1] -> { S_0[O_s1_n, O_s1_m] : (-t1 + O_s1_n) mod 8 = 0 and (-t0 + O_s1_m) mod 32 = 0 and 0 <= t0 <= 31 and 0 <= t1 <= 7 })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
thread_specific()
I0607 13:02:54.079829 21973 cuda_tc_executor.cc:90] Mapped schedule:
domain(
[M, N] -> { S_0[O_s1_n, O_s1_m] : M = 64 and N = 64 and 0 <= O_s1_n <= 63 and 0 <= O_s1_m <= 63 })
context([M, N, t1, t0, t2, b2, b1, b0] -> { [] : t2 = 0 and b2 = 0 and 0 <= t1 <= 7 and 0 <= t0 <= 31 and 0 <= b1 <= 255 and 0 <= b0 <= 255 })
mapping_filter(ids(b1, b0, )
[M, N, b0, b1] -> { S_0[O_s1_n, O_s1_m] : -31 - 32b1 + O_s1_m <= 8192*floor((O_s1_m)/8192) <= -32b1 + O_s1_m and -31 - 32b0 + O_s1_n <= 8192*floor((O_s1_n)/8192) <= -32b0 + O_s1_n })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
mapping_filter(ids(t1, t0, )
[M, N, t0, t1] -> { S_0[O_s1_n, O_s1_m] : (-t1 + O_s1_n) mod 8 = 0 and (-t0 + O_s1_m) mod 32 = 0 and 0 <= t0 <= 31 and 0 <= t1 <= 7 })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
thread_specific()
I0607 13:02:54.091660 21973 mapped_scop.cc:900] Codegen with tightened bounds [blocks:CudaDim(2, 2, 1) @0x7ffefab63f90, threads:CudaDim(32, 8, 1) @0x7ffefab63fd0] for tree:
domain(
[M, N] -> { S_0[O_s1_n, O_s1_m] : M = 64 and N = 64 and 0 <= O_s1_n <= 63 and 0 <= O_s1_m <= 63 })
context([M, N, t1, t0, t2, b2, b1, b0] -> { [] : M = 64 and N = 64 and t2 = 0 and b2 = 0 and 0 <= t1 <= 7 and 0 <= t0 <= 31 and 0 <= b1 <= 1 and 0 <= b0 <= 1 })
mapping_filter(ids(b1, b0, )
[M, N, b0, b1] -> { S_0[O_s1_n, O_s1_m] : -31 - 32b1 + O_s1_m <= 8192*floor((O_s1_m)/8192) <= -32b1 + O_s1_m and -31 - 32b0 + O_s1_n <= 8192*floor((O_s1_n)/8192) <= -32b0 + O_s1_n })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(32*floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
mapping_filter(ids(t1, t0, )
[M, N, t0, t1] -> { S_0[O_s1_n, O_s1_m] : (-t1 + O_s1_n) mod 8 = 0 and (-t0 + O_s1_m) mod 32 = 0 and 0 <= t0 <= 31 and 0 <= t1 <= 7 })
band(n(2) permutable(1) coincident(1, 1) unroll(0, 0)
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_n - 32*floor((O_s1_n)/32))] }
-----------------------------------------------------------------------
| [M, N] -> { S_0[O_s1_n, O_s1_m] -> [(O_s1_m - 32*floor((O_s1_m)/32))] }
-----------------------------------------------------------------------
thread_specific()
I0607 13:02:54.130249 21973 cuda_rtc.cc:58] NVRTC function source:

template<typename T> inline __device__ T floord(T n, T d) {
return n < 0 ? - (-n + d - 1)/d : n / d;
}
#define if_then_else(cond,a,b) ((cond) ? (a) : (b))

// Halide type handling
typedef int int32;
typedef long int64;
typedef float float32;
typedef double float64;

#define inff __int_as_float(0x7f800000)
#define inf __longlong_as_double(0x7ff0000000000000LL)

// Before CUDA 9, syncwarp is a noop since warps are always synchronized.
#if __CUDACC_VER_MAJOR__ < 9
__device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {}
#endif

extern "C" {
__global__ void tensoraddstrided_64_64(int32 M, int32 N, float32* pO, const float32* pI0_view, const float32* pI1_view) {
int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
float32 (*O)[64] = reinterpret_cast<float32 (*)[64]>(pO);
const float32 (*I0_view)[128] = reinterpret_cast<const float32 (*)[128]>(pI0_view);
const float32 (*I1_view)[128] = reinterpret_cast<const float32 (*)[128]>(pI1_view);
for (int c2 = t1; c2 <= 31; c2 += 8) {
O[(32 * b0 + c2)][(t0 + 32 * b1)] = (O[(32 * b0 + c2)][(t0 + 32 * b1)] + (I0_view[(32 * b0 + c2)][(t0 + 32 * b1)] + I1_view[(32 * b0 + c2)][(t0 + 32 * b1)]));
}
}
}
I0607 13:02:54.348301 21973 cuda_tc_executor.cc:64] [COMPILE] Compiling with host JIT compiler took: 218ms
WARNING:
Reduction without initialization. If O is not pre-initialized before calling the TC function, consider using the !-suffixed reduction operator +=! instead of +=:

def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
O(n, m) += I0_view(n, m) + I1_view(n, m)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~... <--- HERE
}

[ OK ] TcCudaMapperTest.TensorAddStrided (297 ms)
[----------] 1 test from TcCudaMapperTest (297 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test case ran. (298 ms total)
[ PASSED ] 1 test.
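The NVRTC source in this log shows the new casts in effect: I0_view and I1_view become const float32 (*)[128] while the dense output O keeps (*)[64]. The TODO in emitTensorView about non-unit innermost strides still stands, since an array cast of this form always assumes unit element stride in the last dimension; a small illustration (hypothetical sizes, not taken from the test) of the case that remains out of reach:

#include <cstddef>
#include <iostream>
#include <vector>

// A view that keeps every other column (inner stride 2) cannot be expressed
// by casting to float (*)[D]: the cast can fold the outer stride into D, but
// the innermost dimension is always walked with unit element stride.
int main() {
  const std::size_t rows = 4, rowStride = 8, innerStride = 2;
  std::vector<float> buf(rows * rowStride);
  for (std::size_t i = 0; i < buf.size(); ++i) buf[i] = static_cast<float>(i);

  // Element (n, m) of the even-column view lives at n*rowStride + m*innerStride.
  auto want = [&](std::size_t n, std::size_t m) {
    return buf[n * rowStride + m * innerStride];
  };

  const float (*cast)[8] = reinterpret_cast<const float (*)[8]>(buf.data());
  std::cout << "wanted (1,2): " << want(1, 2)     // element at offset 12
            << ", cast[1][2]: " << cast[1][2]     // element at offset 10
            << "\n";
}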

