diff --git a/poetry.lock b/poetry.lock
index 803ed416..c37eb7da 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1443,7 +1443,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
diff --git a/src/vunnel/providers/nvd/__init__.py b/src/vunnel/providers/nvd/__init__.py
index 76057352..6ce5dca7 100644
--- a/src/vunnel/providers/nvd/__init__.py
+++ b/src/vunnel/providers/nvd/__init__.py
@@ -21,6 +21,8 @@ class Config:
     )
     request_timeout: int = 125
     api_key: Optional[str] = "env:NVD_API_KEY"  # noqa: UP007
+    overrides_url: str = "https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz"
+    overrides_enabled: bool = False

     def __post_init__(self) -> None:
         if self.api_key and self.api_key.startswith("env:"):
@@ -36,6 +38,8 @@ def __str__(self) -> str:


 class Provider(provider.Provider):
+    __version__ = 2
+
     def __init__(self, root: str, config: Config | None = None):
         if not config:
             config = Config()
@@ -50,12 +54,25 @@ def __init__(self, root: str, config: Config | None = None):
             "(otherwise incremental updates will fail)",
         )

+        if self.config.runtime.result_store != result.StoreStrategy.SQLITE:
+            raise ValueError(
+                f"only 'SQLITE' is supported for 'runtime.result_store' but got '{self.config.runtime.result_store}'",
+            )
+
+        if self.config.overrides_enabled and not self.config.overrides_url:
+            raise ValueError(
+                "if 'overrides_enabled' is set then 'overrides_url' must be set",
+            )
+
         self.schema = schema.NVDSchema()
         self.manager = Manager(
             workspace=self.workspace,
+            schema=self.schema,
             download_timeout=self.config.request_timeout,
             api_key=self.config.api_key,
             logger=self.logger,
+            overrides_enabled=self.config.overrides_enabled,
+            overrides_url=self.config.overrides_url,
         )

     @classmethod
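
Note: the new `overrides_enabled`/`overrides_url` settings default to overrides being disabled, and the provider now refuses any result store other than SQLite, since the override pass needs to re-read prior results from a SQLite input database. A minimal sketch of wiring this up from code, mirroring `test_require_override_configuration` further down (the root path here is hypothetical):

    from vunnel import provider, result
    from vunnel.providers import nvd

    cfg = nvd.Config(
        runtime=provider.RuntimeConfig(result_store=result.StoreStrategy.SQLITE),
        overrides_enabled=True,
        overrides_url="https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz",
    )
    # raises ValueError if overrides_enabled is set without a URL,
    # or if a non-SQLite result store is configured
    p = nvd.Provider("/tmp/vunnel-nvd", cfg)
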
diff --git a/src/vunnel/providers/nvd/manager.py b/src/vunnel/providers/nvd/manager.py
index 19c85b78..cad58940 100644
--- a/src/vunnel/providers/nvd/manager.py
+++ b/src/vunnel/providers/nvd/manager.py
@@ -5,7 +5,9 @@
 import os
 from typing import TYPE_CHECKING, Any

-from .api import NvdAPI
+from vunnel import result, schema
+from vunnel.providers.nvd.api import NvdAPI
+from vunnel.providers.nvd.overrides import NVDOverrides

 if TYPE_CHECKING:
     from collections.abc import Generator
@@ -14,12 +16,17 @@


 class Manager:
-    def __init__(
+    __nvd_input_db__ = "nvd-input.db"
+
+    def __init__(  # noqa: PLR0913
         self,
         workspace: Workspace,
+        schema: schema.Schema,
+        overrides_url: str,
         logger: logging.Logger | None = None,
         download_timeout: int = 125,
         api_key: str | None = None,
+        overrides_enabled: bool = False,
     ) -> None:
         self.workspace = workspace
@@ -28,19 +35,87 @@ def __init__(
         self.logger = logger

         self.api = NvdAPI(api_key=api_key, logger=logger, timeout=download_timeout)
+
+        self.overrides = NVDOverrides(
+            enabled=overrides_enabled,
+            url=overrides_url,
+            workspace=workspace,
+            logger=logger,
+            download_timeout=download_timeout,
+        )
+
         self.urls = [self.api._cve_api_url_]  # noqa: SLF001
+        self.schema = schema

     def get(
         self,
         last_updated: datetime.datetime | None,
         skip_if_exists: bool = False,
     ) -> Generator[tuple[str, dict[str, Any]], Any, None]:
-        if skip_if_exists and self._can_update_incrementally(last_updated):
-            yield from self._download_updates(last_updated)  # type: ignore  # noqa: PGH003
-        else:
-            yield from self._download_all()
+        self.overrides.download()
+
+        cves_processed = set()
+        for record_id, record in self._download_nvd_input(last_updated, skip_if_exists):
+            cves_processed.add(id_to_cve(record_id).lower())
+            yield record_id, record
+
+        if self.overrides.enabled:
+            self.urls.append(self.overrides.url)
+            self.logger.debug("applying NVD data overrides...")
+
+            override_cves = {cve.lower() for cve in self.overrides.cves()}
+            override_remaining_cves = override_cves - cves_processed
+            with self._sqlite_reader() as reader:
+                for cve in override_remaining_cves:
+                    original_record = reader.read(cve_to_id(cve))
+                    if not original_record:
+                        self.logger.warning(f"override for {cve} not found in original data")
+                        continue
+
+                    original_record = original_record["item"]
+                    if not original_record:
+                        self.logger.warning(f"missing original data for {cve}")
+                        continue
+
+                    yield cve_to_id(cve), self._apply_override(cve, original_record)
+
+            self.logger.debug(f"applied overrides for {len(override_remaining_cves)} CVEs")
+        else:
+            self.logger.debug("overrides are not enabled, skipping...")
+
+    def _download_nvd_input(
+        self,
+        last_updated: datetime.datetime | None,
+        skip_if_exists: bool = False,
+    ) -> Generator[tuple[str, dict[str, Any]], Any, None]:
+        with self._nvd_input_writer() as writer:
+            if skip_if_exists and self._can_update_incrementally(last_updated):
+                yield from self._download_updates(last_updated, writer)  # type: ignore  # noqa: PGH003
+            else:
+                yield from self._download_all(writer)
+
+    def _nvd_input_writer(self) -> result.Writer:
+        return result.Writer(
+            workspace=self.workspace,
+            result_state_policy=result.ResultStatePolicy.KEEP,
+            logger=self.logger,
+            store_strategy=result.StoreStrategy.SQLITE,
+            write_location=self._input_nvd_path,
+        )
+
+    def _sqlite_reader(self) -> result.SQLiteReader:
+        return result.SQLiteReader(sqlite_db_path=self._input_nvd_path)
+
+    @property
+    def _input_nvd_path(self) -> str:
+        return os.path.join(self.workspace.input_path, self.__nvd_input_db__)

     def _can_update_incrementally(self, last_updated: datetime.datetime | None) -> bool:
+        input_db_path = os.path.join(self.workspace.input_path, self.__nvd_input_db__)
+        if not os.path.exists(input_db_path):
+            return False
+
         if not last_updated:
             return False

@@ -55,15 +130,19 @@ def _can_update_incrementally(self, last_updated: datetime.datetime | None) -> b
         return True

-    def _download_all(self) -> Generator[tuple[str, dict[str, Any]], Any, None]:
+    def _download_all(self, writer: result.Writer) -> Generator[tuple[str, dict[str, Any]], Any, None]:
         self.logger.info("downloading all CVEs")

         # TODO: should we delete all existing state in this case first?
         for response in self.api.cve():
-            yield from self._unwrap_records(response)
+            yield from self._unwrap_records(response, writer)

-    def _download_updates(self, last_updated: datetime.datetime) -> Generator[tuple[str, dict[str, Any]], Any, None]:
+    def _download_updates(
+        self,
+        last_updated: datetime.datetime,
+        writer: result.Writer,
+    ) -> Generator[tuple[str, dict[str, Any]], Any, None]:
         self.logger.debug(f"downloading CVEs changed since {last_updated.isoformat()}")

         # get the list of CVEs that have been updated since the last sync
@@ -74,10 +153,45 @@ def _download_updates(self, last_updated: datetime.datetime) -> Generator[tuple[
             if total_results:
                 self.logger.debug(f"discovered {total_results} updated CVEs")

-            yield from self._unwrap_records(response)
+            yield from self._unwrap_records(response, writer)

-    def _unwrap_records(self, response: dict[str, Any]) -> Generator[tuple[str, dict[str, Any]], Any, None]:
+    def _unwrap_records(
+        self,
+        response: dict[str, Any],
+        writer: result.Writer,
+    ) -> Generator[tuple[str, dict[str, Any]], Any, None]:
         for vuln in response["vulnerabilities"]:
             cve_id = vuln["cve"]["id"]
-            year = cve_id.split("-")[1]
-            yield os.path.join(year, cve_id), vuln
+            record_id = cve_to_id(cve_id)
+
+            # keep input for future overrides
+            writer.write(record_id, self.schema, vuln)
+
+            # apply overrides to output
+            yield record_id, self._apply_override(cve_id=cve_id, record=vuln)
+
+    def _apply_override(self, cve_id: str, record: dict[str, Any]) -> dict[str, Any]:
+        override = self.overrides.cve(cve_id)
+        if override:
+            self.logger.debug(f"applying override for {cve_id}")
+
+        # ignore empty overrides
+        if override is None or "cve" not in override:
+            return record
+
+        # explicitly only support CPE configurations for now and always override the
+        # original record configurations. Can figure out more complicated scenarios
+        # later if needed
+        if "configurations" not in override["cve"]:
+            return record
+
+        record["cve"]["configurations"] = override["cve"]["configurations"]
+
+        return record
+
+
+def cve_to_id(cve: str) -> str:
+    year = cve.split("-")[1]
+    return os.path.join(year, cve)
+
+
+def id_to_cve(cve_id: str) -> str:
+    return cve_id.split("/")[1]
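
Note: the `cve_to_id`/`id_to_cve` helpers pin down the record-identifier scheme (a year-prefixed path) that both the input writer and the override pass rely on; `get()` lowercases CVE IDs before comparing against override names, and `SQLiteReader.read()` lowercases identifiers on lookup. A quick round-trip illustration (values are hypothetical; the separator shown assumes POSIX, since `os.path.join` is used):

    cve_to_id("CVE-2011-0022")       # -> "2011/CVE-2011-0022"
    id_to_cve("2011/CVE-2011-0022")  # -> "CVE-2011-0022"
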
diff --git a/src/vunnel/providers/nvd/overrides.py b/src/vunnel/providers/nvd/overrides.py
new file mode 100644
index 00000000..c11faa24
--- /dev/null
+++ b/src/vunnel/providers/nvd/overrides.py
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+import glob
+import logging
+import os
+import tarfile
+from typing import TYPE_CHECKING, Any
+
+from orjson import loads
+
+from vunnel.utils import http
+
+if TYPE_CHECKING:
+    from vunnel.workspace import Workspace
+
+
+class NVDOverrides:
+    __file_name__ = "nvd-overrides.tar.gz"
+    __extract_name__ = "nvd-overrides"
+
+    def __init__(  # noqa: PLR0913
+        self,
+        enabled: bool,
+        url: str,
+        workspace: Workspace,
+        logger: logging.Logger | None = None,
+        download_timeout: int = 125,
+    ):
+        self.enabled = enabled
+        self.__url__ = url
+        self.workspace = workspace
+        self.download_timeout = download_timeout
+        if not logger:
+            logger = logging.getLogger(self.__class__.__name__)
+        self.logger = logger
+        self.__filepaths_by_cve__: dict[str, str] | None = None
+
+    @property
+    def url(self) -> str:
+        return self.__url__
+
+    def download(self) -> None:
+        if not self.enabled:
+            self.logger.debug("overrides are not enabled, skipping download...")
+            return
+
+        req = http.get(self.__url__, self.logger, stream=True, timeout=self.download_timeout)
+
+        file_path = os.path.join(self.workspace.input_path, self.__file_name__)
+        with open(file_path, "wb") as fp:
+            for chunk in req.iter_content():
+                fp.write(chunk)
+
+        untar_file(file_path, self._extract_path)
+
+    @property
+    def _extract_path(self) -> str:
+        return os.path.join(self.workspace.input_path, self.__extract_name__)
+
+    def _build_files_by_cve(self) -> dict[str, Any]:
+        filepaths_by_cve__: dict[str, str] = {}
+        for path in glob.glob(os.path.join(self._extract_path, "**/data/**/", "CVE-*.json"), recursive=True):
+            cve_id = os.path.basename(path).removesuffix(".json").upper()
+            filepaths_by_cve__[cve_id] = path
+
+        return filepaths_by_cve__
+
+    def cve(self, cve_id: str) -> dict[str, Any] | None:
+        if not self.enabled:
+            return None
+
+        if self.__filepaths_by_cve__ is None:
+            self.__filepaths_by_cve__ = self._build_files_by_cve()
+
+        # TODO: implement in-memory index
+        path = self.__filepaths_by_cve__.get(cve_id.upper())
+        if path and os.path.exists(path):
+            with open(path) as f:
+                return loads(f.read())
+        return None
+
+    def cves(self) -> list[str]:
+        if not self.enabled:
+            return []
+
+        if self.__filepaths_by_cve__ is None:
+            self.__filepaths_by_cve__ = self._build_files_by_cve()
+
+        return list(self.__filepaths_by_cve__.keys())
+
+
+def untar_file(file_path: str, extract_path: str) -> None:
+    with tarfile.open(file_path, "r:gz") as tar:
+
+        def filter_path_traversal(tarinfo: tarfile.TarInfo, path: str) -> tarfile.TarInfo | None:
+            # we do not expect any relative file paths that would result in the clean
+            # path being different from the original path, e.g.
+            #   expected:   results/results.db
+            #   unexpected: results/../../../../etc/passwd
+            # we filter (drop) any such entries
+            if tarinfo.name != os.path.normpath(tarinfo.name):
+                return None
+            return tarinfo
+
+        # note: we have a filter that drops any entries that would result in a path
+        # traversal, which is what S202 is referring to (the linter isn't smart enough
+        # to understand this)
+        tar.extractall(path=extract_path, filter=filter_path_traversal)  # noqa: S202
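
Note: `untar_file` relies on the `filter=` callback of `tarfile.extractall` (added in Python 3.12 and backported to recent patch releases of earlier versions) to drop any member whose name does not survive `os.path.normpath` unchanged. A small standalone sketch of that behavior, with hypothetical paths and archive contents:

    import io
    import os
    import tarfile

    from vunnel.providers.nvd.overrides import untar_file

    # build a throwaway archive with one clean entry and one traversal attempt
    with tarfile.open("/tmp/overrides.tar.gz", "w:gz") as tar:
        for name in ("data/CVE-2011-0022.json", "data/../../escape.json"):
            info = tarfile.TarInfo(name)
            info.size = 2
            tar.addfile(info, io.BytesIO(b"{}"))

    untar_file("/tmp/overrides.tar.gz", "/tmp/extracted")
    print(os.path.exists("/tmp/extracted/data/CVE-2011-0022.json"))  # True: normpath-stable entry kept
    print(os.path.exists("/tmp/escape.json"))                        # False: "../" entry silently dropped
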
diff --git a/src/vunnel/result.py b/src/vunnel/result.py
index 3fdd4ba7..ee447795 100644
--- a/src/vunnel/result.py
+++ b/src/vunnel/result.py
@@ -42,6 +42,7 @@ def __init__(
         result_state_policy: ResultStatePolicy,
         skip_duplicates: bool = False,
         logger: logging.Logger | None = None,
+        **kwargs: dict[str, Any],
     ):
         self.workspace = workspace
         self.result_state_policy = result_state_policy
@@ -117,6 +118,10 @@ def __init__(self, *args: Any, **kwargs: Any):
         self.conn = None
         self.engine = None
         self.table = None
+        self.write_location = kwargs.get("write_location", None)
+        if self.write_location:
+            self.filename = os.path.basename(self.write_location)
+            self.temp_filename = f"{self.filename}.tmp"

         @db.event.listens_for(db.engine.Engine, "connect")
         def set_sqlite_pragma(dbapi_connection, connection_record):  # type: ignore[no-untyped-def]
@@ -133,13 +138,21 @@ def connection(self) -> tuple[db.engine.Connection, db.Table]:
             self.table = self._create_table()
         return self.conn, self.table

+    @property
+    def write_dir(self) -> str:
+        if self.write_location:
+            return os.path.dirname(self.write_location)
+        return self.workspace.results_path
+
     @property
     def db_file_path(self) -> str:
-        return os.path.join(self.workspace.results_path, self.filename)
+        if self.write_location:
+            return self.write_location
+        return os.path.join(self.write_dir, self.filename)

     @property
     def temp_db_file_path(self) -> str:
-        return os.path.join(self.workspace.results_path, self.temp_filename)
+        return os.path.join(self.write_dir, self.temp_filename)

     def _create_table(self) -> db.Table:
         metadata = db.MetaData()
@@ -171,6 +184,15 @@ def store(self, identifier: str, record: Envelope) -> None:

             conn.execute(statement)

+    def read(self, identifier: str) -> Envelope:
+        conn, table = self.connection()
+        with conn.begin():
+            result = conn.execute(table.select().where(table.c.id == identifier)).first()
+            if not result:
+                raise KeyError(f"no result found for identifier: {identifier!r}")
+
+            return Envelope(**orjson.loads(result.record))
+
     def prepare(self) -> None:
         if os.path.exists(self.temp_db_file_path):
             self.logger.warning("removing unexpected partial result state")
@@ -188,7 +210,7 @@ def close(self, successful: bool) -> None:
         self.engine = None
         self.table = None

-        if successful:
+        if successful and os.path.exists(self.temp_db_file_path):
             os.rename(self.temp_db_file_path, self.db_file_path)
         elif os.path.exists(self.temp_db_file_path):
             os.remove(self.temp_db_file_path)
@@ -202,6 +224,7 @@ def __init__(  # noqa: PLR0913
         logger: logging.Logger | None = None,
         skip_duplicates: bool = False,
         store_strategy: StoreStrategy = StoreStrategy.FLAT_FILE,
+        write_location: str | None = None,
    ):
         self.workspace = workspace
         self.skip_duplicates = skip_duplicates
@@ -216,6 +239,7 @@ def __init__(  # noqa: PLR0913
             result_state_policy=result_state_policy,
             skip_duplicates=skip_duplicates,
             logger=logger,
+            write_location=write_location,
         )

     def __enter__(self) -> Writer:
@@ -242,3 +266,46 @@ def write(self, identifier: str, schema: Schema, payload: Any) -> None:

         self.store.store(identifier, envelope)
         self.wrote += 1
+
+
+class SQLiteReader:
+    def __init__(self, sqlite_db_path: str, table_name: str = "results"):
+        self.db_path = sqlite_db_path
+        self.table_name = table_name
+        self.conn = None
+        self.engine = None
+        self.table = None
+
+    def read(self, identifier: str) -> dict[str, Any] | None:
+        conn, table = self.connection()
+        with conn.begin():
+            result = conn.execute(table.select().where(table.c.id == identifier.lower())).first()
+            if not result:
+                return None
+
+            return orjson.loads(result.record)
+
+    def connection(self) -> tuple[db.engine.Connection, db.Table]:
+        if not self.conn:
+            self.engine = db.create_engine(f"sqlite:///{self.db_path}")
+            self.conn = self.engine.connect()  # type: ignore[attr-defined]
+            metadata = db.MetaData(bind=self.engine)
+            self.table = db.Table(self.table_name, metadata, autoload=True, autoload_with=self.engine)
+        return self.conn, self.table
+
+    def __enter__(self) -> SQLiteReader:
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        if self.conn:
+            self.conn.close()
+            self.engine.dispose()
+
+        self.conn = None
+        self.engine = None
+        self.table = None
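
Note: `write_location` lets a provider point the SQLite store at an arbitrary path (here, the NVD input database under `input/` rather than `results/`), and `SQLiteReader` is the read-only counterpart that the override pass uses. A usage sketch matching how the manager reads a stored envelope back (the workspace path is hypothetical):

    from vunnel import result

    with result.SQLiteReader(sqlite_db_path="/path/to/workspace/input/nvd-input.db") as reader:
        envelope = reader.read("2011/CVE-2011-0022")  # identifier is lowercased on lookup
        if envelope:
            record = envelope["item"]  # payload as written by result.Writer
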
diff --git a/tests/unit/cli/test-fixtures/full.yaml b/tests/unit/cli/test-fixtures/full.yaml
index 1f4a9365..66cef0f9 100644
--- a/tests/unit/cli/test-fixtures/full.yaml
+++ b/tests/unit/cli/test-fixtures/full.yaml
@@ -37,6 +37,8 @@ providers:
   nvd:
     runtime: *runtime
     request_timeout: 20
+    overrides_enabled: true
+    overrides_url: https://github.com/anchore/nvd-data-overrides/SOMEWHEREELSE/main.tar.gz
   oracle:
     runtime: *runtime
     request_timeout: 20
diff --git a/tests/unit/cli/test_cli.py b/tests/unit/cli/test_cli.py
index 3055898c..5a78bd11 100644
--- a/tests/unit/cli/test_cli.py
+++ b/tests/unit/cli/test_cli.py
@@ -212,6 +212,8 @@ def test_config(monkeypatch) -> None:
             result_store: sqlite
           nvd:
             api_key: secret
+            overrides_enabled: false
+            overrides_url: https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz
             request_timeout: 125
             runtime:
               existing_input: keep
diff --git a/tests/unit/cli/test_config.py b/tests/unit/cli/test_config.py
index e240fb37..78e153d8 100644
--- a/tests/unit/cli/test_config.py
+++ b/tests/unit/cli/test_config.py
@@ -82,6 +82,8 @@ def test_full_config(helpers):
         nvd=providers.nvd.Config(
             runtime=runtime_cfg,
             request_timeout=20,
+            overrides_enabled=True,
+            overrides_url="https://github.com/anchore/nvd-data-overrides/SOMEWHEREELSE/main.tar.gz",
         ),
         oracle=providers.oracle.Config(
             runtime=runtime_cfg,
diff --git a/tests/unit/providers/nvd/test_manager.py b/tests/unit/providers/nvd/test_manager.py
index eda92458..42d532e9 100644
--- a/tests/unit/providers/nvd/test_manager.py
+++ b/tests/unit/providers/nvd/test_manager.py
@@ -3,7 +3,7 @@ import json

 import pytest

-from vunnel import workspace
+from vunnel import workspace, schema
 from vunnel.providers.nvd import manager

@@ -26,7 +26,11 @@ def test_parser(tmpdir, helpers, mock_data_path, mocker):
         identity = f"{year}/{cve_id}"
         expected_vulns.append((identity, v))

-    subject = manager.Manager(workspace=workspace.Workspace(tmpdir, "test", create=True))
+    subject = manager.Manager(
+        workspace=workspace.Workspace(tmpdir, "test", create=True),
+        schema=schema.NVDSchema(),
+        overrides_url="http://example.com",
+    )
     subject.api.cve = mocker.Mock(return_value=[json_dict])

     actual_vulns = list(subject.get(None))
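
Note: as the updated test shows, `Manager` now takes `schema` and `overrides_url` as required arguments even when overrides stay disabled (the default). For reference, a minimal construction sketch (paths and URL are hypothetical):

    from vunnel import schema, workspace
    from vunnel.providers.nvd import manager

    ws = workspace.Workspace("/tmp/vunnel-root", "nvd", create=True)
    subject = manager.Manager(
        workspace=ws,
        schema=schema.NVDSchema(),
        overrides_url="https://example.com/overrides.tar.gz",
    )
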
diff --git a/tests/unit/providers/nvd/test_nvd.py b/tests/unit/providers/nvd/test_nvd.py
index 71c5db0f..14aca52f 100644
--- a/tests/unit/providers/nvd/test_nvd.py
+++ b/tests/unit/providers/nvd/test_nvd.py
@@ -6,7 +6,6 @@ import pytest

 from vunnel import provider, result
 from vunnel.providers import nvd
-from vunnel.providers.nvd import api as nvd_api


 @pytest.mark.parametrize(
@@ -19,7 +18,55 @@
 )
 def test_incremental_update_with_existing_results(policy, should_raise):
     def make():
-        nvd.Provider("/tmp/doesntmatter", nvd.Config(runtime=provider.RuntimeConfig(existing_results=policy)))
+        nvd.Provider(
+            "/tmp/doesntmatter",
+            nvd.Config(runtime=provider.RuntimeConfig(existing_results=policy, result_store=result.StoreStrategy.SQLITE)),
+        )
+
+    if should_raise:
+        with pytest.raises(Exception):
+            make()
+    else:
+        make()
+
+
+@pytest.mark.parametrize(
+    ("store", "should_raise"),
+    (
+        (result.StoreStrategy.FLAT_FILE, True),
+        (result.StoreStrategy.SQLITE, False),
+    ),
+)
+def test_require_sqlite_store(store, should_raise):
+    def make():
+        nvd.Provider("/tmp/doesntmatter", nvd.Config(runtime=provider.RuntimeConfig(result_store=store)))
+
+    if should_raise:
+        with pytest.raises(Exception):
+            make()
+    else:
+        make()
+
+
+@pytest.mark.parametrize(
+    ("overrides_enabled", "overrides_url", "should_raise"),
+    (
+        (True, "something", False),
+        (False, "something", False),
+        (True, "", True),
+        (False, "", False),
+    ),
+)
+def test_require_override_configuration(overrides_enabled, overrides_url, should_raise):
+    def make():
+        nvd.Provider(
+            "/tmp/doesntmatter",
+            nvd.Config(
+                overrides_enabled=overrides_enabled,
+                overrides_url=overrides_url,
+                runtime=provider.RuntimeConfig(result_store=result.StoreStrategy.SQLITE),
+            ),
+        )

     if should_raise:
         with pytest.raises(Exception):
             make()
     else:
         make()
@@ -42,8 +89,8 @@ def test_provider_schema(helpers, mock_data_path, expected_written_entries, disa
         json_dict = json.load(f)

     c = nvd.Config()
-    c.runtime.result_store = result.StoreStrategy.FLAT_FILE
     p = nvd.Provider(root=workspace.root, config=c)
+    p.config.runtime.result_store = result.StoreStrategy.FLAT_FILE
     p.manager.api.cve = mocker.Mock(return_value=[json_dict])

     p.update(None)
@@ -68,12 +115,12 @@ def test_provider_via_snapshot(helpers, mock_data_path, disable_get_requests, mo
     )

     c = nvd.Config()
-    # keep all of the default values for the result store, but override the strategy
-    c.runtime.result_store = result.StoreStrategy.FLAT_FILE
     p = nvd.Provider(
         root=workspace.root,
         config=c,
     )
+    # keep all of the default values for the result store, but override the strategy
+    p.config.runtime.result_store = result.StoreStrategy.FLAT_FILE

     mock_data_path = helpers.local_dir(mock_data_path)
diff --git a/tests/unit/providers/nvd/test_overrides.py b/tests/unit/providers/nvd/test_overrides.py
new file mode 100644
index 00000000..6eafabcb
--- /dev/null
+++ b/tests/unit/providers/nvd/test_overrides.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+import tarfile
+from unittest.mock import patch, MagicMock
+
+import pytest
+from vunnel import workspace
+from vunnel.providers.nvd import overrides
+
+
+@pytest.fixture
+def overrides_tar(tmpdir):
+    tar = tmpdir.join("overrides.tar.gz")
+
+    with tarfile.open(tar, "w:gz") as f:
+        f.add("tests/unit/providers/nvd/test-fixtures/single-entry.json", arcname="data/CVE-2011-0022.json")
+
+    return tar
+
+
+@pytest.fixture
+def path_traversal_tar(tmpdir):
+    tar = tmpdir.join("overrides.tar.gz")
+
+    with tarfile.open(tar, "w:gz") as f:
+        f.add("tests/unit/providers/nvd/test-fixtures/single-entry.json", arcname="data/../../CVE-2011-0022.json")
+
+    return tar
+
+
+@patch("requests.get")
+def test_overrides_disabled(mock_requests, tmpdir):
+    subject = overrides.NVDOverrides(
+        enabled=False,
+        url="http://localhost:8080/failed",
+        workspace=workspace.Workspace(tmpdir, "test", create=True),
+    )
+    subject.__filepaths_by_cve__ = {"CVE-2020-0000": '{"fail": true}'}
+
+    # ensure requests.get is not called
+    subject.download()
+    mock_requests.get.assert_not_called()
+
+    # ensure cve returns None
+    assert subject.cve("CVE-2020-0000") is None
+    assert subject.cves() == []
+
+
+@patch("requests.get")
+def test_overrides_enabled(mock_requests, overrides_tar, tmpdir):
+    mock_requests.return_value = MagicMock(status_code=200, iter_content=lambda: [open(overrides_tar, "rb").read()])
+    subject = overrides.NVDOverrides(
+        enabled=True,
+        url="http://localhost:8080/failed",
+        workspace=workspace.Workspace(tmpdir, "test", create=True),
+    )
+
+    subject.download()
+
+    assert subject.cve("CVE-2011-0022") is not None
+    assert subject.cves() == ["CVE-2011-0022"]
+
+
+def test_untar_file(overrides_tar, tmpdir):
+    overrides.untar_file(overrides_tar, tmpdir)
+    assert tmpdir.join("data/CVE-2011-0022.json").check(file=True)
+
+
+def test_untar_file_path_traversal(path_traversal_tar, tmpdir):
+    overrides.untar_file(path_traversal_tar, tmpdir.join("somewhere", "else"))
+    assert tmpdir.join("somewhere/else/CVE-2011-0022.json").check(file=False)
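
Note: taken together, these tests pin down the override-bundle contract: a gzipped tarball whose members match `data/**/CVE-*.json` (per `_build_files_by_cve`), each carrying at minimum a `cve.configurations` block, which is the only part `_apply_override` consults. A hedged sketch of producing a conforming bundle (the CVE, layout, and configuration content are purely illustrative):

    import io
    import json
    import tarfile

    override = {"cve": {"configurations": []}}  # only this key is applied to the original record

    payload = json.dumps(override).encode()
    with tarfile.open("/tmp/my-overrides.tar.gz", "w:gz") as tar:
        info = tarfile.TarInfo("data/2011/CVE-2011-0022.json")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
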