Skip to content

Commit

Permalink
Add EmbeddingBag (#2031)
Browse files Browse the repository at this point in the history
* embedding bag

* doc fix

* Apply suggestions from code review

Co-authored-by: Carlo Lucibello <[email protected]>

* Remove references to `Statistics`

Statistics is imported by Flux so we can just call `mean` rather than `Statistics.mean`.

* non mutating bag and onehot changes

* better docs and todo

* input/offset docs

* doctest

* Apply suggestions from code review

Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: Michael Abbott <[email protected]>

* reduce docs

* broadcast to map

* remove extra doc example line

* add _splitat

* rename input/offset

* minor docs

* Apply suggestions from code review

* Update test/layers/basic.jl

* Update test/layers/basic.jl

* Update test/layers/basic.jl

* typo

* docstring

* Apply suggestions from code review

---------

Co-authored-by: Carlo Lucibello <[email protected]>
Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: Michael Abbott <[email protected]>
  • Loading branch information
4 people authored Apr 18, 2023
1 parent ccf87bb commit dfea43c
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/src/models/layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ These layers accept an index, and return a vector (or several indices, and sever

```@docs
Flux.Embedding
Flux.EmbeddingBag
```

## [Dataflow Layers, or Containers](@id man-dataflow-layers)
Expand Down
148 changes: 148 additions & 0 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -716,3 +716,151 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
function Base.show(io::IO, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end


"""
_splitat(data::AbstractVector, at::AbstractVector{Int})
Partitions `data` into a vector of views.
Each index `i in at` specifies that a view starts with `data[i]`.
These indices must be strictly increasing, and start at `1`.
The resulting views do not overlap, and are never empty.
The last view always ends with `data[end]`.
### Example
```jldoctest
julia> Flux._splitat(collect('A':'Z'), [1, 3, 4, 13])
4-element Vector{SubArray{Char, 1, Vector{Char}, Tuple{UnitRange{Int64}}, true}}:
['A', 'B']
['C']
['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L']
['M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
```
"""
function _splitat(data::AbstractVector, at::AbstractVector{<:Integer})
at[begin] == firstindex(data) || throw(ArgumentError("The first element in `at` must be 1."))
at[end] <= lastindex(data) || throw(ArgumentError("The last element in `at` must be at most the length of `data`."))
issorted(at, lt = <=) || throw(ArgumentError("`at` must be monotonically increasing with no duplicates."))
iplus = vcat(at, lastindex(data)+1)
return [view(data, iplus[n]:(iplus[n+1]-1)) for n in eachindex(at)]
end

"""
EmbeddingBag(in => out, reduction=mean; init=Flux.randn32)
A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`.
Differs from [`Embedding`](@ref) in that, instead of acting on a single vocabulary index,
it always acts a vector of indices which it calls a "bag".
Their individual embedding vectors are reduced to one, using `mean` or some other function.
Instead of acting on one "bag", such as `x::Vector{Int}`, the layer can also act on several:
* Acting on a vector of "bags", it produces a matrix whose columns are the reduced vectors.
More generally on `x::Array{Vector{Int}}`, its output is of size `(out, size(x)...)`.
* Any higher-rank array of integers is interpreted as a collection of "bags" each along the first dimension.
Thus the output is `mapslices(e, x; dims=1)` when `e::EmbeddingBag` and `x::Array{Int,N}`.
This method is more efficient, but requires that all "bags" have the same length.
* A vector of "bags" may also be produced by splitting a vector of indices at specified points.
For this case the layer takes two inputs, both vectors of integers. See details below.
The "bag" may equivalently be represented as a `OneHotMatrix`. A collection of these,
or one higher-rank `OneHotArray`, again produce a stack of embeddings. See details below.
# Examples
```jldoctest
julia> vocab_size = 26; # embed into 3 dimensions, with non-random vectors:
julia> eb = EmbeddingBag(vocab_size => 3, init=Flux.identity_init(gain=100))
EmbeddingBag(26 => 3) # 78 parameters
julia> eb([2]) # one bag of 1 item
3-element Vector{Float32}:
0.0
100.0
0.0
julia> eb([3,3,1]) # one bag of 3 items, one mean embedding
3-element Vector{Float32}:
33.333332
0.0
66.666664
julia> eb([[3,1,3], [2,1]]) # two bags
3×2 Matrix{Float32}:
33.3333 50.0
0.0 50.0
66.6667 0.0
julia> eb([1 1 1 1; 1 2 3 4]) # 4 bags each of 2 items, eachcol([1 1 1 1; 1 2 3 4])
3×4 Matrix{Float32}:
100.0 50.0 50.0 50.0
0.0 50.0 0.0 0.0
0.0 0.0 50.0 0.0
julia> eb(rand(1:26, 10, 5, 5)) |> size # 25 bags each of 10 items
(3, 5, 5)
```
Another way to specify "many bags of many items" is to provide a vector `data` (each in `1:in`)
and a vector `at` stating where to split that up into "bags".
The first bag starts with `data[at[1]]`, the second at `data[at[2]]`, and so on,
with no overlaps and nothing left out (thus it requires `at[1]==1`).
```jldoctest
julia> data = [11, 1, 12, 2, 13, 3, 14];
julia> Flux._splitat(data, [1, 4]) |> println # internal function, makes data[1:3], data[4:end]
[[11, 1, 12], [2, 13, 3, 14]]
julia> eb(data, [1, 4]) # two bags, of 3 and 4 items
3×2 Matrix{Float32}:
33.3333 0.0
0.0 25.0
0.0 25.0
```
Finally, each bag may also be also be represented as a [`OneHotMatrix`](@ref OneHotArrays.onehotbatch).
```jldoctest
julia> eb(Flux.onehotbatch("bba", 'a':'z')) # same as [2,2,1], one bag of 3 items
3-element Vector{Float32}:
33.333332
66.666664
0.0
julia> eb([Flux.onehotbatch("bba", 'a':'z'), Flux.onehotbatch("cc", 'a':'z')]) # two bags
3×2 Matrix{Float32}:
33.3333 0.0
66.6667 0.0
0.0 100.0
```
"""
struct EmbeddingBag{F, W<:AbstractMatrix}
weight::W
reduction::F
end

@functor EmbeddingBag

EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction)
EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean)

(m::EmbeddingBag)(data::AbstractVector, at::AbstractVector) = m(_splitat(data, at))
(m::EmbeddingBag)(inds::AbstractArray{<:Integer}) = dropdims(m.reduction(Embedding(m.weight)(inds), dims=2), dims=2)
(m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one")

(m::EmbeddingBag)(hot::AbstractArray{Bool}) = dropdims(m.reduction(Embedding(m.weight)(hot), dims=2), dims=2)
(m::EmbeddingBag)(hot::AbstractVector{Bool}) = error("EmbeddingBag not defined for a one-hot vector")

# These two could be stack(m, bags), but no AD support yet. (Gradient for weight quite inefficient here.)
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
(m::EmbeddingBag)(bags::AbstractArray{<:AbstractVector}) = reshape(m(vec(bags)), :, size(bags)...)

(m::EmbeddingBag)(bags::AbstractArray{<:AbstractMatrix{Bool}}) = reshape(reduce(hcat, m.(vec(bags))), :, size(bags)...)

function Base.show(io::IO, m::EmbeddingBag)
print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end
2 changes: 1 addition & 1 deletion src/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ _show_children(p::Parallel) = (p.connection, p.layers...)
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)

for T in [
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag,
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
Expand Down
75 changes: 75 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,81 @@ import Flux: activations
y3 = m(x3)
@test size(y3) == (embed_size, 3, 4)
end

@testset "EmbeddingBag" begin

# test _splitat
data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
offsets_good = [1, 3, 6]
offsets_each = [1,2,3,4,5,6,7,8,9]
offsets_just_one = [1]
offsets_all_but_last = [1, 9]

@test Flux._splitat(data, offsets_good) == [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
@test Flux._splitat(data, offsets_each) == [[1], [2], [3], [4], [5], [6], [7], [8], [9]]
@test Flux._splitat(data, offsets_just_one) == [[1,2,3,4,5,6,7,8,9]]
@test Flux._splitat(data, offsets_all_but_last) == [[1,2,3,4,5,6,7,8], [9]]

offsets_non_monotonic = [1, 2, 2, 5]
offsets_non_sorted = [1, 5, 2]
offsets_non_one = [2, 3, 5]
offsets_too_large = [1, 5, 11]

@test_throws ArgumentError Flux._splitat(data, offsets_non_monotonic)
@test_throws ArgumentError Flux._splitat(data, offsets_non_sorted)
@test_throws ArgumentError Flux._splitat(data, offsets_non_one)
@test_throws ArgumentError Flux._splitat(data, offsets_too_large)

@testset for reduction in [sum, Statistics.mean, maximum]
vocab_size, embed_size = 10, 4
emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction)
emb = Flux.Embedding(emb_bag.weight)
@test size(emb_bag.weight) == (embed_size, vocab_size)
@test_throws ErrorException emb_bag(2)

# single bag (input as a vector)
x = rand(1:vocab_size, 3)
y = emb_bag(x)
z = vec(reduction(emb(x), dims=2))
@test y isa Vector{Float32}
@test y z

# PyTorch style `input`/`offset` bagging
@test emb_bag([1,3,2,4,5,7], [1,3,5]) emb_bag([[1,3], [2,4], [5,7]])
@test emb_bag([1,3,2,4,5,7], [1,3,5]) emb_bag([1 2 5; 3 4 7])
@test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2, 4])
@test_throws ArgumentError emb_bag([1,2,3,4,5,6], [1, 12])

# docstring example
@test emb_bag([1,2,3,4,5,6,7,8,9,10], [1,5,6,8]) emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]])

# multiple bags (input as a vector of vectors)
x = [rand(1:vocab_size, 3) for _ in 1:4]
y = emb_bag(x)
z = reduce(hcat, reduction.(emb.(x), dims=2))
@test y isa Matrix{Float32}
@test y z

# multiple bags (input as a matrix)
x = rand(1:vocab_size, (3, 5))
xvec = collect(eachcol(x))
y = emb_bag(x)
z = reduce(hcat, reduction.(emb.(xvec), dims=2))
@test y emb_bag(xvec)
@test y z

# a one-hot matrix is a bag, but a one-hot vector is not.
@test_throws ErrorException emb_bag(Flux.OneHotVector(3, vocab_size))

i2 = rand(1:vocab_size, 3)
x2 = Flux.OneHotMatrix(i2, vocab_size)
y2 = emb_bag(x2)
z2 = emb(i2)
@test y2 isa Vector{Float32}
@test y2 vec(reduction(z2, dims=2))
@test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000))
end
end
end

@testset "second derivatives" begin
Expand Down

0 comments on commit dfea43c

Please sign in to comment.