Skip to content

Commit

Permalink
Add dump_stan_json function
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Aug 25, 2023
1 parent 897f4ac commit 5ffb2ab
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
13 changes: 10 additions & 3 deletions stanio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from .csv import read_csv
from .json import write_stan_json
from .json import write_stan_json, dump_stan_json
from .reshape import Variable, parse_header, stan_variables

__all__ = ["read_csv", "write_stan_json", "Variable", "parse_header", "stan_variables"]
__all__ = [
"read_csv",
"write_stan_json",
"dump_stan_json",
"Variable",
"parse_header",
"stan_variables",
]

__version__ = "0.2.0"
__version__ = "0.3.0"
19 changes: 19 additions & 0 deletions stanio/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ def process_value(val: Any) -> Any:
return val


def dump_stan_json(data: Mapping[str, Any]) -> str:
"""
Convert a mapping of strings to data to a JSON string.
Values can be any numeric type, a boolean (converted to int),
or any collection compatible with :func:`numpy.asarray`, e.g a
:class:`pandas.Series`.
Produces a string compatible with the
`Json Format for Cmdstan
<https://mc-stan.org/docs/cmdstan-guide/json.html>`__
:param data: A mapping from strings to values. This can be a dictionary
or something more exotic like an :class:`xarray.Dataset`. This will be
copied before type conversion, not modified
"""
return json.dumps(process_dictionary(data))


def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
"""
Dump a mapping of strings to data to a JSON file.
Expand Down
10 changes: 9 additions & 1 deletion test/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
import pytest

from stanio.json import write_stan_json
from stanio.json import dump_stan_json, write_stan_json


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -163,3 +163,11 @@ def test_tuples(TMPDIR) -> None:
file_tuple = os.path.join(TMPDIR, "tuple.json")
write_stan_json(file_tuple, dict_tuples)
compare_before_after(file_tuple, dict_tuples, dict_tuple_exp)


def test_write_vs_dump(TMPDIR):
dict_list = {"a": [1.0, 2.0, 3.0]}
file_write = os.path.join(TMPDIR, "write.json")
write_stan_json(file_write, dict_list)
with open(file_write) as fd:
assert fd.read() == dump_stan_json(dict_list)

0 comments on commit 5ffb2ab

Please sign in to comment.