Skip to content

Commit

Permalink
Add parallel keyword to GMM (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
fhagemann authored Dec 11, 2024
1 parent 9fd2dd0 commit 0151911
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ GMM(x::Vector{T}) where T <: AbstractFloat = GMM(reshape(x, length(x), 1)) # st

## constructors based on data or matrix
function GMM(n::Int, x::DataOrMatrix{T}; method::Symbol=:kmeans, kind=:diag,
nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0) where T <: AbstractFloat
nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0, parallel::Bool=true) where T <: AbstractFloat
if n < 2
return GMM(x, kind=kind)
elseif method == :split
return GMM2(n, x, kind=kind, nIter=nIter, nFinal=nFinal, sparse=sparse)
return GMM2(n, x, kind=kind, nIter=nIter, nFinal=nFinal, sparse=sparse, parallel=parallel)
elseif method == :kmeans
return GMMk(n, x, kind=kind, nInit=nInit, nIter=nIter, sparse=sparse)
return GMMk(n, x, kind=kind, nInit=nInit, nIter=nIter, sparse=sparse, parallel=parallel)
else
error("Unknown method ", method)
end
end
## a 1-dimensional Gaussian can be initialized with a vector, skip kind=
GMM(n::Int, x::Vector{T}; method::Symbol=:kmeans, nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0) where T <: AbstractFloat = GMM(n, reshape(x, length(x), 1); method=method, kind=:diag, nInit=nInit, nIter=nIter, nFinal=nFinal, sparse=sparse)
GMM(n::Int, x::Vector{T}; method::Symbol=:kmeans, nInit::Int=50, nIter::Int=10, nFinal::Int=nIter, sparse=0, parallel::Bool=true) where T <: AbstractFloat = GMM(n, reshape(x, length(x), 1); method=method, kind=:diag, nInit=nInit, nIter=nIter, nFinal=nFinal, sparse=sparse, parallel=parallel)

## we sometimes end up with pathological gmms
function sanitycheck!(gmm::GMM)
Expand Down Expand Up @@ -73,7 +73,7 @@ end


## initialize GMM using Clustering.kmeans (which uses a method similar to kmeans++)
function GMMk(n::Int, x::DataOrMatrix{T}; kind=:diag, nInit::Int=50, nIter::Int=10, sparse=0) where T <: AbstractFloat
function GMMk(n::Int, x::DataOrMatrix{T}; kind=:diag, nInit::Int=50, nIter::Int=10, sparse=0, parallel::Bool=true) where T <: AbstractFloat
nₓ, d = size(x)
hist = [History(@sprintf("Initializing GMM, %d Gaussians %s covariance %d dimensions using %d data points", n, diag, d, nₓ))]
@info(last(hist).s)
Expand Down Expand Up @@ -141,22 +141,22 @@ function GMMk(n::Int, x::DataOrMatrix{T}; kind=:diag, nInit::Int=50, nIter::Int=
@info(last(hist).s)
gmm = GMM(w, μ, Σ, hist, nxx)
sanitycheck!(gmm)
em!(gmm, x; nIter=nIter, sparse=sparse)
em!(gmm, x; nIter=nIter, sparse=sparse, parallel=parallel)
return gmm
end

## Train a GMM by consecutively splitting all means. n must be a power of 2
## This kind of initialization is deterministic, but doesn't work particularly well, it seems
## We start with one Gaussian, and consecutively split.
function GMM2(n::Int, x::DataOrMatrix; kind=:diag, nIter::Int=10, nFinal::Int=nIter, sparse=0)
function GMM2(n::Int, x::DataOrMatrix; kind=:diag, nIter::Int=10, nFinal::Int=nIter, sparse=0, parallel::Bool=true)
log2n = round(Int,log2(n))
2^log2n == n || error("n must be power of 2")
gmm = GMM(x, kind=kind)
tll = [avll(gmm, x)]
@info("0: avll = ", tll[1])
for i in 1:log2n
gmm = gmmsplit(gmm)
avll = em!(gmm, x; nIter=(i==log2n ? nFinal : nIter), sparse=sparse)
avll = em!(gmm, x; nIter=(i==log2n ? nFinal : nIter), sparse=sparse, parallel=parallel)
@info(i, avll)
append!(tll, avll)
end
Expand Down Expand Up @@ -235,7 +235,7 @@ end
# the log-likelihood history, per data frame per dimension
## Note: 0 iterations is allowed, this just computes the average log likelihood
## of the data and stores this in the history.
function em!(gmm::GMM, x::DataOrMatrix; nIter::Int = 10, varfloor::Float64=1e-3, sparse=0, debug=1)
function em!(gmm::GMM, x::DataOrMatrix; nIter::Int = 10, varfloor::Float64=1e-3, sparse=0, parallel::Bool=true, debug=1)
size(x,2)==gmm.d || error("Inconsistent size gmm and x")
d = gmm.d # dim
ng = gmm.n # n gaussians
Expand All @@ -247,7 +247,7 @@ function em!(gmm::GMM, x::DataOrMatrix; nIter::Int = 10, varfloor::Float64=1e-3,

for i in 1:nIter
## E-step
nₓ, ll[i], N, F, S = stats(gmm, x, parallel=true)
nₓ, ll[i], N, F, S = stats(gmm, x, parallel=parallel)
## M-step
gmm.w = N / nₓ
gmm.μ = F ./ N
Expand Down

0 comments on commit 0151911

Please sign in to comment.