Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update deep_walk and unit test #1447

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions global_helpers/global_helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import string
import sys
import unittest
from collections.abc import Sequence

from panther_core.enriched_event import PantherEvent
from panther_core.immutable import ImmutableCaseInsensitiveDict, ImmutableList
Expand Down Expand Up @@ -1420,7 +1421,11 @@ def test_deep_walk_success_random(self):
"""
for _ in range(1000):
data, keys, expected = self.generate_random_test_case_success()
self.assertEqual(p_b_h.deep_walk(data, *keys, default=""), expected)
result = p_b_h.deep_walk(data, *keys, default="")
if isinstance(result, Sequence) and not isinstance(result, str):
self.assertEqual(set(result), {expected})
else:
self.assertEqual(result, expected)

def test_deep_walk_default_random(self):
"""
Expand Down Expand Up @@ -1497,7 +1502,7 @@ def test_deep_walk_manual(self):
p_b_h.deep_walk(
event, "key", "very_nested", "outer_key", "nested_key2", "nested_key3", default=""
),
"value2",
["value2", "value2"],
)
self.assertEqual(
p_b_h.deep_walk(event, "key", "very_nested", "outer_key2", "nested_key4", default=""),
Expand Down
9 changes: 3 additions & 6 deletions global_helpers/panther_base_helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import re
from base64 import b64decode
from binascii import Error as AsciiError
from collections import OrderedDict
from collections.abc import Mapping
from datetime import datetime
from fnmatch import fnmatch
Expand Down Expand Up @@ -73,7 +72,7 @@ def _empty_list(sub_obj: Any):
return default if _empty_list(obj) else obj

current_key = keys[0]
found: OrderedDict = OrderedDict()
found_list: list[Any] = []

if isinstance(obj, Mapping):
next_key = obj.get(current_key, None)
Expand All @@ -87,12 +86,10 @@ def _empty_list(sub_obj: Any):
value = deep_walk(item, *keys, default=default, return_val=return_val)
if value is not None:
if isinstance(value, Sequence) and not isinstance(value, str):
for sub_item in value:
found[sub_item] = None
found_list.extend(value)
else:
found[value] = None
found_list.append(value)

found_list: list[Any] = list(found.keys())
if not found_list:
return default
return {
Expand Down
Loading