diff --git a/plaso/engine/artifacts_trie.py b/plaso/engine/artifacts_trie.py index bec64b9e16..a634e3250c 100644 --- a/plaso/engine/artifacts_trie.py +++ b/plaso/engine/artifacts_trie.py @@ -16,15 +16,15 @@ class TrieNode(object): artifacts_names (list[str]): Names of artifacts associated with this node. children (dict[str, TrieNode]): Child nodes, keyed by path segment. is_root (bool): True if this is the root node. - path_separator (str): Path separator used in the Trie. + path_separator (str): Path separator used in the Trie, default is '/'. """ - def __init__(self, path_separator=None, is_root=False): + def __init__(self, path_separator='/', is_root=False): """Initializes a trie node object. Args: - path_separator (str): the path separator used in paths stored in - the Trie, typically '/' or '\'. + path_separator (Optional[str]): the path separator used in paths stored + in the Trie, default is '/'. is_root (bool): True if this node is the root node. """ super(TrieNode, self).__init__() @@ -61,28 +61,29 @@ def AddPath(self, artifact_name, path, path_separator): path_list = self.artifacts_paths.setdefault(artifact_name, []) path_list.append(path) - node = self.root - + current_node = self.root + # Add a path separator node if this is a new separator. - if path_separator not in node.children: - node.children[path_separator] = TrieNode(path_separator=path_separator) - node = node.children[path_separator] - + if path_separator not in current_node.children: + current_node.children[path_separator] = TrieNode( + path_separator=path_separator) + current_node = current_node.children[path_separator] + # Handle the case when the input path is equal to the path_separator. if path == path_separator: - node.artifacts_names.append(artifact_name) + current_node.artifacts_names.append(artifact_name) return path_segments = path.strip(path_separator).split(path_separator) - for segment in path_segments: + for path_segment in path_segments: # Store the path_separator for each node. - if not hasattr(node, 'path_separator'): - node.path_separator = path_separator + if not hasattr(current_node, 'path_separator'): + current_node.path_separator = path_separator - if segment not in node.children: - node.children[segment] = TrieNode(path_separator) - node = node.children[segment] - node.artifacts_names.append(artifact_name) + if path_segment not in current_node.children: + current_node.children[path_segment] = TrieNode(path_separator) + current_node = current_node.children[path_segment] + current_node.artifacts_names.append(artifact_name) def GetMatchingArtifacts(self, path, path_separator): """Retrieves the artifact names that match the given path. @@ -99,7 +100,7 @@ def GetMatchingArtifacts(self, path, path_separator): return [] sub_root_node = self.root.children[path_separator] - + # Handle the case when the input path is equal to the path_separator. if path == path_separator: matching_artifacts = set() @@ -111,61 +112,66 @@ def GetMatchingArtifacts(self, path, path_separator): # Update self.artifacts_paths before starting the search. self.artifacts_paths = self._GetArtifactsPaths(sub_root_node) - def _search_trie(node, current_path, segments): - """Searches the trie for paths matching the given path segments. + self._SearchTrie( + sub_root_node, '', path_segments, path_separator, matching_artifacts) + return list(matching_artifacts) - Args: - node (TrieNode): current trie node being traversed. - current_path (str): path represented by the current node. - segments (list[str]): remaining path segments to match. - """ - if node.artifacts_names: - for artifact_name in node.artifacts_names: - for artifact_path in self.artifacts_paths.get(artifact_name, []): - if self._ComparePathIfSanitized( - current_path, path_separator, artifact_path, - node.path_separator): + def _SearchTrie( + self,node, current_path, segments, path_separator, matching_artifacts): + """Searches the trie for paths matching the given path segments. + + Args: + node (TrieNode): current trie node being traversed. + current_path (str): path represented by the current node. + segments (list[str]): remaining path segments to match. + path_separator (str): path separator. + matching_artifacts (set[str]): Set to store matching artifact names. + """ + if node.artifacts_names: + for artifact_name in node.artifacts_names: + for artifact_path in self.artifacts_paths.get(artifact_name, []): + if self._ComparePathIfSanitized( + current_path, path_separator, artifact_path, + node.path_separator): + matching_artifacts.add(artifact_name) + elif glob.has_magic(artifact_path): + if self._MatchesGlobPattern( + artifact_path, current_path, node.path_separator): matching_artifacts.add(artifact_name) - elif glob.has_magic(artifact_path): - if self._MatchesGlobPattern( - artifact_path, current_path, node.path_separator): - matching_artifacts.add(artifact_name) - - if not segments: - return - - segment = segments[0] - remaining_segments = segments[1:] - - # Handle glob characters in the current segment. - for child_segment, child_node in node.children.items(): - if ( - child_segment == segment or - # comapring the sanitized version of the path segment stored in - # the tree to the path segment from to the tool output as it - # sanitizes path segments before writting data to disk. - path_helper.PathHelper.SanitizePathSegments( - [child_segment]).pop() == segment - ): - # If the child is an exact match, continue traversal. - _search_trie(child_node, self._CustomPathJoin( + + if not segments: + return + + segment = segments[0] + remaining_segments = segments[1:] + + # Handle glob characters in the current segment. + for child_segment, child_node in node.children.items(): + # Compare the sanitized version of the path segment stored in + # the tree to the path segment from to the tool output as it + # sanitizes path segments before writing data to disk. + sanitized_child_segment = path_helper.PathHelper.SanitizePathSegments( + [child_segment]).pop() + if segment in (child_segment, sanitized_child_segment): + # If the child is an exact match, continue traversal. + self._SearchTrie(child_node, self._CustomPathJoin( + path_separator, + current_path, child_segment), remaining_segments, path_separator, + matching_artifacts) + # If the child is a glob, see if it matches. + elif glob.has_magic(child_segment): + if self._MatchesGlobPattern( + child_segment, segment, child_node.path_separator): + self._SearchTrie(child_node, self._CustomPathJoin( path_separator, - current_path, child_segment), remaining_segments) - # If the child is a glob, see if it matches. - elif glob.has_magic(child_segment): - if self._MatchesGlobPattern( - child_segment, segment, child_node.path_separator): - _search_trie(child_node, self._CustomPathJoin( - path_separator, - current_path, segment), remaining_segments) - _search_trie( - node, - self._CustomPathJoin( - path_separator, - current_path, segment), remaining_segments) - - _search_trie(sub_root_node, '', path_segments) - return list(matching_artifacts) + current_path, segment), remaining_segments, path_separator, + matching_artifacts) + self._SearchTrie( + node, + self._CustomPathJoin( + path_separator, + current_path, segment), remaining_segments, path_separator, + matching_artifacts) def _ComparePathIfSanitized( self, current_path, path_separator, artifact_path, @@ -275,20 +281,24 @@ def _MatchesGlobPattern(self, glob_pattern, path, path_separator): glob_pattern = glob_pattern.strip(path_separator).split(path_separator) path = path.strip(path_separator).split(path_separator) - i = 0 - j = 0 - while i < len(glob_pattern) and j < len(path): - if glob_pattern[i] == '**': + glob_index = 0 + path_index = 0 + while glob_index < len(glob_pattern) and path_index < len(path): + if glob_pattern[glob_index] == '**': # If ** is the last part, it matches everything remaining - if i == len(glob_pattern) - 1: + if glob_index == len(glob_pattern) - 1: return True - i += 1 # Move to the next part after ** - while j < len(path) and not fnmatch.fnmatch(path[j], glob_pattern[i]): - j += 1 # Keep advancing in the path until the next part matches - elif not fnmatch.fnmatch(path[j], glob_pattern[i]): - return False # Mismatch + # Move to the next part after ** + glob_index += 1 + # Keep advancing in the path until the next part matches + while path_index < len(path) and not fnmatch.fnmatch( + path[path_index], glob_pattern[glob_index]): + path_index += 1 + elif not fnmatch.fnmatch(path[path_index], glob_pattern[glob_index]): + # Mismatch + return False else: - i += 1 - j += 1 + glob_index += 1 + path_index += 1 - return i == len(glob_pattern) and j == len(path) + return glob_index == len(glob_pattern) and path_index == len(path) diff --git a/tests/engine/artifacts_trie.py b/tests/engine/artifacts_trie.py index add63751d3..f0d592f4bc 100644 --- a/tests/engine/artifacts_trie.py +++ b/tests/engine/artifacts_trie.py @@ -15,9 +15,7 @@ def test_initialization(self): self.assertIsNotNone(node) self.assertEqual(node.children, {}) self.assertEqual(node.artifacts_names, []) - self.assertIsNone(node.path_separator) - - # You can add more tests for TrieNode if needed, but it's a simple class. + self.assertEqual(node.path_separator, '/') class ArtifactsTrieTest(unittest.TestCase):