Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 31, 2024
1 parent e59e280 commit 274922d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 16 deletions.
19 changes: 12 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Flux: OneHotArray, OneHotMatrix, OneHotVector
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
import Optimisers

using Zygote
const gradient = Flux.gradient # both Flux & Zygote export this on 0.15
Expand All @@ -21,18 +22,24 @@ using Functors: fmapstructure_with_path
# ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
# ENV["FLUX_TEST_ENZYME"] = "false"

const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true"
if FLUX_TEST_ENZYME
Pkg.add("Enzyme")
using Enzyme: Enzyme
end

include("test_utils.jl") # for test_gradients

Random.seed!(0)

include("testsuite/normalization.jl")

function flux_testsuite(dev)
@testset "Flux Test Suite" begin
@testset "Normalization" begin
normalization_testsuite(dev)
end
@testset "Flux Test Suite" begin
@testset "Normalization" begin
normalization_testsuite(dev)
end
end
end

@testset verbose=true "Flux.jl" begin
Expand Down Expand Up @@ -157,10 +164,8 @@ end
@info "Skipping Distributed tests, set FLUX_TEST_DISTRIBUTED_MPI or FLUX_TEST_DISTRIBUTED_NCCL=true to run them."
end

if get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true"
Pkg.add("Enzyme")
if FLUX_TEST_ENZYME
@testset "Enzyme" begin
import Enzyme
include("ext_enzyme/enzyme.jl")
end
else
Expand Down
11 changes: 2 additions & 9 deletions test/train.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
using Flux
# using Flux.Train
import Optimisers

using Test
using Random
import Enzyme

function train_enzyme!(fn, model, args...; kwargs...)
Flux.train!(fn, Enzyme.Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
end

for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))

if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false")
if name == "Enzyme" && FLUX_TEST_ENZYME
continue
end

Expand Down Expand Up @@ -50,7 +43,7 @@ end
for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
# TODO reinstate Enzyme
name == "Enzyme" && continue
# if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false")
# if name == "Enzyme" && FLUX_TEST_ENZYME
# continue
# end

Expand Down

0 comments on commit 274922d

Please sign in to comment.