Skip to content

Commit

Permalink
Merge pull request #154 from dhairyagandhi96/dg/infer_test
Browse files Browse the repository at this point in the history
Refactor tests and make NNPACK optional
  • Loading branch information
DhairyaLGandhi authored Dec 24, 2019
2 parents 480754a + 20e8e47 commit 398bfdd
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 30 deletions.
16 changes: 12 additions & 4 deletions deps/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,18 @@ end

# If we have a download, and we are unsatisfied (or the version we're
# trying to install is not itself installed) then load it up!
if unsatisfied || !isinstalled(dl_info...; prefix=prefix)
# Download and install binaries
# Download and install binaries
use_nnpack = get(ENV, "NNLIB_USE_NNPACK", "false") == "true"
os_support = Sys.islinux() || Sys.isapple()
if use_nnpack && os_support
if unsatisfied || !isinstalled(dl_info...; prefix=prefix)
install(dl_info...; prefix=prefix, force=true, verbose=verbose)
end
# Write out a deps.jl file that will contain mappings for our products
write_deps_file(joinpath(@__DIR__, "deps.jl"), products, verbose=verbose)
else
open(joinpath(@__DIR__, "deps.jl"), "w") do io
write(io, "check_deps() = false")
end
end

# Write out a deps.jl file that will contain mappings for our products
write_deps_file(joinpath(@__DIR__, "deps.jl"), products, verbose=verbose)
7 changes: 4 additions & 3 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ using Requires
include("dim_helpers.jl")

# NNPACK support
if Sys.islinux() || Sys.isapple()
include("nnpack/NNPACK.jl")
include(joinpath(@__DIR__, "..", "deps", "deps.jl"))
if check_deps() == nothing
include("nnpack/NNPACK.jl")
else
is_nnpack_available() = false
is_nnpack_available() = false
end

include("activation.jl")
Expand Down
3 changes: 1 addition & 2 deletions src/nnpack/NNPACK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ const depsjl_path = joinpath(dirname(@__FILE__), "..", "..", "deps", "deps.jl")
if !isfile(depsjl_path)
error("NNPACK not installed properly, run Pkg.build(\"NNlib\"), restart Julia and try again")
end
include(depsjl_path)

const shared_threadpool_dict = Dict{UInt64, Base.RefValue}()

Expand All @@ -18,7 +17,7 @@ const shared_threadpool_dict = Dict{UInt64, Base.RefValue}()
Checks if the current hardware is supported by NNPACK.
"""
function is_nnpack_available()
check_deps()
check_deps() isa Nothing || return false
status = nnp_initialize()
if status == nnp_status_unsupported_hardware
return false
Expand Down
30 changes: 18 additions & 12 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,13 @@ conv_answer_dict = Dict(
# A "drop channels and batch dimension" helper
ddims(x) = dropdims(x, dims=(rank+1, rank+2))

for conv in (NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct, NNlib.conv_nnpack)
if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(DenseConvDims(x, w))
continue
convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,]
NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack)
for conv in convs
if NNlib.is_nnpack_available()
if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(DenseConvDims(x, w))
continue
end
end
@testset "$(conv)" begin
cdims = DenseConvDims(x, w)
Expand Down Expand Up @@ -352,12 +356,11 @@ conv_answer_dict = Dict(
end
end
end
end

if get(ENV,"NNLIB_TEST_FUZZING","false") == "true"
# @info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
@testset "fuzzing" begin
if get(ENV,"NNLIB_TEST_FUZZING","false") != "true"
@info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
return
end
@info("Starting Convolutional fuzzing tests; this can take a few minutes...")
# Now that we're fairly certain things are working, let's fuzz things a little bit:
for x_size in (
Expand Down Expand Up @@ -441,9 +444,10 @@ conv_answer_dict = Dict(
end
println()
end
else
@info "Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them"
end


@testset "Depthwise Convolution" begin
# Start with some easy-to-debug cases that we have worked through and _know_ work
for rank in (1,) #2,3)
Expand Down Expand Up @@ -552,12 +556,11 @@ end
end
end
end
end


if get(ENV,"NNLIB_TEST_FUZZING","false") == "true"
@testset "fuzzing" begin
if get(ENV,"NNLIB_TEST_FUZZING","false") != "true"
@info("Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
return
end
@info("Starting Depthwise Convolutional fuzzing tests; this can take a few minutes...")
# Now that we're fairly certain things are working, let's fuzz things a little bit:
for x_size in (
Expand Down Expand Up @@ -641,8 +644,11 @@ end
end
println()
end
else
@info "Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them"
end


@testset "conv_wrapper" begin
x = rand(10, 10, 3, 10)
w = rand(2, 2, 3, 16)
Expand Down
5 changes: 2 additions & 3 deletions test/inference.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using NNlib, Test
using NNlib: conv_direct, conv_im2col
import NNlib: conv_direct, conv_im2col

@testset "Conv Inference" begin
x = rand(10, 10, 3, 2)
Expand All @@ -9,6 +8,6 @@ using NNlib: conv_direct, conv_im2col
NNlib.is_nnpack_available() && push!(impl, NNlib.conv_nnpack)

for T in impl
@inferred T(x, w, DenseConvDims(x, w))
@test T(x, w, DenseConvDims(x, w)) isa AbstractArray{K,4} where K
end
end
14 changes: 8 additions & 6 deletions test/pooling.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using NNlib, Test
#using NNlib, Test

maxpool_answer_dict = Dict(
1 => Dict(
Expand Down Expand Up @@ -298,11 +298,13 @@ for rank in (1, 2, 3)
end
end

x = rand(10, 10, 3, 10)
@test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)
@test size(maxpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
@test size(meanpool(x, (2, 2))) == (5, 5, 3, 10)
@test size(meanpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
@testset "Pooling - Check Sizes" begin
x = rand(10, 10, 3, 10)
@test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)
@test size(maxpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
@test size(meanpool(x, (2, 2))) == (5, 5, 3, 10)
@test size(meanpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
end

# Add another test for 2d maxpool that uses an odd-length size:
@testset "Issue #133" begin
Expand Down

0 comments on commit 398bfdd

Please sign in to comment.