From b59df612d59ed0b9a57e2d1c6b5e788d850f22ee Mon Sep 17 00:00:00 2001 From: qfl3x Date: Tue, 19 Apr 2022 14:41:09 +0100 Subject: [PATCH] Attempt to fix the compatibility issue. --- examples/semisupervised_gcn.jl | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/semisupervised_gcn.jl b/examples/semisupervised_gcn.jl index 2030f3af9..899731be5 100644 --- a/examples/semisupervised_gcn.jl +++ b/examples/semisupervised_gcn.jl @@ -42,22 +42,23 @@ end target_dim = 7 # target dimension end -## Loss: cross entropy with first layer L2 regularization +## Loss: cross entropy with first layer L2 regularization l2norm(x) = sum(abs2, x) -function model_loss(model, λ, batch) + +function model_loss(model, λ, batch, batch_size::Int) loss = 0.f0 - for (x, y) in batch + for (x, y) in [[batch[1][i], batch[2][:,:,i]] for i = 1:batch_size] loss += logitcrossentropy(model(x), y) loss += λ*sum(l2norm, Flux.params(model[1])) end return loss end -function accuracy(model, batch::AbstractVector) - return mean(mean(onecold(softmax(cpu(model(x)))) .== onecold(cpu(y))) for (x, y) in batch) +function accuracy(model, batch::Tuple{AbstractVector, AbstractArray}, batch_size::Int) + return mean(mean(onecold(softmax(cpu(model(x)))) .== onecold(cpu(y))) for (x,y) in [[batch[1][i], batch[2][:,:,i]] for i = 1:batch_size]) end -accuracy(model, loader::DataLoader, device) = mean(accuracy(model, batch |> device) for batch in loader) +accuracy(model, loader::DataLoader, device, batch_size::Int) = mean(accuracy(model, batch |> device, batch_size) for batch in loader) function train(; kws...) # load hyperparamters @@ -75,7 +76,7 @@ function train(; kws...) # load Cora from Planetoid dataset train_loader, test_loader = load_data(:cora, args.batch_size) - + # build model model = Chain( GCNConv(args.input_dim=>args.hidden_dim, relu), @@ -86,7 +87,7 @@ function train(; kws...) # ADAM optimizer opt = ADAM(args.η) - + # parameters ps = Flux.params(model) @@ -96,13 +97,12 @@ function train(; kws...) for epoch = 1:args.epochs @info "Epoch $(epoch)" progress = Progress(length(train_loader)) - for batch in train_loader loss, back = Flux.pullback(ps) do - model_loss(model, args.λ, batch |> device) + model_loss(model, args.λ, batch |> device, args.batch_size) end - train_acc = accuracy(model, train_loader, device) - test_acc = accuracy(model, test_loader, device) + train_acc = accuracy(model, train_loader, device, args.batch_size) + test_acc = accuracy(model, test_loader, device, args.batch_size) grad = back(1f0) Flux.Optimise.update!(opt, ps, grad) @@ -115,6 +115,7 @@ function train(; kws...) train_steps += 1 end + end return model, args