Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] authored and zachjweiner committed Nov 4, 2023
1 parent cf5c9d4 commit ab63bf4
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions src/emcee/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,20 @@ def infer_dict_mapping(state):

def array_to_dict(ary, param_slice_shape):
return {
key: ary[:, slc].reshape((-1,)+shape)
key: ary[:, slc].reshape((-1,) + shape)
for key, (slc, shape) in param_slice_shape.items()
}


def array_to_list_of_dicts(ary, param_slice_shape):
# reshape adds a small amount of overhead; don't do it unless necessary
return [{
key: ary_i[slc].reshape(shape) if len(shape) > 1 else ary_i[slc]
for key, (slc, shape) in param_slice_shape.items()
} for ary_i in ary]
return [
{
key: ary_i[slc].reshape(shape) if len(shape) > 1 else ary_i[slc]
for key, (slc, shape) in param_slice_shape.items()
}
for ary_i in ary
]


def collapse_and_hstack(values, nwalkers=None):
Expand Down Expand Up @@ -199,15 +202,15 @@ def __init__(
if isinstance(parameter_names, Sequence):
if len(parameter_names) != ndim:
raise ValueError(
f"`parameter_names` does not specify {ndim} names")
f"`parameter_names` does not specify {ndim} names"
)
parameter_names = dict(zip(parameter_names, range(ndim)))

indices = np.arange(ndim)

try:
index_map = {
key: indices[slc]
for key, slc in parameter_names.items()
key: indices[slc] for key, slc in parameter_names.items()
}
indexed = collapse_and_hstack(index_map.values())
except IndexError as err:
Expand Down Expand Up @@ -330,7 +333,8 @@ def sample(
_state = {key: val[0] for key, val in initial_state.items()}
self.param_slice_shape = infer_dict_mapping(_state)
initial_state = collapse_and_hstack(
initial_state.values(), self.nwalkers)
initial_state.values(), self.nwalkers
)

state = State(initial_state, copy=True)
state_shape = np.shape(state.coords)
Expand Down

0 comments on commit ab63bf4

Please sign in to comment.