fix: remove flawed skip_if_exists logic (#53)
The current skip_if_exists logic causes provider updates to be missed because it never revisits files once they have been downloaded. This only worked previously because the flag defaulted to false and was never actually set for the majority of the legacy enterprise drivers.

Signed-off-by: Weston Steimel <[email protected]>
westonsteimel authored Jan 30, 2023
1 parent 15d178c commit 840ee3b
Showing 18 changed files with 216 additions and 232 deletions.
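For context, here is a minimal sketch (not part of the commit; function names are hypothetical) of the pattern the commit message describes: skipping a download purely because a file already exists means the provider keeps whatever it fetched on its first run and never picks up upstream changes.

import os

import requests


# Hypothetical helper showing the removed pattern: once the file exists on disk,
# the provider never contacts the upstream source again, so updates published
# after the first run are silently missed.
def download_stale(url: str, dest: str, skip_if_exists: bool = False, timeout: int = 125) -> None:
    if skip_if_exists and os.path.exists(dest):
        return  # stale forever: nothing ever refreshes this file
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with open(dest, "wb") as fp:
        fp.write(response.content)


# After this commit, providers either always re-download or, like the CentOS
# parser further down, compare a remote checksum against the stored one before
# deciding whether a fresh download is needed.
def download_fresh(url: str, dest: str, timeout: int = 125) -> None:
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with open(dest, "wb") as fp:
        fp.write(response.content)

Note that the amazon provider's _get_alas_html still reads a previously downloaded ALAS page from disk when one exists; only the feed-level downloads lose the flag.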
2 changes: 1 addition & 1 deletion src/vunnel/providers/alpine/__init__.py
@@ -41,7 +41,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int

with self.results_writer() as writer:
# TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all)
for namespace, vulns in self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists):
for namespace, vulns in self.parser.get():
namespace = namespace.lower()
for vuln_id, record in vulns.items():
vuln_id = vuln_id.lower()
111 changes: 53 additions & 58 deletions src/vunnel/providers/alpine/parser.py
@@ -70,7 +70,7 @@ def __init__(
def urls(self) -> list[str]:
return list(self._urls)

def _download(self, skip_if_exists=False):
def _download(self):
"""
Downloads alpine sec db files
:return:
@@ -81,62 +81,57 @@ def _download(self, skip_if_exists=False):
if os.path.exists(os.path.join(self.secdb_dir_path, "alpine-secdb-master.tar.gz")):
os.remove(os.path.join(self.secdb_dir_path, "alpine-secdb-master.tar.gz"))

if skip_if_exists and os.path.exists(self.secdb_dir_path):
self.logger.debug(
"'skip_if_exists' flag enabled and found source under {}. Skipping download".format(self.secdb_dir_path)
)
else:
links = []
try:
if not os.path.exists(self.secdb_dir_path):
os.makedirs(self.secdb_dir_path, exist_ok=True)

self.logger.info("downloading alpine secdb metadata from: {}".format(self.metadata_url))
r = requests.get(self.metadata_url, timeout=self.download_timeout)
if r.status_code == 200:
try:
self.logger.debug("HTML parsing secdb landing page content for links")
parser = SecdbLandingParser()
parser.feed(r.text)
links = parser.links
except:
self.logger.warning("unable to html parse secdb landing page content for links")

if not links:
self.logger.debug("string parsing secdb landing page content for links")
links = re.findall(self._link_finder_regex_, r.text)
else:
r.raise_for_status()
except Exception:
self.logger.exception("error downloading or parsing alpine secdb metadata")
raise

if links:
self.logger.debug("found release specific secdb links: {}".format(links))
links = []
try:
if not os.path.exists(self.secdb_dir_path):
os.makedirs(self.secdb_dir_path, exist_ok=True)

self.logger.info("downloading alpine secdb metadata from: {}".format(self.metadata_url))
r = requests.get(self.metadata_url, timeout=self.download_timeout)
if r.status_code == 200:
try:
self.logger.debug("HTML parsing secdb landing page content for links")
parser = SecdbLandingParser()
parser.feed(r.text)
links = parser.links
except:
self.logger.warning("unable to html parse secdb landing page content for links")

if not links:
self.logger.debug("string parsing secdb landing page content for links")
links = re.findall(self._link_finder_regex_, r.text)
else:
raise Exception("unable to find release specific secdb links")

for link in links:
if link not in ignore_links:
try:
rel = link.strip("/")
rel_dir = os.path.join(self.secdb_dir_path, rel)
os.makedirs(rel_dir, exist_ok=True)
for db_type in self._db_types:
file_name = "{}.yaml".format(db_type)
download_url = "/".join([self.metadata_url, rel, file_name])
self._urls.add(download_url)
self.logger.info("Downloading secdb {} {}".format(rel, db_type))
r = requests.get(download_url, stream=True, timeout=self.download_timeout)
if r.status_code == 200:
file_path = os.path.join(rel_dir, file_name)
with open(file_path, "wb") as fp:
for chunk in r.iter_content():
fp.write(chunk)
else:
r.raise_for_status()
except:
self.logger.exception("ignoring error processing secdb for {}".format(link))
r.raise_for_status()
except Exception:
self.logger.exception("error downloading or parsing alpine secdb metadata")
raise

if links:
self.logger.debug("found release specific secdb links: {}".format(links))
else:
raise Exception("unable to find release specific secdb links")

for link in links:
if link not in ignore_links:
try:
rel = link.strip("/")
rel_dir = os.path.join(self.secdb_dir_path, rel)
os.makedirs(rel_dir, exist_ok=True)
for db_type in self._db_types:
file_name = "{}.yaml".format(db_type)
download_url = "/".join([self.metadata_url, rel, file_name])
self._urls.add(download_url)
self.logger.info("Downloading secdb {} {}".format(rel, db_type))
r = requests.get(download_url, stream=True, timeout=self.download_timeout)
if r.status_code == 200:
file_path = os.path.join(rel_dir, file_name)
with open(file_path, "wb") as fp:
for chunk in r.iter_content():
fp.write(chunk)
else:
r.raise_for_status()
except:
self.logger.exception("ignoring error processing secdb for {}".format(link))

def _load(self):
"""
@@ -248,13 +243,13 @@ def _normalize(self, release, dbtype_data_dict):

return vuln_dict

def get(self, skip_if_exists: bool = False):
def get(self):
"""
        Download, load and normalize alpine sec db and return a dict of release - list of vulnerability records
:return:
"""
# download the data
self._download(skip_if_exists)
self._download()

for release, dbtype_data_dict in self._load():
# normalize the loaded data
33 changes: 15 additions & 18 deletions src/vunnel/providers/amazon/parser.py
@@ -48,22 +48,19 @@ def __init__(self, workspace, download_timeout=125, security_advisories=None, lo
self.logger = logger

@utils.retry_with_backoff()
def _download_rss(self, rss_url, rss_file, skip_if_exists=False):
if skip_if_exists and os.path.exists(rss_file):
self.logger.debug(f"'skip_if_exists' flag enabled and found {rss_file}. Skipping download")
else:
try:
self.logger.info(f"downloading amazon security advisory from {rss_url}")
self.urls.append(rss_url)
r = requests.get(rss_url, timeout=self.download_timeout)
if r.status_code == 200:
with open(rss_file, "w", encoding="utf-8") as fp:
fp.write(r.text)
else:
raise Exception(f"GET {rss_url} failed with HTTP error {r.status_code}")
except Exception:
self.logger.exception("error downloading amazon linux vulnerability feeds")
raise
def _download_rss(self, rss_url, rss_file):
try:
self.logger.info(f"downloading amazon security advisory from {rss_url}")
self.urls.append(rss_url)
r = requests.get(rss_url, timeout=self.download_timeout)
if r.status_code == 200:
with open(rss_file, "w", encoding="utf-8") as fp:
fp.write(r.text)
else:
raise Exception(f"GET {rss_url} failed with HTTP error {r.status_code}")
except Exception:
self.logger.exception("error downloading amazon linux vulnerability feeds")
raise

def _parse_rss(self, file_path):
self.logger.debug(f"parsing RSS data from {file_path}")
@@ -99,7 +96,7 @@ def _parse_rss(self, file_path):
@utils.retry_with_backoff()
def _get_alas_html(self, alas_url, alas_file, skip_if_exists=True):
if skip_if_exists and os.path.exists(alas_file): # read alas from disk if its available
self.logger.debug(f"loading ALAS from {alas_file}")
self.logger.debug(f"loading existing ALAS from {alas_file}")
with open(alas_file, encoding="utf-8") as fp:
content = fp.read()
return content
@@ -136,7 +133,7 @@ def get(self, skip_if_exists=False):
rss_file = os.path.join(self.workspace.input_path, self._rss_file_name_.format(version))
html_dir = os.path.join(self.workspace.input_path, self._html_dir_name_.format(version))

self._download_rss(url, rss_file, skip_if_exists)
self._download_rss(url, rss_file)

# parse rss for alas summaries
alas_summaries = self._parse_rss(rss_file)
2 changes: 1 addition & 1 deletion src/vunnel/providers/centos/__init__.py
@@ -43,7 +43,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int
# { (CVE, namespace): {...data...}" }

# TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all)
vuln_dict = self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists)
vuln_dict = self.parser.get()

self.logger.info(f"processed {len(vuln_dict)} entries")

88 changes: 42 additions & 46 deletions src/vunnel/providers/centos/parser.py
@@ -98,60 +98,56 @@ def _get_sha256(self):
raise Exception("Error fetching/processing sha256")

@utils.retry_with_backoff()
def _download(self, skip_if_exists=False):

if skip_if_exists and os.path.exists(self.xml_file_path):
self.logger.debug("'skip_if_exists' flag enabled and found {}. Skipping download".format(self.xml_file_path))
def _download(self):
download = True

if os.path.exists(self.xml_file_path) and os.path.exists(self.xml_sha_file_path):
with open(self.xml_sha_file_path) as fp:
previous = fp.read()
previous = previous.strip()

latest = self._get_sha256()
self.logger.debug("previous sha256: {}, latest sha256: {}".format(previous, latest))
download = previous.lower() != latest.lower()

if download:
try:
self.logger.info("downloading RHSA from {}".format(self._url_))
r = requests.get(self._url_, stream=True, timeout=self.download_timeout)
if r.status_code == 200:
# compute the sha256 as the file is decompressed
sha256 = hashlib.sha256()
with open(self.xml_file_path, "wb") as extracted:
decompressor = bz2.BZ2Decompressor()
for chunk in r.iter_content(chunk_size=1024):
uncchunk = decompressor.decompress(chunk)
extracted.write(uncchunk)
sha256.update(uncchunk)

sha256sum = str(sha256.hexdigest()).lower()
self.logger.debug("sha256 for {}: {}".format(self.xml_file_path, sha256sum))

# save the sha256 to another file
with open(self.xml_sha_file_path, "w") as fp:
fp.write(sha256sum)

return sha256sum
else:
raise Exception("GET {} failed with HTTP error {}".format(self._url_, r.status_code))
except Exception:
self.logger.exception("error downloading RHSA file")
raise Exception("error downloading RHSA file")
else:
download = True

if os.path.exists(self.xml_file_path) and os.path.exists(self.xml_sha_file_path):
with open(self.xml_sha_file_path) as fp:
previous = fp.read()
previous = previous.strip()

latest = self._get_sha256()
self.logger.debug("previous sha256: {}, latest sha256: {}".format(previous, latest))
download = previous.lower() != latest.lower()

if download:
try:
self.logger.info("downloading RHSA from {}".format(self._url_))
r = requests.get(self._url_, stream=True, timeout=self.download_timeout)
if r.status_code == 200:
# compute the sha256 as the file is decompressed
sha256 = hashlib.sha256()
with open(self.xml_file_path, "wb") as extracted:
decompressor = bz2.BZ2Decompressor()
for chunk in r.iter_content(chunk_size=1024):
uncchunk = decompressor.decompress(chunk)
extracted.write(uncchunk)
sha256.update(uncchunk)

sha256sum = str(sha256.hexdigest()).lower()
self.logger.debug("sha256 for {}: {}".format(self.xml_file_path, sha256sum))

# save the sha256 to another file
with open(self.xml_sha_file_path, "w") as fp:
fp.write(sha256sum)

return sha256sum
else:
raise Exception("GET {} failed with HTTP error {}".format(self._url_, r.status_code))
except Exception:
self.logger.exception("error downloading RHSA file")
raise Exception("error downloading RHSA file")
else:
self.logger.info("stored csum matches server csum. Skipping download")
self.logger.info("stored csum matches server csum. Skipping download")

return None

def parse(self):
# normalize and return results
return parse(self.xml_file_path, self.config)

def get(self, skip_if_exists=False):
def get(self):
# download
self._download(skip_if_exists=skip_if_exists)
self._download()

return self.parse()
2 changes: 1 addition & 1 deletion src/vunnel/providers/debian/__init__.py
@@ -48,7 +48,7 @@ def update(self, last_updated: datetime.datetime | None) -> tuple[list[str], int

with self.results_writer() as writer:
# TODO: tech debt: on subsequent runs, we should only write new vulns (this currently re-writes all)
for relno, vuln_id, record in self.parser.get(skip_if_exists=self.config.runtime.skip_if_exists):
for relno, vuln_id, record in self.parser.get():
vuln_id = vuln_id.lower()
writer.write(
identifier=os.path.join(f"debian:{relno}", vuln_id),