Skip to content

Commit

Permalink
Merge pull request #226 from ReactiveBayes/logmean
Browse files Browse the repository at this point in the history
Implement log mean for tensordirichlet
  • Loading branch information
bvdmitri authored Jan 20, 2025
2 parents 7a5d361 + a5c0164 commit 9f06eb4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/distributions/tensor_dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ getlogbasemeasure(::Type{TensorDirichlet}, conditioner) = (x) -> zero(Float64)
getsufficientstatistics(::Type{TensorDirichlet}, conditioner) = (x -> vmap(log, x),)

BayesBase.mean(dist::TensorDirichlet) = dist.a ./ dist.α0
BayesBase.mean(::BroadcastFunction{typeof(log)}, dist::TensorDirichlet) = digamma.(dist.a) .- digamma.(dist.α0)

function BayesBase.cov(dist::TensorDirichlet{T}) where {T}
s = size(dist.a)
news = (first(s), first(s), Base.tail(s)...)
Expand Down
22 changes: 22 additions & 0 deletions test/distributions/tensor_dirichlet_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,28 @@ end
end
end

@testitem "TensorDirichlet: logmean" begin
include("distributions_setuptests.jl")

for rank in (3, 5)
for d in (2, 5, 10)
for _ in 1:10
alpha = rand([d for _ in 1:rank]...)

distribution = TensorDirichlet(alpha)
mat_of_dir = Dirichlet.(eachslice(alpha, dims = Tuple(2:rank)))

temp = mean.(Base.Broadcast.BroadcastFunction(log), mat_of_dir)
mat_mean = similar(alpha)
for i in CartesianIndices(Base.tail(size(alpha)))
mat_mean[:, i] = temp[i]
end
@test mean(Base.Broadcast.BroadcastFunction(log), distribution) mat_mean
end
end
end
end

@testitem "TensorDirichlet: std" begin
include("distributions_setuptests.jl")

Expand Down

0 comments on commit 9f06eb4

Please sign in to comment.