Skip to content

Commit

Permalink
Merge pull request #157 from trevorb1/issue-151
Browse files Browse the repository at this point in the history
Adds checks for input data and config file names
  • Loading branch information
trevorb1 authored Apr 19, 2023
2 parents 7952254 + b85ed28 commit a10d336
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 79 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ install_requires =
pyyaml
pydot
importlib_resources; python_version<'3.7'
pandas>=1.1
pandas>=1.1, <2.0
amply>=0.1.4
networkx
flatten_dict
Expand Down
12 changes: 7 additions & 5 deletions src/otoole/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,20 @@ def __str__(self):
return f"{self.name} -> {self.message}"


class OtooleExcelNameMismatchError(OtooleException):
"""Name mismatch between config and excel tabs."""
class OtooleNameMismatchError(OtooleException):
"""Names not consistent between read in data and config file"""

def __init__(
self, excel_name: str, message: str = "Excel tab name not found in config file"
self,
name: str,
message: str = "Name not consistent between data and config file",
) -> None:
self.excel_name = excel_name
self.name = name
self.message = message
super().__init__(self.message)

def __str__(self):
return f"{self.excel_name} -> {self.message}"
return f"{self.name} -> {self.message}"


class OtooleDeprecationError(OtooleException):
Expand Down
46 changes: 40 additions & 6 deletions src/otoole/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TextIO, Tuple, Union
from typing import Any, Dict, List, Optional, TextIO, Tuple, Union

import pandas as pd

from otoole.exceptions import OtooleNameMismatchError

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -124,10 +126,9 @@ class Strategy(ABC):
def __init__(self, user_config: Dict[str, Dict]):

self.user_config = user_config

# self.input_config = {
# x: y for x, y in self.user_config.items() if y["type"] in ["param", 'set']
# }
self.input_config = {
x: y for x, y in self.user_config.items() if y["type"] in ["param", "set"]
}
self.results_config = {
x: y for x, y in self.user_config.items() if y["type"] == "result"
}
Expand Down Expand Up @@ -418,7 +419,7 @@ def _get_missing_input_dataframes(
Arguments:
----------
input_data: Dict[str, pd.DataFrame]
Data read in from the excel notebook
Internal datastore
config_type: str
Type of value. Must be "set", "param", or "result"
Expand Down Expand Up @@ -450,6 +451,39 @@ def _get_missing_input_dataframes(

return input_data

def _compare_read_to_expected(
    self, names: List[str], short_names: bool = False
) -> None:
    """Compares input data definitions to config file definitions

    Arguments:
    ---------
    names: List[str]
        Parameter and set names read in
    short_names: bool = False
        If True, compare against each config entry's ``short_name``,
        falling back to the full name for entries that do not define one

    Raises:
    -------
    OtooleNameMismatchError
        If the names in the data and the config file do not match exactly
        (a name missing from either side counts as a mismatch). The
        exception carries the first mismatched name in sorted order.
    """
    user_config = self.input_config
    if short_names:
        expected = {
            details.get("short_name", name)
            for name, details in user_config.items()
        }
    else:
        expected = set(user_config)

    # Sort so the name reported in the exception is deterministic; the
    # iteration order of a set of strings can differ between interpreter runs.
    errors = sorted(expected.symmetric_difference(names), key=str)
    if errors:
        logger.debug(f"data and config name errors are: {errors}")
        raise OtooleNameMismatchError(name=errors[0])

@abstractmethod
def read(
self, filepath: Union[str, TextIO], **kwargs
Expand Down
36 changes: 6 additions & 30 deletions src/otoole/read_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from amply import Amply
from flatten_dict import flatten

from otoole.exceptions import OtooleDeprecationError, OtooleExcelNameMismatchError
from otoole.exceptions import OtooleDeprecationError
from otoole.input import ReadStrategy
from otoole.preprocess.longify_data import check_datatypes, check_set_datatype
from otoole.utils import create_name_mappings
Expand Down Expand Up @@ -121,8 +121,7 @@ def read(
excel_to_csv = create_name_mappings(config, map_full_to_short=False)

xl = pd.ExcelFile(filepath, engine="openpyxl")

self._check_input_sheet_names(xl.sheet_names)
self._compare_read_to_expected(names=xl.sheet_names, short_names=True)

input_data = {}

Expand Down Expand Up @@ -155,33 +154,6 @@ def read(

return input_data, default_values

def _check_input_sheet_names(self, sheet_names: List[str]) -> None:
    """Verify that every excel sheet name is known to the config file.

    Arguments:
    ---------
    sheet_names: list[str]
        Sheet names from the excel file

    Raises:
    -------
    OtooleExcelNameMismatchError
        Raised for the first sheet name that matches neither a config
        entry's mapped name nor, when no mapping exists, its full name
    """
    config = self.user_config
    name_mapping = create_name_mappings(config)

    def _expected_sheet_name(full_name):
        # Use the mapped name when one exists, otherwise the config name itself
        try:
            return name_mapping[full_name]
        except KeyError:
            return full_name

    allowed_names = [_expected_sheet_name(full_name) for full_name in config]

    for sheet in sheet_names:
        if sheet in allowed_names:
            continue
        raise OtooleExcelNameMismatchError(excel_name=sheet)


class ReadCsv(_ReadTabular):
"""Read in a folder of CSV files"""
Expand All @@ -193,6 +165,10 @@ def read(
input_data = {}

self._check_for_default_values_csv(filepath)
self._compare_read_to_expected(
names=[f.split(".csv")[0] for f in os.listdir(filepath)]
)

default_values = self._read_default_values(self.user_config)

for parameter, details in self.user_config.items():
Expand Down
Binary file modified tests/fixtures/combined_inputs.xlsx
Binary file not shown.
5 changes: 0 additions & 5 deletions tests/fixtures/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,6 @@ DiscountRate:
type: param
dtype: float
default: 0.05
DiscountRateIdv:
indices: [REGION,TECHNOLOGY]
type: param
dtype: float
default: 0.05
DiscountRateStorage:
indices: [REGION,STORAGE]
type: param
Expand Down
3 changes: 0 additions & 3 deletions tests/fixtures/config_r.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,3 @@ UseByTechnology:
dtype: float
default: 0
calculated: False
_REGION:
type: set
dtype: str
53 changes: 53 additions & 0 deletions tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pandas.testing import assert_frame_equal
from pytest import fixture, mark, raises

from otoole.exceptions import OtooleNameMismatchError
from otoole.input import ReadStrategy, WriteStrategy


Expand Down Expand Up @@ -40,6 +41,31 @@ def simple_input_data(region, year, technology):
}


@fixture
def simple_user_config():
    """Minimal user config: one parameter with a short name plus its index sets."""
    config = {
        "AccumulatedAnnualDemand": {
            "indices": ["REGION", "FUEL", "YEAR"],
            "type": "param",
            "dtype": "float",
            "default": 0,
            "short_name": "AAD",
        },
    }
    # The three sets differ only in their dtype
    for set_name, set_dtype in (("REGION", "str"), ("FUEL", "str"), ("YEAR", "int")):
        config[set_name] = {"dtype": set_dtype, "type": "set"}
    return config


# To instantiate abstract class WriteStrategy
class DummyWriteStrategy(WriteStrategy):
def _header(self) -> Union[TextIO, Any]:
Expand Down Expand Up @@ -279,6 +305,14 @@ class TestReadStrategy:
),
("set", "REGION", pd.DataFrame(columns=["VALUE"])),
)
compare_read_to_expected_data = [
[["AccumulatedAnnualDemand", "REGION", "FUEL", "YEAR"], False],
[["AAD", "REGION", "FUEL", "YEAR"], True],
]
compare_read_to_expected_data_exception = [
["AccumulatedAnnualDemand", "REGION", "FUEL"],
["AccumulatedAnnualDemand", "REGION", "FUEL", "YEAR", "Extra"],
]

@mark.parametrize(
"config_type, test_value, expected",
Expand All @@ -299,3 +333,22 @@ def test_get_missing_input_dataframes_excpetion(self, user_config):
reader = DummyReadStrategy(user_config)
with raises(ValueError):
reader._get_missing_input_dataframes(input_data, config_type="not_valid")

@mark.parametrize(
    "expected, short_name",
    compare_read_to_expected_data,
    ids=["full_name", "short_name"],
)
def test_compare_read_to_expected(self, simple_user_config, expected, short_name):
    """Matching names (full or short) must pass the check without raising."""
    DummyReadStrategy(simple_user_config)._compare_read_to_expected(
        names=expected, short_names=short_name
    )

@mark.parametrize(
    "expected",
    compare_read_to_expected_data_exception,
    ids=["missing_value", "extra_value"],
)
def test_compare_read_to_expected_exception(self, simple_user_config, expected):
    """A missing or an extra name must raise OtooleNameMismatchError."""
    strategy = DummyReadStrategy(simple_user_config)
    with raises(OtooleNameMismatchError):
        strategy._compare_read_to_expected(names=expected)
16 changes: 0 additions & 16 deletions tests/test_read_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,22 +945,6 @@ def test_read_excel_yearsplit(self, user_config):

assert (actual_data == expected).all()

def test_read_excel_discount_rate(self, user_config):
    """Tests that parameters not in excel are saved in datastore"""
    parameter = "DiscountRateIdv"
    workbook_path = os.path.join("tests", "fixtures", "combined_inputs.xlsx")

    # Precondition: the fixture workbook has no sheet for the parameter
    workbook = pd.ExcelFile(workbook_path, engine="openpyxl")
    assert parameter not in workbook.sheet_names

    excel_reader = ReadExcel(user_config=user_config)
    actual, _ = excel_reader.read(workbook_path)

    # After reading, the parameter exists in the datastore but holds no rows
    assert parameter in actual
    assert actual[parameter].empty

def test_narrow_parameters(self, user_config):
data = [
["IW0016", 0.238356164, 0.238356164, 0.238356164],
Expand Down
14 changes: 1 addition & 13 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
import pytest
import yaml

from otoole.exceptions import (
OtooleDeprecationError,
OtooleExcelNameLengthError,
OtooleExcelNameMismatchError,
)
from otoole.read_strategies import ReadExcel
from otoole.exceptions import OtooleDeprecationError, OtooleExcelNameLengthError
from otoole.utils import (
UniqueKeyLoader,
create_name_mappings,
Expand Down Expand Up @@ -71,13 +66,6 @@ def test_create_name_mappings_reversed(self, user_config):
assert actual == expected


def test_excel_name_mismatch_error(user_config):
    """A sheet name absent from the config must raise a name mismatch error."""
    reader = ReadExcel(user_config=user_config)
    with pytest.raises(OtooleExcelNameMismatchError):
        reader._check_input_sheet_names(
            sheet_names=["AccumulatedAnnualDemand", "MismatchSheet"]
        )


user_config_name_errors = ["user_config_long_param_name", "user_config_long_short_name"]


Expand Down

0 comments on commit a10d336

Please sign in to comment.