diff --git a/README.md b/README.md index b190395..e7f63bd 100644 --- a/README.md +++ b/README.md @@ -33,10 +33,21 @@ class APITestCase(TestCase): self.assertMatchSnapshot(my_gpg_response, 'gpg_response') ``` -If you want to update the snapshots automatically you can use the `nosetests --snapshot-update`. +You'll also need to let your test runner know about snapshottest, +to summarize the snapshot results and handle removing unused snapshots: +* If your code calls `unittest.main()`, replace that with `snapshottest.main()` +* If you run `python -m unittest ...`, switch to `python -m snapshottest ...` +* If you use nose, snapshottest automatically loads a nose plugin + that handles this for you +* Or if you have a custom unittest TestRunner, add + `snapshottest.unittest.SnapshotTestRunnerMixin` (see its docstring for more info) + +To generate new snapshots, add `--snapshot-update` to your usual test command line +(e.g., `python -m snapshottest ... --snapshot-update` or `nosetests --snapshot-update`). Check the [Unittest example](https://github.com/syrusakbary/snapshottest/tree/master/examples/unittest). + ## Usage with pytest ```python diff --git a/examples/unittest/test_demo.py b/examples/unittest/test_demo.py index 268e833..4e37be7 100644 --- a/examples/unittest/test_demo.py +++ b/examples/unittest/test_demo.py @@ -1,4 +1,3 @@ -import unittest import snapshottest @@ -8,6 +7,12 @@ def api_client_get(url): } +# Use snapshottest.TestCase in place of unittest.TestCase +# where you want to run snapshot tests. +# +# (You can also mix it into any subclass of unittest.TestCase: +# class TestDemo(snapshottest.TestCase, MyCustomTestCase): +# ...) class TestDemo(snapshottest.TestCase): def setUp(self): pass @@ -18,4 +23,6 @@ def test_api_me(self): if __name__ == "__main__": - unittest.main() + # Replace unittest.main() with snapshottest's version: + # unittest.main() + snapshottest.main() diff --git a/snapshottest/__init__.py b/snapshottest/__init__.py index 8db737d..13adfb2 100644 --- a/snapshottest/__init__.py +++ b/snapshottest/__init__.py @@ -1,7 +1,7 @@ from .snapshot import Snapshot from .generic_repr import GenericRepr from .module import assert_match_snapshot -from .unittest import TestCase +from .unittest import TestCase, main -__all__ = ["Snapshot", "GenericRepr", "assert_match_snapshot", "TestCase"] +__all__ = ["Snapshot", "GenericRepr", "assert_match_snapshot", "TestCase", "main"] diff --git a/snapshottest/__main__.py b/snapshottest/__main__.py new file mode 100644 index 0000000..776d18a --- /dev/null +++ b/snapshottest/__main__.py @@ -0,0 +1,20 @@ +"""Main entry point (for unittest with snapshottest support)""" + +# This is here to support invoking snapshottest-augmented unittest via +# `python -m snapshottest ...` (paralleling unittest's own `python -m unittest ...`). +# It's copied almost directly from unittest.__main__. + +import sys + +if sys.argv[0].endswith("__main__.py"): + import os.path + + # We change sys.argv[0] to make help message more useful + # use executable without path, unquoted + executable = os.path.basename(sys.executable) + sys.argv[0] = executable + " -m snapshottest" + del os + +from .unittest import main + +main(module=None) diff --git a/snapshottest/unittest.py b/snapshottest/unittest.py index 535b24a..8021b7f 100644 --- a/snapshottest/unittest.py +++ b/snapshottest/unittest.py @@ -1,9 +1,10 @@ -import unittest import inspect +import sys +import unittest -from .module import SnapshotModule, SnapshotTest from .diff import PrettyDiff -from .reporting import diff_report +from .module import SnapshotModule, SnapshotTest +from .reporting import diff_report, reporting_lines class UnitTestSnapshotTest(SnapshotTest): @@ -36,6 +37,10 @@ def test_name(self): # Inspired by https://gist.github.com/twolfson/13f5f5784f67fd49b245 class TestCase(unittest.TestCase): + # Whether snapshots should be updated, for all unittest-derived frameworks. + # Set (perhaps circuitously) in runner init from the --snapshot-update + # command line option. (.unittest.TestCase.snapshot_should_update is the + # equivalent of pytest's config.option.snapshot_update.) snapshot_should_update = False @classmethod @@ -99,3 +104,140 @@ def assert_match_snapshot(self, value, name=""): self._snapshot.assert_match(value, name=name) assertMatchSnapshot = assert_match_snapshot + + +def output_snapshottest_summary(stream=None, testing_cli=None): + """ + Outputs a summary of snapshot tests for the session, if any. + + Call at the end of a test session to write results summary + to stream (default sys.stderr). If no snapshot tests were run, + outputs nothing. + + testing_cli (default from sys.argv) should be the string command + line that invokes the tests, and is used to explain how to update + snapshots. + + (This is the equivalent of .pytest.SnapshotSession.display, + for unittest-derived frameworks.) + """ + # TODO: Call this to replace near-duplicate code in .django and .nose. + + if not SnapshotModule.has_snapshots(): + return + + if stream is None: + # This follows unittest.TextTestRunner, which by default uses sys.stderr + # for test status and summary info (not sys.stdout). + stream = sys.stderr + if testing_cli is None: + # We can't really recover the exact command line formatted for the user's shell + # (quoting, etc.), but this should be close enough to get the point across. + testing_cli = " ".join(sys.argv) + + separator1 = "=" * 70 + separator2 = "-" * 70 + + print(separator1, file=stream) + print("SnapshotTest summary", file=stream) + print(separator2, file=stream) + for line in reporting_lines(testing_cli): + print(line, file=stream) + print(separator1, file=stream) + + +def finalize_snapshots(): + """ + Call at the end of a unittest session to delete unused snapshots. + + (This deletes the data needed for SnapshotModule.total_unvisited_snapshots. + Complete any reporting before calling this function.) + """ + # TODO: this is duplicated in four places (with varying "should_update" conditions). + # Move it into shared code for snapshot sessions (which is currently implemented + # as classmethods on SnapshotModule). + if TestCase.snapshot_should_update: + for module in SnapshotModule.get_modules(): + module.delete_unvisited() + module.save() + + +class SnapshotTestRunnerMixin: + """ + A mixin for a unittest TestRunner that adds snapshottest session handling. + + Note: a TestRunner is not responsible for command line options. If you are + adding snapshottest support to other unittest-derived frameworks, you must + also arrange to set snapshottest.unittest.TestCase.snapshot_should_update + when the user requests --snapshot-update. + """ + + def run(self, test): + result = super().run(test) + self.report_snapshottest_summary() + finalize_snapshots() + return result + + def report_snapshottest_summary(self): + """Report a summary of snapshottest results for the session""" + if hasattr(self, "stream"): + # Mixed into a unittest.TextTestRunner or similar (with an output stream) + output_snapshottest_summary(self.stream) + else: + # Mixed into some sort of graphical frontend, probably + raise NotImplementedError( + "Non-text TestRunner with SnapshotTestRunnerMixin" + " must implement report_snapshottest_summary" + ) + + +class SnapshotTextTestRunner(SnapshotTestRunnerMixin, unittest.TextTestRunner): + """ + Version of unittest.TextTestRunner that adds snapshottest session handling. + """ + + pass + + +class SnapshotTestProgram(unittest.TestProgram): + """ + Augmented implementation of unittest.main that adds --snapshot-update + command line option, and that ensures testRunner includes snapshottest + session handling. + """ + + def __init__(self, *args, testRunner=None, **kwargs): + # (For simplicity, we only allow testRunner as a kwarg.) + if testRunner is None: + testRunner = SnapshotTextTestRunner + # Verify the testRunner includes snapshot session handling. + # "The testRunner argument can either be a test runner class + # or an already created instance of it." + if not issubclass(testRunner, SnapshotTestRunnerMixin) and not isinstance( + testRunner, SnapshotTestRunnerMixin + ): + raise TypeError( + "snapshottest testRunner must include SnapshotTestRunnerMixin" + ) + + self._snapshot_update = False + super().__init__(*args, testRunner=testRunner, **kwargs) + + def _getParentArgParser(self): + # (Yes, this is hooking a private method. Sorry. + # unittest.TestProgram isn't really designed to be extended.) + parser = super()._getParentArgParser() + parser.add_argument( + "--snapshot-update", + dest="_snapshot_update", + action="store_true", + help="Update snapshottest snapshots", + ) + return parser + + def runTests(self): + TestCase.snapshot_should_update = self._snapshot_update + super().runTests() + + +main = SnapshotTestProgram