Skip to content

Commit

Permalink
applying reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sa3eed3ed committed Jan 22, 2025
1 parent 01e9e98 commit 5d77c1b
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 87 deletions.
178 changes: 94 additions & 84 deletions plaso/engine/artifacts_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class TrieNode(object):
artifacts_names (list[str]): Names of artifacts associated with this node.
children (dict[str, TrieNode]): Child nodes, keyed by path segment.
is_root (bool): True if this is the root node.
path_separator (str): Path separator used in the Trie.
path_separator (str): Path separator used in the Trie, default is '/'.
"""

def __init__(self, path_separator=None, is_root=False):
def __init__(self, path_separator='/', is_root=False):
"""Initializes a trie node object.
Args:
path_separator (str): the path separator used in paths stored in
the Trie, typically '/' or '\'.
path_separator (Optional[str]): the path separator used in paths stored
in the Trie, default is '/'.
is_root (bool): True if this node is the root node.
"""
super(TrieNode, self).__init__()
Expand Down Expand Up @@ -61,28 +61,29 @@ def AddPath(self, artifact_name, path, path_separator):
path_list = self.artifacts_paths.setdefault(artifact_name, [])
path_list.append(path)

node = self.root
current_node = self.root

# Add a path separator node if this is a new separator.
if path_separator not in node.children:
node.children[path_separator] = TrieNode(path_separator=path_separator)
node = node.children[path_separator]

if path_separator not in current_node.children:
current_node.children[path_separator] = TrieNode(
path_separator=path_separator)
current_node = current_node.children[path_separator]

# Handle the case when the input path is equal to the path_separator.
if path == path_separator:
node.artifacts_names.append(artifact_name)
current_node.artifacts_names.append(artifact_name)
return

path_segments = path.strip(path_separator).split(path_separator)
for segment in path_segments:
for path_segment in path_segments:
# Store the path_separator for each node.
if not hasattr(node, 'path_separator'):
node.path_separator = path_separator
if not hasattr(current_node, 'path_separator'):
current_node.path_separator = path_separator

if segment not in node.children:
node.children[segment] = TrieNode(path_separator)
node = node.children[segment]
node.artifacts_names.append(artifact_name)
if path_segment not in current_node.children:
current_node.children[path_segment] = TrieNode(path_separator)
current_node = current_node.children[path_segment]
current_node.artifacts_names.append(artifact_name)

def GetMatchingArtifacts(self, path, path_separator):
"""Retrieves the artifact names that match the given path.
Expand All @@ -99,7 +100,7 @@ def GetMatchingArtifacts(self, path, path_separator):
return []

sub_root_node = self.root.children[path_separator]

# Handle the case when the input path is equal to the path_separator.
if path == path_separator:
matching_artifacts = set()
Expand All @@ -111,61 +112,66 @@ def GetMatchingArtifacts(self, path, path_separator):
# Update self.artifacts_paths before starting the search.
self.artifacts_paths = self._GetArtifactsPaths(sub_root_node)

def _search_trie(node, current_path, segments):
"""Searches the trie for paths matching the given path segments.
self._SearchTrie(
sub_root_node, '', path_segments, path_separator, matching_artifacts)
return list(matching_artifacts)

Args:
node (TrieNode): current trie node being traversed.
current_path (str): path represented by the current node.
segments (list[str]): remaining path segments to match.
"""
if node.artifacts_names:
for artifact_name in node.artifacts_names:
for artifact_path in self.artifacts_paths.get(artifact_name, []):
if self._ComparePathIfSanitized(
current_path, path_separator, artifact_path,
node.path_separator):
def _SearchTrie(
self,node, current_path, segments, path_separator, matching_artifacts):
"""Searches the trie for paths matching the given path segments.
Args:
node (TrieNode): current trie node being traversed.
current_path (str): path represented by the current node.
segments (list[str]): remaining path segments to match.
path_separator (str): path separator.
matching_artifacts (set[str]): Set to store matching artifact names.
"""
if node.artifacts_names:
for artifact_name in node.artifacts_names:
for artifact_path in self.artifacts_paths.get(artifact_name, []):
if self._ComparePathIfSanitized(
current_path, path_separator, artifact_path,
node.path_separator):
matching_artifacts.add(artifact_name)
elif glob.has_magic(artifact_path):
if self._MatchesGlobPattern(
artifact_path, current_path, node.path_separator):
matching_artifacts.add(artifact_name)
elif glob.has_magic(artifact_path):
if self._MatchesGlobPattern(
artifact_path, current_path, node.path_separator):
matching_artifacts.add(artifact_name)

if not segments:
return

segment = segments[0]
remaining_segments = segments[1:]

# Handle glob characters in the current segment.
for child_segment, child_node in node.children.items():
if (
child_segment == segment or
# comapring the sanitized version of the path segment stored in
# the tree to the path segment from to the tool output as it
# sanitizes path segments before writting data to disk.
path_helper.PathHelper.SanitizePathSegments(
[child_segment]).pop() == segment
):
# If the child is an exact match, continue traversal.
_search_trie(child_node, self._CustomPathJoin(

if not segments:
return

segment = segments[0]
remaining_segments = segments[1:]

# Handle glob characters in the current segment.
for child_segment, child_node in node.children.items():
# Compare the sanitized version of the path segment stored in
# the tree to the path segment from to the tool output as it
# sanitizes path segments before writing data to disk.
sanitized_child_segment = path_helper.PathHelper.SanitizePathSegments(
[child_segment]).pop()
if segment in (child_segment, sanitized_child_segment):
# If the child is an exact match, continue traversal.
self._SearchTrie(child_node, self._CustomPathJoin(
path_separator,
current_path, child_segment), remaining_segments, path_separator,
matching_artifacts)
# If the child is a glob, see if it matches.
elif glob.has_magic(child_segment):
if self._MatchesGlobPattern(
child_segment, segment, child_node.path_separator):
self._SearchTrie(child_node, self._CustomPathJoin(
path_separator,
current_path, child_segment), remaining_segments)
# If the child is a glob, see if it matches.
elif glob.has_magic(child_segment):
if self._MatchesGlobPattern(
child_segment, segment, child_node.path_separator):
_search_trie(child_node, self._CustomPathJoin(
path_separator,
current_path, segment), remaining_segments)
_search_trie(
node,
self._CustomPathJoin(
path_separator,
current_path, segment), remaining_segments)

_search_trie(sub_root_node, '', path_segments)
return list(matching_artifacts)
current_path, segment), remaining_segments, path_separator,
matching_artifacts)
self._SearchTrie(
node,
self._CustomPathJoin(
path_separator,
current_path, segment), remaining_segments, path_separator,
matching_artifacts)

def _ComparePathIfSanitized(
self, current_path, path_separator, artifact_path,
Expand Down Expand Up @@ -275,20 +281,24 @@ def _MatchesGlobPattern(self, glob_pattern, path, path_separator):
glob_pattern = glob_pattern.strip(path_separator).split(path_separator)
path = path.strip(path_separator).split(path_separator)

i = 0
j = 0
while i < len(glob_pattern) and j < len(path):
if glob_pattern[i] == '**':
glob_index = 0
path_index = 0
while glob_index < len(glob_pattern) and path_index < len(path):
if glob_pattern[glob_index] == '**':
# If ** is the last part, it matches everything remaining
if i == len(glob_pattern) - 1:
if glob_index == len(glob_pattern) - 1:
return True
i += 1 # Move to the next part after **
while j < len(path) and not fnmatch.fnmatch(path[j], glob_pattern[i]):
j += 1 # Keep advancing in the path until the next part matches
elif not fnmatch.fnmatch(path[j], glob_pattern[i]):
return False # Mismatch
# Move to the next part after **
glob_index += 1
# Keep advancing in the path until the next part matches
while path_index < len(path) and not fnmatch.fnmatch(
path[path_index], glob_pattern[glob_index]):
path_index += 1
elif not fnmatch.fnmatch(path[path_index], glob_pattern[glob_index]):
# Mismatch
return False
else:
i += 1
j += 1
glob_index += 1
path_index += 1

return i == len(glob_pattern) and j == len(path)
return glob_index == len(glob_pattern) and path_index == len(path)
4 changes: 1 addition & 3 deletions tests/engine/artifacts_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ def test_initialization(self):
self.assertIsNotNone(node)
self.assertEqual(node.children, {})
self.assertEqual(node.artifacts_names, [])
self.assertIsNone(node.path_separator)

# You can add more tests for TrieNode if needed, but it's a simple class.
self.assertEqual(node.path_separator, '/')


class ArtifactsTrieTest(unittest.TestCase):
Expand Down

0 comments on commit 5d77c1b

Please sign in to comment.