Skip to content

Commit

Permalink
Pass seed through in dha multiprocessing example.
Browse files Browse the repository at this point in the history
Fixes Github issue #104.
  • Loading branch information
riastradh-probcomp committed Jul 21, 2016
1 parent 6bd5d4f commit 2de0192
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions examples/dha_example_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import argparse
from multiprocessing import Pool
import os
import random
#
import numpy
#
Expand All @@ -44,6 +45,9 @@
num_chains = args.num_chains
num_transitions = args.num_transitions
#
rng = random.Random(gen_seed)
get_next_seed = lambda: rng.randint(1, 2**31 - 1)
#
pkl_filename = 'dha_example_num_transitions_%s.pkl.gz' % num_transitions


Expand Down Expand Up @@ -75,8 +79,10 @@ def determine_unobserved_Y(num_rows, M_c, condition_tuples):

# run the chains
engine = MultiprocessingEngine.MultiprocessingEngine()
X_L_list, X_D_list = engine.initialize(M_c, M_r, T, n_chains=num_chains)
X_L_list, X_D_list = engine.analyze(M_c, T, X_L_list, X_D_list)
X_L_list, X_D_list = engine.initialize(
M_c, M_r, T, n_chains=num_chains, seed=get_next_seed())
X_L_list, X_D_list = engine.analyze(
M_c, T, X_L_list, X_D_list, seed=get_next_seed())

# save the progress
to_pickle = dict(X_L_list=X_L_list, X_D_list=X_D_list)
Expand Down Expand Up @@ -115,7 +121,8 @@ def determine_unobserved_Y(num_rows, M_c, condition_tuples):
impute_names = [col_names[impute_col]]
Q = determine_Q(M_c, impute_names, num_rows, impute_row=impute_row)
#
imputed = engine.impute(M_c, X_L_list, X_D_list, Y, Q, 1000)
imputed = engine.impute(
M_c, X_L_list, X_D_list, Y, Q, seed=get_next_seed(), n=1000)
imputed_list.append(imputed)
print
print actual_values
Expand Down

0 comments on commit 2de0192

Please sign in to comment.