-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmobilenet.jl
36 lines (26 loc) · 1.19 KB
/
mobilenet.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
using Pkg # hideall
# Activate the example's project environment and install its pinned
# dependencies. Note: "./" paths resolve against the current working
# directory, so this script expects to be run from its own folder.
Pkg.activate("./Project.toml")
Pkg.instantiate()
# Shared helpers for this example. NOTE(review): presumably this is where
# BSON, VisualWakeWords, DataLoader, the artifact helpers,
# prepare_bitstream_model, etc. used below come from — not visible here.
include("./src/setup.jl");
# Artifact manifest listing the downloadable model and dataset (path is
# relative to the current working directory).
artifacts = "./Artifacts.toml"
# Make sure the "mobilenet" artifact is available locally (presumably the
# Pkg.Artifacts helper, which installs it if missing — confirm in setup.jl).
ensure_artifact_installed("mobilenet", artifacts)
# Content hash of the artifact, used to find its on-disk directory.
mobilenet = artifact_hash("mobilenet", artifacts)
modelpath = joinpath(artifact_path(mobilenet), "mobilenet.bson")
# Load the saved checkpoint and take the model stored under key :m. The
# module argument tells BSON where to resolve the serialized types. The
# trailing `;` suppresses display of the (large) model in generated output.
model = BSON.load(modelpath, @__MODULE__)[:m];
# Make sure the Visual Wake Words ("vww") dataset artifact is available
# locally, then locate its "vww-hackathon" directory.
ensure_artifact_installed("vww", artifacts)
vwwdata = artifact_hash("vww", artifacts)
dataroot = joinpath(artifact_path(vwwdata), "vww-hackathon")
# Validation split only.
valdata = VisualWakeWords(dataroot; subset = :val)
# Apply the ImageToTensor transform to every validation sample.
# NOTE(review): map_augmentation is defined outside this file; assumed lazy.
valaug = map_augmentation(ImageToTensor(), valdata)
# Group samples into batches of 32 and wrap them in a DataLoader.
# NOTE(review): `nothing` is presumably the DataLoader batch size (batching is
# already done by BatchView), and `buffered = true` presumably reuses batch
# buffers between iterations — confirm against the DataLoaders.jl version.
valloader = DataLoader(BatchView(valaug; batchsize = 32), nothing; buffered = true)
# Elementwise correctness mask shared by both `accfn` methods: a prediction
# counts as positive when the model output is > 0, and it is correct when
# that Bool equals the label (compared with broadcasting).
_accfn_hits(ŷ, y) = (ŷ .> 0) .== y

"""
    accfn(ŷ::AbstractArray, y::AbstractArray)

Return the fraction of entries where the thresholded prediction `ŷ .> 0`
equals the label `y`, comparing elementwise with broadcasting.
"""
accfn(ŷ::AbstractArray, y::AbstractArray) = mean(_accfn_hits(ŷ, y))

"""
    accfn(data, model)

Return the accuracy of `model` over every sample in `data`, an iterable of
`(x, y)` batches. Correct predictions are pooled across all batches before
dividing. Averaging per-batch accuracies instead would over-weight a smaller
(e.g. final, partial) batch whenever batch sizes differ.

Throws `ArgumentError` if `data` yields no samples.
"""
function accfn(data, model)
    ncorrect = 0
    ntotal = 0
    for (x, y) in data
        hits = _accfn_hits(model(x), y)
        ncorrect += count(hits)
        ntotal += length(hits)
    end
    ntotal > 0 || throw(ArgumentError("accfn: `data` yielded no samples"))
    return ncorrect / ntotal
end
# Baseline: validation accuracy of the original model.
accfn(valloader, model)
# Convert the model for bitstream execution. `scalings` holds the scale
# factors introduced by the conversion; each element is itself reduced with
# `prod`, so `total_scaling` is the product of every factor.
model_scaled, scalings = prepare_bitstream_model(model)
@show total_scaling = prod(prod.(scalings))
model_scaled
# Number of simulation steps passed to the conversion-error model.
simulation_length = 1000
# Mutates `model_scaled` in place (per the `!` convention). NOTE(review):
# presumably injects the error expected from a finite-length simulation.
add_conversion_error!(model_scaled, simulation_length);
# Multiply the output by the accumulated scaling. The factor is bound once in
# a `let`, so the closure captures a concretely typed value. Otherwise it would
# read the untyped global `total_scaling` on every call (type-unstable) and
# would silently follow any later reassignment of that global.
model_rescaled = let s = total_scaling
    Chain(model_scaled, x -> x .* s)
end
accfn(valloader, model_rescaled)
# This file was generated using Literate.jl, https://github.com/fredrikekre/Literate.jl