From 840ee3bb0d411667495e5ed4dbdd32ade101f239 Mon Sep 17 00:00:00 2001 From: Weston Steimel Date: Mon, 30 Jan 2023 19:45:12 +0000 Subject: [PATCH] fix: remove flawed skip_if_exists logic (#53) The current skip_if_exists logic causes provider updates to be missed because it never revisits files once they are downloaded. This only worked previously because the flag was defaulted to false and never actually set for the majority of the legacy enterprise drivers. Signed-off-by: Weston Steimel --- src/vunnel/providers/alpine/__init__.py | 2 +- src/vunnel/providers/alpine/parser.py | 111 ++++++++++----------- src/vunnel/providers/amazon/parser.py | 33 +++--- src/vunnel/providers/centos/__init__.py | 2 +- src/vunnel/providers/centos/parser.py | 88 ++++++++-------- src/vunnel/providers/debian/__init__.py | 2 +- src/vunnel/providers/debian/parser.py | 58 +++++------ src/vunnel/providers/oracle/__init__.py | 2 +- src/vunnel/providers/oracle/parser.py | 10 +- src/vunnel/providers/sles/__init__.py | 2 +- src/vunnel/providers/sles/parser.py | 14 +-- src/vunnel/providers/wolfi/__init__.py | 2 +- src/vunnel/providers/wolfi/parser.py | 75 +++++--------- tests/unit/providers/alpine/test_alpine.py | 7 +- tests/unit/providers/amazon/test_amazon.py | 7 +- tests/unit/providers/centos/test_centos.py | 18 +++- tests/unit/providers/debian/test_debian.py | 8 +- tests/unit/providers/oracle/test_oracle.py | 7 +- 18 files changed, 216 insertions(+), 232 deletions(-) diff --git a/src/vunnel/providers/alpine/__init__.py b/src/vunnel/providers/alpine/__init__.py index 9d0b96aa..e44700f7 100644 --- a/src/vunnel/providers/alpine/__init__.py +++ b/src/vunnel/providers/alpine/__init__.py @@ -41,7 +41,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int with self.results_writer() as writer: # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all) - for namespace, vulns in self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists): + for namespace, vulns in self.parser.get(): namespace = namespace.lower() for vuln_id, record in vulns.items(): vuln_id = vuln_id.lower() diff --git a/src/vunnel/providers/alpine/parser.py b/src/vunnel/providers/alpine/parser.py index 9405790d..96e97ed9 100644 --- a/src/vunnel/providers/alpine/parser.py +++ b/src/vunnel/providers/alpine/parser.py @@ -70,7 +70,7 @@ def __init__( def urls(self) -> list[str]: return list(self._urls) - def _download(self, skip_if_exists=False): + def _download(self): """ Downloads alpine sec db files :return: @@ -81,62 +81,57 @@ def _download(self, skip_if_exists=False): if os.path.exists(os.path.join(self.secdb_dir_path, "alpine-secdb-master.tar.gz")): os.remove(os.path.join(self.secdb_dir_path, "alpine-secdb-master.tar.gz")) - if skip_if_exists and os.path.exists(self.secdb_dir_path): - self.logger.debug( - "'skip_if_exists' flag enabled and found source under {}. Skipping download".format(self.secdb_dir_path) - ) - else: - links = [] - try: - if not os.path.exists(self.secdb_dir_path): - os.makedirs(self.secdb_dir_path, exist_ok=True) - - self.logger.info("downloading alpine secdb metadata from: {}".format(self.metadata_url)) - r = requests.get(self.metadata_url, timeout=self.download_timeout) - if r.status_code == 200: - try: - self.logger.debug("HTML parsing secdb landing page content for links") - parser = SecdbLandingParser() - parser.feed(r.text) - links = parser.links - except: - self.logger.warning("unable to html parse secdb landing page content for links") - - if not links: - self.logger.debug("string parsing secdb landing page content for links") - links = re.findall(self._link_finder_regex_, r.text) - else: - r.raise_for_status() - except Exception: - self.logger.exception("error downloading or parsing alpine secdb metadata") - raise - - if links: - self.logger.debug("found release specific secdb links: {}".format(links)) + links = [] + try: + if not os.path.exists(self.secdb_dir_path): + os.makedirs(self.secdb_dir_path, exist_ok=True) + + self.logger.info("downloading alpine secdb metadata from: {}".format(self.metadata_url)) + r = requests.get(self.metadata_url, timeout=self.download_timeout) + if r.status_code == 200: + try: + self.logger.debug("HTML parsing secdb landing page content for links") + parser = SecdbLandingParser() + parser.feed(r.text) + links = parser.links + except: + self.logger.warning("unable to html parse secdb landing page content for links") + + if not links: + self.logger.debug("string parsing secdb landing page content for links") + links = re.findall(self._link_finder_regex_, r.text) else: - raise Exception("unable to find release specific secdb links") - - for link in links: - if link not in ignore_links: - try: - rel = link.strip("/") - rel_dir = os.path.join(self.secdb_dir_path, rel) - os.makedirs(rel_dir, exist_ok=True) - for db_type in self._db_types: - file_name = "{}.yaml".format(db_type) - download_url = "/".join([self.metadata_url, rel, file_name]) - self._urls.add(download_url) - self.logger.info("Downloading secdb {} {}".format(rel, db_type)) - r = requests.get(download_url, stream=True, timeout=self.download_timeout) - if r.status_code == 200: - file_path = os.path.join(rel_dir, file_name) - with open(file_path, "wb") as fp: - for chunk in r.iter_content(): - fp.write(chunk) - else: - r.raise_for_status() - except: - self.logger.exception("ignoring error processing secdb for {}".format(link)) + r.raise_for_status() + except Exception: + self.logger.exception("error downloading or parsing alpine secdb metadata") + raise + + if links: + self.logger.debug("found release specific secdb links: {}".format(links)) + else: + raise Exception("unable to find release specific secdb links") + + for link in links: + if link not in ignore_links: + try: + rel = link.strip("/") + rel_dir = os.path.join(self.secdb_dir_path, rel) + os.makedirs(rel_dir, exist_ok=True) + for db_type in self._db_types: + file_name = "{}.yaml".format(db_type) + download_url = "/".join([self.metadata_url, rel, file_name]) + self._urls.add(download_url) + self.logger.info("Downloading secdb {} {}".format(rel, db_type)) + r = requests.get(download_url, stream=True, timeout=self.download_timeout) + if r.status_code == 200: + file_path = os.path.join(rel_dir, file_name) + with open(file_path, "wb") as fp: + for chunk in r.iter_content(): + fp.write(chunk) + else: + r.raise_for_status() + except: + self.logger.exception("ignoring error processing secdb for {}".format(link)) def _load(self): """ @@ -248,13 +243,13 @@ def _normalize(self, release, dbtype_data_dict): return vuln_dict - def get(self, skip_if_exists: bool = False): + def get(self): """ Download, load and normalize alpine sec db and return a dict of releae - list of vulnerability records :return: """ # download the data - self._download(skip_if_exists) + self._download() for release, dbtype_data_dict in self._load(): # normalize the loaded data diff --git a/src/vunnel/providers/amazon/parser.py b/src/vunnel/providers/amazon/parser.py index 12e8d72c..38b19704 100644 --- a/src/vunnel/providers/amazon/parser.py +++ b/src/vunnel/providers/amazon/parser.py @@ -48,22 +48,19 @@ def __init__(self, workspace, download_timeout=125, security_advisories=None, lo self.logger = logger @utils.retry_with_backoff() - def _download_rss(self, rss_url, rss_file, skip_if_exists=False): - if skip_if_exists and os.path.exists(rss_file): - self.logger.debug(f"'skip_if_exists' flag enabled and found {rss_file}. Skipping download") - else: - try: - self.logger.info(f"downloading amazon security advisory from {rss_url}") - self.urls.append(rss_url) - r = requests.get(rss_url, timeout=self.download_timeout) - if r.status_code == 200: - with open(rss_file, "w", encoding="utf-8") as fp: - fp.write(r.text) - else: - raise Exception(f"GET {rss_url} failed with HTTP error {r.status_code}") - except Exception: - self.logger.exception("error downloading amazon linux vulnerability feeds") - raise + def _download_rss(self, rss_url, rss_file): + try: + self.logger.info(f"downloading amazon security advisory from {rss_url}") + self.urls.append(rss_url) + r = requests.get(rss_url, timeout=self.download_timeout) + if r.status_code == 200: + with open(rss_file, "w", encoding="utf-8") as fp: + fp.write(r.text) + else: + raise Exception(f"GET {rss_url} failed with HTTP error {r.status_code}") + except Exception: + self.logger.exception("error downloading amazon linux vulnerability feeds") + raise def _parse_rss(self, file_path): self.logger.debug(f"parsing RSS data from {file_path}") @@ -99,7 +96,7 @@ def _parse_rss(self, file_path): @utils.retry_with_backoff() def _get_alas_html(self, alas_url, alas_file, skip_if_exists=True): if skip_if_exists and os.path.exists(alas_file): # read alas from disk if its available - self.logger.debug(f"loading ALAS from {alas_file}") + self.logger.debug(f"loading existing ALAS from {alas_file}") with open(alas_file, encoding="utf-8") as fp: content = fp.read() return content @@ -136,7 +133,7 @@ def get(self, skip_if_exists=False): rss_file = os.path.join(self.workspace.input_path, self._rss_file_name_.format(version)) html_dir = os.path.join(self.workspace.input_path, self._html_dir_name_.format(version)) - self._download_rss(url, rss_file, skip_if_exists) + self._download_rss(url, rss_file) # parse rss for alas summaries alas_summaries = self._parse_rss(rss_file) diff --git a/src/vunnel/providers/centos/__init__.py b/src/vunnel/providers/centos/__init__.py index c0cf26ab..85f4a81d 100644 --- a/src/vunnel/providers/centos/__init__.py +++ b/src/vunnel/providers/centos/__init__.py @@ -43,7 +43,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int # { (CVE, namespace): {...data...}" } # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all) - vuln_dict = self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists) + vuln_dict = self.parser.get() self.logger.info(f"processed {len(vuln_dict)} entries") diff --git a/src/vunnel/providers/centos/parser.py b/src/vunnel/providers/centos/parser.py index ab52a61e..167cb9f8 100644 --- a/src/vunnel/providers/centos/parser.py +++ b/src/vunnel/providers/centos/parser.py @@ -98,51 +98,47 @@ def _get_sha256(self): raise Exception("Error fetching/processing sha256") @utils.retry_with_backoff() - def _download(self, skip_if_exists=False): - - if skip_if_exists and os.path.exists(self.xml_file_path): - self.logger.debug("'skip_if_exists' flag enabled and found {}. Skipping download".format(self.xml_file_path)) + def _download(self): + download = True + + if os.path.exists(self.xml_file_path) and os.path.exists(self.xml_sha_file_path): + with open(self.xml_sha_file_path) as fp: + previous = fp.read() + previous = previous.strip() + + latest = self._get_sha256() + self.logger.debug("previous sha256: {}, latest sha256: {}".format(previous, latest)) + download = previous.lower() != latest.lower() + + if download: + try: + self.logger.info("downloading RHSA from {}".format(self._url_)) + r = requests.get(self._url_, stream=True, timeout=self.download_timeout) + if r.status_code == 200: + # compute the sha256 as the file is decompressed + sha256 = hashlib.sha256() + with open(self.xml_file_path, "wb") as extracted: + decompressor = bz2.BZ2Decompressor() + for chunk in r.iter_content(chunk_size=1024): + uncchunk = decompressor.decompress(chunk) + extracted.write(uncchunk) + sha256.update(uncchunk) + + sha256sum = str(sha256.hexdigest()).lower() + self.logger.debug("sha256 for {}: {}".format(self.xml_file_path, sha256sum)) + + # save the sha256 to another file + with open(self.xml_sha_file_path, "w") as fp: + fp.write(sha256sum) + + return sha256sum + else: + raise Exception("GET {} failed with HTTP error {}".format(self._url_, r.status_code)) + except Exception: + self.logger.exception("error downloading RHSA file") + raise Exception("error downloading RHSA file") else: - download = True - - if os.path.exists(self.xml_file_path) and os.path.exists(self.xml_sha_file_path): - with open(self.xml_sha_file_path) as fp: - previous = fp.read() - previous = previous.strip() - - latest = self._get_sha256() - self.logger.debug("previous sha256: {}, latest sha256: {}".format(previous, latest)) - download = previous.lower() != latest.lower() - - if download: - try: - self.logger.info("downloading RHSA from {}".format(self._url_)) - r = requests.get(self._url_, stream=True, timeout=self.download_timeout) - if r.status_code == 200: - # compute the sha256 as the file is decompressed - sha256 = hashlib.sha256() - with open(self.xml_file_path, "wb") as extracted: - decompressor = bz2.BZ2Decompressor() - for chunk in r.iter_content(chunk_size=1024): - uncchunk = decompressor.decompress(chunk) - extracted.write(uncchunk) - sha256.update(uncchunk) - - sha256sum = str(sha256.hexdigest()).lower() - self.logger.debug("sha256 for {}: {}".format(self.xml_file_path, sha256sum)) - - # save the sha256 to another file - with open(self.xml_sha_file_path, "w") as fp: - fp.write(sha256sum) - - return sha256sum - else: - raise Exception("GET {} failed with HTTP error {}".format(self._url_, r.status_code)) - except Exception: - self.logger.exception("error downloading RHSA file") - raise Exception("error downloading RHSA file") - else: - self.logger.info("stored csum matches server csum. Skipping download") + self.logger.info("stored csum matches server csum. Skipping download") return None @@ -150,8 +146,8 @@ def parse(self): # normalize and return results return parse(self.xml_file_path, self.config) - def get(self, skip_if_exists=False): + def get(self): # download - self._download(skip_if_exists=skip_if_exists) + self._download() return self.parse() diff --git a/src/vunnel/providers/debian/__init__.py b/src/vunnel/providers/debian/__init__.py index e9cb442f..39426f74 100644 --- a/src/vunnel/providers/debian/__init__.py +++ b/src/vunnel/providers/debian/__init__.py @@ -48,7 +48,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int with self.results_writer() as writer: # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all) - for relno, vuln_id, record in self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists): + for relno, vuln_id, record in self.parser.get(): vuln_id = vuln_id.lower() writer.write( identifier=os.path.join(f"debian:{relno}", vuln_id), diff --git a/src/vunnel/providers/debian/parser.py b/src/vunnel/providers/debian/parser.py index 5f1f9b72..1dafe32e 100644 --- a/src/vunnel/providers/debian/parser.py +++ b/src/vunnel/providers/debian/parser.py @@ -58,50 +58,44 @@ def __init__(self, workspace, download_timeout=125, logger=None, distro_map=None self.logger = logger @utils.retry_with_backoff() - def _download_json(self, skip_if_exists=False): + def _download_json(self): """ Downloads debian json file :return: """ - if skip_if_exists and os.path.exists(self.json_file_path): - self.logger.debug(f"'skip_if_exists' flag enabled and found {self.json_file_path}. Skipping download") - else: - try: - self.logger.info(f"downloading debian security tracker data from {self._dsa_url_}") + try: + self.logger.info(f"downloading debian security tracker data from {self._dsa_url_}") - r = requests.get(self._json_url_, timeout=self.download_timeout) - if r.status_code != 200: - raise Exception(f"GET {self._json_url_} failed with HTTP error {r.status_code}") + r = requests.get(self._json_url_, timeout=self.download_timeout) + if r.status_code != 200: + raise Exception(f"GET {self._json_url_} failed with HTTP error {r.status_code}") - json.loads(r.text) # quick check if json is valid - with open(self.json_file_path, "w", encoding="utf-8") as OFH: - OFH.write(r.text) + json.loads(r.text) # quick check if json is valid + with open(self.json_file_path, "w", encoding="utf-8") as OFH: + OFH.write(r.text) - except Exception: - self.logger.exception("Error downloading debian json file") - raise + except Exception: + self.logger.exception("Error downloading debian json file") + raise @utils.retry_with_backoff() - def _download_dsa(self, skip_if_exists=False): + def _download_dsa(self): """ Downloads debian dsa file :return: """ - if skip_if_exists and os.path.exists(self.dsa_file_path): - self.logger.debug(f"'skip_if_exists' flag enabled and found {self.dsa_file_path}. Skipping download") - else: - try: - self.logger.info(f"downloading DSA from {self._dsa_url_}") - r = requests.get(self._dsa_url_, timeout=self.download_timeout) - if r.status_code != 200: - raise Exception(f"GET {self._dsa_url_} failed with HTTP error {r.status_code}") + try: + self.logger.info(f"downloading DSA from {self._dsa_url_}") + r = requests.get(self._dsa_url_, timeout=self.download_timeout) + if r.status_code != 200: + raise Exception(f"GET {self._dsa_url_} failed with HTTP error {r.status_code}") - with open(self.dsa_file_path, "w", encoding="utf-8") as OFH: - OFH.write(r.text) + with open(self.dsa_file_path, "w", encoding="utf-8") as OFH: + OFH.write(r.text) - except Exception: - self.logger.exception("error downloading debian DSA file") - raise + except Exception: + self.logger.exception("error downloading debian DSA file") + raise def _get_cve_to_dsalist(self, dsa): """ @@ -462,10 +456,10 @@ def _normalize_json(self, ns_cve_dsalist=None): return vuln_records - def get(self, skip_if_exists=False): + def get(self): # download the files - self._download_json(skip_if_exists) - self._download_dsa(skip_if_exists) + self._download_json() + self._download_dsa() # normalize dsa list first ns_cve_dsalist = self._normalize_dsa_list() diff --git a/src/vunnel/providers/oracle/__init__.py b/src/vunnel/providers/oracle/__init__.py index 784d7de6..b28274ba 100644 --- a/src/vunnel/providers/oracle/__init__.py +++ b/src/vunnel/providers/oracle/__init__.py @@ -43,7 +43,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int with self.results_writer() as writer: # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all) - vuln_dict = self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists) + vuln_dict = self.parser.get() for (vuln_id, namespace), (_, record) in vuln_dict.items(): namespace = namespace.lower() diff --git a/src/vunnel/providers/oracle/parser.py b/src/vunnel/providers/oracle/parser.py index 6560c5f4..b45791e7 100644 --- a/src/vunnel/providers/oracle/parser.py +++ b/src/vunnel/providers/oracle/parser.py @@ -66,11 +66,7 @@ def urls(self): return [self._url_] @utils.retry_with_backoff() - def _download(self, skip_if_exists=False): - if skip_if_exists and os.path.exists(self.xml_file_path): - self.logger.debug(f"'skip_if_exists' flag enabled and found {self.xml_file_path}. Skipping download") - return - + def _download(self): try: self.logger.info(f"downloading ELSA from {self._url_}") r = requests.get(self._url_, stream=True, timeout=self.download_timeout) @@ -99,9 +95,9 @@ def _parse_oval_data(self, path: str, config: dict): filtered_results = filterer.filter(raw_results) return filtered_results - def get(self, skip_if_exists=False): + def get(self): # download - self._download(skip_if_exists=skip_if_exists) + self._download() return self._parse_oval_data(self.xml_file_path, self.config) diff --git a/src/vunnel/providers/sles/__init__.py b/src/vunnel/providers/sles/__init__.py index cdd3619c..98df4628 100644 --- a/src/vunnel/providers/sles/__init__.py +++ b/src/vunnel/providers/sles/__init__.py @@ -48,7 +48,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int with self.results_writer() as writer: # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all) - for namespace, vuln_id, record in self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists): + for namespace, vuln_id, record in self.parser.get(): namespace = namespace.lower() vuln_id = vuln_id.lower() writer.write( diff --git a/src/vunnel/providers/sles/parser.py b/src/vunnel/providers/sles/parser.py index a0114464..018bef83 100644 --- a/src/vunnel/providers/sles/parser.py +++ b/src/vunnel/providers/sles/parser.py @@ -68,20 +68,12 @@ def __init__( Parser.logger = logger @utils.retry_with_backoff() - def _download(self, major_version: str, skip_if_exists: bool = False) -> str: + def _download(self, major_version: str) -> str: if not os.path.exists(self.oval_dir_path): self.logger.debug(f"creating workspace for OVAL source data at {self.oval_dir_path}") os.makedirs(self.oval_dir_path) oval_file_path = os.path.join(self.oval_dir_path, self.__oval_file_name__.format(major_version)) - - if skip_if_exists and os.path.exists(oval_file_path): - self.logger.debug( - "'skip_if_exists' flag enabled and found %s. Skipping download", - oval_file_path, - ) - return oval_file_path - download_url = self.__oval_url__.format(major_version) self.urls.append(download_url) @@ -328,7 +320,7 @@ def _transform_oval_vulnerabilities(cls, major_version: str, parsed_dict: dict) return results - def get(self, skip_if_exists: bool = False): + def get(self): parser_factory = OVALParserFactory( parsers=[ SLESVulnerabilityParser, @@ -343,7 +335,7 @@ def get(self, skip_if_exists: bool = False): for major_version in self.allow_versions: try: # download oval - oval_file_path = self._download(major_version, skip_if_exists) + oval_file_path = self._download(major_version) # parse oval contents parsed_dict = iter_parse_vulnerability_file( diff --git a/src/vunnel/providers/wolfi/__init__.py b/src/vunnel/providers/wolfi/__init__.py index f56ac400..fd326b06 100644 --- a/src/vunnel/providers/wolfi/__init__.py +++ b/src/vunnel/providers/wolfi/__init__.py @@ -42,7 +42,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int with self.results_writer() as writer: # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all) - for release, vuln_dict in self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists): + for release, vuln_dict in self.parser.get(): for vuln_id, record in vuln_dict.items(): writer.write( diff --git a/src/vunnel/providers/wolfi/parser.py b/src/vunnel/providers/wolfi/parser.py index ef435b1e..474d712c 100644 --- a/src/vunnel/providers/wolfi/parser.py +++ b/src/vunnel/providers/wolfi/parser.py @@ -31,39 +31,35 @@ def __init__(self, workspace, download_timeout=125, url=None, logger=None): self.logger = logger @utils.retry_with_backoff() - def _download(self, skip_if_exists=False): + def _download(self): """ Downloads wolfi sec db files :return: """ - - if skip_if_exists and os.path.exists(self.secdb_dir_path): - self.logger.debug(f"'skip_if_exists' flag enabled and found source under {self.secdb_dir_path}. Skipping download") - else: - if not os.path.exists(self.secdb_dir_path): - os.makedirs(self.secdb_dir_path, exist_ok=True) - - for t in self._db_types: - try: - rel_dir = os.path.join(self.secdb_dir_path, t) - os.makedirs(rel_dir, exist_ok=True) - - filename = "security.json" - download_url = f"{self.metadata_url}/{t}/{filename}" - - self.urls.append(download_url) - - self.logger.info(f"downloading Wolfi secdb {download_url}") - r = requests.get(download_url, stream=True, timeout=self.download_timeout) - if r.status_code == 200: - file_path = os.path.join(rel_dir, filename) - with open(file_path, "wb") as fp: - for chunk in r.iter_content(): - fp.write(chunk) - else: - r.raise_for_status() - except: # noqa - self.logger.exception(f"ignoring error processing secdb for {t}") + if not os.path.exists(self.secdb_dir_path): + os.makedirs(self.secdb_dir_path, exist_ok=True) + + for t in self._db_types: + try: + rel_dir = os.path.join(self.secdb_dir_path, t) + os.makedirs(rel_dir, exist_ok=True) + + filename = "security.json" + download_url = f"{self.metadata_url}/{t}/{filename}" + + self.urls.append(download_url) + + self.logger.info(f"downloading Wolfi secdb {download_url}") + r = requests.get(download_url, stream=True, timeout=self.download_timeout) + if r.status_code == 200: + file_path = os.path.join(rel_dir, filename) + with open(file_path, "wb") as fp: + for chunk in r.iter_content(): + fp.write(chunk) + else: + r.raise_for_status() + except: # noqa + self.logger.exception(f"ignoring error processing secdb for {t}") def _load(self): """ @@ -135,23 +131,6 @@ def _normalize(self, release, dbtype_data_dict): vuln_record["Vulnerability"]["NamespaceName"] = namespace + ":" + str(release) vuln_record["Vulnerability"]["Link"] = "http://cve.mitre.org/cgi-bin/cvename.cgi?name=" + str(vid) vuln_record["Vulnerability"]["Severity"] = "Unknown" - - # lookup nvd record only when creating the vulnerability, no point looking it up every time - nvd_severity = None - # TODO: ALEX fix this in grype-db-builder - # if session: - # try: - # nvd_severity = nvd.get_severity( - # vid, session=session - # ) - # except Exception: - # self.logger.exception( - # "Ignoring error processing nvdv2 record" - # ) - - # use nvd severity - if nvd_severity: - vuln_record["Vulnerability"]["Severity"] = nvd_severity else: vuln_record = vuln_dict[vid] @@ -167,13 +146,13 @@ def _normalize(self, release, dbtype_data_dict): return vuln_dict - def get(self, skip_if_exists=False): + def get(self): """ Download, load and normalize wolfi sec db and return a dict of release - list of vulnerability records :return: """ # download the data - self._download(skip_if_exists) + self._download() # load the data for release, dbtype_data_dict in self._load(): diff --git a/tests/unit/providers/alpine/test_alpine.py b/tests/unit/providers/alpine/test_alpine.py index cd6ed0ee..b0555723 100644 --- a/tests/unit/providers/alpine/test_alpine.py +++ b/tests/unit/providers/alpine/test_alpine.py @@ -199,7 +199,7 @@ def disabled(*args, **kwargs): monkeypatch.setattr(parser.requests, "get", disabled) -def test_provider_schema(helpers, disable_get_requests): +def test_provider_schema(helpers, disable_get_requests, monkeypatch): workspace = helpers.provider_workspace_helper(name=Provider.name()) c = Config() @@ -212,6 +212,11 @@ def test_provider_schema(helpers, disable_get_requests): mock_data_path = helpers.local_dir("test-fixtures/input") shutil.copytree(mock_data_path, workspace.input_dir, dirs_exist_ok=True) + def mock_download(): + return + + monkeypatch.setattr(p.parser, "_download", mock_download) + p.update(None) assert 16 == workspace.num_result_entries() diff --git a/tests/unit/providers/amazon/test_amazon.py b/tests/unit/providers/amazon/test_amazon.py index f5a0b3e2..2b871158 100644 --- a/tests/unit/providers/amazon/test_amazon.py +++ b/tests/unit/providers/amazon/test_amazon.py @@ -84,7 +84,7 @@ def disabled(*args, **kwargs): monkeypatch.setattr(parser.requests, "get", disabled) -def test_provider_schema(helpers, disable_get_requests): +def test_provider_schema(helpers, disable_get_requests, monkeypatch): workspace = helpers.provider_workspace_helper(name=Provider.name()) c = Config() @@ -94,6 +94,11 @@ def test_provider_schema(helpers, disable_get_requests): mock_data_path = helpers.local_dir("test-fixtures/input") shutil.copytree(mock_data_path, workspace.input_dir, dirs_exist_ok=True) + def mock_download(self, *args, **kwargs): + pass + + monkeypatch.setattr(p.parser, "_download_rss", mock_download) + p.update(None) assert 2 == workspace.num_result_entries() diff --git a/tests/unit/providers/centos/test_centos.py b/tests/unit/providers/centos/test_centos.py index 10e4d602..981a55de 100644 --- a/tests/unit/providers/centos/test_centos.py +++ b/tests/unit/providers/centos/test_centos.py @@ -6,7 +6,7 @@ from vunnel import result, workspace from vunnel.providers import centos -from vunnel.providers.centos.parser import Parser +from vunnel.providers.centos import Parser, parser @pytest.mark.parametrize( @@ -83,6 +83,14 @@ def test_parser(tmpdir, helpers, mock_data_path, full_entry): assert vuln == full_entry +@pytest.fixture +def disable_get_requests(monkeypatch): + def disabled(*args, **kwargs): + raise RuntimeError("requests disabled but HTTP GET attempted") + + monkeypatch.setattr(parser.requests, "get", disabled) + + @pytest.mark.parametrize( "mock_data_path,expected_written_entries", [ @@ -90,7 +98,7 @@ def test_parser(tmpdir, helpers, mock_data_path, full_entry): ("test-fixtures/centos-7-entry", 1), ], ) -def test_provider_schema(helpers, mock_data_path, expected_written_entries): +def test_provider_schema(helpers, mock_data_path, expected_written_entries, disable_get_requests, monkeypatch): workspace = helpers.provider_workspace_helper(name=centos.Provider.name()) mock_data_path = helpers.local_dir(mock_data_path) @@ -101,6 +109,12 @@ def test_provider_schema(helpers, mock_data_path, expected_written_entries): config=c, ) shutil.copy(mock_data_path, p.parser.xml_file_path) + + def mock_download(): + return None + + monkeypatch.setattr(p.parser, "_download", mock_download) + p.update(None) assert expected_written_entries == workspace.num_result_entries() diff --git a/tests/unit/providers/debian/test_debian.py b/tests/unit/providers/debian/test_debian.py index a8f6494f..b87d7c49 100644 --- a/tests/unit/providers/debian/test_debian.py +++ b/tests/unit/providers/debian/test_debian.py @@ -94,7 +94,7 @@ def test_normalize_json(self, tmpdir, helpers, disable_get_requests): assert all(x.get("Vulnerability", {}).get("Description") is not None for x in vuln_dict.values()) -def test_provider_schema(helpers, disable_get_requests): +def test_provider_schema(helpers, disable_get_requests, monkeypatch): workspace = helpers.provider_workspace_helper(name=Provider.name()) c = Config() @@ -107,6 +107,12 @@ def test_provider_schema(helpers, disable_get_requests): mock_data_path = helpers.local_dir("test-fixtures/input") shutil.copytree(mock_data_path, workspace.input_dir, dirs_exist_ok=True) + def mock_download(): + return None + + monkeypatch.setattr(p.parser, "_download_json", mock_download) + monkeypatch.setattr(p.parser, "_download_dsa", mock_download) + p.update(None) assert 21 == workspace.num_result_entries() diff --git a/tests/unit/providers/oracle/test_oracle.py b/tests/unit/providers/oracle/test_oracle.py index aa5ba4ad..9f13f6a8 100644 --- a/tests/unit/providers/oracle/test_oracle.py +++ b/tests/unit/providers/oracle/test_oracle.py @@ -370,7 +370,7 @@ def disabled(*args, **kwargs): monkeypatch.setattr(parser.requests, "get", disabled) -def test_provider_schema(helpers, disable_get_requests): +def test_provider_schema(helpers, disable_get_requests, monkeypatch): workspace = helpers.provider_workspace_helper(name=Provider.name()) c = Config() @@ -380,6 +380,11 @@ def test_provider_schema(helpers, disable_get_requests): mock_data_path = helpers.local_dir("test-fixtures/mock_data") shutil.copy(mock_data_path, workspace.input_dir / "com.oracle.elsa-all.xml") + def mock_download(): + return None + + monkeypatch.setattr(p.parser, "_download", mock_download) + p.update(None) assert 2 == workspace.num_result_entries()