diff --git a/stanio/__init__.py b/stanio/__init__.py index aeaaf2e..bcda030 100644 --- a/stanio/__init__.py +++ b/stanio/__init__.py @@ -1,7 +1,14 @@ from .csv import read_csv -from .json import write_stan_json +from .json import write_stan_json, dump_stan_json from .reshape import Variable, parse_header, stan_variables -__all__ = ["read_csv", "write_stan_json", "Variable", "parse_header", "stan_variables"] +__all__ = [ + "read_csv", + "write_stan_json", + "dump_stan_json", + "Variable", + "parse_header", + "stan_variables", +] -__version__ = "0.2.0" +__version__ = "0.3.0" diff --git a/stanio/json.py b/stanio/json.py index ce60799..39a1cd4 100644 --- a/stanio/json.py +++ b/stanio/json.py @@ -36,6 +36,25 @@ def process_value(val: Any) -> Any: return val +def dump_stan_json(data: Mapping[str, Any]) -> str: + """ + Convert a mapping of strings to data to a JSON string. + + Values can be any numeric type, a boolean (converted to int), + or any collection compatible with :func:`numpy.asarray`, e.g a + :class:`pandas.Series`. + + Produces a string compatible with the + `Json Format for Cmdstan + `__ + + :param data: A mapping from strings to values. This can be a dictionary + or something more exotic like an :class:`xarray.Dataset`. This will be + copied before type conversion, not modified + """ + return json.dumps(process_dictionary(data)) + + def write_stan_json(path: str, data: Mapping[str, Any]) -> None: """ Dump a mapping of strings to data to a JSON file. diff --git a/test/test_json.py b/test/test_json.py index ebe8821..9499b04 100644 --- a/test/test_json.py +++ b/test/test_json.py @@ -8,7 +8,7 @@ import pandas as pd import pytest -from stanio.json import write_stan_json +from stanio.json import dump_stan_json, write_stan_json @pytest.fixture(scope="module") @@ -163,3 +163,11 @@ def test_tuples(TMPDIR) -> None: file_tuple = os.path.join(TMPDIR, "tuple.json") write_stan_json(file_tuple, dict_tuples) compare_before_after(file_tuple, dict_tuples, dict_tuple_exp) + + +def test_write_vs_dump(TMPDIR): + dict_list = {"a": [1.0, 2.0, 3.0]} + file_write = os.path.join(TMPDIR, "write.json") + write_stan_json(file_write, dict_list) + with open(file_write) as fd: + assert fd.read() == dump_stan_json(dict_list)