diff --git a/src/vunnel/providers/alpine/__init__.py b/src/vunnel/providers/alpine/__init__.py
index 9d0b96aa..e44700f7 100644
--- a/src/vunnel/providers/alpine/__init__.py
+++ b/src/vunnel/providers/alpine/__init__.py
@@ -41,7 +41,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int

         with self.results_writer() as writer:
             # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all)
-            for namespace, vulns in self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists):
+            for namespace, vulns in self.parser.get():
                 namespace = namespace.lower()
                 for vuln_id, record in vulns.items():
                     vuln_id = vuln_id.lower()
diff --git a/src/vunnel/providers/alpine/parser.py b/src/vunnel/providers/alpine/parser.py
index 9405790d..96e97ed9 100644
--- a/src/vunnel/providers/alpine/parser.py
+++ b/src/vunnel/providers/alpine/parser.py
@@ -70,7 +70,7 @@ def __init__(
     def urls(self) -> list[str]:
         return list(self._urls)

-    def _download(self, skip_if_exists=False):
+    def _download(self):
         """
         Downloads alpine sec db files
         :return:
         """
@@ -81,62 +81,57 @@ def _download(self, skip_if_exists=False):
         if os.path.exists(os.path.join(self.secdb_dir_path, "alpine-secdb-master.tar.gz")):
             os.remove(os.path.join(self.secdb_dir_path, "alpine-secdb-master.tar.gz"))

-        if skip_if_exists and os.path.exists(self.secdb_dir_path):
-            self.logger.debug(
-                "'skip_if_exists' flag enabled and found source under {}. Skipping download".format(self.secdb_dir_path)
-            )
-        else:
-            links = []
-            try:
-                if not os.path.exists(self.secdb_dir_path):
-                    os.makedirs(self.secdb_dir_path, exist_ok=True)
-
-                self.logger.info("downloading alpine secdb metadata from: {}".format(self.metadata_url))
-                r = requests.get(self.metadata_url, timeout=self.download_timeout)
-                if r.status_code == 200:
-                    try:
-                        self.logger.debug("HTML parsing secdb landing page content for links")
-                        parser = SecdbLandingParser()
-                        parser.feed(r.text)
-                        links = parser.links
-                    except:
-                        self.logger.warning("unable to html parse secdb landing page content for links")
-
-                    if not links:
-                        self.logger.debug("string parsing secdb landing page content for links")
-                        links = re.findall(self._link_finder_regex_, r.text)
-                else:
-                    r.raise_for_status()
-            except Exception:
-                self.logger.exception("error downloading or parsing alpine secdb metadata")
-                raise
-
-            if links:
-                self.logger.debug("found release specific secdb links: {}".format(links))
+        links = []
+        try:
+            if not os.path.exists(self.secdb_dir_path):
+                os.makedirs(self.secdb_dir_path, exist_ok=True)
+
+            self.logger.info("downloading alpine secdb metadata from: {}".format(self.metadata_url))
+            r = requests.get(self.metadata_url, timeout=self.download_timeout)
+            if r.status_code == 200:
+                try:
+                    self.logger.debug("HTML parsing secdb landing page content for links")
+                    parser = SecdbLandingParser()
+                    parser.feed(r.text)
+                    links = parser.links
+                except:
+                    self.logger.warning("unable to html parse secdb landing page content for links")
+
+                if not links:
+                    self.logger.debug("string parsing secdb landing page content for links")
+                    links = re.findall(self._link_finder_regex_, r.text)
             else:
-                raise Exception("unable to find release specific secdb links")
-
-            for link in links:
-                if link not in ignore_links:
-                    try:
-                        rel = link.strip("/")
-                        rel_dir = os.path.join(self.secdb_dir_path, rel)
-                        os.makedirs(rel_dir, exist_ok=True)
-                        for db_type in self._db_types:
-                            file_name = "{}.yaml".format(db_type)
-                            download_url = "/".join([self.metadata_url, rel, file_name])
-                            self._urls.add(download_url)
-                            self.logger.info("Downloading secdb {} {}".format(rel, db_type))
-                            r = requests.get(download_url, stream=True, timeout=self.download_timeout)
-                            if r.status_code == 200:
-                                file_path = os.path.join(rel_dir, file_name)
-                                with open(file_path, "wb") as fp:
-                                    for chunk in r.iter_content():
-                                        fp.write(chunk)
-                            else:
-                                r.raise_for_status()
-                    except:
-                        self.logger.exception("ignoring error processing secdb for {}".format(link))
+                r.raise_for_status()
+        except Exception:
+            self.logger.exception("error downloading or parsing alpine secdb metadata")
+            raise
+
+        if links:
+            self.logger.debug("found release specific secdb links: {}".format(links))
+        else:
+            raise Exception("unable to find release specific secdb links")
+
+        for link in links:
+            if link not in ignore_links:
+                try:
+                    rel = link.strip("/")
+                    rel_dir = os.path.join(self.secdb_dir_path, rel)
+                    os.makedirs(rel_dir, exist_ok=True)
+                    for db_type in self._db_types:
+                        file_name = "{}.yaml".format(db_type)
+                        download_url = "/".join([self.metadata_url, rel, file_name])
+                        self._urls.add(download_url)
+                        self.logger.info("Downloading secdb {} {}".format(rel, db_type))
+                        r = requests.get(download_url, stream=True, timeout=self.download_timeout)
+                        if r.status_code == 200:
+                            file_path = os.path.join(rel_dir, file_name)
+                            with open(file_path, "wb") as fp:
+                                for chunk in r.iter_content():
+                                    fp.write(chunk)
+                        else:
+                            r.raise_for_status()
+                except:
+                    self.logger.exception("ignoring error processing secdb for {}".format(link))

     def _load(self):
         """
@@ -248,13 +243,13 @@ def _normalize(self, release, dbtype_data_dict):

         return vuln_dict

-    def get(self, skip_if_exists: bool = False):
+    def get(self):
         """
         Download, load and normalize alpine sec db and return a dict of releae - list of vulnerability records
         :return:
         """
         # download the data
-        self._download(skip_if_exists)
+        self._download()

         for release, dbtype_data_dict in self._load():
             # normalize the loaded data
diff --git a/src/vunnel/providers/amazon/parser.py b/src/vunnel/providers/amazon/parser.py
index 12e8d72c..38b19704 100644
--- a/src/vunnel/providers/amazon/parser.py
+++ b/src/vunnel/providers/amazon/parser.py
@@ -48,22 +48,19 @@ def __init__(self, workspace, download_timeout=125, security_advisories=None, lo
         self.logger = logger

     @utils.retry_with_backoff()
-    def _download_rss(self, rss_url, rss_file, skip_if_exists=False):
-        if skip_if_exists and os.path.exists(rss_file):
-            self.logger.debug(f"'skip_if_exists' flag enabled and found {rss_file}. Skipping download")
-        else:
-            try:
-                self.logger.info(f"downloading amazon security advisory from {rss_url}")
-                self.urls.append(rss_url)
-                r = requests.get(rss_url, timeout=self.download_timeout)
-                if r.status_code == 200:
-                    with open(rss_file, "w", encoding="utf-8") as fp:
-                        fp.write(r.text)
-                else:
-                    raise Exception(f"GET {rss_url} failed with HTTP error {r.status_code}")
-            except Exception:
-                self.logger.exception("error downloading amazon linux vulnerability feeds")
-                raise
+    def _download_rss(self, rss_url, rss_file):
+        try:
+            self.logger.info(f"downloading amazon security advisory from {rss_url}")
+            self.urls.append(rss_url)
+            r = requests.get(rss_url, timeout=self.download_timeout)
+            if r.status_code == 200:
+                with open(rss_file, "w", encoding="utf-8") as fp:
+                    fp.write(r.text)
+            else:
+                raise Exception(f"GET {rss_url} failed with HTTP error {r.status_code}")
+        except Exception:
+            self.logger.exception("error downloading amazon linux vulnerability feeds")
+            raise

     def _parse_rss(self, file_path):
         self.logger.debug(f"parsing RSS data from {file_path}")
@@ -99,7 +96,7 @@ def _parse_rss(self, file_path):
     @utils.retry_with_backoff()
     def _get_alas_html(self, alas_url, alas_file, skip_if_exists=True):
         if skip_if_exists and os.path.exists(alas_file):  # read alas from disk if its available
-            self.logger.debug(f"loading ALAS from {alas_file}")
+            self.logger.debug(f"loading existing ALAS from {alas_file}")
             with open(alas_file, encoding="utf-8") as fp:
                 content = fp.read()
             return content
@@ -136,7 +133,7 @@ def get(self, skip_if_exists=False):
            rss_file = os.path.join(self.workspace.input_path, self._rss_file_name_.format(version))
            html_dir = os.path.join(self.workspace.input_path, self._html_dir_name_.format(version))

-            self._download_rss(url, rss_file, skip_if_exists)
+            self._download_rss(url, rss_file)

            # parse rss for alas summaries
            alas_summaries = self._parse_rss(rss_file)
diff --git a/src/vunnel/providers/centos/__init__.py b/src/vunnel/providers/centos/__init__.py
index c0cf26ab..85f4a81d 100644
--- a/src/vunnel/providers/centos/__init__.py
+++ b/src/vunnel/providers/centos/__init__.py
@@ -43,7 +43,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int
         # { (CVE, namespace): {...data...}" }

         # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all)
-        vuln_dict = self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists)
+        vuln_dict = self.parser.get()

         self.logger.info(f"processed {len(vuln_dict)} entries")

diff --git a/src/vunnel/providers/centos/parser.py b/src/vunnel/providers/centos/parser.py
index ab52a61e..167cb9f8 100644
--- a/src/vunnel/providers/centos/parser.py
+++ b/src/vunnel/providers/centos/parser.py
@@ -98,51 +98,47 @@ def _get_sha256(self):
             raise Exception("Error fetching/processing sha256")

     @utils.retry_with_backoff()
-    def _download(self, skip_if_exists=False):
-
-        if skip_if_exists and os.path.exists(self.xml_file_path):
-            self.logger.debug("'skip_if_exists' flag enabled and found {}. Skipping download".format(self.xml_file_path))
+    def _download(self):
+        download = True
+
+        if os.path.exists(self.xml_file_path) and os.path.exists(self.xml_sha_file_path):
+            with open(self.xml_sha_file_path) as fp:
+                previous = fp.read()
+                previous = previous.strip()
+
+            latest = self._get_sha256()
+            self.logger.debug("previous sha256: {}, latest sha256: {}".format(previous, latest))
+            download = previous.lower() != latest.lower()
+
+        if download:
+            try:
+                self.logger.info("downloading RHSA from {}".format(self._url_))
+                r = requests.get(self._url_, stream=True, timeout=self.download_timeout)
+                if r.status_code == 200:
+                    # compute the sha256 as the file is decompressed
+                    sha256 = hashlib.sha256()
+                    with open(self.xml_file_path, "wb") as extracted:
+                        decompressor = bz2.BZ2Decompressor()
+                        for chunk in r.iter_content(chunk_size=1024):
+                            uncchunk = decompressor.decompress(chunk)
+                            extracted.write(uncchunk)
+                            sha256.update(uncchunk)
+
+                    sha256sum = str(sha256.hexdigest()).lower()
+                    self.logger.debug("sha256 for {}: {}".format(self.xml_file_path, sha256sum))
+
+                    # save the sha256 to another file
+                    with open(self.xml_sha_file_path, "w") as fp:
+                        fp.write(sha256sum)
+
+                    return sha256sum
+                else:
+                    raise Exception("GET {} failed with HTTP error {}".format(self._url_, r.status_code))
+            except Exception:
+                self.logger.exception("error downloading RHSA file")
+                raise Exception("error downloading RHSA file")
         else:
-            download = True
-
-            if os.path.exists(self.xml_file_path) and os.path.exists(self.xml_sha_file_path):
-                with open(self.xml_sha_file_path) as fp:
-                    previous = fp.read()
-                    previous = previous.strip()
-
-                latest = self._get_sha256()
-                self.logger.debug("previous sha256: {}, latest sha256: {}".format(previous, latest))
-                download = previous.lower() != latest.lower()
-
-            if download:
-                try:
-                    self.logger.info("downloading RHSA from {}".format(self._url_))
-                    r = requests.get(self._url_, stream=True, timeout=self.download_timeout)
-                    if r.status_code == 200:
-                        # compute the sha256 as the file is decompressed
-                        sha256 = hashlib.sha256()
-                        with open(self.xml_file_path, "wb") as extracted:
-                            decompressor = bz2.BZ2Decompressor()
-                            for chunk in r.iter_content(chunk_size=1024):
-                                uncchunk = decompressor.decompress(chunk)
-                                extracted.write(uncchunk)
-                                sha256.update(uncchunk)
-
-                        sha256sum = str(sha256.hexdigest()).lower()
-                        self.logger.debug("sha256 for {}: {}".format(self.xml_file_path, sha256sum))
-
-                        # save the sha256 to another file
-                        with open(self.xml_sha_file_path, "w") as fp:
-                            fp.write(sha256sum)
-
-                        return sha256sum
-                    else:
-                        raise Exception("GET {} failed with HTTP error {}".format(self._url_, r.status_code))
-                except Exception:
-                    self.logger.exception("error downloading RHSA file")
-                    raise Exception("error downloading RHSA file")
-            else:
-                self.logger.info("stored csum matches server csum. Skipping download")
+            self.logger.info("stored csum matches server csum. Skipping download")

         return None

@@ -150,8 +146,8 @@ def parse(self):
         # normalize and return results
         return parse(self.xml_file_path, self.config)

-    def get(self, skip_if_exists=False):
+    def get(self):
         # download
-        self._download(skip_if_exists=skip_if_exists)
+        self._download()

         return self.parse()
diff --git a/src/vunnel/providers/debian/__init__.py b/src/vunnel/providers/debian/__init__.py
index e9cb442f..39426f74 100644
--- a/src/vunnel/providers/debian/__init__.py
+++ b/src/vunnel/providers/debian/__init__.py
@@ -48,7 +48,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int

         with self.results_writer() as writer:
             # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all)
-            for relno, vuln_id, record in self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists):
+            for relno, vuln_id, record in self.parser.get():
                 vuln_id = vuln_id.lower()
                 writer.write(
                     identifier=os.path.join(f"debian:{relno}", vuln_id),
diff --git a/src/vunnel/providers/debian/parser.py b/src/vunnel/providers/debian/parser.py
index 5f1f9b72..1dafe32e 100644
--- a/src/vunnel/providers/debian/parser.py
+++ b/src/vunnel/providers/debian/parser.py
@@ -58,50 +58,44 @@ def __init__(self, workspace, download_timeout=125, logger=None, distro_map=None
         self.logger = logger

     @utils.retry_with_backoff()
-    def _download_json(self, skip_if_exists=False):
+    def _download_json(self):
         """
         Downloads debian json file
         :return:
         """
-        if skip_if_exists and os.path.exists(self.json_file_path):
-            self.logger.debug(f"'skip_if_exists' flag enabled and found {self.json_file_path}. Skipping download")
-        else:
-            try:
-                self.logger.info(f"downloading debian security tracker data from {self._dsa_url_}")
+        try:
+            self.logger.info(f"downloading debian security tracker data from {self._dsa_url_}")

-                r = requests.get(self._json_url_, timeout=self.download_timeout)
-                if r.status_code != 200:
-                    raise Exception(f"GET {self._json_url_} failed with HTTP error {r.status_code}")
+            r = requests.get(self._json_url_, timeout=self.download_timeout)
+            if r.status_code != 200:
+                raise Exception(f"GET {self._json_url_} failed with HTTP error {r.status_code}")

-                json.loads(r.text)  # quick check if json is valid
-                with open(self.json_file_path, "w", encoding="utf-8") as OFH:
-                    OFH.write(r.text)
+            json.loads(r.text)  # quick check if json is valid
+            with open(self.json_file_path, "w", encoding="utf-8") as OFH:
+                OFH.write(r.text)

-            except Exception:
-                self.logger.exception("Error downloading debian json file")
-                raise
+        except Exception:
+            self.logger.exception("Error downloading debian json file")
+            raise

     @utils.retry_with_backoff()
-    def _download_dsa(self, skip_if_exists=False):
+    def _download_dsa(self):
         """
         Downloads debian dsa file
         :return:
         """
-        if skip_if_exists and os.path.exists(self.dsa_file_path):
-            self.logger.debug(f"'skip_if_exists' flag enabled and found {self.dsa_file_path}. Skipping download")
-        else:
-            try:
-                self.logger.info(f"downloading DSA from {self._dsa_url_}")
-                r = requests.get(self._dsa_url_, timeout=self.download_timeout)
-                if r.status_code != 200:
-                    raise Exception(f"GET {self._dsa_url_} failed with HTTP error {r.status_code}")
+        try:
+            self.logger.info(f"downloading DSA from {self._dsa_url_}")
+            r = requests.get(self._dsa_url_, timeout=self.download_timeout)
+            if r.status_code != 200:
+                raise Exception(f"GET {self._dsa_url_} failed with HTTP error {r.status_code}")

-            with open(self.dsa_file_path, "w", encoding="utf-8") as OFH:
-                OFH.write(r.text)
+            with open(self.dsa_file_path, "w", encoding="utf-8") as OFH:
+                OFH.write(r.text)

-            except Exception:
-                self.logger.exception("error downloading debian DSA file")
-                raise
+        except Exception:
+            self.logger.exception("error downloading debian DSA file")
+            raise

     def _get_cve_to_dsalist(self, dsa):
         """
@@ -462,10 +456,10 @@ def _normalize_json(self, ns_cve_dsalist=None):

         return vuln_records

-    def get(self, skip_if_exists=False):
+    def get(self):
         # download the files
-        self._download_json(skip_if_exists)
-        self._download_dsa(skip_if_exists)
+        self._download_json()
+        self._download_dsa()

         # normalize dsa list first
         ns_cve_dsalist = self._normalize_dsa_list()
diff --git a/src/vunnel/providers/oracle/__init__.py b/src/vunnel/providers/oracle/__init__.py
index 784d7de6..b28274ba 100644
--- a/src/vunnel/providers/oracle/__init__.py
+++ b/src/vunnel/providers/oracle/__init__.py
@@ -43,7 +43,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int

         with self.results_writer() as writer:
             # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all)
-            vuln_dict = self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists)
+            vuln_dict = self.parser.get()

             for (vuln_id, namespace), (_, record) in vuln_dict.items():
                 namespace = namespace.lower()
diff --git a/src/vunnel/providers/oracle/parser.py b/src/vunnel/providers/oracle/parser.py
index 6560c5f4..b45791e7 100644
--- a/src/vunnel/providers/oracle/parser.py
+++ b/src/vunnel/providers/oracle/parser.py
@@ -66,11 +66,7 @@ def urls(self):
         return [self._url_]

     @utils.retry_with_backoff()
-    def _download(self, skip_if_exists=False):
-        if skip_if_exists and os.path.exists(self.xml_file_path):
-            self.logger.debug(f"'skip_if_exists' flag enabled and found {self.xml_file_path}. Skipping download")
-            return
-
+    def _download(self):
         try:
             self.logger.info(f"downloading ELSA from {self._url_}")
             r = requests.get(self._url_, stream=True, timeout=self.download_timeout)
@@ -99,9 +95,9 @@ def _parse_oval_data(self, path: str, config: dict):
         filtered_results = filterer.filter(raw_results)

         return filtered_results

-    def get(self, skip_if_exists=False):
+    def get(self):
         # download
-        self._download(skip_if_exists=skip_if_exists)
+        self._download()

         return self._parse_oval_data(self.xml_file_path, self.config)
diff --git a/src/vunnel/providers/sles/__init__.py b/src/vunnel/providers/sles/__init__.py
index cdd3619c..98df4628 100644
--- a/src/vunnel/providers/sles/__init__.py
+++ b/src/vunnel/providers/sles/__init__.py
@@ -48,7 +48,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int

         with self.results_writer() as writer:
             # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all)
-            for namespace, vuln_id, record in self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists):
+            for namespace, vuln_id, record in self.parser.get():
                 namespace = namespace.lower()
                 vuln_id = vuln_id.lower()
                 writer.write(
diff --git a/src/vunnel/providers/sles/parser.py b/src/vunnel/providers/sles/parser.py
index a0114464..018bef83 100644
--- a/src/vunnel/providers/sles/parser.py
+++ b/src/vunnel/providers/sles/parser.py
@@ -68,20 +68,12 @@ def __init__(
         Parser.logger = logger

     @utils.retry_with_backoff()
-    def _download(self, major_version: str, skip_if_exists: bool = False) -> str:
+    def _download(self, major_version: str) -> str:
         if not os.path.exists(self.oval_dir_path):
             self.logger.debug(f"creating workspace for OVAL source data at {self.oval_dir_path}")
             os.makedirs(self.oval_dir_path)

         oval_file_path = os.path.join(self.oval_dir_path, self.__oval_file_name__.format(major_version))
-
-        if skip_if_exists and os.path.exists(oval_file_path):
-            self.logger.debug(
-                "'skip_if_exists' flag enabled and found %s. Skipping download",
-                oval_file_path,
-            )
-            return oval_file_path
-
         download_url = self.__oval_url__.format(major_version)
         self.urls.append(download_url)

@@ -328,7 +320,7 @@ def _transform_oval_vulnerabilities(cls, major_version: str, parsed_dict: dict)

         return results

-    def get(self, skip_if_exists: bool = False):
+    def get(self):
         parser_factory = OVALParserFactory(
             parsers=[
                 SLESVulnerabilityParser,
@@ -343,7 +335,7 @@ def get(self, skip_if_exists: bool = False):
         for major_version in self.allow_versions:
             try:
                 # download oval
-                oval_file_path = self._download(major_version, skip_if_exists)
+                oval_file_path = self._download(major_version)

                 # parse oval contents
                 parsed_dict = iter_parse_vulnerability_file(
diff --git a/src/vunnel/providers/wolfi/__init__.py b/src/vunnel/providers/wolfi/__init__.py
index f56ac400..fd326b06 100644
--- a/src/vunnel/providers/wolfi/__init__.py
+++ b/src/vunnel/providers/wolfi/__init__.py
@@ -42,7 +42,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int

         with self.results_writer() as writer:
             # TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all)
-            for release, vuln_dict in self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists):
+            for release, vuln_dict in self.parser.get():

                 for vuln_id, record in vuln_dict.items():
                     writer.write(
diff --git a/src/vunnel/providers/wolfi/parser.py b/src/vunnel/providers/wolfi/parser.py
index ef435b1e..474d712c 100644
--- a/src/vunnel/providers/wolfi/parser.py
+++ b/src/vunnel/providers/wolfi/parser.py
@@ -31,39 +31,35 @@ def __init__(self, workspace, download_timeout=125, url=None, logger=None):
         self.logger = logger

     @utils.retry_with_backoff()
-    def _download(self, skip_if_exists=False):
+    def _download(self):
         """
         Downloads wolfi sec db files
         :return:
         """
-
-        if skip_if_exists and os.path.exists(self.secdb_dir_path):
-            self.logger.debug(f"'skip_if_exists' flag enabled and found source under {self.secdb_dir_path}. Skipping download")
-        else:
-            if not os.path.exists(self.secdb_dir_path):
-                os.makedirs(self.secdb_dir_path, exist_ok=True)
-
-            for t in self._db_types:
-                try:
-                    rel_dir = os.path.join(self.secdb_dir_path, t)
-                    os.makedirs(rel_dir, exist_ok=True)
-
-                    filename = "security.json"
-                    download_url = f"{self.metadata_url}/{t}/{filename}"
-
-                    self.urls.append(download_url)
-
-                    self.logger.info(f"downloading Wolfi secdb {download_url}")
-                    r = requests.get(download_url, stream=True, timeout=self.download_timeout)
-                    if r.status_code == 200:
-                        file_path = os.path.join(rel_dir, filename)
-                        with open(file_path, "wb") as fp:
-                            for chunk in r.iter_content():
-                                fp.write(chunk)
-                    else:
-                        r.raise_for_status()
-                except:  # noqa
-                    self.logger.exception(f"ignoring error processing secdb for {t}")
+        if not os.path.exists(self.secdb_dir_path):
+            os.makedirs(self.secdb_dir_path, exist_ok=True)
+
+        for t in self._db_types:
+            try:
+                rel_dir = os.path.join(self.secdb_dir_path, t)
+                os.makedirs(rel_dir, exist_ok=True)
+
+                filename = "security.json"
+                download_url = f"{self.metadata_url}/{t}/{filename}"
+
+                self.urls.append(download_url)
+
+                self.logger.info(f"downloading Wolfi secdb {download_url}")
+                r = requests.get(download_url, stream=True, timeout=self.download_timeout)
+                if r.status_code == 200:
+                    file_path = os.path.join(rel_dir, filename)
+                    with open(file_path, "wb") as fp:
+                        for chunk in r.iter_content():
+                            fp.write(chunk)
+                else:
+                    r.raise_for_status()
+            except:  # noqa
+                self.logger.exception(f"ignoring error processing secdb for {t}")

     def _load(self):
         """
@@ -135,23 +131,6 @@ def _normalize(self, release, dbtype_data_dict):
                     vuln_record["Vulnerability"]["NamespaceName"] = namespace + ":" + str(release)
                     vuln_record["Vulnerability"]["Link"] = "http://cve.mitre.org/cgi-bin/cvename.cgi?name=" + str(vid)
                     vuln_record["Vulnerability"]["Severity"] = "Unknown"
-
-                    # lookup nvd record only when creating the vulnerability, no point looking it up every time
-                    nvd_severity = None
-                    # TODO: ALEX fix this in grype-db-builder
-                    # if session:
-                    #     try:
-                    #         nvd_severity = nvd.get_severity(
-                    #             vid, session=session
-                    #         )
-                    #     except Exception:
-                    #         self.logger.exception(
-                    #             "Ignoring error processing nvdv2 record"
-                    #         )
-
-                    # use nvd severity
-                    if nvd_severity:
-                        vuln_record["Vulnerability"]["Severity"] = nvd_severity
                 else:
                     vuln_record = vuln_dict[vid]

@@ -167,13 +146,13 @@ def _normalize(self, release, dbtype_data_dict):

         return vuln_dict

-    def get(self, skip_if_exists=False):
+    def get(self):
         """
         Download, load and normalize wolfi sec db and return a dict of release - list of vulnerability records
         :return:
         """
         # download the data
-        self._download(skip_if_exists)
+        self._download()

         # load the data
         for release, dbtype_data_dict in self._load():
diff --git a/tests/unit/providers/alpine/test_alpine.py b/tests/unit/providers/alpine/test_alpine.py
index cd6ed0ee..b0555723 100644
--- a/tests/unit/providers/alpine/test_alpine.py
+++ b/tests/unit/providers/alpine/test_alpine.py
@@ -199,7 +199,7 @@ def disabled(*args, **kwargs):
     monkeypatch.setattr(parser.requests, "get", disabled)


-def test_provider_schema(helpers, disable_get_requests):
+def test_provider_schema(helpers, disable_get_requests, monkeypatch):
     workspace = helpers.provider_workspace_helper(name=Provider.name())

     c = Config()
@@ -212,6 +212,11 @@ def test_provider_schema(helpers, disable_get_requests):
     mock_data_path = helpers.local_dir("test-fixtures/input")
     shutil.copytree(mock_data_path, workspace.input_dir, dirs_exist_ok=True)

+    def mock_download():
+        return
+
+    monkeypatch.setattr(p.parser, "_download", mock_download)
+
     p.update(None)

     assert 16 == workspace.num_result_entries()
diff --git a/tests/unit/providers/amazon/test_amazon.py b/tests/unit/providers/amazon/test_amazon.py
index f5a0b3e2..2b871158 100644
--- a/tests/unit/providers/amazon/test_amazon.py
+++ b/tests/unit/providers/amazon/test_amazon.py
@@ -84,7 +84,7 @@ def disabled(*args, **kwargs):
     monkeypatch.setattr(parser.requests, "get", disabled)


-def test_provider_schema(helpers, disable_get_requests):
+def test_provider_schema(helpers, disable_get_requests, monkeypatch):
     workspace = helpers.provider_workspace_helper(name=Provider.name())

     c = Config()
@@ -94,6 +94,11 @@ def test_provider_schema(helpers, disable_get_requests):
     mock_data_path = helpers.local_dir("test-fixtures/input")
     shutil.copytree(mock_data_path, workspace.input_dir, dirs_exist_ok=True)

+    def mock_download(self, *args, **kwargs):
+        pass
+
+    monkeypatch.setattr(p.parser, "_download_rss", mock_download)
+
     p.update(None)

     assert 2 == workspace.num_result_entries()
diff --git a/tests/unit/providers/centos/test_centos.py b/tests/unit/providers/centos/test_centos.py
index 10e4d602..981a55de 100644
--- a/tests/unit/providers/centos/test_centos.py
+++ b/tests/unit/providers/centos/test_centos.py
@@ -6,7 +6,7 @@

 from vunnel import result, workspace
 from vunnel.providers import centos
-from vunnel.providers.centos.parser import Parser
+from vunnel.providers.centos import Parser, parser


 @pytest.mark.parametrize(
@@ -83,6 +83,14 @@ def test_parser(tmpdir, helpers, mock_data_path, full_entry):
         assert vuln == full_entry


+@pytest.fixture
+def disable_get_requests(monkeypatch):
+    def disabled(*args, **kwargs):
+        raise RuntimeError("requests disabled but HTTP GET attempted")
+
+    monkeypatch.setattr(parser.requests, "get", disabled)
+
+
 @pytest.mark.parametrize(
     "mock_data_path,expected_written_entries",
     [
@@ -90,7 +98,7 @@ def test_parser(tmpdir, helpers, mock_data_path, full_entry):
         ("test-fixtures/centos-7-entry", 1),
     ],
 )
-def test_provider_schema(helpers, mock_data_path, expected_written_entries):
+def test_provider_schema(helpers, mock_data_path, expected_written_entries, disable_get_requests, monkeypatch):
     workspace = helpers.provider_workspace_helper(name=centos.Provider.name())

     mock_data_path = helpers.local_dir(mock_data_path)
@@ -101,6 +109,12 @@ def test_provider_schema(helpers, mock_data_path, expected_written_entries):
         config=c,
     )
     shutil.copy(mock_data_path, p.parser.xml_file_path)
+
+    def mock_download():
+        return None
+
+    monkeypatch.setattr(p.parser, "_download", mock_download)
+
     p.update(None)

     assert expected_written_entries == workspace.num_result_entries()
diff --git a/tests/unit/providers/debian/test_debian.py b/tests/unit/providers/debian/test_debian.py
index a8f6494f..b87d7c49 100644
--- a/tests/unit/providers/debian/test_debian.py
+++ b/tests/unit/providers/debian/test_debian.py
@@ -94,7 +94,7 @@ def test_normalize_json(self, tmpdir, helpers, disable_get_requests):
         assert all(x.get("Vulnerability", {}).get("Description") is not None for x in vuln_dict.values())


-def test_provider_schema(helpers, disable_get_requests):
+def test_provider_schema(helpers, disable_get_requests, monkeypatch):
     workspace = helpers.provider_workspace_helper(name=Provider.name())

     c = Config()
@@ -107,6 +107,12 @@ def test_provider_schema(helpers, disable_get_requests):
     mock_data_path = helpers.local_dir("test-fixtures/input")
     shutil.copytree(mock_data_path, workspace.input_dir, dirs_exist_ok=True)

+    def mock_download():
+        return None
+
+    monkeypatch.setattr(p.parser, "_download_json", mock_download)
+    monkeypatch.setattr(p.parser, "_download_dsa", mock_download)
+
     p.update(None)

     assert 21 == workspace.num_result_entries()
diff --git a/tests/unit/providers/oracle/test_oracle.py b/tests/unit/providers/oracle/test_oracle.py
index aa5ba4ad..9f13f6a8 100644
--- a/tests/unit/providers/oracle/test_oracle.py
+++ b/tests/unit/providers/oracle/test_oracle.py
@@ -370,7 +370,7 @@ def disabled(*args, **kwargs):
     monkeypatch.setattr(parser.requests, "get", disabled)


-def test_provider_schema(helpers, disable_get_requests):
+def test_provider_schema(helpers, disable_get_requests, monkeypatch):
     workspace = helpers.provider_workspace_helper(name=Provider.name())

     c = Config()
@@ -380,6 +380,11 @@ def test_provider_schema(helpers, disable_get_requests):
     mock_data_path = helpers.local_dir("test-fixtures/mock_data")
     shutil.copy(mock_data_path, workspace.input_dir / "com.oracle.elsa-all.xml")

+    def mock_download():
+        return None
+
+    monkeypatch.setattr(p.parser, "_download", mock_download)
+
     p.update(None)

     assert 2 == workspace.num_result_entries()