From 90b176c8b0ee004f509ae99192792b5d7994f0b3 Mon Sep 17 00:00:00 2001
From: William Murphy
Date: Wed, 27 Mar 2024 16:49:09 -0400
Subject: [PATCH] add ability to download cached workspace (#520)

* create "stale" field on workspace state
A provider that downloads its workspace state directly cannot assume that this state is a valid basis for a future incremental update, and should mark the downloaded workspace as stale.
Signed-off-by: Will Murphy
* WIP add configs
Signed-off-by: Will Murphy
* lint fix
Signed-off-by: Will Murphy
* [wip] working on vunnel results db listing
Signed-off-by: Alex Goodman
* update and tests for safe_extract_tar
Now that we're using it for more than one thing, make an extractor that generally prevents path traversal.
Signed-off-by: Will Murphy
* [wip] adding tests for fetching listing and archives
Signed-off-by: Alex Goodman
* [wip] add more negative tests for provider tests
Signed-off-by: Alex Goodman
* unit test for new workspace changes
Signed-off-by: Will Murphy
* replace the workspace results instead of overlaying
Signed-off-by: Will Murphy
* clean up hasher implementation
Signed-off-by: Alex Goodman
* add tests for prep workspace from listing entry
Signed-off-by: Will Murphy
* do not include inputs in tar test fixture
Signed-off-by: Alex Goodman
* vunnel fetch existing workspace working
Signed-off-by: Will Murphy
* add unit test for full update flow
Signed-off-by: Will Murphy
* update existing unit tests for new config values
Signed-off-by: Will Murphy
* add unit test for default behavior of new configs
Signed-off-by: Will Murphy
* lint fix
Signed-off-by: Will Murphy
* add missing annotations import
Signed-off-by: Will Murphy
* Use 3.9 compatible annotations
Relying on `from __future__ import annotations` doesn't work with mashumaro.
Signed-off-by: Will Murphy
* validate that enabling import results requires host and path
Signed-off-by: Will Murphy
* rename listing field and add schema
Signed-off-by: Alex Goodman
* only require github token when downloading
Signed-off-by: Alex Goodman
* add zstd support
Signed-off-by: Alex Goodman
* add tests for zstd support
Signed-off-by: Alex Goodman
* add tests for _has_newer_archive
Signed-off-by: Will Murphy
* fix tests for zstd
Signed-off-by: Alex Goodman
* show stderr to log when git commands fail
Signed-off-by: Alex Goodman
* move import_results to common field on provider
Signed-off-by: Will Murphy
* add concept for distribution version
Signed-off-by: Alex Goodman
* single source of truth for provider schemas
Signed-off-by: Alex Goodman
* add distribution-version to schema, provider state, and listing entry
Signed-off-by: Alex Goodman
* clear workspace on different dist version
Signed-off-by: Alex Goodman
* fix defaulting logic and update tests
Signed-off-by: Will Murphy
* default distribution version and path
Signed-off-by: Will Murphy
* make "" and None both use default path
Signed-off-by: Will Murphy
---------
Signed-off-by: Will Murphy
Signed-off-by: Alex Goodman
Co-authored-by: Alex Goodman
---
 poetry.lock                                 | 150 ++++-
 pyproject.toml                              |   2 +
 schema/provider-archive-listing/README.md   |  17 +
 .../schema-1.0.0.json                       |  66 ++
 .../schema-1.0.2.json                       |  80 +++
 src/vunnel/cli/config.py                    |  58 +-
 src/vunnel/distribution.py                  |  89 +++
 src/vunnel/provider.py                      | 149 ++++-
 src/vunnel/providers/alpine/__init__.py     |   7 +-
 src/vunnel/providers/amazon/__init__.py     |   7 +-
 src/vunnel/providers/chainguard/__init__.py |   7 +-
 src/vunnel/providers/debian/__init__.py     |   7 +-
 src/vunnel/providers/github/__init__.py     |   7 +-
 src/vunnel/providers/github/parser.py       |   7 +-
 src/vunnel/providers/mariner/__init__.py    |   8 +-
 src/vunnel/providers/nvd/__init__.py        |   8 +-
 src/vunnel/providers/nvd/manager.py         |   6 +-
 src/vunnel/providers/nvd/overrides.py       |  25 +-
 src/vunnel/providers/oracle/__init__.py     |   7 +-
 src/vunnel/providers/rhel/__init__.py       |   7 +-
 src/vunnel/providers/sles/__init__.py       |   8 +-
 src/vunnel/providers/ubuntu/__init__.py     |   8 +-
 src/vunnel/providers/ubuntu/git.py          |  18 +-
 src/vunnel/providers/wolfi/__init__.py      |   7 +-
 src/vunnel/result.py                        |   4 +-
 src/vunnel/schema.py                        |  22 +-
 src/vunnel/utils/archive.py                 |  50 ++
 src/vunnel/utils/hasher.py                  |  34 ++
 src/vunnel/utils/http.py                    |  12 +-
 src/vunnel/workspace.py                     |  92 ++-
 tests/conftest.py                           |  17 +-
 tests/unit/cli/test_cli.py                  |  44 ++
 tests/unit/cli/test_config.py               |  86 +++
 tests/unit/providers/nvd/test_overrides.py  |  10 -
 tests/unit/test_distribution.py             | 149 +++++
 tests/unit/test_provider.py                 | 574 +++++++++++++++++-
 tests/unit/test_result.py                   |  12 +-
 tests/unit/test_schema.py                   |   9 +
 tests/unit/test_workspace.py                | 113 +++-
 tests/unit/utils/test_archive.py            |  75 +++
 tests/unit/utils/test_hasher.py             |  36 ++
 41 files changed, 1967 insertions(+), 127 deletions(-)
 create mode 100644 schema/provider-archive-listing/README.md
 create mode 100644 schema/provider-archive-listing/schema-1.0.0.json
 create mode 100644 schema/provider-workspace-state/schema-1.0.2.json
 create mode 100644 src/vunnel/distribution.py
 create mode 100644 src/vunnel/utils/archive.py
 create mode 100644 src/vunnel/utils/hasher.py
 create mode 100644 tests/unit/test_distribution.py
 create mode 100644 tests/unit/test_schema.py
 create mode 100644 tests/unit/utils/test_archive.py
 create mode 100644 tests/unit/utils/test_hasher.py

diff --git a/poetry.lock b/poetry.lock
index 78ee3fba..c648cfed 100644
--- a/poetry.lock
+++
b/poetry.lock @@ -87,6 +87,70 @@ files = [ {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, ] +[[package]] +name = "cffi" +version = "1.16.0" +description = "Foreign Function Interface for Python calling C code." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"}, + {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7"}, + {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614"}, + {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743"}, + {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d"}, + {file = "cffi-1.16.0-cp310-cp310-win32.whl", hash = "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a"}, + {file = "cffi-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1"}, + {file = "cffi-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b84834d0cf97e7d27dd5b7f3aca7b6e9263c56308ab9dc8aae9784abb774d404"}, + {file = "cffi-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b8ebc27c014c59692bb2664c7d13ce7a6e9a629be20e54e7271fa696ff2b417"}, + {file = "cffi-1.16.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ee07e47c12890ef248766a6e55bd38ebfb2bb8edd4142d56db91b21ea68b7627"}, + {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8a9d3ebe49f084ad71f9269834ceccbf398253c9fac910c4fd7053ff1386936"}, + {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e70f54f1796669ef691ca07d046cd81a29cb4deb1e5f942003f401c0c4a2695d"}, + {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5bf44d66cdf9e893637896c7faa22298baebcd18d1ddb6d2626a6e39793a1d56"}, + {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b78010e7b97fef4bee1e896df8a4bbb6712b7f05b7ef630f9d1da00f6444d2e"}, + {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c6a164aa47843fb1b01e941d385aab7215563bb8816d80ff3a363a9f8448a8dc"}, + {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e09f3ff613345df5e8c3667da1d918f9149bd623cd9070c983c013792a9a62eb"}, + {file = "cffi-1.16.0-cp311-cp311-win32.whl", hash = 
"sha256:2c56b361916f390cd758a57f2e16233eb4f64bcbeee88a4881ea90fca14dc6ab"}, + {file = "cffi-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:db8e577c19c0fda0beb7e0d4e09e0ba74b1e4c092e0e40bfa12fe05b6f6d75ba"}, + {file = "cffi-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956"}, + {file = "cffi-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6"}, + {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969"}, + {file = "cffi-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520"}, + {file = "cffi-1.16.0-cp312-cp312-win32.whl", hash = "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b"}, + {file = "cffi-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235"}, + {file = "cffi-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc"}, + {file = "cffi-1.16.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a09582f178759ee8128d9270cd1344154fd473bb77d94ce0aeb2a93ebf0feaf0"}, + {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e760191dd42581e023a68b758769e2da259b5d52e3103c6060ddc02c9edb8d7b"}, + {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80876338e19c951fdfed6198e70bc88f1c9758b94578d5a7c4c91a87af3cf31c"}, + {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6a14b17d7e17fa0d207ac08642c8820f84f25ce17a442fd15e27ea18d67c59b"}, + {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6602bc8dc6f3a9e02b6c22c4fc1e47aa50f8f8e6d3f78a5e16ac33ef5fefa324"}, + {file = "cffi-1.16.0-cp38-cp38-win32.whl", hash = "sha256:131fd094d1065b19540c3d72594260f118b231090295d8c34e19a7bbcf2e860a"}, + {file = "cffi-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:31d13b0f99e0836b7ff893d37af07366ebc90b678b6664c955b54561fc36ef36"}, + {file = "cffi-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:582215a0e9adbe0e379761260553ba11c58943e4bbe9c36430c4ca6ac74b15ed"}, + {file = "cffi-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b29ebffcf550f9da55bec9e02ad430c992a87e5f512cd63388abb76f1036d8d2"}, + {file = "cffi-1.16.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc9b18bf40cc75f66f40a7379f6a9513244fe33c0e8aa72e2d56b0196a7ef872"}, + {file = 
"cffi-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cb4a35b3642fc5c005a6755a5d17c6c8b6bcb6981baf81cea8bfbc8903e8ba8"}, + {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b86851a328eedc692acf81fb05444bdf1891747c25af7529e39ddafaf68a4f3f"}, + {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0f31130ebc2d37cdd8e44605fb5fa7ad59049298b3f745c74fa74c62fbfcfc4"}, + {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098"}, + {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:748dcd1e3d3d7cd5443ef03ce8685043294ad6bd7c02a38d1bd367cfd968e000"}, + {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8895613bcc094d4a1b2dbe179d88d7fb4a15cee43c052e8885783fac397d91fe"}, + {file = "cffi-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed86a35631f7bfbb28e108dd96773b9d5a6ce4811cf6ea468bb6a359b256b1e4"}, + {file = "cffi-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:3686dffb02459559c74dd3d81748269ffb0eb027c39a6fc99502de37d501faa8"}, + {file = "cffi-1.16.0.tar.gz", hash = "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0"}, +] + +[package.dependencies] +pycparser = "*" + [[package]] name = "cfgv" version = "3.4.0" @@ -664,6 +728,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "iso8601" +version = "2.1.0" +description = "Simple module to parse ISO 8601 dates" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "iso8601-2.1.0-py3-none-any.whl", hash = "sha256:aac4145c4dcb66ad8b648a02830f5e2ff6c24af20f4f482689be402db2429242"}, + {file = "iso8601-2.1.0.tar.gz", hash = "sha256:6b1d3829ee8921c4301998c909f7829fa9ed3cbdac0d3b16af2d743aed1ba8df"}, +] + [[package]] name = "jinja2" version = "3.1.3" @@ -1217,6 +1292,17 @@ files = [ [package.dependencies] wcwidth = "*" +[[package]] +name = "pycparser" +version = "2.21" +description = "C parser in Python" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, + {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, +] + [[package]] name = "pygments" version = "2.17.2" @@ -1443,7 +1529,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = 
"sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2190,7 +2275,68 @@ files = [ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +[[package]] +name = "zstandard" +version = "0.22.0" +description = "Zstandard bindings for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zstandard-0.22.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019"}, + {file = "zstandard-0.22.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a"}, + {file = "zstandard-0.22.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e"}, + {file = "zstandard-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2"}, + {file = "zstandard-0.22.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202"}, + {file = "zstandard-0.22.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356"}, + {file = "zstandard-0.22.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d"}, + {file = "zstandard-0.22.0-cp310-cp310-win32.whl", hash = "sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e"}, + {file = "zstandard-0.22.0-cp310-cp310-win_amd64.whl", hash = "sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375"}, + {file = "zstandard-0.22.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08"}, + {file = "zstandard-0.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26"}, + {file = "zstandard-0.22.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09"}, + {file = "zstandard-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775"}, + {file = "zstandard-0.22.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8"}, + {file = "zstandard-0.22.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94"}, + {file = "zstandard-0.22.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88"}, + {file = "zstandard-0.22.0-cp311-cp311-win32.whl", hash = "sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440"}, + {file = "zstandard-0.22.0-cp311-cp311-win_amd64.whl", hash = "sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd"}, + {file = 
"zstandard-0.22.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a"}, + {file = "zstandard-0.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912"}, + {file = "zstandard-0.22.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f"}, + {file = "zstandard-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c"}, + {file = "zstandard-0.22.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4"}, + {file = "zstandard-0.22.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc"}, + {file = "zstandard-0.22.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45"}, + {file = "zstandard-0.22.0-cp312-cp312-win32.whl", hash = "sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2"}, + {file = "zstandard-0.22.0-cp312-cp312-win_amd64.whl", hash = "sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c"}, + {file = "zstandard-0.22.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df"}, + {file = "zstandard-0.22.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e"}, + {file = "zstandard-0.22.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0"}, + {file = "zstandard-0.22.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69"}, + {file = "zstandard-0.22.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004"}, + {file = "zstandard-0.22.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf"}, + {file = "zstandard-0.22.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d"}, + {file = "zstandard-0.22.0-cp38-cp38-win32.whl", hash = "sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292"}, + {file = "zstandard-0.22.0-cp38-cp38-win_amd64.whl", hash = "sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73"}, + {file = "zstandard-0.22.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b"}, + {file = "zstandard-0.22.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93"}, + {file = "zstandard-0.22.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3"}, + {file = "zstandard-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe"}, + {file = 
"zstandard-0.22.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb"}, + {file = "zstandard-0.22.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303"}, + {file = "zstandard-0.22.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c"}, + {file = "zstandard-0.22.0-cp39-cp39-win32.whl", hash = "sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0"}, + {file = "zstandard-0.22.0-cp39-cp39-win_amd64.whl", hash = "sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2"}, + {file = "zstandard-0.22.0.tar.gz", hash = "sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70"}, +] + +[package.dependencies] +cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\""} + +[package.extras] +cffi = ["cffi (>=1.11)"] + [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "243f36bc61eadc19c8f7e44bc4db0739a8d757538e9ad9e952958baee43ad62a" +content-hash = "f5ddc7374fb4e7cf321c99e49ea96315b111bed0585884a2481d085ef242eafd" diff --git a/pyproject.toml b/pyproject.toml index 4d064b4c..2a6799a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,8 @@ importlib-metadata = "^7.0.1" xsdata = {extras = ["cli", "lxml", "soap"], version = ">=22.12,<25.0"} pytest-snapshot = "^0.9.0" mashumaro = "^3.10" +iso8601 = "^2.1.0" +zstandard = "^0.22.0" [tool.poetry.group.dev.dependencies] pytest = ">=7.2.2,<9.0.0" diff --git a/schema/provider-archive-listing/README.md b/schema/provider-archive-listing/README.md new file mode 100644 index 00000000..dbcf0d6d --- /dev/null +++ b/schema/provider-archive-listing/README.md @@ -0,0 +1,17 @@ +# `ProviderState` JSON Schema + +This schema governs the `listing.json` file used when providers are configured to fetch pre-computed results (by using `import_results_enabled`). The listing file is how the provider knows what results are available, where to fetch them from, and how to validate them. + +See `src/vunnel.distribution.Listing` for the root object that represents this schema. + +## Updating the schema + +Versioning the JSON schema must be done manually by copying the existing JSON schema into a new `schema-x.y.z.json` file and manually making the necessary updates (or by using an online tool such as https://www.liquid-technologies.com/online-json-to-schema-converter). + +This schema is being versioned based off of the "SchemaVer" guidelines, which slightly diverges from Semantic Versioning to tailor for the purposes of data models. 
+ +Given a version number format `MODEL.REVISION.ADDITION`: + +- `MODEL`: increment when you make a breaking schema change which will prevent interaction with any historical data +- `REVISION`: increment when you make a schema change which may prevent interaction with some historical data +- `ADDITION`: increment when you make a schema change that is compatible with all historical data diff --git a/schema/provider-archive-listing/schema-1.0.0.json b/schema/provider-archive-listing/schema-1.0.0.json new file mode 100644 index 00000000..33392e93 --- /dev/null +++ b/schema/provider-archive-listing/schema-1.0.0.json @@ -0,0 +1,66 @@ +{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "object", + "properties": { + "schema": { + "type": "object", + "properties": { + "version": { + "type": "string" + }, + "url": { + "type": "string" + } + }, + "required": [ + "version", + "url" + ] + }, + "provider": { + "type": "string" + }, + "available": { + "type": "object", + "properties": { + "1": { + "type": "array", + "items": [ + { + "type": "object", + "properties": { + "distribution_checksum": { + "type": "string" + }, + "built": { + "type": "string" + }, + "checksum": { + "type": "string" + }, + "url": { + "type": "string" + }, + "version": { + "type": "integer" + } + }, + "required": [ + "built", + "checksum", + "distribution_checksum", + "url", + "version" + ] + } + ] + } + } + } + }, + "required": [ + "schema", + "available", + "provider" + ] +} diff --git a/schema/provider-workspace-state/schema-1.0.2.json b/schema/provider-workspace-state/schema-1.0.2.json new file mode 100644 index 00000000..fedcae0f --- /dev/null +++ b/schema/provider-workspace-state/schema-1.0.2.json @@ -0,0 +1,80 @@ +{ + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "object", + "title": "provider-workspace-state", + "description": "describes the filesystem state of a provider workspace directory", + "properties": { + "provider": { + "type": "string" + }, + "urls": { + "type": "array", + "items": [ + { + "type": "string" + } + ] + }, + "store": { + "type": "string" + }, + "timestamp": { + "type": "string" + }, + "listing": { + "type": "object", + "properties": { + "digest": { + "type": "string" + }, + "path": { + "type": "string" + }, + "algorithm": { + "type": "string" + } + }, + "required": [ + "digest", + "path", + "algorithm" + ] + }, + "version": { + "type": "integer", + "description": "version describing the result data shape + the provider processing behavior semantics" + }, + "distribution_version": { + "type": "integer", + "description": "version describing purely the result data shape" + }, + "schema": { + "type": "object", + "properties": { + "version": { + "type": "string" + }, + "url": { + "type": "string" + } + }, + "required": [ + "version", + "url" + ] + }, + "stale": { + "type": "boolean", + "description": "set to true if the workspace is stale and cannot be used for an incremental update" + } + }, + "required": [ + "provider", + "urls", + "store", + "timestamp", + "listing", + "version", + "schema" + ] +} diff --git a/src/vunnel/cli/config.py b/src/vunnel/cli/config.py index db208f7f..d499cb5f 100644 --- a/src/vunnel/cli/config.py +++ b/src/vunnel/cli/config.py @@ -2,13 +2,41 @@ import os from dataclasses import dataclass, field, fields -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Generator import mergedeep import yaml from mashumaro.mixins.dict import DataClassDictMixin -from vunnel import providers +from 
vunnel import provider, providers + + +@dataclass +class ImportResults: + """ + These are the defaults for all providers. Corresponding + fields on specific providers override these values. + + If a path is "" or None, path will be set to "providers/{provider_name}/listing.json". + If an empty path is needed, specify "/". + """ + + __default_path__ = "providers/{provider_name}/listing.json" + host: str = "" + path: str = __default_path__ + enabled: bool = False + + def __post_init__(self) -> None: + if not self.path: + self.path = self.__default_path__ + + +@dataclass +class CommonProviderConfig: + import_results: ImportResults = field(default_factory=ImportResults) @dataclass @@ -26,12 +54,32 @@ class Providers: ubuntu: providers.ubuntu.Config = field(default_factory=providers.ubuntu.Config) wolfi: providers.wolfi.Config = field(default_factory=providers.wolfi.Config) + common: CommonProviderConfig = field(default_factory=CommonProviderConfig) + + def __post_init__(self) -> None: + for name in self.provider_names(): + runtime_cfg = getattr(self, name).runtime + if runtime_cfg and isinstance(runtime_cfg, provider.RuntimeConfig): + if runtime_cfg.import_results_enabled is None: + runtime_cfg.import_results_enabled = self.common.import_results.enabled + if not runtime_cfg.import_results_host: + runtime_cfg.import_results_host = self.common.import_results.host + if not runtime_cfg.import_results_path: + runtime_cfg.import_results_path = self.common.import_results.path + def get(self, name: str) -> Any | None: - for f in fields(Providers): - if self._normalize_name(f.name) == self._normalize_name(name): - return getattr(self, f.name) + for candidate in self.provider_names(): + if self._normalize_name(candidate) == self._normalize_name(name): + return getattr(self, candidate) return None + @staticmethod + def provider_names() -> Generator[str, None, None]: + for f in fields(Providers): + if f.name == "common": + continue + yield f.name + @staticmethod def _normalize_name(name: str) -> str: return name.lower().replace("-", "_") diff --git a/src/vunnel/distribution.py b/src/vunnel/distribution.py new file mode 100644 index 00000000..f2fcbf5a --- /dev/null +++ b/src/vunnel/distribution.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import datetime +import os +from dataclasses import dataclass, field +from urllib.parse import urlparse + +import iso8601 +from mashumaro.mixins.dict import DataClassDictMixin + +from vunnel import schema as schema_def + +DB_SUFFIXES = {".tar.gz", ".tar.zst"} + + +@dataclass +class ListingEntry(DataClassDictMixin): + # the date this archive was built relative to the data enclosed in the archive + built: str + + # the URL where the vunnel provider archive is located + url: str + + # the digest of the archive referenced at the URL. + # Note: all checksums are labeled with "algorithm:value" ( e.g. sha256:1234567890abcdef1234567890abcdef) + distribution_checksum: str + + # the digest of the checksums file within the archive referenced at the URL + # Note: all checksums are labeled with "algorithm:value" ( e.g. 
xxhash64:1234567890abcdef) + enclosed_checksum: str + + # the provider distribution version this archive was built with (different than the provider version) + distribution_version: int = 1 + + def basename(self) -> str: + basename = os.path.basename(urlparse(self.url, allow_fragments=False).path) + if not _has_suffix(basename, suffixes=DB_SUFFIXES): + msg = f"entry url is not a db archive: {basename}" + raise RuntimeError(msg) + + return basename + + def age_in_days(self, now: datetime.datetime | None = None) -> int: + if not now: + now = datetime.datetime.now(tz=datetime.timezone.utc) + return (now - iso8601.parse_date(self.built)).days + + +@dataclass +class ListingDocument(DataClassDictMixin): + # mapping of provider versions to a list of ListingEntry objects denoting archives available for download + available: dict[int, list[ListingEntry]] + + # the provider name this document is associated with + provider: str + + # the schema information for this document + schema: schema_def.Schema = field(default_factory=schema_def.ProviderListingSchema) + + @classmethod + def new(cls, provider: str) -> ListingDocument: + return cls(available={}, provider=provider) + + def latest_entry(self, schema_version: int) -> ListingEntry | None: + if schema_version not in self.available: + return None + + if not self.available[schema_version]: + return None + + return self.available[schema_version][0] + + def add(self, entry: ListingEntry) -> None: + if not self.available.get(entry.distribution_version): + self.available[entry.distribution_version] = [] + + self.available[entry.distribution_version].append(entry) + + # keep listing entries sorted by date (rfc3339 formatted entries, which iso8601 is a superset of) + self.available[entry.distribution_version].sort( + key=lambda x: iso8601.parse_date(x.built), + reverse=True, + ) + + +def _has_suffix(el: str, suffixes: set[str] | None) -> bool: + if not suffixes: + return True + return any(el.endswith(s) for s in suffixes) diff --git a/src/vunnel/provider.py b/src/vunnel/provider.py index 0af820ed..4dde68d0 100644 --- a/src/vunnel/provider.py +++ b/src/vunnel/provider.py @@ -4,11 +4,17 @@ import datetime import enum import logging +import os +import tempfile import time from dataclasses import dataclass, field -from typing import Any +from typing import Any, Optional +from urllib.parse import urlparse -from . import result, workspace +from vunnel.utils import archive, hasher, http + +from . import distribution, result, workspace +from . 
import schema as schema_def from .result import ResultStatePolicy @@ -62,6 +68,10 @@ class RuntimeConfig: # the format the results should be written in result_store: result.StoreStrategy = result.StoreStrategy.FLAT_FILE + import_results_host: Optional[str] = None # noqa: UP007 - breaks mashumaro + import_results_path: Optional[str] = None # noqa: UP007 - breaks mashumaro + import_results_enabled: Optional[bool] = None # noqa: UP007 - breaks mashumaro + def __post_init__(self) -> None: if not isinstance(self.existing_input, InputStatePolicy): self.existing_input = InputStatePolicy(self.existing_input) @@ -74,6 +84,17 @@ def __post_init__(self) -> None: def skip_if_exists(self) -> bool: return self.existing_input == InputStatePolicy.KEEP + def import_url(self, provider_name: str) -> str: + path = self.import_results_path + if path is None: + path = "" + path = path.format(provider_name=provider_name) + host = self.import_results_host + if host is None: + host = "" + + return f"{host.strip('/')}/{path.strip('/')}" + def disallow_existing_input_policy(cfg: RuntimeConfig) -> None: if cfg.existing_input != InputStatePolicy.KEEP: @@ -83,20 +104,41 @@ def disallow_existing_input_policy(cfg: RuntimeConfig) -> None: class Provider(abc.ABC): - # a breaking change to the semantics or values that the provider writes out should incur a version bump here. - # this is used to determine if the provider can be run on an existing workspace or if it must be cleared first - # (regardless of the existing_input and existing_result policy is). + # a breaking change to the semantics of how the provider processes results. + # + # NOTE: this value should only be changed in classes that inherit this class. Do not change the value in this class! __version__: int = 1 + # a breaking change to the schema of the results that the provider writes out should incur a version bump here. + # + # NOTE: this value should only be changed in classes that inherit this class. Do not change the value in this class! + __distribution_version__: int = 1 + def __init__(self, root: str, runtime_cfg: RuntimeConfig = RuntimeConfig()): # noqa: B008 self.logger = logging.getLogger(self.name()) self.workspace = workspace.Workspace(root, self.name(), logger=self.logger, create=False) self.urls: list[str] = [] + if runtime_cfg.import_results_enabled: + if not runtime_cfg.import_results_host: + raise RuntimeError("enabling import results requires host") + if not runtime_cfg.import_results_path: + raise RuntimeError("enabling import results requires path") + self.runtime_cfg = runtime_cfg @classmethod def version(cls) -> int: - return cls.__version__ + return cls.__version__ + (cls.distribution_version() - 1) + + @classmethod + def distribution_version(cls) -> int: + """This version represents when a breaking change is made for interpreting purely the provider results. This + tends to be an aggregation of all schema versions involved in the provider (i.e. the provider workspace state + and results shape). 
This is slightly different from the `version` method which is specific to the provider, + which encapsulates at least the distribution version + any other behavioral or data differences of the + provider itself (which is valid during processing, but not strictly interpreting results).""" + workspace_version = int(schema_def.ProviderStateSchema().major_version) + return (workspace_version - 1) + cls.__distribution_version__ @classmethod @abc.abstractmethod @@ -122,13 +164,21 @@ def _update(self) -> None: last_updated = None current_state = self.read_state() - if current_state: + if current_state and not current_state.stale: last_updated = current_state.timestamp - urls, count = self.update(last_updated=last_updated) + stale = False + if self.runtime_cfg.import_results_enabled: + urls, count = self._fetch_or_use_results_archive() + stale = True + else: + urls, count = self.update(last_updated=last_updated) + if count > 0: self.workspace.record_state( + stale=stale, version=self.version(), + distribution_version=self.distribution_version(), timestamp=start, urls=urls, store=self.runtime_cfg.result_store.value, @@ -136,11 +186,66 @@ def _update(self) -> None: else: self.logger.debug("skipping recording of workspace state (no new results found)") + def _fetch_or_use_results_archive(self) -> tuple[list[str], int]: + listing_doc = self._fetch_listing_document() + latest_entry = listing_doc.latest_entry(schema_version=self.distribution_version()) + if not latest_entry: + raise RuntimeError("no listing entry found") + + if self._has_newer_archive(latest_entry=latest_entry): + self._prep_workspace_from_listing_entry(entry=latest_entry) + state = self.workspace.state() + return state.urls, state.result_count(self.workspace.path) + + def _fetch_listing_document(self) -> distribution.ListingDocument: + url = self.runtime_cfg.import_url(provider_name=self.name()) + resp = http.get(url, logger=self.logger) + resp.raise_for_status() + + return distribution.ListingDocument.from_dict(resp.json()) + + def _has_newer_archive(self, latest_entry: distribution.ListingEntry) -> bool: + if not os.path.exists(self.workspace.metadata_path): + return True + + state = self.workspace.state() + if not state: + return True + + if state.distribution_version != self.distribution_version(): + return True + + if not state.listing: + return True + + # note: the checksum is the digest of the checksums file within the archive, which is in the form "algo:value" + return f"{state.listing.algorithm}:{state.listing.digest}" != latest_entry.enclosed_checksum + + def _prep_workspace_from_listing_entry(self, entry: distribution.ListingEntry) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + unarchived_path = _fetch_listing_entry_archive(dest=temp_dir, entry=entry, logger=self.logger) + + temp_ws = workspace.Workspace(unarchived_path, self.name(), logger=self.logger, create=False) + + # validate that the workspace is in a good state + temp_ws.validate_checksums() + + # then switch the existing workspace to the new one... 
+ # move the contents of the tmp dir to the workspace destination + self.workspace.replace_results(temp_workspace=temp_ws) + def run(self) -> None: self.logger.debug(f"using {self.workspace.path!r} as workspace") current_state = self.read_state() - if current_state and current_state.version != self.version(): + if self.runtime_cfg.import_results_enabled: + if current_state and current_state.distribution_version != self.distribution_version(): + self.logger.warning( + f"provider distribution version has changed from {current_state.distribution_version} to {self.distribution_version()}", + ) + self.logger.warning("clearing workspace to ensure consistency of existing results") + self.workspace.clear() + elif current_state and current_state.version != self.version(): self.logger.warning(f"provider version has changed from {current_state.version} to {self.version()}") self.logger.warning("clearing workspace to ensure consistency of existing input and results") self.workspace.clear() @@ -152,6 +257,7 @@ def run(self) -> None: self.workspace.clear_input() self.workspace.create() + try: self._update() except Exception as e: @@ -212,3 +318,28 @@ def results_writer(self, **kwargs: Any) -> result.Writer: store_strategy=self.runtime_cfg.result_store, **kwargs, ) + + +def _fetch_listing_entry_archive(dest: str, entry: distribution.ListingEntry, logger: logging.Logger) -> str: + + archive_path = os.path.join(dest, os.path.basename(urlparse(entry.url, allow_fragments=False).path)) + + # download the URL for the archive + resp = http.get(entry.url, logger=logger, stream=True) + resp.raise_for_status() + logger.debug(f"downloading {entry.url} to {archive_path}") + with open(archive_path, "wb") as fp: + for chunk in resp.iter_content(chunk_size=None): + fp.write(chunk) + + logger.debug(f"validating checksum for {archive_path}") + hashMethod = hasher.Method.parse(entry.distribution_checksum) + actual_labeled_digest = hashMethod.digest(archive_path) + if actual_labeled_digest != entry.distribution_checksum: + raise ValueError(f"archive checksum mismatch: {actual_labeled_digest} != {entry.distribution_checksum}") + + unarchive_path = os.path.join(dest, "unarchived") + logger.debug(f"unarchiving {archive_path} to {unarchive_path}") + archive.extract(archive_path, unarchive_path) + + return unarchive_path diff --git a/src/vunnel/providers/alpine/__init__.py b/src/vunnel/providers/alpine/__init__.py index 27933653..4b7e8a6e 100644 --- a/src/vunnel/providers/alpine/__init__.py +++ b/src/vunnel/providers/alpine/__init__.py @@ -24,6 +24,10 @@ class Config: class Provider(provider.Provider): + + __schema__ = schema.OSSchema() + __distribution_version__ = int(__schema__.major_version) + def __init__(self, root: str, config: Config | None = None): if not config: config = Config() @@ -31,7 +35,6 @@ def __init__(self, root: str, config: Config | None = None): self.config = config self.logger.debug(f"config: {self.config}") - self.schema = schema.OSSchema() self.parser = Parser( workspace=self.workspace, download_timeout=self.config.request_timeout, @@ -55,7 +58,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int writer.write( identifier=os.path.join(namespace, vuln_id), - schema=self.schema, + schema=self.__schema__, payload=record, ) diff --git a/src/vunnel/providers/amazon/__init__.py b/src/vunnel/providers/amazon/__init__.py index dc0c0933..0cb6843a 100644 --- a/src/vunnel/providers/amazon/__init__.py +++ b/src/vunnel/providers/amazon/__init__.py @@ -28,6 +28,10 @@ def 
__post_init__(self) -> None: class Provider(provider.Provider): + + __schema__ = schema.OSSchema() + __distribution_version__ = int(__schema__.major_version) + def __init__(self, root: str, config: Config | None = None): if not config: config = Config() @@ -36,7 +40,6 @@ def __init__(self, root: str, config: Config | None = None): self.logger.debug(f"config: {config}") - self.schema = schema.OSSchema() self.parser = Parser( workspace=self.workspace, security_advisories=config.security_advisories, @@ -57,7 +60,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int writer.write( identifier=os.path.join(namespace, vuln_id), - schema=self.schema, + schema=self.__schema__, payload={"Vulnerability": vuln.json()}, ) diff --git a/src/vunnel/providers/chainguard/__init__.py b/src/vunnel/providers/chainguard/__init__.py index b5bce01f..0d7436df 100644 --- a/src/vunnel/providers/chainguard/__init__.py +++ b/src/vunnel/providers/chainguard/__init__.py @@ -23,6 +23,10 @@ class Config: class Provider(provider.Provider): + + __schema__ = schema.OSSchema() + __distribution_version__ = int(__schema__.major_version) + _url = "https://packages.cgr.dev/chainguard/security.json" _namespace = "chainguard" @@ -34,7 +38,6 @@ def __init__(self, root: str, config: Config | None = None): self.logger.debug(f"config: {config}") - self.schema = schema.OSSchema() self.parser = Parser( workspace=self.workspace, url=self._url, @@ -57,7 +60,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int for vuln_id, record in vuln_dict.items(): writer.write( identifier=os.path.join(f"{self._namespace.lower()}:{release.lower()}", vuln_id), - schema=self.schema, + schema=self.__schema__, payload=record, ) diff --git a/src/vunnel/providers/debian/__init__.py b/src/vunnel/providers/debian/__init__.py index b42b5b0f..aefc48db 100644 --- a/src/vunnel/providers/debian/__init__.py +++ b/src/vunnel/providers/debian/__init__.py @@ -28,6 +28,10 @@ def __post_init__(self) -> None: class Provider(provider.Provider): + + __schema__ = schema.OSSchema() + __distribution_version__ = int(__schema__.major_version) + def __init__(self, root: str, config: Config | None = None): if not config: config = Config() @@ -36,7 +40,6 @@ def __init__(self, root: str, config: Config | None = None): self.logger.debug(f"config: {config}") - self.schema = schema.OSSchema() self.parser = Parser( workspace=self.workspace, download_timeout=self.config.request_timeout, @@ -60,7 +63,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int vuln_id = vuln_id.lower() writer.write( identifier=os.path.join(f"debian:{relno}", vuln_id), - schema=self.schema, + schema=self.__schema__, payload=record, ) diff --git a/src/vunnel/providers/github/__init__.py b/src/vunnel/providers/github/__init__.py index b8f744f0..38d928de 100644 --- a/src/vunnel/providers/github/__init__.py +++ b/src/vunnel/providers/github/__init__.py @@ -39,6 +39,10 @@ def __str__(self) -> str: class Provider(provider.Provider): + + __schema__ = schema.GithubSecurityAdvisorySchema() + __distribution_version__ = int(__schema__.major_version) + def __init__(self, root: str, config: Config | None = None): if not config: config = Config() @@ -47,7 +51,6 @@ def __init__(self, root: str, config: Config | None = None): self.logger.debug(f"config: {config}") - self.schema = schema.GithubSecurityAdvisorySchema() self.parser = Parser( workspace=self.workspace, token=config.token, @@ -79,7 +82,7 @@ def update(self, last_updated: 
datetime.datetime | None) -> tuple[list[str], int writer.write( identifier=os.path.join(f"{namespace}:{ecosystem}", vuln_id), - schema=self.schema, + schema=self.__schema__, payload={"Vulnerability": {}, "Advisory": dict(advisory)}, ) diff --git a/src/vunnel/providers/github/parser.py b/src/vunnel/providers/github/parser.py index 3df57695..54d73a3f 100644 --- a/src/vunnel/providers/github/parser.py +++ b/src/vunnel/providers/github/parser.py @@ -61,10 +61,6 @@ def __init__( # noqa: PLR0913 self.download_timeout = download_timeout self.api_url = api_url self.token = token - - if not self.token: - raise ValueError("Github token must be defined") - self.timestamp = None self.cursor = None if not logger: @@ -72,6 +68,9 @@ def __init__( # noqa: PLR0913 self.logger = logger def _download(self, vuln_cursor=None): + if not self.token: + raise ValueError("Github token must be defined") + """ Download the advisories from Github via the GraphQL API, using a cursor if it was defined in the class. Advisories stay in memory until diff --git a/src/vunnel/providers/mariner/__init__.py b/src/vunnel/providers/mariner/__init__.py index 1b5a014d..d3bc2641 100644 --- a/src/vunnel/providers/mariner/__init__.py +++ b/src/vunnel/providers/mariner/__init__.py @@ -24,6 +24,10 @@ class Config: class Provider(provider.Provider): + + __schema__ = schema.OSSchema() + __distribution_version__ = int(__schema__.major_version) + def __init__(self, root: str, config: Config | None = None): if not config: config = Config() @@ -31,7 +35,7 @@ def __init__(self, root: str, config: Config | None = None): self.config = config self.logger.debug(f"config: {config}") - self.schema = schema.OSSchema() + self.parser = Parser( workspace=self.workspace, allow_versions=self.config.allow_versions, @@ -48,7 +52,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int for namespace, vuln_id, record in self.parser.get(): writer.write( identifier=os.path.join(namespace, vuln_id), - schema=self.schema, + schema=self.__schema__, payload=record, ) return self.parser.urls, len(writer) diff --git a/src/vunnel/providers/nvd/__init__.py b/src/vunnel/providers/nvd/__init__.py index 6ce5dca7..d91fc310 100644 --- a/src/vunnel/providers/nvd/__init__.py +++ b/src/vunnel/providers/nvd/__init__.py @@ -38,7 +38,10 @@ def __str__(self) -> str: class Provider(provider.Provider): + # this is the version for the behavior of the provider (processing) not an indication of the data shape. 
__version__ = 2 + __schema__ = schema.NVDSchema() + __distribution_version__ = int(__schema__.major_version) def __init__(self, root: str, config: Config | None = None): if not config: @@ -64,10 +67,9 @@ def __init__(self, root: str, config: Config | None = None): "if 'overrides_enabled' is set then 'overrides_url' must be set", ) - self.schema = schema.NVDSchema() self.manager = Manager( workspace=self.workspace, - schema=self.schema, + schema=self.__schema__, download_timeout=self.config.request_timeout, api_key=self.config.api_key, logger=self.logger, @@ -87,7 +89,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int ): writer.write( identifier=identifier.lower(), - schema=self.schema, + schema=self.__schema__, payload=record, ) diff --git a/src/vunnel/providers/nvd/manager.py b/src/vunnel/providers/nvd/manager.py index cad58940..28cb3431 100644 --- a/src/vunnel/providers/nvd/manager.py +++ b/src/vunnel/providers/nvd/manager.py @@ -5,13 +5,14 @@ import os from typing import TYPE_CHECKING, Any -from vunnel import result, schema +from vunnel import result from vunnel.providers.nvd.api import NvdAPI from vunnel.providers.nvd.overrides import NVDOverrides if TYPE_CHECKING: from collections.abc import Generator + from vunnel import schema as schema_def from vunnel.workspace import Workspace @@ -21,7 +22,7 @@ class Manager: def __init__( # noqa: PLR0913 self, workspace: Workspace, - schema: schema.Schema, + schema: schema_def.Schema, overrides_url: str, logger: logging.Logger | None = None, download_timeout: int = 125, @@ -67,7 +68,6 @@ def get( override_remaining_cves = override_cves - cves_processed with self._sqlite_reader() as reader: for cve in override_remaining_cves: - original_record = reader.read(cve_to_id(cve)) if not original_record: self.logger.warning(f"override for {cve} not found in original data") diff --git a/src/vunnel/providers/nvd/overrides.py b/src/vunnel/providers/nvd/overrides.py index c11faa24..aa549df6 100644 --- a/src/vunnel/providers/nvd/overrides.py +++ b/src/vunnel/providers/nvd/overrides.py @@ -3,12 +3,11 @@ import glob import logging import os -import tarfile from typing import TYPE_CHECKING, Any from orjson import loads -from vunnel.utils import http +from vunnel.utils import archive, http if TYPE_CHECKING: from vunnel.workspace import Workspace @@ -51,7 +50,7 @@ def download(self) -> None: for chunk in req.iter_content(): fp.write(chunk) - untar_file(file_path, self._extract_path) + archive.extract(file_path, self._extract_path) @property def _extract_path(self) -> str: @@ -87,23 +86,3 @@ def cves(self) -> list[str]: self.__filepaths_by_cve__ = self._build_files_by_cve() return list(self.__filepaths_by_cve__.keys()) - - -def untar_file(file_path: str, extract_path: str) -> None: - with tarfile.open(file_path, "r:gz") as tar: - - def filter_path_traversal(tarinfo: tarfile.TarInfo, path: str) -> tarfile.TarInfo | None: - # we do not expect any relative file paths that would result in the clean - # path being different from the original path - # e.g. 
- # expected: results/results.db - # unexpected: results/../../../../etc/passwd - # we filter (drop) any such entries - - if tarinfo.name != os.path.normpath(tarinfo.name): - return None - return tarinfo - - # note: we have a filter that drops any entries that would result in a path traversal - # which is what S202 is referring to (linter isn't smart enough to understand this) - tar.extractall(path=extract_path, filter=filter_path_traversal) # noqa: S202 diff --git a/src/vunnel/providers/oracle/__init__.py b/src/vunnel/providers/oracle/__init__.py index 6e0983cf..553f176f 100644 --- a/src/vunnel/providers/oracle/__init__.py +++ b/src/vunnel/providers/oracle/__init__.py @@ -24,6 +24,10 @@ class Config: class Provider(provider.Provider): + + __schema__ = schema.OSSchema() + __distribution_version__ = int(__schema__.major_version) + def __init__(self, root: str, config: Config | None = None): if not config: config = Config() @@ -32,7 +36,6 @@ def __init__(self, root: str, config: Config | None = None): self.logger.debug(f"config: {config}") - self.schema = schema.OSSchema() self.parser = Parser( workspace=self.workspace, config=ol_config, @@ -58,7 +61,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int writer.write( identifier=os.path.join(namespace, vuln_id), - schema=self.schema, + schema=self.__schema__, payload=record, ) diff --git a/src/vunnel/providers/rhel/__init__.py b/src/vunnel/providers/rhel/__init__.py index a4bbc420..860b4d35 100644 --- a/src/vunnel/providers/rhel/__init__.py +++ b/src/vunnel/providers/rhel/__init__.py @@ -27,6 +27,10 @@ class Config: class Provider(provider.Provider): + + __schema__ = schema.OSSchema() + __distribution_version__ = int(__schema__.major_version) + def __init__(self, root: str, config: Config | None = None): if not config: config = Config() @@ -35,7 +39,6 @@ def __init__(self, root: str, config: Config | None = None): self.logger.debug(f"config: {config}") - self.schema = schema.OSSchema() self.parser = Parser( workspace=self.workspace, download_timeout=self.config.request_timeout, @@ -56,7 +59,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int vuln_id = vuln_id.lower() writer.write( identifier=os.path.join(namespace, vuln_id), - schema=self.schema, + schema=self.__schema__, payload=record, ) diff --git a/src/vunnel/providers/sles/__init__.py b/src/vunnel/providers/sles/__init__.py index a6d83454..2efc439b 100644 --- a/src/vunnel/providers/sles/__init__.py +++ b/src/vunnel/providers/sles/__init__.py @@ -28,15 +28,19 @@ def __post_init__(self) -> None: class Provider(provider.Provider): + + __schema__ = schema.OSSchema() + __distribution_version__ = int(__schema__.major_version) + def __init__(self, root: str, config: Config | None = None): if not config: config = Config() + super().__init__(root, runtime_cfg=config.runtime) self.config = config self.logger.debug(f"config: {config}") - self.schema = schema.OSSchema() self.parser = Parser( workspace=self.workspace, allow_versions=self.config.allow_versions, @@ -59,7 +63,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int vuln_id = vuln_id.lower() writer.write( identifier=os.path.join(namespace, vuln_id), - schema=self.schema, + schema=self.__schema__, payload=record, ) diff --git a/src/vunnel/providers/ubuntu/__init__.py b/src/vunnel/providers/ubuntu/__init__.py index 9dcb3f60..707fde10 100644 --- a/src/vunnel/providers/ubuntu/__init__.py +++ b/src/vunnel/providers/ubuntu/__init__.py @@ -29,9 +29,12 
@@ class Config: class Provider(provider.Provider): - # Bumping to version 2 because upstream changed the values of some data which requires reprocessing all of the history + # this is the version for the behavior of the provider (processing) not an indication of the data shape. __version__ = 2 + __schema__ = schema.OSSchema() + __distribution_version__ = int(__schema__.major_version) + def __init__(self, root: str, config: Config | None = None): if not config: config = Config() @@ -40,7 +43,6 @@ def __init__(self, root: str, config: Config | None = None): self.logger.debug(f"config: {config}") - self.schema = schema.OSSchema() self.parser = Parser( workspace=self.workspace, logger=self.logger, @@ -62,7 +64,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int vuln_id = vuln_id.lower() writer.write( identifier=os.path.join(namespace, vuln_id), - schema=self.schema, + schema=self.__schema__, payload={"Vulnerability": record}, ) diff --git a/src/vunnel/providers/ubuntu/git.py b/src/vunnel/providers/ubuntu/git.py index 23dde53c..6078cb6d 100644 --- a/src/vunnel/providers/ubuntu/git.py +++ b/src/vunnel/providers/ubuntu/git.py @@ -393,12 +393,22 @@ def _exec_cmd(self, cmd, *args, **kwargs) -> bytes: try: self.logger.trace(f"running: {cmd}") cmd_list = shlex.split(cmd) - # S603 disable exaplanation: running git commands by design + + # S603 disable explanation: running git commands by design return subprocess.check_output(cmd_list, *args, **kwargs, stderr=subprocess.PIPE) # noqa: S603 - except Exception as e: - self.logger.exception(f"error executing command: {cmd}") - if isinstance(e, subprocess.CalledProcessError) and e.stderr and self._ubuntu_server_503_message in e.stderr.decode(): + except subprocess.CalledProcessError as e: + stderr = "" + if e.stderr: + stderr = e.stderr.decode() + self.logger.exception(f"error executing command: {cmd}\nstderr:{stderr}") + + if self._ubuntu_server_503_message in stderr: raise UbuntuGitServer503Error from e raise e + + except Exception as e: + self.logger.exception(f"error executing command: {cmd}") + + raise e diff --git a/src/vunnel/providers/wolfi/__init__.py b/src/vunnel/providers/wolfi/__init__.py index 65f0f085..40e54123 100644 --- a/src/vunnel/providers/wolfi/__init__.py +++ b/src/vunnel/providers/wolfi/__init__.py @@ -24,6 +24,10 @@ class Config: class Provider(provider.Provider): + + __schema__ = schema.OSSchema() + __distribution_version__ = int(__schema__.major_version) + _url = "https://packages.wolfi.dev/os/security.json" _namespace = "wolfi" @@ -35,7 +39,6 @@ def __init__(self, root: str, config: Config | None = None): self.logger.debug(f"config: {config}") - self.schema = schema.OSSchema() self.parser = Parser( workspace=self.workspace, url=self._url, @@ -58,7 +61,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int for vuln_id, record in vuln_dict.items(): writer.write( identifier=os.path.join(f"{self._namespace.lower()}:{release.lower()}", vuln_id), - schema=self.schema, + schema=self.__schema__, payload=record, ) diff --git a/src/vunnel/result.py b/src/vunnel/result.py index ee447795..c2bb80b3 100644 --- a/src/vunnel/result.py +++ b/src/vunnel/result.py @@ -113,12 +113,12 @@ class SQLiteStore(Store): temp_filename = "results.db.tmp" table_name = "results" - def __init__(self, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, write_location: str | None = None, **kwargs: Any): super().__init__(*args, **kwargs) self.conn = None self.engine = None self.table = None - 
self.write_location = kwargs.get("write_location", None) + self.write_location = write_location if self.write_location: self.filename = os.path.basename(self.write_location) self.temp_filename = f"{self.filename}.tmp" diff --git a/src/vunnel/schema.py b/src/vunnel/schema.py index 018b1ad3..18c4cfda 100644 --- a/src/vunnel/schema.py +++ b/src/vunnel/schema.py @@ -1,8 +1,12 @@ from __future__ import annotations +import os.path from dataclasses import dataclass -PROVIDER_WORKSPACE_STATE_SCHEMA_VERSION = "1.0.1" +# Note: this metadata.json file currently is not allowed to have a breaking change +PROVIDER_WORKSPACE_STATE_SCHEMA_VERSION = "1.0.2" + +PROVIDER_ARCHIVE_LISTING_SCHEMA_VERSION = "1.0.0" MATCH_EXCLUSION_SCHEMA_VERSION = "1.0.0" GITHUB_SECURITY_ADVISORY_SCHEMA_VERSION = "1.0.1" MSRC_SCHEMA_VERSION = "1.0.0" @@ -16,6 +20,22 @@ class Schema: version: str url: str + @property + def major_version(self) -> str: + return self.version.split(".")[0] + + @property + def name(self) -> str: + name = self.url.removeprefix("https://raw.githubusercontent.com/anchore/vunnel/main/schema/") + return os.path.dirname(name) + + +def ProviderListingSchema(version: str = PROVIDER_ARCHIVE_LISTING_SCHEMA_VERSION) -> Schema: + return Schema( + version=version, + url=f"https://raw.githubusercontent.com/anchore/vunnel/main/schema/provider-archive-listing/schema-{version}.json", + ) + def ProviderStateSchema(version: str = PROVIDER_WORKSPACE_STATE_SCHEMA_VERSION) -> Schema: return Schema( diff --git a/src/vunnel/utils/archive.py b/src/vunnel/utils/archive.py new file mode 100644 index 00000000..bd51b0d2 --- /dev/null +++ b/src/vunnel/utils/archive.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import os +import tarfile +import tempfile +from pathlib import Path + +import zstandard + + +def extract(path: str, destination_dir: str) -> None: + if path.endswith(".tar.zst"): + return _extract_tar_zst(path, destination_dir) + + # open for reading with transparent compression (supports gz, bz2, and xz) + with tarfile.open(path, mode="r:*") as tar: + _safe_extract_tar(tar, destination_dir) + + return None + + +def _extract_tar_zst(path: str, unarchive_path: str) -> None: + archive_path = Path(path).expanduser() + dctx = zstandard.ZstdDecompressor(max_window_size=2147483648) + + with tempfile.TemporaryFile(suffix=".tar") as ofh: + with archive_path.open("rb") as ifh: + dctx.copy_stream(ifh, ofh) + ofh.seek(0) + with tarfile.open(fileobj=ofh, mode="r") as z: + return _safe_extract_tar(z, unarchive_path) + + +def _safe_extract_tar(tar: tarfile.TarFile, destination_dir: str) -> None: + # explanation of noqa: S202 + # This function is a safe wrapper around tar.extractall. + tar.extractall(destination_dir, filter=_filter_path_traversal) # noqa: S202 + + +def _filter_path_traversal(tarinfo: tarfile.TarInfo, path: str) -> tarfile.TarInfo | None: + # drop any path that would result in a write outside the destination dir + # e.g. 
+ # allowed: './some-dir/file.txt' + # not allowed: 'some-dir/../../../../../etc/passwd' + dest_dir = Path(os.path.abspath(path)) + write_path = Path(os.path.normpath(os.path.join(dest_dir, tarinfo.name))) + + if dest_dir in write_path.parents: + return tarinfo + return None diff --git a/src/vunnel/utils/hasher.py b/src/vunnel/utils/hasher.py new file mode 100644 index 00000000..d7a3262f --- /dev/null +++ b/src/vunnel/utils/hasher.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import hashlib +from enum import Enum + +import xxhash + + +class Method(Enum): + SHA256 = "sha256" + XXH64 = "xxh64" + + def digest(self, path: str, label: bool = True, size: int = 65536) -> str: + hasher = self.hasher() + with open(path, "rb") as f: + while b := f.read(size): + hasher.update(b) + if label: + return self.value + ":" + hasher.hexdigest() + return hasher.hexdigest() + + def hasher(self): # type: ignore[no-untyped-def] + if self == self.SHA256: + return hashlib.sha256() + if self == self.XXH64: + return xxhash.xxh64() + raise ValueError(f"unknown digest label: {self.value}") + + @staticmethod + def parse(value: str) -> Method: + try: + return Method(value.lower().replace("-", "").strip().split(":")[0]) + except ValueError: + raise ValueError(f"unknown digest label: {value}") from None diff --git a/src/vunnel/utils/http.py b/src/vunnel/utils/http.py index 41c8703f..7b73fa1c 100644 --- a/src/vunnel/utils/http.py +++ b/src/vunnel/utils/http.py @@ -1,11 +1,15 @@ -import logging +from __future__ import annotations + import random import time -from collections.abc import Callable -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import requests +if TYPE_CHECKING: + import logging + from collections.abc import Callable + DEFAULT_TIMEOUT = 30 @@ -15,7 +19,7 @@ def get( # noqa: PLR0913 retries: int = 5, backoff_in_seconds: int = 3, timeout: int = DEFAULT_TIMEOUT, - status_handler: Optional[Callable[[requests.Response], None]] = None, + status_handler: Optional[Callable[[requests.Response], None]] = None, # noqa: UP007 - python 3.9 **kwargs: Any, ) -> requests.Response: """ diff --git a/src/vunnel/workspace.py b/src/vunnel/workspace.py index 7151c1c6..ca44291a 100644 --- a/src/vunnel/workspace.py +++ b/src/vunnel/workspace.py @@ -6,14 +6,15 @@ import shutil import sqlite3 from dataclasses import asdict, dataclass, field -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional import orjson import xxhash from mashumaro.mixins.dict import DataClassDictMixin -from vunnel import schema as schemaDef +from vunnel import schema as schema_def from vunnel import utils +from vunnel.utils import hasher if TYPE_CHECKING: from collections.abc import Generator @@ -36,8 +37,10 @@ class State(DataClassDictMixin): store: str timestamp: datetime.datetime version: int = 1 + distribution_version: int = 1 listing: Optional[File] = None # noqa:UP007 # why use Optional? 
mashumaro does not support this on python 3.9 - schema: schemaDef.Schema = field(default_factory=schemaDef.ProviderStateSchema) + schema: schema_def.Schema = field(default_factory=schema_def.ProviderStateSchema) + stale: bool = False @staticmethod def read(root: str) -> State: @@ -112,6 +115,14 @@ def path(self) -> str: def results_path(self) -> str: return os.path.join(self.path, "results") + @property + def metadata_path(self) -> str: + return os.path.join(self.path, METADATA_FILENAME) + + @property + def checksums_path(self) -> str: + return os.path.join(self.path, CHECKSUM_LISTING_FILENAME) + @property def input_path(self) -> str: return os.path.join(self.path, "input") @@ -162,7 +173,15 @@ def clear_input(self) -> None: shutil.rmtree(self.input_path) os.makedirs(self.input_path, exist_ok=True) - def record_state(self, version: int, timestamp: datetime.datetime, urls: list[str], store: str) -> None: + def record_state( # noqa: PLR0913 + self, + version: int, + distribution_version: int, + timestamp: datetime.datetime, + urls: list[str], + store: str, + stale: bool = False, + ) -> None: try: current_state = State.read(root=self.path) except FileNotFoundError: @@ -176,31 +195,68 @@ def record_state(self, version: int, timestamp: datetime.datetime, urls: list[st self.logger.info("recording workspace state") - state = State(provider=self.name, version=version, urls=urls, store=store, timestamp=timestamp) + state = State( + provider=self.name, + version=version, + distribution_version=distribution_version, + urls=urls, + store=store, + timestamp=timestamp, + stale=stale, + ) metadata_path = state.write(self.path, self.results_path) self.logger.debug(f"wrote workspace state to {metadata_path}") - def state(self) -> State | None: + def state(self) -> State: return State.read(self.path) + def validate_checksums(self) -> None: + state = State.read(self.path) + if not state.listing: + raise RuntimeError("no file listing found in workspace state") -def digest_path_with_hasher(path: str, hasher: Any, label: str | None, size: int = 65536) -> str: - with open(path, "rb") as f: - while b := f.read(size): - hasher.update(b) + full_path = os.path.join(self.path, state.listing.path) - if label: - return label + ":" + hasher.hexdigest() - return hasher.hexdigest() + # ensure the checksums file itself is not modified + if state.listing.digest != hasher.Method.XXH64.digest(full_path, label=False): + raise RuntimeError(f"file {full_path!r} has been modified") + # validate the checksums in the listing file + with open(full_path) as f: + for line in f.readlines(): + digest, path = line.split() + full_path = os.path.join(self.path, path) + if not os.path.exists(full_path): + raise RuntimeError(f"file {full_path!r} does not exist") -# def sha256_digest(path: str, label: bool = True) -> str: -# return digest_path_with_hasher(path, hashlib.sha256(), "sha256" if label else None) + if digest != hasher.Method.XXH64.digest(full_path, label=False): + raise RuntimeError(f"file {full_path!r} has been modified") + def overlay_existing(self, source: str, move: bool = False) -> None: + self.logger.info(f"overlaying existing workspace {source!r} to {self.path!r}") -def xxhash64_digest(path: str, label: bool = True) -> str: - return digest_path_with_hasher(path, xxhash.xxh64(), "xxh64" if label else None) + for root, _, files in os.walk(source): + for file in files: + src = os.path.join(root, file) + dst = os.path.join(self.path, os.path.relpath(src, source)) + os.makedirs(os.path.dirname(dst), exist_ok=True) + + if move: + 
os.rename(src, dst) + else: + shutil.copy2(src, dst) + + def replace_results(self, temp_workspace: Workspace) -> None: + self.logger.info(f"replacing results in {self.path!r} with results from {temp_workspace.path!r}") + self.clear_results() + os.rename(temp_workspace.results_path, self.results_path) + self._clear_metadata() + os.rename(temp_workspace.metadata_path, self.metadata_path) + os.rename(temp_workspace.checksums_path, self.checksums_path) + state = self.state() + state.stale = True + self.record_state(state.version, state.distribution_version, state.timestamp, state.urls, state.store, True) def write_file_listing(output_file: str, path: str) -> str: @@ -214,7 +270,7 @@ def write_file_listing(output_file: str, path: str) -> str: path_relative_to_results = os.path.relpath(full_path, path) path_relative_to_workspace = os.path.join(os.path.basename(path), path_relative_to_results) - contents = f"{xxhash64_digest(full_path, label=False)} {path_relative_to_workspace}\n" + contents = f"{hasher.Method.XXH64.digest(full_path, label=False)} {path_relative_to_workspace}\n" listing_hasher.update(contents.encode("utf-8")) f.write(contents) diff --git a/tests/conftest.py b/tests/conftest.py index e875ba48..3c16d81a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,12 +115,27 @@ def assert_result_snapshots(self): pytest.fail("\n".join(message_lines), pytrace=False) +@pytest.fixture() +def validate_json_schema(): + def apply(content: str): + doc = json.loads(content) + schema_url = doc.get("schema", {}).get("url") + if not schema_url: + raise ValueError("No schema URL found in document") + + schema_path = get_schema_repo_path(schema_url) + schema = load_json_schema(schema_path) + _validate_json_schema(instance=doc, schema=schema) + + return apply + + def load_json_schema(path: str) -> dict: with open(path) as f: return json.load(f) -def get_schema_repo_path(url: str): +def get_schema_repo_path(url: str) -> str: # e.g. 
https://raw.githubusercontent.com/anchore/vunnel/main/schema/vulnerability/nvd/schema-{version}.json relative_path = url.removeprefix("https://raw.githubusercontent.com/anchore/vunnel/main/") if relative_path == url: diff --git a/tests/unit/cli/test_cli.py b/tests/unit/cli/test_cli.py index 5a78bd11..97651ad8 100644 --- a/tests/unit/cli/test_cli.py +++ b/tests/unit/cli/test_cli.py @@ -74,6 +74,9 @@ def test_run(mocker, monkeypatch) -> None: existing_input=provider.InputStatePolicy.KEEP, existing_results=provider.InputStatePolicy.KEEP, result_store=result.StoreStrategy.SQLITE, + import_results_host="", + import_results_path="providers/{provider_name}/listing.json", + import_results_enabled=False, ), request_timeout=125, api_key="secret", @@ -125,6 +128,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: delete-before-write + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -137,6 +143,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: delete-before-write + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -153,6 +162,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: delete-before-write + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -160,6 +172,11 @@ def test_config(monkeypatch) -> None: retry_count: 3 retry_delay: 5 result_store: sqlite + common: + import_results: + enabled: false + host: '' + path: providers/{provider_name}/listing.json debian: releases: bookworm: '12' @@ -174,6 +191,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: delete-before-write + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -187,6 +207,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: delete-before-write + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -203,6 +226,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: delete-before-write + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -218,6 +244,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: keep + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -230,6 +259,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: delete-before-write + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -244,6 +276,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: delete-before-write + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -263,6 +298,9 @@ def 
test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: delete-before-write + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -280,6 +318,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: keep + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep @@ -292,6 +333,9 @@ def test_config(monkeypatch) -> None: runtime: existing_input: keep existing_results: delete-before-write + import_results_enabled: false + import_results_host: '' + import_results_path: providers/{provider_name}/listing.json on_error: action: fail input: keep diff --git a/tests/unit/cli/test_config.py b/tests/unit/cli/test_config.py index 78e153d8..884b105c 100644 --- a/tests/unit/cli/test_config.py +++ b/tests/unit/cli/test_config.py @@ -1,5 +1,6 @@ from __future__ import annotations +import pytest from vunnel import provider, providers, result from vunnel.cli import config @@ -116,3 +117,88 @@ def test_full_config(helpers): ), ), ) + + +@pytest.mark.parametrize( + "common_enabled,provider_enabled,want", + [ + # default wins if provider is None + (True, None, True), + (False, None, False), + # provider always overrides default + (True, False, False), + (False, True, True), + # if everything agrees, that's the answer + (True, True, True), + (False, False, False), + ], +) +def test_import_results_enabled(common_enabled: bool, provider_enabled: bool | None, want: bool): + cfg = config.Application( + providers=config.Providers( + common=config.CommonProviderConfig( + import_results=config.ImportResults( + enabled=common_enabled, + ) + ), + nvd=providers.nvd.Config( + runtime=provider.RuntimeConfig( + import_results_enabled=provider_enabled, + ) + ), + ) + ) + assert cfg.providers.nvd.runtime.import_results_enabled == want + + +@pytest.mark.parametrize( + "common_path,provider_path,want", + [ + ("default_value", None, "default_value"), + ("default_value", "specific_value", "specific_value"), + ("default_value", "", "default_value"), + ("default_value", "/", "/"), + ], +) +def test_import_results_path(common_path: str, provider_path: str | None, want: str): + cfg = config.Application( + providers=config.Providers( + common=config.CommonProviderConfig( + import_results=config.ImportResults( + path=common_path, + ) + ), + nvd=providers.nvd.Config( + runtime=provider.RuntimeConfig( + import_results_path=provider_path, + ) + ), + ) + ) + assert cfg.providers.nvd.runtime.import_results_path == want + + +@pytest.mark.parametrize( + "common_host,provider_host,want", + [ + ("default-host", None, "default-host"), + ("default-host", "specific-host", "specific-host"), + ("default-host", "", "default-host"), # TODO: should this be "default-host"? 
+ ], +) +def test_import_results_host(common_host: str, provider_host: str | None, want: str): + cfg = config.Application( + providers=config.Providers( + common=config.CommonProviderConfig( + import_results=config.ImportResults( + host=common_host, + ) + ), + nvd=providers.nvd.Config( + runtime=provider.RuntimeConfig( + import_results_host=provider_host, + ) + ), + ) + ) + assert cfg.providers.nvd.runtime.import_results_host == want diff --git a/tests/unit/providers/nvd/test_overrides.py b/tests/unit/providers/nvd/test_overrides.py index 6eafabcb..091fc6be 100644 --- a/tests/unit/providers/nvd/test_overrides.py +++ b/tests/unit/providers/nvd/test_overrides.py @@ -59,13 +59,3 @@ def test_overrides_enabled(mock_requests, overrides_tar, tmpdir): assert subject.cve("CVE-2011-0022") is not None assert subject.cves() == ["CVE-2011-0022"] - - -def test_untar_file(overrides_tar, tmpdir): - overrides.untar_file(overrides_tar, tmpdir) - assert tmpdir.join("data/CVE-2011-0022.json").check(file=True) - - -def test_untar_file_path_traversal(path_traversal_tar, tmpdir): - overrides.untar_file(path_traversal_tar, tmpdir.join("somewhere", "else")) - assert tmpdir.join("somewhere/else/CVE-2011-0022.json").check(file=False) diff --git a/tests/unit/test_distribution.py b/tests/unit/test_distribution.py new file mode 100644 index 00000000..92b0d937 --- /dev/null +++ b/tests/unit/test_distribution.py @@ -0,0 +1,149 @@ +import json + +import pytest +from datetime import datetime, timezone +from src.vunnel.distribution import ListingEntry, ListingDocument + + +class TestListingEntry: + + @pytest.fixture + def entry1(self): + return ListingEntry( + built="2022-01-01T00:00:00Z", + distribution_version=1, + url="http://example.com/archive.tar.gz", + distribution_checksum="sha256:1234567890abcdef1234567890abcdef", + enclosed_checksum="xxhash64:1234567890abcdef", + ) + + @pytest.fixture + def entry2(self): + return ListingEntry( + built="2022-01-01T00:00:00Z", + distribution_version=1, + url="http://example.com/archive.tar.zst", + distribution_checksum="sha256:abcdef1234567890abcdef1234567890", + enclosed_checksum="xxhash64:abcdef1234567890", + ) + + def test_basename(self, entry1, entry2): + assert entry1.basename() == "archive.tar.gz" + assert entry2.basename() == "archive.tar.zst" + + def test_basename_invalid_url(self, entry1): + with pytest.raises(RuntimeError): + entry1.url = "http://example.com/archive.tar.unsupported" + entry1.basename() + + def test_age_in_days(self, entry1): + now = datetime.now(tz=timezone.utc) + assert entry1.age_in_days(now) == (now - datetime(2022, 1, 1, tzinfo=timezone.utc)).days + + def test_age_in_days_no_now(self, entry1): + assert entry1.age_in_days() == (datetime.now(tz=timezone.utc) - datetime(2022, 1, 1, tzinfo=timezone.utc)).days + + +class TestListingDocument: + @pytest.fixture + def document(self): + entries = [ + ListingEntry( + built="2022-01-01T00:00:00Z", + distribution_version=1, + url="http://example.com/archive1.tar.gz", + distribution_checksum="sha256:1234567890abcdef1234567890abcdef", + enclosed_checksum="xxhash64:1234567890abcdef", + ), + ListingEntry( + built="2022-01-02T00:00:00Z", + distribution_version=1, + url="http://example.com/archive2.tar.gz", + distribution_checksum="sha256:abcdef1234567890abcdef1234567890", + enclosed_checksum="xxhash64:abcdef1234567890", + ), + ] + return ListingDocument(available={1: entries}, provider="test_provider") + + @pytest.fixture + def built_document(self): + subject = ListingDocument.new(provider="nvd") + + subject.add( + 
ListingEntry( + built=datetime(2017, 11, 28, 23, 55, 59, 342380).strftime("%Y-%m-%dT%H:%M:%S.%f%z"), + distribution_version=3, + url="https://b-place.com/something-1.tar.gz", + distribution_checksum="sha256:123456789", + enclosed_checksum="xxh64:123456789", + ) + ) + + subject.add( + ListingEntry( + built=datetime(2016, 11, 28, 23, 55, 59, 342380).strftime("%Y-%m-%dT%H:%M:%S.%f%z"), + distribution_version=3, + url="https://a-place.com/something.tar.gz", + distribution_checksum="sha256:123456789", + enclosed_checksum="xxh64:123456789", + ) + ) + + subject.add( + ListingEntry( + built=datetime(2019, 11, 28, 23, 55, 59, 342380).strftime("%Y-%m-%dT%H:%M:%S.%f%z"), + distribution_version=3, + url="https://c-place.com/something.tar.gz", + distribution_checksum="sha256:123456789", + enclosed_checksum="xxh64:123456789", + ) + ) + + subject.add( + ListingEntry( + built=datetime(2017, 11, 28, 23, 55, 59, 342380).strftime("%Y-%m-%dT%H:%M:%S.%f%z"), + distribution_version=4, + url="https://b-place.com/something-1.tar.zst", + distribution_checksum="sha256:123456789", + enclosed_checksum="xxh64:123456789", + ) + ) + + subject.add( + ListingEntry( + built=datetime(2016, 11, 28, 23, 55, 59, 342380).strftime("%Y-%m-%dT%H:%M:%S.%f%z"), + distribution_version=4, + url="https://a-place.com/something.tar.zst", + distribution_checksum="sha256:123456789", + enclosed_checksum="xxh64:123456789", + ) + ) + + subject.add( + ListingEntry( + built=datetime(2019, 11, 28, 23, 55, 59, 342380).strftime("%Y-%m-%dT%H:%M:%S.%f%z"), + distribution_version=4, + url="https://c-place.com/something.tar.zst", + distribution_checksum="sha256:123456789", + enclosed_checksum="xxh64:123456789", + ) + ) + + return subject + + def test_latest_entry(self, document, built_document): + latest_entry = document.latest_entry(1) + assert latest_entry is not None + assert latest_entry.distribution_version == 1 + assert latest_entry.url == "http://example.com/archive1.tar.gz" + + assert "https://c-place.com/something.tar.gz" == built_document.latest_entry(3).url + assert "https://c-place.com/something.tar.zst" == built_document.latest_entry(4).url + + def test_latest_entry_no_entries(self, document): + latest_entry = document.latest_entry(2) + assert latest_entry is None + + def test_schema(self, built_document, validate_json_schema): + content = json.dumps(built_document.to_dict()) + validate_json_schema(content) diff --git a/tests/unit/test_provider.py b/tests/unit/test_provider.py index ed29a62c..7b698a3f 100644 --- a/tests/unit/test_provider.py +++ b/tests/unit/test_provider.py @@ -1,11 +1,21 @@ from __future__ import annotations -import json import os +import tarfile +import json +import logging +import random +import string +import hashlib +import shutil +from contextlib import contextmanager from unittest.mock import MagicMock, patch import pytest -from vunnel import provider, result, schema, workspace +import zstandard + +from vunnel import provider, result, schema, workspace, distribution +from vunnel.utils import hasher, archive def assert_path(path: str, exists: bool = True): @@ -31,6 +41,9 @@ def input_file(self): def assert_state_file(self, exists: bool = True): assert_path(os.path.join(self.workspace.path, "state.json"), exists) + def _fetch_or_use_results_archive(self): + return self.update() + def update(self, *args, **kwargs): self.count += 1 if self.count <= self.errors: @@ -52,11 +65,16 @@ def update(self, *args, **kwargs): return ["http://localhost:8000/dummy-input-1.json"], 1 +def get_random_string(length=10): + characters = 
string.ascii_letters + string.digits + return "".join(random.choice(characters) for _ in range(length)) + + @pytest.fixture() def dummy_provider(tmpdir): - def apply(populate=True, use_dir=None, **kwargs): + def apply(populate=True, use_dir=None, **kwargs) -> provider.Provider: if not use_dir: - use_dir = tmpdir + use_dir = tmpdir + get_random_string() # create a dummy provider subject = DummyProvider(root=use_dir, **kwargs) @@ -114,6 +132,76 @@ def test_clear_existing_state_from_mismatched_versions(dummy_provider): assert subject.workspace._clear_metadata.call_count == 1 +def test_clear_existing_state_from_mismatched_distribution_versions(dummy_provider): + policy = provider.RuntimeConfig( + existing_input=provider.InputStatePolicy.KEEP, + existing_results=provider.ResultStatePolicy.KEEP, + import_results_enabled=True, + import_results_path="{provider_name}/listing.json", + import_results_host="http://localhost", + ) + + subject = dummy_provider(populate=True, runtime_cfg=policy) + + # track calls without affecting behavior (get mock tracking abilities without mocking) + subject.workspace.clear_input = MagicMock(side_effect=subject.workspace.clear_input) + subject.workspace.clear_results = MagicMock(side_effect=subject.workspace.clear_results) + subject.workspace._clear_metadata = MagicMock(side_effect=subject.workspace._clear_metadata) + subject.distribution_version = MagicMock(return_value=2) + + subject.run() + + assert subject.workspace.clear_input.call_count == 1 + assert subject.workspace.clear_results.call_count == 1 + assert subject.workspace._clear_metadata.call_count == 1 + + +def test_mismatched_distribution_versions_has_no_effect_when_import_disabled(dummy_provider): + policy = provider.RuntimeConfig( + existing_input=provider.InputStatePolicy.KEEP, + existing_results=provider.ResultStatePolicy.KEEP, + import_results_enabled=False, + ) + + subject = dummy_provider(populate=True, runtime_cfg=policy) + + # track calls without affecting behavior (get mock tracking abilities without mocking) + subject.workspace.clear_input = MagicMock(side_effect=subject.workspace.clear_input) + subject.workspace.clear_results = MagicMock(side_effect=subject.workspace.clear_results) + subject.workspace._clear_metadata = MagicMock(side_effect=subject.workspace._clear_metadata) + subject.distribution_version = MagicMock(return_value=2) + + subject.run() + + assert subject.workspace.clear_input.call_count == 0 + assert subject.workspace.clear_results.call_count == 0 + assert subject.workspace._clear_metadata.call_count == 0 + + +def test_mismatched_versions_has_no_effect_when_import_enabled(dummy_provider): + policy = provider.RuntimeConfig( + existing_input=provider.InputStatePolicy.KEEP, + existing_results=provider.ResultStatePolicy.KEEP, + import_results_enabled=True, + import_results_path="{provider_name}/listing.json", + import_results_host="http://localhost", + ) + + subject = dummy_provider(populate=True, runtime_cfg=policy) + + # track calls without affecting behavior (get mock tracking abilities without mocking) + subject.workspace.clear_input = MagicMock(side_effect=subject.workspace.clear_input) + subject.workspace.clear_results = MagicMock(side_effect=subject.workspace.clear_results) + subject.workspace._clear_metadata = MagicMock(side_effect=subject.workspace._clear_metadata) + subject.version = MagicMock(return_value=2) + + subject.run() + + assert subject.workspace.clear_input.call_count == 0 + assert subject.workspace.clear_results.call_count == 0 + assert 
subject.workspace._clear_metadata.call_count == 0 + + def test_keep_existing_state_from_matching_versions(dummy_provider): policy = provider.RuntimeConfig( existing_input=provider.InputStatePolicy.KEEP, @@ -304,6 +392,377 @@ def test_retry_on_failure_max_attempts(dummy_provider, dummy_file): subject.assert_state_file(exists=False) +def listing_tar_entry( + tmpdir: str, + port: int, + dummy_provider_factory, + archive_name: str | None = None, + archive_checksum: str | None = None, + results_checksum: str | None = None, +) -> tuple[str, str, distribution.ListingEntry, str]: + if not archive_name: + archive_name = "results.tar.gz" + + policy = provider.RuntimeConfig( + result_store=result.StoreStrategy.SQLITE, + existing_input=provider.InputStatePolicy.KEEP, + existing_results=provider.ResultStatePolicy.KEEP, + ) + + subject = dummy_provider_factory(populate=True, runtime_cfg=policy) + subject.run() + + dest = os.path.join(tmpdir, subject.name()) + os.makedirs(dest, exist_ok=True) + + # tar up the subject.workspace.path into a tarfile + shutil.rmtree(subject.workspace.input_path, ignore_errors=True) + tarfile_path = os.path.join(dest, archive_name) + with _get_tar_writer_obj(tarfile_path) as tar: + tar.add(subject.workspace.path, arcname=subject.name()) + + if not archive_checksum: + archive_checksum = hasher.Method.XXH64.digest(tarfile_path, label=True) + + if not results_checksum: + workspace_state: workspace.State = subject.workspace.state() + results_checksum = workspace_state.listing.digest + + listing_entry = distribution.ListingEntry( + built="2021-01-01T00:00:00Z", + distribution_version=1, + url=f"http://localhost:{port}/{subject.name()}/{archive_name}", + distribution_checksum=archive_checksum, + enclosed_checksum=results_checksum, + ) + + listing_doc = distribution.ListingDocument(available={"1": [listing_entry]}, provider=subject.name()) + listing_url = f"http://localhost:{port}/{subject.name()}/listing.json" + + # write out the listing document + listing_path = os.path.join(dest, "listing.json") + with open(listing_path, "w") as f: + json.dump(listing_doc.to_dict(), f) + + return tarfile_path, listing_url, listing_entry, listing_path + + +def _get_tar_writer_obj(tarfile_path): + if tarfile_path.endswith(".tar.zst"): + return _get_tar_zst_writer_obj(tarfile_path) + + elif tarfile_path.endswith(".tar"): + return tarfile.open(tarfile_path, "w:") + + if tarfile_path.endswith(".tar.gz"): + return tarfile.open(tarfile_path, "w:gz") + + raise ValueError("unsupported tarfile extension") + + +@contextmanager +def _get_tar_zst_writer_obj(tarfile_path): + fileobj = zstandard.ZstdCompressor().stream_writer(open(tarfile_path, "wb")) + tf = None + try: + tf = tarfile.open(tarfile_path, "w|", fileobj=fileobj) + yield tf + finally: + if tf: + tf.close() + fileobj.close() + + +@pytest.mark.parametrize( + "archive_name,archive_checksum,raises_type", + ( + ("results.tar.gz", None, None), + ("results.tar.zst", None, None), + ("results.tar", None, None), + ("results.tar.gz", "sha256:1234567890abcdef", ValueError), + ), +) +@patch("requests.get") +def test_fetch_listing_entry_archive(mock_requests, tmpdir, dummy_provider, archive_name, archive_checksum, raises_type): + port = 8080 + + tarfile_path, listing_url, listing_entry, listing_path = listing_tar_entry( + tmpdir, port, dummy_provider_factory=dummy_provider, archive_name=archive_name, archive_checksum=archive_checksum + ) + + with open(tarfile_path, "rb") as f: + content = f.read() + + mock_requests.return_value.status_code = 200 + 
mock_requests.return_value.iter_content.return_value = [content] + + logger = logging.getLogger("test") + + if not raises_type: + unarchived_dir = provider._fetch_listing_entry_archive(entry=listing_entry, dest=tmpdir, logger=logger) + + # assert the unarchived_dir path contents is the same as the tarfile contents + compare_dir_tar(tmpdir, unarchived_dir, tarfile_path) + + args, _ = mock_requests.call_args + assert args == (listing_entry.url,) + else: + with pytest.raises(raises_type): + provider._fetch_listing_entry_archive(entry=listing_entry, dest=tmpdir, logger=logger) + + +def checksum(file_path): + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +def compare_dir_tar(tmpdir, dir_path, tar_path): + temp_dir = os.path.join(tmpdir, "extracted") + + archive.extract(tar_path, temp_dir) + + dir_checksums = {} + tar_checksums = {} + + # walk through directory and calculate checksums + for root, dirs, files in os.walk(dir_path): + for name in files: + rel_dir = os.path.relpath(root, dir_path) + rel_file = os.path.join(rel_dir, name) + file_path = os.path.join(root, name) + dir_checksums[rel_file] = checksum(file_path) + + # walk through extracted tar contents and calculate checksums + for root, dirs, files in os.walk(temp_dir): + for name in files: + rel_dir = os.path.relpath(root, temp_dir) + rel_file = os.path.join(rel_dir, name) + file_path = os.path.join(root, name) + tar_checksums[rel_file] = checksum(file_path) + + # cleanup temporary directory + for root, dirs, files in os.walk(temp_dir, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + os.rmdir(temp_dir) + + assert dir_checksums == tar_checksums, "Directory and TAR file contents differ" + + +@patch("requests.get") +def test_fetch_listing_document(mock_requests, tmpdir, dummy_provider): + port = 8080 + + policy = provider.RuntimeConfig( + result_store=result.StoreStrategy.SQLITE, + existing_input=provider.InputStatePolicy.KEEP, + existing_results=provider.ResultStatePolicy.KEEP, + import_results_enabled=True, + import_results_path="{provider_name}/listing.json", + import_results_host="http://localhost", + ) + + tarfile_path, listing_url, entry, listing_path = listing_tar_entry(tmpdir, port, dummy_provider_factory=dummy_provider) + + subject = dummy_provider(populate=False, runtime_cfg=policy) + mock_requests.return_value.status_code = 200 + mock_requests.return_value.json.return_value = json.loads(open(listing_path, "r").read()) + + doc = subject._fetch_listing_document() + + args, _ = mock_requests.call_args + assert args == ("http://localhost/dummy/listing.json",) + + +@patch("requests.get") +def test_prep_workspace_from_listing_entry(mock_requests, tmpdir, dummy_provider): + provider = dummy_provider(populate=False) + tarfile_path, listing_url, entry, listing_path = listing_tar_entry( + tmpdir=tmpdir, port=8080, dummy_provider_factory=dummy_provider + ) + + with open(tarfile_path, "rb") as f: + content = f.read() + mock_requests.return_value.status_code = 200 + mock_requests.return_value.iter_content.return_value = [content] + + with tarfile.open(tarfile_path, "r:gz") as tar: + list_of_files = tar.getnames() + + provider._prep_workspace_from_listing_entry(entry=entry) + + state = provider.workspace.state() + + assert state.stale + + provider.workspace.validate_checksums() + + for file in list_of_files: + assert 
os.path.exists(os.path.join(provider.workspace.path, "..", file)) + + # what this does is: + # 1. it receives a listing entry and makes a call to fetch and unarchive it + # 2. it creates a temp workspace around the unarchive path + # 3. it validates the checksums on the temp workspace + # 4. it overlays it's current workspace with the temp workspace + + +@patch("requests.get") +def test_fetch_or_use_results_archive(mock_requests, tmpdir, dummy_provider): + port = 8080 + + tarfile_path, listing_url, entry, listing_path = listing_tar_entry( + tmpdir=tmpdir, port=port, dummy_provider_factory=dummy_provider + ) + # fetch the tar file + tarfile_bytes = None + with open(tarfile_path, "rb") as f: + tarfile_bytes = f.read() + + policy = provider.RuntimeConfig( + result_store=result.StoreStrategy.SQLITE, + existing_input=provider.InputStatePolicy.KEEP, + existing_results=provider.ResultStatePolicy.KEEP, + import_results_enabled=True, + import_results_path="{provider_name}/listing.json", + import_results_host=f"http://localhost:{port}", + ) + + subject = dummy_provider(populate=False, runtime_cfg=policy) + + def handle_get_requests(url, *args, **kwargs): + listing_response = MagicMock() + listing_response.status_code = 200 + listing_response.raise_for_status.side_effect = None + listing_response.json.return_value = json.loads(open(listing_path, "r").read()) + + entry_response = MagicMock() + entry_response.status_code = 200 + entry_response.raise_for_status.side_effect = None + entry_response.iter_content.return_value = [tarfile_bytes] + + not_found_response = MagicMock() + not_found_response.status_code = 404 + not_found_response.raise_for_status.side_effect = Exception("404") + + if url == f"http://localhost:{port}/{subject.name()}/listing.json": + return listing_response + elif url == entry.url: + return entry_response + else: + return not_found_response + + mock_requests.side_effect = handle_get_requests + + urls, count = subject._fetch_or_use_results_archive() + assert urls == ["http://localhost:8000/dummy-input-1.json"] + assert count == 1 + + +@pytest.mark.parametrize( + "enabled,host,path,error_message", + [ + (True, "", "", "enablign import results requires host"), + (True, "http://example.com", "", "enablign import results requires path"), + (False, "", "", None), + ], +) +def test_validate_import_results_config(enabled: bool, host: str, path: str, error_message: str | None, dummy_provider): + runtime_config = provider.RuntimeConfig() + runtime_config.import_results_enabled = enabled + runtime_config.import_results_host = host + runtime_config.import_results_path = path + if error_message: + with pytest.raises(RuntimeError) as e: + dummy_provider(runtime_cfg=runtime_config) + assert error_message == str(e) + else: + dummy_provider(runtime_cfg=runtime_config) + + +def test_has_newer_archive_distribution_version_mismatch_true(dummy_provider): + subject = dummy_provider() + distribution_version = subject.distribution_version() + mismatched_distribution_version = distribution_version + 1 + existing_state = subject.workspace.state() + subject.workspace.record_state( + version=subject.version(), + distribution_version=mismatched_distribution_version, + timestamp=existing_state.timestamp, + store=result.StoreStrategy.FLAT_FILE.value, + urls=existing_state.urls, + ) + entry = distribution.ListingEntry( + enclosed_checksum=f"{existing_state.listing.algorithm}:{existing_state.listing.digest}", + distribution_checksum="xxh64:12341234aedf", + distribution_version=subject.distribution_version(), + 
built="2024-03-25T13:36:36Z", + url="http://example.com/some-example", + ) + assert subject._has_newer_archive(latest_entry=entry) + + +def test_has_newer_archive_version_mismatch_has_no_effect(dummy_provider): + subject = dummy_provider() + version = subject.version() + mismatched_version = version + 1 + existing_state = subject.workspace.state() + subject.workspace.record_state( + version=mismatched_version, + distribution_version=subject.distribution_version(), + timestamp=existing_state.timestamp, + store=result.StoreStrategy.FLAT_FILE.value, + urls=existing_state.urls, + ) + entry = distribution.ListingEntry( + enclosed_checksum=f"{existing_state.listing.algorithm}:{existing_state.listing.digest}", + distribution_checksum="xxh64:12341234aedf", + distribution_version=subject.distribution_version(), + built="2024-03-25T13:36:36Z", + url="http://example.com/some-example", + ) + assert not subject._has_newer_archive(latest_entry=entry) + + +def test_has_newer_archive_false(dummy_provider): + subject = dummy_provider(populate=True) + state = subject.workspace.state() + entry = distribution.ListingEntry( + enclosed_checksum=f"{state.listing.algorithm}:{state.listing.digest}", + distribution_checksum="xxh64:12341234aedf", + distribution_version=subject.distribution_version(), + built="2024-03-25T13:36:36Z", + url="http://example.com/some-example", + ) + assert not subject._has_newer_archive(entry) + + +@pytest.mark.parametrize( + "host,path,want", + [ + ("http://example.com/", "{provider_name}/listing.json", "http://example.com/test-provider/listing.json"), + # extra leading and trailing slashes are handled correctly: + ("http://example.com////", "///{provider_name}/listing.json", "http://example.com/test-provider/listing.json"), + ("http://example.com/", "specific-path/listing.json", "http://example.com/specific-path/listing.json"), + ("http://sub.example.com/", "v1/{provider_name}/listing.json", "http://sub.example.com/v1/test-provider/listing.json"), + ("http://sub.example.com/v1", "/{provider_name}/listing.json", "http://sub.example.com/v1/test-provider/listing.json"), + ], +) +def test_import_url(host, path, want, dummy_provider): + subject = provider.RuntimeConfig( + import_results_path=path, + import_results_enabled=True, + import_results_host=host, + ) + got = subject.import_url(provider_name="test-provider") + assert got == want + + def assert_dummy_workspace_state(ws): current_state = workspace.State.read(root=ws.path) @@ -321,3 +780,110 @@ def assert_dummy_workspace_state(ws): ) assert current_state == expected_state + + +@patch("vunnel.provider.schema_def.ProviderStateSchema") +def test_version(mock_schema): + mock_schema.return_value = MagicMock(major_version=3) + + class Impl(provider.Provider): + __distribution_version__ = 4 + __version__ = 2 + + def __init__(self): + # intentionally do not call super().__init__() + pass + + def name(self): + return "dummy" + + def update(self): + return None + + # distribution version = 4 + (3-1) = 6 + # provider version = __version__ + (distribution version - 1) = 2 + (6-1) = 7 + assert Impl().version() == 7 + + +@patch("vunnel.provider.schema_def.ProviderStateSchema") +def test_distribution_version(mock_schema): + mock_schema.return_value = MagicMock(major_version=1) + + class Impl(provider.Provider): + __distribution_version__ = 4 + + def __init__(self): + # intentionally do not call super().__init__() + pass + + def name(self): + return "dummy" + + def update(self): + return None + + assert Impl().distribution_version() == 4 + + # a 
breaking change to the workspace schema should reflect a change in distribution version + mock_schema.return_value = MagicMock(major_version=2) + assert Impl().distribution_version() == 5 + + # a change in the provider's distribution version should reflect a change in distribution version + Impl.__distribution_version__ = 6 + assert Impl().distribution_version() == 7 + + +def test_provider_versions(tmpdir): + from vunnel import providers + + # WARNING: changing the values of these versions has operational impact! Do not change them without + # understanding the implications! + expected = { + "alpine": 1, + "amazon": 1, + "chainguard": 1, + "debian": 1, + "github": 1, + "mariner": 1, + "nvd": 2, + "oracle": 1, + "rhel": 1, + "sles": 1, + "ubuntu": 2, + "wolfi": 1, + } + + got = {} + for name in providers.names(): + p = providers.create(name, tmpdir) + got[p.name()] = p.version() + + assert expected == got, "WARNING! CHANGES TO VERSIONS HAVE OPERATIONAL IMPACT!" + + +def test_provider_distribution_versions(tmpdir): + from vunnel import providers + + # WARNING: changing the values of these distributions has operational impact! Do not change them without + # understanding the implications! + expected = { + "alpine": 1, + "amazon": 1, + "chainguard": 1, + "debian": 1, + "github": 1, + "mariner": 1, + "nvd": 1, + "oracle": 1, + "rhel": 1, + "sles": 1, + "ubuntu": 1, + "wolfi": 1, + } + + got = {} + for name in providers.names(): + p = providers.create(name, tmpdir) + got[p.name()] = p.distribution_version() + + assert expected == got, "WARNING! CHANGES TO DISTRIBUTION VERSIONS HAVE OPERATIONAL IMPACT!" diff --git a/tests/unit/test_result.py b/tests/unit/test_result.py index a7a3005a..b2748ae1 100644 --- a/tests/unit/test_result.py +++ b/tests/unit/test_result.py @@ -29,7 +29,7 @@ def flat_file_existing_workspace(ws: workspace.Workspace) -> workspace.Workspace payload={"Vulnerability": {"dummy": "result-2"}}, ) - ws.record_state(timestamp=datetime.datetime.now(), urls=[], store=store_strategy, version=1) + ws.record_state(timestamp=datetime.datetime.now(), urls=[], store=store_strategy, version=1, distribution_version=1) return ws @@ -51,7 +51,7 @@ def sqlite_existing_workspace(ws: workspace.Workspace) -> workspace.Workspace: payload={"Vulnerability": {"dummy": "result-2"}}, ) - ws.record_state(timestamp=datetime.datetime.now(), urls=[], store=store_strategy, version=1) + ws.record_state(timestamp=datetime.datetime.now(), urls=[], store=store_strategy, version=1, distribution_version=1) return ws @@ -62,7 +62,7 @@ def sqlite_existing_workspace_with_partial_results(sqlite_existing_workspace: wo os.path.join(sqlite_existing_workspace.results_path, "results.db.tmp"), ) sqlite_existing_workspace.record_state( - timestamp=datetime.datetime.now(), urls=[], store=result.StoreStrategy.SQLITE, version=1 + timestamp=datetime.datetime.now(), urls=[], store=result.StoreStrategy.SQLITE, version=1, distribution_version=1 ) assert len(list(sqlite_existing_workspace.state().result_files(sqlite_existing_workspace.path))) == 2 return sqlite_existing_workspace @@ -118,7 +118,7 @@ def test_result_writer_flat_file( payload={"Vulnerability": {"dummy": "result-4"}}, ) - ws.record_state(timestamp=datetime.datetime.now(), urls=[], store=store_strategy, version=1) + ws.record_state(timestamp=datetime.datetime.now(), urls=[], store=store_strategy, version=1, distribution_version=1) state = ws.state() @@ -168,7 +168,7 @@ def test_result_writer_sqlite( payload={"Vulnerability": {"dummy": "result-4"}}, ) - 
ws.record_state(timestamp=datetime.datetime.now(), urls=[], store=store_strategy, version=1) + ws.record_state(timestamp=datetime.datetime.now(), urls=[], store=store_strategy, version=1, distribution_version=1) state = ws.state() # note: since the hash changes on each test run, just confirm the result file count, not the extra metadata @@ -218,7 +218,7 @@ def test_result_writer_sqlite_with_partial_result( payload={"Vulnerability": {"dummy": "result-4"}}, ) - ws.record_state(timestamp=datetime.datetime.now(), urls=[], store=store_strategy, version=1) + ws.record_state(timestamp=datetime.datetime.now(), urls=[], store=store_strategy, version=1, distribution_version=1) state = ws.state() # note: since the hash changes on each test run, just confirm the result file count, not the extra metadata diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py new file mode 100644 index 00000000..db430ffc --- /dev/null +++ b/tests/unit/test_schema.py @@ -0,0 +1,9 @@ +from vunnel import schema as schema_def + + +def test_provider_workspace_schema_v1(): + # it is vital that we do not make any breaking changes to the provider workspace state schema + # until there is a mechanism to deal with the state version detection, migration, and possibly supporting + # multiple version implementations in the codebase + assert schema_def.PROVIDER_WORKSPACE_STATE_SCHEMA_VERSION.startswith("1.") + assert schema_def.ProviderStateSchema().version.startswith("1.") diff --git a/tests/unit/test_workspace.py b/tests/unit/test_workspace.py index c997feaa..52f97239 100644 --- a/tests/unit/test_workspace.py +++ b/tests/unit/test_workspace.py @@ -3,6 +3,8 @@ import datetime import os +import pytest + from vunnel import result, schema, workspace @@ -46,7 +48,7 @@ def test_clear_results(tmpdir, dummy_file): urls = ["http://localhost:8000/dummy-input-1.json"] store = result.StoreStrategy.FLAT_FILE - ws.record_state(urls=urls, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1) + ws.record_state(urls=urls, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1, distribution_version=1) assert_directory(ws.input_path, exists=True, empty=False) assert_directory(ws.results_path, exists=True, empty=False) @@ -71,7 +73,7 @@ def test_record_state(tmpdir, dummy_file): urls = ["http://localhost:8000/dummy-input-1.json"] store = result.StoreStrategy.FLAT_FILE - ws.record_state(urls=urls, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1) + ws.record_state(urls=urls, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1, distribution_version=1) current_state = workspace.State.read(root=ws.path) @@ -100,10 +102,10 @@ def test_record_state_urls_persisted_across_runs(tmpdir, dummy_file): urls = ["http://localhost:8000/dummy-input-1.json"] store = result.StoreStrategy.FLAT_FILE - ws.record_state(urls=urls, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1) + ws.record_state(urls=urls, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1, distribution_version=1) # this call should not clear the URLs - ws.record_state(urls=None, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1) + ws.record_state(urls=None, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1, distribution_version=1) current_state = workspace.State.read(root=ws.path) @@ -133,8 +135,109 @@ def test_state_schema(tmpdir, dummy_file, helpers): urls = ["http://localhost:8000/dummy-input-1.json"] store = result.StoreStrategy.FLAT_FILE - 
ws.record_state(urls=urls, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1) + ws.record_state(urls=urls, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1, distribution_version=1) ws_helper = helpers.provider_workspace_helper(name=name, create=False) assert ws_helper.metadata_schema_valid() + + +def test_replace_results(tmpdir, dummy_file): + # make 2 workspaces + ws1_dir = tmpdir.mkdir("ws1") + ws1 = workspace.Workspace(root=ws1_dir, name="old", create=True) + dummy_file(ws1.input_path, "old-input-1.json") + dummy_file(ws1.results_path, "old-00000.json") + urls = ["http://localhost:8000/old_workspace.json"] + store = result.StoreStrategy.FLAT_FILE + ws1.record_state(urls=urls, store=store.value, timestamp=datetime.datetime(2021, 1, 1), version=1, distribution_version=1) + + ws2_dir = tmpdir.mkdir("ws2") + ws2 = workspace.Workspace(root=ws2_dir, name="new", create=True) + dummy_file(ws2.input_path, "new-input-2.json") + urls = ["http://localhost:8000/new_workspace.json"] + store = result.StoreStrategy.FLAT_FILE + dummy_file(ws2.results_path, "new-00000.json") + ws2.record_state(urls=urls, store=store.value, timestamp=datetime.datetime(2024, 1, 1), version=1, distribution_version=1) + # be able to tell the contents apart + # replace the results in one with the other + # assert on the contents + ws1.replace_results(ws2) + # assert that metadata / state is replaced + assert "http://localhost:8000/new_workspace.json" in ws1.state().urls + assert "http://localhost:8000/old_workspace.json" not in ws1.state().urls + + # assert that results have been replaced + assert os.path.exists(os.path.join(ws1.results_path, "new-00000.json")) + assert not os.path.exists(os.path.join(ws1.results_path, "old-00000.json")) + + # assert that input path is not changed + assert os.path.exists(os.path.join(ws1.input_path, "old-input-1.json")) + assert not os.path.exists(os.path.join(ws1.input_path, "new-input-1.json")) + + +def test_validate_checksums_valid(tmpdir, dummy_file): + ws = workspace.Workspace(root=tmpdir, name="dummy", create=True) + urls = [] + dummy_file(ws.results_path, "dummy-result-1.json") + ws.record_state( + urls=urls, + store=result.StoreStrategy.FLAT_FILE.value, + timestamp=datetime.datetime(2021, 1, 1), + version=1, + distribution_version=1, + ) + ws.validate_checksums() + + +def test_validate_checksums_invalid(tmpdir, dummy_file): + ws = workspace.Workspace(root=tmpdir, name="dummy", create=True) + urls = [] + result_file_path = dummy_file(ws.results_path, "dummy-result-1.json") + ws.record_state( + urls=urls, + store=result.StoreStrategy.FLAT_FILE.value, + timestamp=datetime.datetime(2021, 1, 1), + version=1, + distribution_version=1, + ) + with open(result_file_path, "a") as f: + f.write("invalid") + with pytest.raises(RuntimeError) as e: + ws.validate_checksums() + assert "has been modified" in str(e) + + +def test_validate_checksums_missing(tmpdir, dummy_file): + ws = workspace.Workspace(root=tmpdir, name="dummy", create=True) + urls = [] + result_file_path = dummy_file(ws.results_path, "dummy-result-1.json") + ws.record_state( + urls=urls, + store=result.StoreStrategy.FLAT_FILE.value, + timestamp=datetime.datetime(2021, 1, 1), + version=1, + distribution_version=1, + ) + os.remove(result_file_path) + with pytest.raises(RuntimeError) as e: + ws.validate_checksums() + assert "does not exist" in str(e) + + +def test_validate_checksums_input_changed_ok(tmpdir, dummy_file): + ws = workspace.Workspace(root=tmpdir, name="dummy", create=True) + urls = [] 
+    result_file_path = dummy_file(ws.results_path, "dummy-result-1.json")
+    input_file_path = dummy_file(ws.input_path, "dummy-input-1.json")
+    ws.record_state(
+        urls=urls,
+        store=result.StoreStrategy.FLAT_FILE.value,
+        timestamp=datetime.datetime(2021, 1, 1),
+        version=1,
+        distribution_version=1,
+    )
+    with open(input_file_path, "a") as f:
+        f.write("this won't change the checksums")
+
+    ws.validate_checksums()
diff --git a/tests/unit/utils/test_archive.py b/tests/unit/utils/test_archive.py
new file mode 100644
index 00000000..eae00f9b
--- /dev/null
+++ b/tests/unit/utils/test_archive.py
@@ -0,0 +1,75 @@
+import tempfile
+import tarfile
+import zstandard
+from pathlib import Path
+from unittest.mock import patch, MagicMock
+
+from vunnel.utils import archive
+
+import pytest
+
+
+@pytest.mark.parametrize(
+    "tar_info_name, allowed",
+    [
+        ("file.txt", True),
+        ("./file.txt", True),
+        ("some-dir/file.txt", True),
+        ("../file.txt", False),
+        ("/file.txt", False),
+        ("some-dir/../../../../../../etc/passwd", False),
+    ],
+)
+def test_filter_path_traversal(tar_info_name: str, allowed: bool):
+    tar_info = tarfile.TarInfo(tar_info_name)
+    actual = archive._filter_path_traversal(tar_info, "/some/path")
+    if allowed:
+        assert actual is not None
+        assert actual.name == tar_info_name
+    else:
+        assert actual is None
+
+
+def test_extract_tar_zst():
+    # create a temporary directory with a file to compress
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        file_path = Path(tmp_dir) / "test_file.txt"
+        file_path.write_text("Test content")
+
+        # compress the file into a .tar.zst file
+        tar_path = Path(tmp_dir) / "test_file.tar"
+        zst_path = Path(tmp_dir) / "test_file.tar.zst"
+        with tarfile.open(tar_path, "w") as tar:
+            tar.add(file_path, arcname="test_file.txt")
+        with open(tar_path, "rb") as tar, open(zst_path, "wb") as zst:
+            cctx = zstandard.ZstdCompressor()
+            cctx.copy_stream(tar, zst)
+
+        # extract the .tar.zst file
+        with tempfile.TemporaryDirectory() as extract_dir:
+            archive._extract_tar_zst(zst_path, extract_dir)
+
+            # check if the file was correctly extracted
+            assert (Path(extract_dir) / "test_file.txt").read_text() == "Test content"
+
+
+@patch("vunnel.utils.archive._extract_tar_zst")
+@patch("vunnel.utils.archive._safe_extract_tar")
+@patch("tarfile.open")
+def test_extract(mock_tarfile_open, mock_safe_extract_tar, mock_extract_tar_zst):
+    open_mock = MagicMock()
+    mock_tarfile_open.return_value.__enter__.return_value = open_mock
+
+    # call extract with a .tar.zst file
+    archive.extract("file.tar.zst", "dest_dir")
+    mock_extract_tar_zst.assert_called_once_with("file.tar.zst", "dest_dir")
+    mock_safe_extract_tar.assert_not_called()
+    mock_tarfile_open.assert_not_called()
+
+    mock_extract_tar_zst.reset_mock()
+    mock_safe_extract_tar.reset_mock()
+
+    # call extract with a .tar file
+    archive.extract("file.tar", "dest_dir")
+    mock_safe_extract_tar.assert_called_once_with(open_mock, "dest_dir")
+    mock_extract_tar_zst.assert_not_called()
diff --git a/tests/unit/utils/test_hasher.py b/tests/unit/utils/test_hasher.py
new file mode 100644
index 00000000..1d9bd171
--- /dev/null
+++ b/tests/unit/utils/test_hasher.py
@@ -0,0 +1,36 @@
+import hashlib
+
+import xxhash
+import pytest
+
+from unittest.mock import mock_open, patch
+from vunnel.utils.hasher import Method
+
+
+@pytest.mark.parametrize(
+    "method,data,label,expected",
+    [
+        (Method.SHA256, b"test data 1", True, "sha256:05e8fdb3598f91bcc3ce41a196e587b4592c8cdfc371c217274bfda2d24b1b4e"),
+        (Method.SHA256, b"test data 2", False, "26637da1bd793f9011a3d304372a9ec44e36cc677d2bbfba32a2f31f912358fe"),
+        (Method.XXH64, b"test data 1", True, "xxh64:7ccde767ab423322"),
+    ],
+)
+def test_digest(method, data, label, expected):
+    m = mock_open(read_data=data)
+    with patch("builtins.open", m):
+        assert method.digest("any path", label) == expected
+
+
+@pytest.mark.parametrize(
+    "value,expected",
+    [
+        ("sha256:05e8fdb3598f91bcc3ce41a196e587b4592c8cdfc371c217274bfda2d24b1b4e", Method.SHA256),
+        ("sha256", Method.SHA256),
+        ("sha-256", Method.SHA256),
+        ("SHA256", Method.SHA256),
+        ("xxh64", Method.XXH64),
+        ("xXh64 ", Method.XXH64),
+    ],
+)
+def test_parse(value, expected):
+    assert Method.parse(value) == expected