-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added metric factory #209
Merged
Merged
Added metric factory #209
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
0bcc1ab
Added metric factory
eacharles 366703c
added force_update to one test to get full coverage of new code
eacharles d180ef6
added base_metric_classes to interface in __init__
eacharles 5b17e3f
fix typo
eacharles 791da31
get last instance of typo
eacharles e489b3b
fix typos
eacharles File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,17 @@ | ||
from .array_metrics import * | ||
from .metrics import * | ||
from .goodness_of_fit import * | ||
from .base_metric_classes import * | ||
|
||
# added for testing purposes | ||
|
||
from .metrics import _calculate_grid_parameters, _check_ensemble_is_not_nested, _check_ensembles_are_same_size | ||
from .metrics import _check_ensembles_contain_correct_number_of_distributions | ||
|
||
from .factory import MetricFactory | ||
|
||
|
||
create_metric = MetricFactory.create_metric | ||
print_metrics = MetricFactory.print_metrics | ||
list_metrics = MetricFactory.list_metrics | ||
update_metrics = MetricFactory.update_metrics |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from .base_metric_classes import BaseMetric | ||
|
||
|
||
def get_all_subclasses(cls): | ||
""" Utility function to recursively get all the subclasses of a class """ | ||
all_subclasses = [] | ||
|
||
for subclass in cls.__subclasses__(): | ||
all_subclasses.append(subclass) | ||
all_subclasses.extend(get_all_subclasses(subclass)) | ||
|
||
return all_subclasses | ||
|
||
|
||
class MetricFactory: | ||
|
||
metric_dict = {} | ||
|
||
@classmethod | ||
def update_metrics(cls): | ||
""" Update the dictionary of known metrics """ | ||
all_subclasses = get_all_subclasses(BaseMetric) | ||
cls.metric_dict = { subcls.metric_name.replace("Metric", ""): subcls for subcls in all_subclasses if subcls.metric_name } | ||
|
||
@classmethod | ||
def print_metrics(cls, force_update=False): | ||
""" List all the known metrics """ | ||
if not cls.metric_dict or force_update: | ||
cls.update_metrics() | ||
print('List of available metrics') | ||
for key, val in cls.metric_dict.items(): | ||
print(key, val, val.metric_input_type.name) | ||
|
||
@classmethod | ||
def list_metrics(cls, force_update=False): | ||
""" Get the list of all the metric names """ | ||
if not cls.metric_dict or force_update: | ||
cls.update_metrics() | ||
return list(cls.metric_dict.keys()) | ||
|
||
@classmethod | ||
def create_metric(cls, name, force_update=False, **kwargs): | ||
""" Create a metric evaluator """ | ||
if not cls.metric_dict or force_update: | ||
cls.update_metrics() | ||
try: | ||
metric_class = cls.metric_dict[name] | ||
except KeyError as msg: | ||
raise KeyError(f"{name} is not in the set of known metrics {str(cls.list_metrics())}") from msg | ||
return metric_class(**kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# pylint: disable=no-member | ||
# pylint: disable=protected-access | ||
|
||
""" | ||
Unit tests for PDF class | ||
""" | ||
import unittest | ||
import numpy as np | ||
import qp | ||
import qp.metrics | ||
|
||
|
||
class MetricFactoryTestCase(unittest.TestCase): | ||
"""Tests for the metric factory""" | ||
|
||
def setUp(self): | ||
""" | ||
Setup an objects that are used in multiple tests. | ||
""" | ||
qp.metrics.update_metrics() | ||
|
||
def test_print_metrics(self): | ||
"""Test printing the metrics.""" | ||
qp.metrics.print_metrics() | ||
qp.metrics.print_metrics(force_update=True) | ||
|
||
|
||
def test_list_metrics(self): | ||
"""Test printing the metrics.""" | ||
the_list = qp.metrics.list_metrics() | ||
assert the_list | ||
the_list = qp.metrics.list_metrics(force_update=True) | ||
assert the_list | ||
|
||
def test_create_metrics(self): | ||
"""Test creating all the metrics""" | ||
all_metric_names = qp.metrics.list_metrics(force_update=True) | ||
for metric_name in all_metric_names: | ||
a_metric = qp.metrics.create_metric(metric_name) | ||
assert a_metric.metric_name == metric_name | ||
a_metric = qp.metrics.create_metric("outlier", force_update=True) | ||
assert a_metric.metric_name == 'outlier' | ||
|
||
def test_bad_metric_name(self): | ||
""" Catch error on making a bad metric """ | ||
with self.assertRaises(KeyError): | ||
qp.metrics.create_metric("Bad Metric") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🤦 <- I'm glad there's a facepalm emoji in slack.