PGAS example where state consists of a tuple of distributions? #82

YSanchezAraujo opened this issue Sep 22, 2023 · 2 comments

YSanchezAraujo commented Sep 22, 2023

I'm wondering whether a model like the one below is possible. The state isn't a single distribution but a collection of distributions, all of which evolve in a Markovian manner. I don't know exactly how this works internally, so the code assumes the state is propagated forward from initialization to transition to observation.

```julia
using AdvancedPS, Distributions, LinearAlgebra, Random, StatsFuns

# Assumes the design matrix X (n_trials × n_cols) and the binary
# responses y (length n_trials) are already defined.
n_trials, n_cols = size(X)

Parameters = @NamedTuple begin
    X::Matrix
    lam_lapse_init::Float64
    sigma_set_init::Array{Float64}
    mu_init::Array{Float64}
    n_trials::Int64
    n_cols::Int64
end

mutable struct PF <: AdvancedPS.AbstractStateSpaceModel
    W::Matrix
    lam_lapse::Array
    sigma_set::Matrix
    theta::Parameters
    PF(theta::Parameters) = new(
        zeros(Float64, theta.n_trials, theta.n_cols),
        zeros(Float64, theta.n_trials),
        zeros(Float64, theta.n_trials, theta.n_cols),
        theta
    )
end

# Initial state: a tuple of distributions for (lam_lapse, sigma1, sigma2, sigma3, mu)
function init_step(m::PF)
    return (
        truncated(Normal(m.theta.lam_lapse_init, 0.1), lower=-10),
        truncated(Normal(m.theta.sigma_set_init[1], 0.1), lower=0.),
        truncated(Normal(m.theta.sigma_set_init[2], 0.1), lower=0.),
        truncated(Normal(m.theta.sigma_set_init[3], 0.1), lower=0.),
        MvNormal(m.theta.mu_init, 1.)
    )
end

AdvancedPS.initialization(m::PF) = init_step(m)

# Random-walk transition for every component; mu's covariance uses the previous sigmas
function transition_step(m::PF, state)
    return (
        truncated(Normal(state[1], 0.1), lower=-10), # lam_lapse
        truncated(Normal(state[2], 0.1), lower=0.), # sigma1
        truncated(Normal(state[3], 0.1), lower=0.), # sigma2
        truncated(Normal(state[4], 0.1), lower=0.), # sigma3
        MvNormal(state[5], Diagonal([state[2], state[3], state[4]])) # mu
    )
end

AdvancedPS.transition(m::PF, state) = transition_step(m, state)

# Logistic choice model with a lapse rate mixing in a 50/50 guess
function obs_density(m::PF, state, t)
    lam_lapse, _, _, _, mu = state
    lapse = logistic(lam_lapse)
    prob = (1 - lapse) * logistic(m.theta.X[t, :]'mu) + lapse * 0.5
    return Bernoulli(prob)
end

# Log-likelihood of the observed response y[t] (y is a global)
function AdvancedPS.observation(m::PF, state, t)
    return logpdf(obs_density(m, state, t), y[t])
end

AdvancedPS.isdone(m::PF, t) = t > m.theta.n_trials

n_particles = 20
n_samples = 200
rng = MersenneTwister(2342)

# Field order: (X, lam_lapse_init, sigma_set_init, mu_init, n_trials, n_cols)
theta0 = Parameters(
    (X, -9.0, zeros(3), zeros(3), n_trials, n_cols)
)

model = PF(theta0)
pgas = AdvancedPS.PGAS(n_particles)
chains = sample(rng, model, pgas, n_samples; progress=true);
```
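For reference, here is one reading of the generative model that the code above encodes. This is an interpretation of the code, not something stated by AdvancedPS. The initial state is indexed as $t = 1$, and the transition applies for $t \ge 2$. In Distributions.jl, `Normal(m, 0.1)` is parameterized by its standard deviation, and the `Diagonal` passed to `MvNormal` is a covariance, so the sigma components act as variances for `mu`. The vector $x_t$ is row $t$ of `X`.

$$
\begin{aligned}
\lambda_1 &\sim \mathcal{N}(\lambda_{\text{init}}, 0.1^2)\ \text{truncated to } [-10, \infty), \\
\sigma_{k,1} &\sim \mathcal{N}(\sigma_{\text{init},k}, 0.1^2)\ \text{truncated to } [0, \infty), \quad k = 1, 2, 3, \\
\mu_1 &\sim \mathcal{N}(\mu_{\text{init}}, I), \\
\lambda_t &\sim \mathcal{N}(\lambda_{t-1}, 0.1^2)\ \text{truncated to } [-10, \infty), \\
\sigma_{k,t} &\sim \mathcal{N}(\sigma_{k,t-1}, 0.1^2)\ \text{truncated to } [0, \infty), \\
\mu_t &\sim \mathcal{N}\big(\mu_{t-1}, \operatorname{diag}(\sigma_{1,t-1}, \sigma_{2,t-1}, \sigma_{3,t-1})\big), \\
y_t &\sim \operatorname{Bernoulli}\big((1 - \ell_t)\,\operatorname{logistic}(x_t^\top \mu_t) + \tfrac{1}{2}\,\ell_t\big), \quad \ell_t = \operatorname{logistic}(\lambda_t).
\end{aligned}
$$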

YSanchezAraujo commented Sep 22, 2023

Looking at the advance function in PGAS, it seems this isn't possible with the current formulation. In my case:


```julia
rand(init_step(model))
```

will just return one randomly chosen element of the tuple

```julia
(
    truncated(Normal(m.theta.lam_lapse_init, 0.1), lower=-10),
    truncated(Normal(m.theta.sigma_set_init[1], 0.1), lower=0.),
    truncated(Normal(m.theta.sigma_set_init[2], 0.1), lower=0.),
    truncated(Normal(m.theta.sigma_set_init[3], 0.1), lower=0.),
    MvNormal(m.theta.mu_init, 1.)
)
```

rather than one draw from each distribution. It seems the workaround would be to allow for

```julia
rand.(init_step(model))
```

instead?
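The sketch below illustrates the difference between `rand` and `rand.` on a tuple of distributions. It also shows one hypothetical way to package the tuple as a single object that can be sampled and scored. `TupleDist` is an illustrative name and is not part of AdvancedPS or Distributions.jl. The sketch assumes that the sampler only needs `rand(rng, d)` to draw a state and a log-density `logpdf(d, x)` of a state, for example for ancestor sampling in PGAS. I have not checked which methods AdvancedPS actually calls. It may require a genuine `Distribution` subtype instead.

```julia
using Distributions, LinearAlgebra, Random

# Hypothetical wrapper (not part of AdvancedPS or Distributions): treats a tuple of
# independent component distributions as one joint distribution over tuples.
struct TupleDist{T<:Tuple}
    dists::T
end

# Draw one sample from every component and return them as a tuple.
Random.rand(rng::AbstractRNG, d::TupleDist) = map(di -> rand(rng, di), d.dists)

# Components are independent, so the joint log-density is the sum of component log-densities.
Distributions.logpdf(d::TupleDist, x::Tuple) = sum(map(logpdf, d.dists, x))

# A shortened version of the tuple returned by init_step.
dists = (
    truncated(Normal(-9.0, 0.1); lower=-10.0),  # lam_lapse
    truncated(Normal(0.0, 0.1); lower=0.0),     # one sigma component
    MvNormal(zeros(3), 1.0 * I),                # mu
)

rng = MersenneTwister(2342)

picked = rand(rng, dists)         # picks ONE of the three distributions at random
draws = rand.(Ref(rng), dists)    # one value from each: (Float64, Float64, Vector{Float64})

d = TupleDist(dists)
x = rand(rng, d)                  # same shape as `draws`
lp = logpdf(d, x)                 # joint log-density of the sampled tuple
```

The wrapper is one design choice among several. Another option, also untested here, is to stack `lam_lapse`, the three sigmas, and `mu` into one state vector so that the state has a single distribution.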

yebai commented Oct 23, 2023

cc @FredericWantiez
