diff --git a/setup.cfg b/setup.cfg index b410ffb8..c509d9bf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,7 +46,7 @@ install_requires = pyyaml pydot importlib_resources; python_version<'3.7' - pandas>=1.1 + pandas>=1.1, <2.0 amply>=0.1.4 networkx flatten_dict diff --git a/src/otoole/exceptions.py b/src/otoole/exceptions.py index 736a3494..9db13768 100644 --- a/src/otoole/exceptions.py +++ b/src/otoole/exceptions.py @@ -72,18 +72,20 @@ def __str__(self): return f"{self.name} -> {self.message}" -class OtooleExcelNameMismatchError(OtooleException): - """Name mismatch between config and excel tabs.""" +class OtooleNameMismatchError(OtooleException): + """Names not consistent between read in data and config file""" def __init__( - self, excel_name: str, message: str = "Excel tab name not found in config file" + self, + name: str, + message: str = "Name not consistent between data and config file", ) -> None: - self.excel_name = excel_name + self.name = name self.message = message super().__init__(self.message) def __str__(self): - return f"{self.excel_name} -> {self.message}" + return f"{self.name} -> {self.message}" class OtooleDeprecationError(OtooleException): diff --git a/src/otoole/input.py b/src/otoole/input.py index fab09792..ebdd13e6 100644 --- a/src/otoole/input.py +++ b/src/otoole/input.py @@ -32,10 +32,12 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, TextIO, Tuple, Union +from typing import Any, Dict, List, Optional, TextIO, Tuple, Union import pandas as pd +from otoole.exceptions import OtooleNameMismatchError + logger = logging.getLogger(__name__) @@ -124,10 +126,9 @@ class Strategy(ABC): def __init__(self, user_config: Dict[str, Dict]): self.user_config = user_config - - # self.input_config = { - # x: y for x, y in self.user_config.items() if y["type"] in ["param", 'set'] - # } + self.input_config = { + x: y for x, y in self.user_config.items() if y["type"] in ["param", "set"] + } self.results_config = { x: y for x, y in self.user_config.items() if y["type"] == "result" } @@ -418,7 +419,7 @@ def _get_missing_input_dataframes( Arguments: ---------- input_data: Dict[str, pd.DataFrame] - Data read in from the excel notebook + Internal datastore config_type: str Type of value. Must be "set", "param", or "result" @@ -450,6 +451,39 @@ def _get_missing_input_dataframes( return input_data + def _compare_read_to_expected( + self, names: List[str], short_names: bool = False + ) -> None: + """Compares input data definitions to config file definitions + + Arguments: + --------- + names: List[str] + Parameter and set names read in + map_names: bool = False + If should be checking short_names from config file + + Raises: + ------- + OtooleNameMismatchError + If the info in the data and config file do not match + """ + user_config = self.input_config + if short_names: + expected = [] + for name in user_config: + try: + expected.append(user_config[name]["short_name"]) + except KeyError: + expected.append(name) + else: + expected = [x for x in user_config] + + errors = list(set(expected).symmetric_difference(set(names))) + if errors: + logger.debug(f"data and config name errors are: {errors}") + raise OtooleNameMismatchError(name=errors[0]) + @abstractmethod def read( self, filepath: Union[str, TextIO], **kwargs diff --git a/src/otoole/read_strategies.py b/src/otoole/read_strategies.py index a74b4166..d8f6fca5 100644 --- a/src/otoole/read_strategies.py +++ b/src/otoole/read_strategies.py @@ -6,7 +6,7 @@ from amply import Amply from flatten_dict import flatten -from otoole.exceptions import OtooleDeprecationError, OtooleExcelNameMismatchError +from otoole.exceptions import OtooleDeprecationError from otoole.input import ReadStrategy from otoole.preprocess.longify_data import check_datatypes, check_set_datatype from otoole.utils import create_name_mappings @@ -121,8 +121,7 @@ def read( excel_to_csv = create_name_mappings(config, map_full_to_short=False) xl = pd.ExcelFile(filepath, engine="openpyxl") - - self._check_input_sheet_names(xl.sheet_names) + self._compare_read_to_expected(names=xl.sheet_names, short_names=True) input_data = {} @@ -155,33 +154,6 @@ def read( return input_data, default_values - def _check_input_sheet_names(self, sheet_names: List[str]) -> None: - """Checks that excel sheet names are in the config file. - - Arguments: - --------- - sheet_names: list[str] - Sheet names from the excel file - - Raises: - ------- - OtooleExcelNameMismatchError - If the sheet name is not found in the config files parameter or - 'short_name' parameter - """ - user_config = self.user_config - csv_to_excel = create_name_mappings(user_config) - config_param_names = [] - for name in user_config: - try: - config_param_names.append(csv_to_excel[name]) - except KeyError: - config_param_names.append(name) - - for sheet_name in sheet_names: - if sheet_name not in config_param_names: - raise OtooleExcelNameMismatchError(excel_name=sheet_name) - class ReadCsv(_ReadTabular): """Read in a folder of CSV files""" @@ -193,6 +165,10 @@ def read( input_data = {} self._check_for_default_values_csv(filepath) + self._compare_read_to_expected( + names=[f.split(".csv")[0] for f in os.listdir(filepath)] + ) + default_values = self._read_default_values(self.user_config) for parameter, details in self.user_config.items(): diff --git a/tests/fixtures/combined_inputs.xlsx b/tests/fixtures/combined_inputs.xlsx index a301dcfc..0495d5b4 100644 Binary files a/tests/fixtures/combined_inputs.xlsx and b/tests/fixtures/combined_inputs.xlsx differ diff --git a/tests/fixtures/config.yaml b/tests/fixtures/config.yaml index b57b8732..6f8ffdcc 100644 --- a/tests/fixtures/config.yaml +++ b/tests/fixtures/config.yaml @@ -84,11 +84,6 @@ DiscountRate: type: param dtype: float default: 0.05 -DiscountRateIdv: - indices: [REGION,TECHNOLOGY] - type: param - dtype: float - default: 0.05 DiscountRateStorage: indices: [REGION,STORAGE] type: param diff --git a/tests/fixtures/config_r.yaml b/tests/fixtures/config_r.yaml index e80e6101..8a56c2e0 100644 --- a/tests/fixtures/config_r.yaml +++ b/tests/fixtures/config_r.yaml @@ -202,6 +202,3 @@ UseByTechnology: dtype: float default: 0 calculated: False -_REGION: - type: set - dtype: str diff --git a/tests/test_input.py b/tests/test_input.py index 9bc69cea..0f2493a8 100644 --- a/tests/test_input.py +++ b/tests/test_input.py @@ -4,6 +4,7 @@ from pandas.testing import assert_frame_equal from pytest import fixture, mark, raises +from otoole.exceptions import OtooleNameMismatchError from otoole.input import ReadStrategy, WriteStrategy @@ -40,6 +41,31 @@ def simple_input_data(region, year, technology): } +@fixture +def simple_user_config(): + return { + "AccumulatedAnnualDemand": { + "indices": ["REGION", "FUEL", "YEAR"], + "type": "param", + "dtype": "float", + "default": 0, + "short_name": "AAD", + }, + "REGION": { + "dtype": "str", + "type": "set", + }, + "FUEL": { + "dtype": "str", + "type": "set", + }, + "YEAR": { + "dtype": "int", + "type": "set", + }, + } + + # To instantiate abstract class WriteStrategy class DummyWriteStrategy(WriteStrategy): def _header(self) -> Union[TextIO, Any]: @@ -279,6 +305,14 @@ class TestReadStrategy: ), ("set", "REGION", pd.DataFrame(columns=["VALUE"])), ) + compare_read_to_expected_data = [ + [["AccumulatedAnnualDemand", "REGION", "FUEL", "YEAR"], False], + [["AAD", "REGION", "FUEL", "YEAR"], True], + ] + compare_read_to_expected_data_exception = [ + ["AccumulatedAnnualDemand", "REGION", "FUEL"], + ["AccumulatedAnnualDemand", "REGION", "FUEL", "YEAR", "Extra"], + ] @mark.parametrize( "config_type, test_value, expected", @@ -299,3 +333,22 @@ def test_get_missing_input_dataframes_excpetion(self, user_config): reader = DummyReadStrategy(user_config) with raises(ValueError): reader._get_missing_input_dataframes(input_data, config_type="not_valid") + + @mark.parametrize( + "expected, short_name", + compare_read_to_expected_data, + ids=["full_name", "short_name"], + ) + def test_compare_read_to_expected(self, simple_user_config, expected, short_name): + reader = DummyReadStrategy(simple_user_config) + reader._compare_read_to_expected(names=expected, short_names=short_name) + + @mark.parametrize( + "expected", + compare_read_to_expected_data_exception, + ids=["missing_value", "extra_value"], + ) + def test_compare_read_to_expected_exception(self, simple_user_config, expected): + reader = DummyReadStrategy(simple_user_config) + with raises(OtooleNameMismatchError): + reader._compare_read_to_expected(names=expected) diff --git a/tests/test_read_strategies.py b/tests/test_read_strategies.py index dd106e9c..ea93b0f7 100644 --- a/tests/test_read_strategies.py +++ b/tests/test_read_strategies.py @@ -945,22 +945,6 @@ def test_read_excel_yearsplit(self, user_config): assert (actual_data == expected).all() - def test_read_excel_discount_rate(self, user_config): - """Tests that parameters not in excel are saved in datastore""" - - spreadsheet = os.path.join("tests", "fixtures", "combined_inputs.xlsx") - xl = pd.ExcelFile(spreadsheet, engine="openpyxl") - - # checks that fixture does not contian discount rate data - assert "DiscountRateIdv" not in xl.sheet_names - - reader = ReadExcel(user_config=user_config) - actual, _ = reader.read(spreadsheet) - - # checks that discount rate has data after reading in excel data - assert "DiscountRateIdv" in actual - assert actual["DiscountRateIdv"].empty - def test_narrow_parameters(self, user_config): data = [ ["IW0016", 0.238356164, 0.238356164, 0.238356164], diff --git a/tests/test_utils.py b/tests/test_utils.py index e4a445dc..8fac9aa3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,12 +4,7 @@ import pytest import yaml -from otoole.exceptions import ( - OtooleDeprecationError, - OtooleExcelNameLengthError, - OtooleExcelNameMismatchError, -) -from otoole.read_strategies import ReadExcel +from otoole.exceptions import OtooleDeprecationError, OtooleExcelNameLengthError from otoole.utils import ( UniqueKeyLoader, create_name_mappings, @@ -71,13 +66,6 @@ def test_create_name_mappings_reversed(self, user_config): assert actual == expected -def test_excel_name_mismatch_error(user_config): - read_excel = ReadExcel(user_config=user_config) - sheet_names = ["AccumulatedAnnualDemand", "MismatchSheet"] - with pytest.raises(OtooleExcelNameMismatchError): - read_excel._check_input_sheet_names(sheet_names=sheet_names) - - user_config_name_errors = ["user_config_long_param_name", "user_config_long_short_name"]