MappedDistribution #12
Here's my current setup:

    struct For{F,T,D,X}
        f :: F
        θ :: T
    end

where...
Some example use cases:

    # T = NTuple{2,Int}
    x ~ For(10,3) do i,j
        Bernoulli(j/i)
    end

    # T = Base.Generator{Base.OneTo{Int64},Base.var"#174#175"{Array{Float64,2}}}
    y ~ For(eachrow(X)) do xrow
        Normal(xrow' * β, 1)
    end

We'll have different methods for
Also, I currently have the following restrictions:
Currently this targets "array-like" results, but in principle

I don't think we need a restriction on

    function logpdf(dist::For, x::AbstractArray)
        @assert size(dist.θ) == size(x)
        return sum(1:length(dist.θ)) do i
            logpdf(dist.f(dist.θ[i]), x[i])
        end
    end
    rand(dist::For) = rand.(dist.f.(dist.θ))

Whether

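For anyone who wants to try the loop-based `logpdf` and the broadcast `rand` without any packages, here is a minimal sketch. It makes several assumptions: `SimpleNormal` is a hypothetical stand-in for a real distribution type such as `Normal` from Distributions.jl, `For` is cut down to two type parameters, and `θ` is a plain array, so this is not the definition used in Soss or proposed in this thread.

```julia
# Hypothetical stand-in for a real distribution type (e.g. Distributions.Normal).
struct SimpleNormal
    μ::Float64
    σ::Float64
end

# Log-density of a normal distribution with mean μ and standard deviation σ.
logpdf(d::SimpleNormal, x::Real) = -0.5 * ((x - d.μ) / d.σ)^2 - log(d.σ) - 0.5 * log(2π)
Base.rand(d::SimpleNormal) = d.μ + d.σ * randn()

# Simplified `For`: f maps each parameter θ[i] to a distribution.
struct For{F,T}
    f::F
    θ::T
end

# Sum the component log-densities; each component is built on the fly.
function logpdf(dist::For, x::AbstractArray)
    @assert size(dist.θ) == size(x)
    return sum(eachindex(x)) do i
        logpdf(dist.f(dist.θ[i]), x[i])
    end
end

# Draw one value from each component distribution.
Base.rand(dist::For) = rand.(dist.f.(dist.θ))

# Do-block syntax passes the closure as the first constructor argument.
d = For([0.0, 1.0, 2.0]) do μ
    SimpleNormal(μ, 1.0)
end

x = rand(d)          # Vector{Float64} with one draw per parameter
lp = logpdf(d, x)    # scalar log-density of the joint draw
```

Evaluating `logpdf(d, [0.0, 1.0, 2.0])` puts every observation at its component mean, so the result is just three times the unit-normal log-normalizer.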
    eltype(dist::For) = mapreduce(i -> eltype(dist.f(dist.θ[i])), promote_type, 1:length(dist.θ))

Note that the above is a dynamically sized distribution. We can also get free specialization and inlining for small, fixed-size distributions when using

I think for Tracker

    sum(logpdf.(dist.f.(dist.θ), x))

will be faster than

    sum(1:length(dist.θ)) do i
        logpdf(dist.f(dist.θ[i]), x[i])
    end

So if either

The most obvious reason for this is type stability, though there may be ways around that. In addition, the vast majority of models will satisfy this anyway, and it often opens up opportunities for optimization. For example, in cases where

One thing I've found a bit tricky is making useful type information available without much computational cost. Unfortunately in Julia, we can't just ask a function about its codomain, so instantiating a

    eltype(dist::For) = mapreduce(i -> eltype(dist.f(dist.θ[i])), promote_type, 1:length(dist.θ))

is appealing, but would require O(n) instantiation cost. Above I suggested

Finally, we had some recent discussion on Discourse about the best approach for parallelism, which will be important for many cases.

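To make the `eltype` trade-off concrete, here is a small sketch that contrasts the O(n) scan with an O(1) shortcut. The shortcut is only one possible alternative (not necessarily the one suggested in this thread) and is valid only under the assumption that `f` returns the same distribution type for every parameter. `SimpleNormal`, `eltype_scan`, and `eltype_first` are hypothetical names introduced for illustration.

```julia
# Hypothetical parametric stand-in for a distribution type.
struct SimpleNormal{T<:Real}
    μ::T
    σ::T
end
Base.eltype(::Type{SimpleNormal{T}}) where {T} = T

# Simplified `For`: f maps each parameter θ[i] to a distribution.
struct For{F,T}
    f::F
    θ::T
end

# O(n): instantiate every component and promote the element types.
eltype_scan(dist::For) =
    mapreduce(i -> eltype(dist.f(dist.θ[i])), promote_type, 1:length(dist.θ))

# O(1): instantiate only the first component.
# Assumption: f is type-stable, i.e. every θ[i] yields the same distribution type.
eltype_first(dist::For) = eltype(dist.f(first(dist.θ)))

# Homogeneous case: both approaches agree.
homog = For(μ -> SimpleNormal(μ, 1.0), [0.0, 1.0, 2.0])

# Heterogeneous case: the first component is Int-valued, the second Float64-valued,
# so the shortcut and the full scan disagree.
mixed = For(i -> i == 1 ? SimpleNormal(0, 1) : SimpleNormal(0.0, 1.0), 1:2)
```

In the homogeneous case both functions return `Float64`. In the mixed case `eltype_scan` promotes to `Float64` while `eltype_first` reports `Int`, which is one way to see why a type-stability restriction on `f` makes the cheap version safe.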
Cleaning this up a bit in Soss, here's the current state:
There's also

    x ~ Normal() |> iid(N)

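For readers unfamiliar with the `|> iid(N)` pattern, the sketch below shows one way a curried `iid` could be written so that it composes with `|>`. This is an illustration under assumptions, not Soss's implementation: `IID` and `SimpleNormal` are hypothetical types, and only `logpdf` and `rand` are covered.

```julia
# Hypothetical stand-in for a real distribution type.
struct SimpleNormal
    μ::Float64
    σ::Float64
end
logpdf(d::SimpleNormal, x::Real) = -0.5 * ((x - d.μ) / d.σ)^2 - log(d.σ) - 0.5 * log(2π)
Base.rand(d::SimpleNormal) = d.μ + d.σ * randn()

# Hypothetical wrapper: n independent draws from the same distribution.
struct IID{D}
    n::Int
    dist::D
end

# Curried constructor, so that `dist |> iid(n)` builds an IID.
iid(n::Integer) = dist -> IID(Int(n), dist)

# Joint log-density is the sum of the identical component log-densities.
function logpdf(d::IID, x::AbstractVector)
    @assert length(x) == d.n
    return sum(xi -> logpdf(d.dist, xi), x)
end

Base.rand(d::IID) = [rand(d.dist) for _ in 1:d.n]

d = SimpleNormal(0.0, 1.0) |> iid(4)
x = rand(d)   # Vector{Float64} of length 4
```

Unlike `For`, this wrapper stores a single distribution and never rebuilds it per element, which is the optimization the identical-parameter case allows.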
Thanks for the PR @cscherrer and sorry for the late review; I was busy the last few weeks. I will review your PR asap.

Sometimes it is useful to be able to define a multivariate distribution on iid variables by generating distributions on the fly, using different distribution parameters for each variable according to a certain rule/function. Defining an efficient logpdf and adjoint can give significant computational savings. This is similar to the Soss For combinator.