Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(data-masking): add custom mask functionalities #5837

Open
wants to merge 13 commits into
base: develop
Choose a base branch
from
136 changes: 126 additions & 10 deletions aws_lambda_powertools/utilities/data_masking/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import logging
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence, overload

from jsonpath_ng.ext import parse
Expand Down Expand Up @@ -94,15 +95,53 @@ def erase(self, data: tuple, fields: list[str]) -> tuple[str]: ...
@overload
def erase(self, data: dict, fields: list[str]) -> dict: ...

def erase(self, data: Sequence | Mapping, fields: list[str] | None = None) -> str | list[str] | tuple[str] | dict:
return self._apply_action(data=data, fields=fields, action=self.provider.erase)
@overload
def erase(self, data: dict[Any, Any], *, masking_rules: dict[str, object]) -> dict[Any, Any]: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if we need the arg * here. Can you try to remove this method signature and see if mypy complains?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to remove, but increases the number of mypy errors

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets keep this conversation opened until we find a solution for this.


@overload
def erase(
self,
data: dict,
fields: list[str],
dynamic_mask: bool | None = None,
custom_mask: str | None = None,
regex_pattern: str | None = None,
mask_format: str | None = None,
) -> dict: ...

def erase(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we will need to remove all the overloads and keep only this function implementation. Now we have different use cases that I don't know if we will can cover all of them and make mypy happy.

self,
data: Sequence | Mapping,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
data: Sequence | Mapping,
data: Any,

fields: list[str] | None = None,
dynamic_mask: bool | None = None,
custom_mask: str | None = None,
regex_pattern: str | None = None,
mask_format: str | None = None,
masking_rules: dict | None = None,
) -> str | list[str] | tuple[str] | dict:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
) -> str | list[str] | tuple[str] | dict:
) -> Any:

if masking_rules:
return self._apply_masking_rules(data, masking_rules)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return self._apply_masking_rules(data, masking_rules)
return self._apply_masking_rules(data=data, masking_rules=masking_rules)

else:
return self._apply_action(
data=data,
fields=fields,
action=self.provider.erase,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
)

def _apply_action(
self,
data,
fields: list[str] | None,
action: Callable,
provider_options: dict | None = None,
dynamic_mask: bool | None = None,
custom_mask: str | None = None,
regex_pattern: str | None = None,
mask_format: str | None = None,
**encryption_context: str,
):
"""
Expand Down Expand Up @@ -136,18 +175,34 @@ def _apply_action(
fields=fields,
action=action,
provider_options=provider_options,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
**encryption_context,
)
else:
logger.debug(f"Running action {action.__name__} with the entire data")
return action(data=data, provider_options=provider_options, **encryption_context)
return action(
data=data,
provider_options=provider_options,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
**encryption_context,
)

def _apply_action_to_fields(
self,
data: dict | str,
fields: list,
action: Callable,
provider_options: dict | None = None,
dynamic_mask: bool | None = None,
custom_mask: str | None = None,
regex_pattern: str | None = None,
mask_format: str | None = None,
**encryption_context: str,
) -> dict | str:
"""
Expand Down Expand Up @@ -194,6 +249,8 @@ def _apply_action_to_fields(
new_dict = {'a': {'b': {'c': '*****'}}, 'x': {'y': '*****'}}
```
"""
if not fields:
raise ValueError("Fields parameter cannot be empty")

data_parsed: dict = self._normalize_data_to_parse(fields, data)

Expand All @@ -204,6 +261,10 @@ def _apply_action_to_fields(
self._call_action,
action=action,
provider_options=provider_options,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
**encryption_context, # type: ignore[arg-type]
)

Expand All @@ -225,12 +286,6 @@ def _apply_action_to_fields(
# For in-place updates, json_parse accepts a callback function
# that receives 3 args: field_value, fields, field_name
# We create a partial callback to pre-populate known provider options (action, provider opts, enc ctx)
update_callback = functools.partial(
self._call_action,
action=action,
provider_options=provider_options,
**encryption_context, # type: ignore[arg-type]
)

json_parse.update(
data_parsed,
Expand All @@ -239,13 +294,66 @@ def _apply_action_to_fields(

return data_parsed

def _apply_masking_rules(self, data: dict, masking_rules: dict) -> dict:
"""
Apply masking rules to data, supporting both simple field names and complex path expressions.
Args:
data: The dictionary containing data to mask
masking_rules: Dictionary mapping field names or path expressions to masking rules
Returns:
dict: The masked data dictionary
"""
result = deepcopy(data)

for path, rule in masking_rules.items():
try:
jsonpath_expr = parse(f"$.{path}")
matches = jsonpath_expr.find(result)

if not matches:
warnings.warn(f"No matches found for path: {path}", stacklevel=2)
continue

for match in matches:
try:
value = match.value
if value is not None:
masked_value = self.provider.erase(str(value), **rule)
match.full_path.update(result, masked_value)

except Exception as e:
warnings.warn(f"Error masking value for path {path}: {str(e)}", stacklevel=2)
continue

except Exception as e:
warnings.warn(f"Error processing path {path}: {str(e)}", stacklevel=2)
continue

return result

def _mask_nested_field(self, data: dict, field_path: str, mask_function):
keys = field_path.split(".")
current = data
for key in keys[:-1]:
current = current.get(key, {})
if not isinstance(current, dict):
return
if keys[-1] in current:
current[keys[-1]] = mask_function(current[keys[-1]])

@staticmethod
def _call_action(
field_value: Any,
fields: dict[str, Any],
field_name: str,
action: Callable,
provider_options: dict[str, Any] | None = None,
dynamic_mask: bool | None = None,
custom_mask: str | None = None,
regex_pattern: str | None = None,
mask_format: str | None = None,
**encryption_context,
) -> None:
"""
Expand All @@ -263,7 +371,15 @@ def _call_action(
Returns:
- fields[field_name]: Returns the processed field value
"""
fields[field_name] = action(field_value, provider_options=provider_options, **encryption_context)
fields[field_name] = action(
field_value,
provider_options=provider_options,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
**encryption_context,
)
return fields[field_name]

def _normalize_data_to_parse(self, fields: list, data: str | dict) -> dict:
Expand Down
157 changes: 142 additions & 15 deletions aws_lambda_powertools/utilities/data_masking/provider/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

import functools
import json
from typing import Any, Callable, Iterable
import re
from typing import Any, Callable

from aws_lambda_powertools.utilities.data_masking.constants import DATA_MASKING_STRING

PRESERVE_CHARS = set("-_. ")
_regex_cache = {}


class BaseProvider:
"""
Expand Down Expand Up @@ -63,19 +67,142 @@ def decrypt(self, data, provider_options: dict | None = None, **encryption_conte
"""
raise NotImplementedError("Subclasses must implement decrypt()")

def erase(self, data, **kwargs) -> Iterable[str]:
"""
This method irreversibly erases data.
def erase(
self,
data: Any,
dynamic_mask: bool | None = None,
custom_mask: str | None = None,
regex_pattern: str | None = None,
mask_format: str | None = None,
masking_rules: dict | None = None,
**kwargs,
) -> Any:

result = DATA_MASKING_STRING

if not any([dynamic_mask, custom_mask, regex_pattern, mask_format, masking_rules]):
if isinstance(data, (str, int, float, dict, bytes)):
return DATA_MASKING_STRING
elif isinstance(data, (list, tuple, set)):
return type(data)([DATA_MASKING_STRING] * len(data))
else:
return DATA_MASKING_STRING

if isinstance(data, (str, int, float)):
result = self._mask_primitive(str(data), dynamic_mask, custom_mask, regex_pattern, mask_format, **kwargs)
elif isinstance(data, dict):
result = self._mask_dict(
data,
dynamic_mask,
custom_mask,
regex_pattern,
mask_format,
masking_rules,
**kwargs,
)
elif isinstance(data, (list, tuple, set)):
result = self._mask_iterable(
data,
dynamic_mask,
custom_mask,
regex_pattern,
mask_format,
masking_rules,
**kwargs,
)

return result

def _mask_primitive(
self,
data: str,
dynamic_mask: bool | None,
custom_mask: str | None,
regex_pattern: str | None,
mask_format: str | None,
**kwargs,
) -> str:
if regex_pattern and mask_format:
return self._regex_mask(data, regex_pattern, mask_format)
elif custom_mask:
return self._pattern_mask(data, custom_mask)
elif dynamic_mask:
return self._custom_erase(data, **kwargs)
else:
return DATA_MASKING_STRING

If the data to be erased is of type `str`, `dict`, or `bytes`,
this method will return an erased string, i.e. "*****".
def _mask_dict(
self,
data: dict,
dynamic_mask: bool | None,
custom_mask: str | None,
regex_pattern: str | None,
mask_format: str | None,
masking_rules: dict | None,
**kwargs,
) -> dict:
if masking_rules:
return self._apply_masking_rules(data, masking_rules)
else:
return {
k: self.erase(
v,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
masking_rules=masking_rules,
**kwargs,
)
for k, v in data.items()
}

def _mask_iterable(
self,
data: list | tuple | set,
dynamic_mask: bool | None,
custom_mask: str | None,
regex_pattern: str | None,
mask_format: str | None,
masking_rules: dict | None,
**kwargs,
) -> list | tuple | set:
masked_data = [
self.erase(
item,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
masking_rules=masking_rules,
**kwargs,
)
for item in data
]
return type(data)(masked_data)

def _apply_masking_rules(self, data: dict, masking_rules: dict) -> Any:
"""Apply masking rules to dictionary data."""
return {
key: self.erase(str(value), **masking_rules[key]) if key in masking_rules else str(value)
for key, value in data.items()
}

If the data to be erased is of an iterable type like `list`, `tuple`,
or `set`, this method will return a new object of the same type as the
input data but with each element replaced by the string "*****".
"""
if isinstance(data, (str, dict, bytes)):
return DATA_MASKING_STRING
elif isinstance(data, (list, tuple, set)):
return type(data)([DATA_MASKING_STRING] * len(data))
return DATA_MASKING_STRING
def _pattern_mask(self, data: str, pattern: str) -> str:
"""Apply pattern masking to string data."""
return pattern[: len(data)] if len(pattern) >= len(data) else pattern

def _regex_mask(self, data: str, regex_pattern: str, mask_format: str) -> str:
"""Apply regex masking to string data."""
try:
if regex_pattern not in _regex_cache:
_regex_cache[regex_pattern] = re.compile(regex_pattern)
return _regex_cache[regex_pattern].sub(mask_format, data)
except re.error:
return data

def _custom_erase(self, data: str, **kwargs) -> str:
if not data:
return ""

return "".join("*" if char not in PRESERVE_CHARS else char for char in data)
Loading
Loading