Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added groupwiseconv and modified depthwise conv for common interface #948

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
DepthwiseConv, GroupwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, params, fmap, cpu, gpu, f32, f64

include("optimise/Optimise.jl")
Expand Down
112 changes: 101 additions & 11 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using NNlib: conv, ∇conv_data, depthwiseconv
using NNlib: conv, ∇conv_data

expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
function (c::Conv)(x::AbstractArray)
  # TODO: breaks gpu broadcast :(
  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
  # Reshape the bias so it broadcasts along the channel dimension of the conv
  # output (one singleton dimension per spatial/stride dimension).
  # NOTE: removed leftover `@show b` debug statement.
  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
  cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
  σ.(conv(x, c.weight, cdims) .+ b)
end
Expand Down Expand Up @@ -160,41 +161,52 @@ struct DepthwiseConv{N,M,F,A,V}
stride::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
groupcount::Int
end

# TODO groupcount should be inferred.
# Construct a DepthwiseConv from an existing weight array `w` and bias vector
# `b`. Scalar `stride`/`pad`/`dilation` are expanded to per-dimension tuples
# (pad gets two entries per spatial dimension).
function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
                       stride = 1, pad = 0, dilation = 1, groupcount = 1) where {T,N}
  stride = expand(Val(N-2), stride)
  pad = expand(Val(2*(N-2)), pad)
  dilation = expand(Val(N-2), dilation)
  return DepthwiseConv(σ, w, b, stride, pad, dilation, groupcount)
end

# Construct a DepthwiseConv from a kernel-size tuple `k` and an `in=>out`
# channel pair. The weight is laid out (k..., in ÷ groupcount, out).
function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                       init = glorot_uniform, stride = 1, pad = 0, dilation = 1, groupcount = 1) where N
  # Validate at the API boundary with ArgumentError rather than @assert
  # (assertions may be disabled at higher optimization levels), and describe
  # the actual divisibility requirement, which is against groupcount.
  ch[2] % groupcount == 0 || throw(ArgumentError("Output channels must be an integer multiple of groupcount"))
  ch[1] % groupcount == 0 || throw(ArgumentError("Input channels must be an integer multiple of groupcount"))
  return DepthwiseConv(
    init(k..., div(ch[1], groupcount), ch[2]),
    zeros(ch[2]),
    σ;
    stride = stride,
    pad = pad,
    dilation = dilation,
    groupcount = groupcount
  )
end

@functor DepthwiseConv

# TODO may not necessary
# Re-wrap the dims object and defer to the generic `conv` implementation.
depthwiseconv(x, w, ddims::DenseConvDims) = conv(x, w, DenseConvDims(ddims))

# Forward pass: grouped convolution via the common DenseConvDims interface
# (NNlib `conv` with a `groupcount`), then broadcast bias add and activation.
function (c::DepthwiseConv)(x)
  # Reshape the bias so it broadcasts along the channel dimension of the output.
  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
  cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=c.groupcount)
  σ.(conv(x, c.weight, cdims) .+ b)
end

function Base.show(io::IO, l::DepthwiseConv)
  # Weight layout is (k..., in ÷ groupcount, out); print the full kernel-size
  # tuple (the PR draft printed only a single dimension), then in=>out.
  print(io, "DepthwiseConv(", size(l.weight)[1:end-2])
  print(io, ", ", size(l.weight)[end-1]*l.groupcount, "=>", size(l.weight)[end])
  l.σ == identity || print(io, ", ", l.σ)
  l.groupcount == 1 || print(io, ", groupcount = ", l.groupcount)
  print(io, ")")
end

Expand All @@ -204,6 +216,84 @@ end
# Promote other Real inputs (e.g. integers) to the weight eltype T before
# applying the layer, so `conv` sees matching element types.
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  a(T.(x))


"""
    GroupwiseConv(size, in=>out)
    GroupwiseConv(size, in=>out, relu)

Groupwise (grouped) convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Note that both `in` and `out` must be integer multiples of the `groupcount`
keyword argument.

Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.

Takes the keyword arguments `pad`, `stride`, `dilation` and `groupcount`.
"""
struct GroupwiseConv{N,M,F,A,V}
  σ::F        # activation function
  weight::A   # kernel array, laid out (k..., in ÷ groupcount, out)
  bias::V     # per-output-channel bias
  stride::NTuple{N,Int}
  pad::NTuple{M,Int}
  dilation::NTuple{N,Int}
  groupcount::Int
end

# TODO groupcount should be mandatory
# Build a GroupwiseConv from a pre-made weight array and bias vector, turning
# scalar stride/pad/dilation into per-spatial-dimension tuples (pad gets two
# entries per spatial dimension).
function GroupwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
                       stride = 1, pad = 0, dilation = 1, groupcount = 1) where {T,N}
  return GroupwiseConv(σ, w, b,
                       expand(Val(N-2), stride),
                       expand(Val(2*(N-2)), pad),
                       expand(Val(N-2), dilation),
                       groupcount)
end

# Construct a GroupwiseConv from a kernel-size tuple `k` and an `in=>out`
# channel pair. The weight is laid out (k..., in ÷ groupcount, out).
function GroupwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                       init = glorot_uniform, stride = 1, pad = 0, dilation = 1, groupcount = 1) where N
  # Validate at the API boundary with ArgumentError rather than @assert
  # (assertions may be disabled at higher optimization levels); the messages
  # now state the real requirement — divisibility by groupcount — and fix the
  # "interger" typo.
  ch[2] % groupcount == 0 || throw(ArgumentError("Output channels must be an integer multiple of groupcount"))
  ch[1] % groupcount == 0 || throw(ArgumentError("Input channels must be an integer multiple of groupcount"))
  return GroupwiseConv(
    init(k..., div(ch[1], groupcount), ch[2]),
    zeros(ch[2]),
    σ;
    stride = stride,
    pad = pad,
    dilation = dilation,
    groupcount = groupcount
  )
end

@functor GroupwiseConv

# TODO may not necessary
# Re-wrap the dims object and defer to the generic `conv` implementation.
groupwiseconv(x, w, ddims::DenseConvDims) = conv(x, w, DenseConvDims(ddims))

# Forward pass: grouped convolution via DenseConvDims with a groupcount,
# then broadcast bias add and activation.
# NOTE: removed leftover `@info b, c.bias` debug logging, which would fire on
# every forward call.
function (c::GroupwiseConv)(x)
  # Reshape the bias so it broadcasts along the channel dimension of the output.
  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
  cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=c.groupcount)
  σ.(conv(x, c.weight, cdims) .+ b)
end

function Base.show(io::IO, l::GroupwiseConv)
  # Weight layout is (k..., in ÷ groupcount, out); print the full kernel-size
  # tuple (the PR draft printed only a single dimension), then in=>out.
  print(io, "GroupwiseConv(", size(l.weight)[1:end-2])
  print(io, ", ", size(l.weight)[end-1]*l.groupcount, "=>", size(l.weight)[end])
  l.σ == identity || print(io, ", ", l.σ)
  l.groupcount == 1 || print(io, ", groupcount = ", l.groupcount)
  print(io, ")")
end

# Input eltype already matches the weight eltype: jump straight to the generic
# AbstractArray method via `invoke`, bypassing these conversion methods.
(a::GroupwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  invoke(a, Tuple{AbstractArray}, x)

# Promote other Real inputs (e.g. integers) to the weight eltype T first.
(a::GroupwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  a(T.(x))


"""
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)
Expand Down