From 18c5e1d04dc1db8eb112b522f783aedd3f6d3ee8 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Wed, 29 May 2024 16:21:55 +0300 Subject: [PATCH 01/23] expose opa port in pdp makefile run --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 4cf8e62b..52b1c8f5 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ build-arm64: prepare @docker buildx build --platform linux/arm64 -t permitio/pdp-v2:$(VERSION) . --load run: run-prepare - @docker run -p 7766:7000 --env PDP_API_KEY=$(API_KEY) --env PDP_DEBUG=true permitio/pdp-v2:$(VERSION) + @docker run -it -p 7766:7000 -p 8181:8181 --env PDP_API_KEY=$(API_KEY) --env PDP_DEBUG=true permitio/pdp-v2:$(VERSION) run-on-background: run-prepare - @docker run -d -p 7766:7000 --env PDP_API_KEY=$(API_KEY) --env PDP_DEBUG=true permitio/pdp-v2:$(VERSION) + @docker run -d -p 7766:7000 -p 8181:8181 --env PDP_API_KEY=$(API_KEY) --env PDP_DEBUG=true permitio/pdp-v2:$(VERSION) From 297abfe8df6307f038f5118841c5e5f97cc19fd9 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Wed, 29 May 2024 16:22:17 +0300 Subject: [PATCH 02/23] fix annoying error message when building pdp locally --- horizon/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/horizon/state.py b/horizon/state.py index 2ba862f5..364ab0be 100644 --- a/horizon/state.py +++ b/horizon/state.py @@ -125,7 +125,7 @@ def _get_pdp_version(cls) -> Optional[str]: if os.path.exists(PDP_VERSION_FILENAME): with open(PDP_VERSION_FILENAME) as f: return f.read().strip() - return None + return "0.0.0" @classmethod def _get_pdp_runtime(cls) -> dict: From 5ac4912efa0cdeb2cd7a24b43421b66deb06e46d Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Wed, 29 May 2024 16:23:25 +0300 Subject: [PATCH 03/23] basic compile client and filtering endpoint --- horizon/enforcer/api.py | 29 +++++++++ horizon/enforcer/data_filtering/__init__.py | 0 .../enforcer/data_filtering/compile_client.py | 59 +++++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 horizon/enforcer/data_filtering/__init__.py create mode 100644 horizon/enforcer/data_filtering/compile_client.py diff --git a/horizon/enforcer/api.py b/horizon/enforcer/api.py index 24b1dc8e..a0f4b394 100644 --- a/horizon/enforcer/api.py +++ b/horizon/enforcer/api.py @@ -19,6 +19,7 @@ from horizon.authentication import enforce_pdp_token from horizon.config import sidecar_config +from horizon.enforcer.data_filtering.compile_client import OpaCompileClient from horizon.enforcer.schemas import ( AuthorizationQuery, AuthorizationResult, @@ -676,4 +677,32 @@ async def is_allowed_kong(request: Request, query: KongAuthorizationQuery): ) return result + @router.post( + "/filter_resources", + response_model=AuthorizationResult, + status_code=status.HTTP_200_OK, + response_model_exclude_none=True, + # TODO: restore authz + # dependencies=[Depends(enforce_pdp_token)], + ) + async def filter_resources( + request: Request, + input: AuthorizationQuery, + x_permit_sdk_language: Optional[str] = Depends(notify_seen_sdk), + ): + headers = transform_headers(request) + client = OpaCompileClient(headers=headers) + COMPILE_ROOT_RULE_REFERENCE = "data.example.rbac4.allow" + query = f"{COMPILE_ROOT_RULE_REFERENCE} == true" + response = await client.compile_query( + query=query, + input=input, + unknowns=[ + "input.resource.key", + "input.resource.tenant", + "input.resource.attributes", + ], + ) + return response + return router diff --git a/horizon/enforcer/data_filtering/__init__.py b/horizon/enforcer/data_filtering/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/horizon/enforcer/data_filtering/compile_client.py b/horizon/enforcer/data_filtering/compile_client.py new file mode 100644 index 00000000..c2d9d280 --- /dev/null +++ b/horizon/enforcer/data_filtering/compile_client.py @@ -0,0 +1,59 @@ +import json + +import aiohttp +from fastapi import HTTPException, Response, status +from opal_client.config import opal_client_config +from opal_client.logger import logger + +from horizon.enforcer.schemas import AuthorizationQuery + + +class OpaCompileClient: + def __init__(self, headers: dict): + self._base_url = f"{opal_client_config.POLICY_STORE_URL}" + self._headers = headers + self._client = aiohttp.ClientSession( + base_url=self._base_url, headers=self._headers + ) + + async def compile_query( + self, query: str, input: AuthorizationQuery, unknowns: list[str] + ): + input = {**input.dict(), "use_debugger": False} + data = { + "query": query, + # we don't want debug rules when we try to reduce the policy into a partial policy + "input": input, + "unknowns": unknowns, + } + try: + logger.debug("Compiling OPA query: {}", data) + async with self._client as session: + async with session.post( + "/v1/compile", + data=json.dumps(data), + raise_for_status=True, + ) as response: + content = await response.text() + return Response( + content=content, + status_code=response.status, + headers=dict(response.headers), + media_type="application/json", + ) + except aiohttp.ClientResponseError as e: + exc = HTTPException( + status.HTTP_502_BAD_GATEWAY, # 502 indicates server got an error from another server + detail="OPA request failed (url: {url}, status: {status}, message: {message})".format( + url=self._base_url, status=e.status, message=e.message + ), + ) + except aiohttp.ClientError as e: + exc = HTTPException( + status.HTTP_502_BAD_GATEWAY, + detail="OPA request failed (url: {url}, error: {error}".format( + url=self._base_url, error=str(e) + ), + ) + logger.warning(exc.detail) + raise exc From 23cc9ae6b3192ab27be8bb475032bf851fd42977 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Thu, 30 May 2024 12:31:50 +0300 Subject: [PATCH 04/23] wip --- horizon/enforcer/data_filtering/rego_ast.py | 309 ++++++++++++++++++++ horizon/enforcer/data_filtering/schemas.py | 0 2 files changed, 309 insertions(+) create mode 100644 horizon/enforcer/data_filtering/rego_ast.py create mode 100644 horizon/enforcer/data_filtering/schemas.py diff --git a/horizon/enforcer/data_filtering/rego_ast.py b/horizon/enforcer/data_filtering/rego_ast.py new file mode 100644 index 00000000..86c085b2 --- /dev/null +++ b/horizon/enforcer/data_filtering/rego_ast.py @@ -0,0 +1,309 @@ +# Rego policy structure +# +# Rego policies are defined using a relatively small set of types: +# modules, package and import declarations, rules, expressions, and terms. +# At their core, policies consist of rules that are defined by one or more expressions over documents available to the policy engine. +# The expressions are defined by intrinsic values (terms) such as strings, objects, variables, etc. +# Rego policies are typically defined in text files and then parsed and compiled by the policy engine at runtime. +# +# The parsing stage takes the text or string representation of the policy and converts it into an abstract syntax tree (AST) +# that consists of the types mentioned above. The AST is organized as follows: +# Module +# | +# +--- Package (Reference) +# | +# +--- Imports +# | | +# | +--- Import (Term) +# | +# +--- Rules +# | +# +--- Rule +# | +# +--- Head +# | | +# | +--- Name (Variable) +# | | +# | +--- Key (Term) +# | | +# | +--- Value (Term) +# | +# +--- Body +# | +# +--- Expression (Term | Terms | Variable Declaration) +# | +# +--- Term + + +import json +from typing import Any, Optional + + +def indent_string(s: str, indent_char: str = "\t", indent_level: int = 1): + indent = indent_char * indent_level + return ["{}{}".format(indent, row) for row in s.splitlines()] + + +class QuerySet: + """ + A queryset is a result of partial evaluation, creating a residual policy consisting + of multiple queries (each query consists of multiple rego expressions). + + The query essentially outlines a set of conditions for the residual policy to be true. + All the expressions of the query must be true (logical AND) in order for the query to evaluate to TRUE. + + You can roughly translate the query set into an SQL WHERE statement. + + Between each query of the queryset - there is a logical OR. + """ + + def __init__(self, queries: list["Query"]): + self.queries = queries + + @classmethod + def parse(cls, queries: list): + """ + example data: + # queryset + [ + # query (an array of expressions) + [ + # expression (an array of terms) + { + "index": 0, + "terms": [ + ... + ] + } + ], + ... + ] + """ + return cls([Query.parse(q) for q in queries]) + + def __repr__(self): + queries_str = "\n".join([indent_string(repr(r)) for r in self.queries]) + return "QuerySet([\n{}\n])\n".format(queries_str) + + +class Query: + """ + A residual query is a result of partial evaluation. + The query essentially outlines a set of conditions for the residual policy to be true. + All the expressions of the query must be true (logical AND) in order for the query to evaluate to TRUE. + """ + + def __init__(self, expressions: list["Expression"]): + self.expressions = expressions + + @classmethod + def parse(cls, expressions: list): + """ + example data: + # query (an array of expressions) + [ + # expression (an array of terms) + { + "index": 0, + "terms": [ + ... + ] + } + ] + """ + return cls([Expression.parse(e) for e in expressions]) + + def __repr__(self): + exprs_str = "\n".join([indent_string(repr(e)) for e in self.expressions]) + return "Query([\n{}\n])\n".format(exprs_str) + + +class Expression: + """ + An expression roughly translate into one line of rego code. + Typically a rego rule consists of multiple expressions. + + An expression is comprised of multiple terms (typically 3), where the first is an operator and the rest are operands. + """ + + def __init__(self, terms: list["Term"]): + self.terms = terms + + @classmethod + def parse(cls, data: dict): + """ + example data: + # expression + { + "index": 0, + # terms + "terms": [ + # first term is typically an operator (e.g: ==, !=, >, <, etc) + # the operator will typically be a *reference* to a built in function. + # for example the "equals" (or "==") operator (within OPA operators are called builtins) is actually the builtin function "eq()". + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "eq" + } + ] + }, + # rest of terms are typically operands + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "input" + }, + { + "type": "string", + "value": "resource" + }, + { + "type": "string", + "value": "tenant" + } + ] + }, + { + "type": "string", + "value": "default" + } + ] + } + """ + terms = data["terms"] + if isinstance(terms, dict): + return cls([Term.parse(terms)]) + return cls([Term.parse(t) for t in terms]) + + @property + def operator(self): + """ + returns the term that is the operator of the expression (typically the first term) + """ + return self.terms[0] + + @property + def operands(self): + """ + returns the terms that are the operands of the expression + """ + return self.terms[1:] + + def __repr__(self): + operands_str = ",".join([repr(o) for o in self.operands]) + return "Expression({}, [{}])".format(repr(self.operator), operands_str) + + +class Term: + def __init__(self, value: Any): + self.value = value + + @classmethod + def parse(cls, data): + if data["type"] == "null": + data["value"] = None + return cls(TERMS_BY_TYPE[data["type"]].parse(data["value"])) + + def __repr__(self): + return repr(self.value) + + +class NullTerm: + def __init__(self): + self.value = None + + @classmethod + def parse(cls): + return cls() + + def __repr__(self): + return json.dumps(self.value) # null + + +class BooleanTerm: + def __init__(self, value: bool): + self.value = value + + @classmethod + def parse(cls, data: bool): + return cls(data) + + def __repr__(self): + return json.dumps(self.value) + + +class NumberTerm: + def __init__(self, value: int | float): + self.value = value + + @classmethod + def parse(cls, data: int | float): + return cls(data) + + def __repr__(self): + return json.dumps(self.value) + + +class StringTerm: + def __init__(self, value: str): + self.value = value + + @classmethod + def parse(cls, data: str): + return cls(data) + + def __repr__(self): + return json.dumps(self.value) + + +class VarTerm: + def __init__(self, value: str): + self.value = value + + @classmethod + def parse(cls, variable_name: str): + return cls(variable_name) + + def __repr__(self): + return self.value + + +class RefTerm: + def __init__(self, ref: str): + self.ref = ref + + @classmethod + def parse(cls, terms: list[dict]): + assert len(terms) > 0 + var_term = VarTerm.parse(terms[0]["value"]) + string_terms = [ + StringTerm.parse(t["value"]) for t in terms[1:] + ] # might be empty + terms = [var_term] + string_terms + ref_parts = [t.value for t in terms] + return cls(".".join(ref_parts)) + + def __repr__(self): + return "Ref({})".format(self.ref) + + +TERMS_BY_TYPE = { + "null": NullTerm, + "boolean": BooleanTerm, + "number": NumberTerm, + "string": StringTerm, + "var": VarTerm, + "ref": RefTerm, + # "array": ArrayTerm, + # "set": SetTerm, + # "object": ObjectTerm, + # "arraycomprehension": ArrayComprehensionTerm, + # "setcomprehension": SetComprehensionTerm, + # "objectcomprehension": ObjectComprehensionTerm, + # "call": CallTerm, +} diff --git a/horizon/enforcer/data_filtering/schemas.py b/horizon/enforcer/data_filtering/schemas.py new file mode 100644 index 00000000..e69de29b From cbdc464a00df49ec986a341bb5b3077ce404bbdc Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Sun, 2 Jun 2024 17:09:04 +0300 Subject: [PATCH 05/23] pydantic foundation for parsing the json (needs expansion later) --- horizon/enforcer/data_filtering/schemas.py | 75 ++++ .../enforcer/data_filtering/tests/__init__.py | 0 .../tests/test_compile_parsing.py | 344 ++++++++++++++++++ 3 files changed, 419 insertions(+) create mode 100644 horizon/enforcer/data_filtering/tests/__init__.py create mode 100644 horizon/enforcer/data_filtering/tests/test_compile_parsing.py diff --git a/horizon/enforcer/data_filtering/schemas.py b/horizon/enforcer/data_filtering/schemas.py index e69de29b..b08d39fe 100644 --- a/horizon/enforcer/data_filtering/schemas.py +++ b/horizon/enforcer/data_filtering/schemas.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import Any, List, Optional + +from pydantic import BaseModel, Field + + +class BaseSchema(BaseModel): + class Config: + orm_mode = True + allow_population_by_field_name = True + + +class CompileResponse(BaseSchema): + result: "CompileResponseComponents" + + +class CompileResponseComponents(BaseSchema): + queries: Optional["CRQuerySet"] = None + support: Optional["CRSupportBlock"] = None + + +class CRQuerySet(BaseSchema): + __root__: List["CRQuery"] + + +class CRQuery(BaseSchema): + __root__: List["CRExpression"] + + +class CRExpression(BaseSchema): + index: int + terms: List["CRTerm"] + + +class CRTerm(BaseSchema): + type: str + value: Any + + +class CRSupportBlock(BaseSchema): + __root__: List["CRSupportModule"] + + +class CRSupportModule(BaseSchema): + package: "CRSupportModulePackage" + rules: List["CRRegoRule"] + + +class CRSupportModulePackage(BaseSchema): + path: List["CRTerm"] + + +class CRRegoRule(BaseSchema): + body: List["CRExpression"] + head: "CRRuleHead" + + +class CRRuleHead(BaseSchema): + name: str + key: "CRTerm" + ref: List["CRTerm"] + + +CompileResponse.update_forward_refs() +CompileResponseComponents.update_forward_refs() +CRQuerySet.update_forward_refs() +CRQuery.update_forward_refs() +CRExpression.update_forward_refs() +CRTerm.update_forward_refs() +CRSupportBlock.update_forward_refs() +CRSupportModule.update_forward_refs() +CRSupportModulePackage.update_forward_refs() +CRRegoRule.update_forward_refs() +CRRuleHead.update_forward_refs() diff --git a/horizon/enforcer/data_filtering/tests/__init__.py b/horizon/enforcer/data_filtering/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/horizon/enforcer/data_filtering/tests/test_compile_parsing.py b/horizon/enforcer/data_filtering/tests/test_compile_parsing.py new file mode 100644 index 00000000..19223bb4 --- /dev/null +++ b/horizon/enforcer/data_filtering/tests/test_compile_parsing.py @@ -0,0 +1,344 @@ +import json + +from horizon.enforcer.data_filtering.schemas import CRTerm, CompileResponse + + +COMPILE_RESPONE_RBAC_NO_SUPPORT_BLOCK = """{ + "result": { + "queries": [ + [ + { + "index": 0, + "terms": [ + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "eq" + } + ] + }, + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "input" + }, + { + "type": "string", + "value": "resource" + }, + { + "type": "string", + "value": "tenant" + } + ] + }, + { + "type": "string", + "value": "default" + } + ] + } + ], + [ + { + "index": 0, + "terms": [ + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "eq" + } + ] + }, + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "input" + }, + { + "type": "string", + "value": "resource" + }, + { + "type": "string", + "value": "tenant" + } + ] + }, + { + "type": "string", + "value": "second" + } + ] + } + ] + ] + } +} +""" + + +COMPILE_RESPONE_RBAC_WITH_SUPPORT_BLOCK = """{ + "result": { + "queries": [ + [ + { + "index": 0, + "terms": [ + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "gt" + } + ] + }, + { + "type": "call", + "value": [ + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "count" + } + ] + }, + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "data" + }, + { + "type": "string", + "value": "partial" + }, + { + "type": "string", + "value": "example" + }, + { + "type": "string", + "value": "rbac4" + }, + { + "type": "string", + "value": "allowed" + } + ] + } + ] + }, + { + "type": "number", + "value": 0 + } + ] + } + ] + ], + "support": [ + { + "package": { + "path": [ + { + "type": "var", + "value": "data" + }, + { + "type": "string", + "value": "partial" + }, + { + "type": "string", + "value": "example" + }, + { + "type": "string", + "value": "rbac4" + } + ] + }, + "rules": [ + { + "body": [ + { + "index": 0, + "terms": [ + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "eq" + } + ] + }, + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "input" + }, + { + "type": "string", + "value": "resource" + }, + { + "type": "string", + "value": "tenant" + } + ] + }, + { + "type": "string", + "value": "default" + } + ] + } + ], + "head": { + "name": "allowed", + "key": { + "type": "string", + "value": "user 'asaf' has role 'editor' in tenant 'default' which grants permission 'task:read'" + }, + "ref": [ + { + "type": "var", + "value": "allowed" + } + ] + } + }, + { + "body": [ + { + "index": 0, + "terms": [ + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "eq" + } + ] + }, + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "input" + }, + { + "type": "string", + "value": "resource" + }, + { + "type": "string", + "value": "tenant" + } + ] + }, + { + "type": "string", + "value": "second" + } + ] + } + ], + "head": { + "name": "allowed", + "key": { + "type": "string", + "value": "user 'asaf' has role 'viewer' in tenant 'second' which grants permission 'task:read'" + }, + "ref": [ + { + "type": "var", + "value": "allowed" + } + ] + } + } + ] + } + ] + } +} +""" + + +def test_parse_compile_response_rbac_no_support_block(): + response = json.loads(COMPILE_RESPONE_RBAC_NO_SUPPORT_BLOCK) + res = CompileResponse(**response) + assert res.result.queries is not None + assert len(res.result.queries.__root__) == 2 + + first_query = res.result.queries.__root__[0] + len(first_query.__root__) == 1 + + first_query_expression = first_query.__root__[0] + assert first_query_expression.index == 0 + assert len(first_query_expression.terms) == 3 + + assert first_query_expression.terms[0].type == "ref" + assert first_query_expression.terms[1].type == "ref" + assert first_query_expression.terms[2].type == "string" + assert first_query_expression.terms[2].value == "default" + + second_query = res.result.queries.__root__[1] + len(second_query.__root__) == 1 + + second_query_expression = second_query.__root__[0] + assert second_query_expression.index == 0 + assert len(second_query_expression.terms) == 3 + + assert second_query_expression.terms[0].type == "ref" + assert second_query_expression.terms[1].type == "ref" + assert second_query_expression.terms[2].type == "string" + assert second_query_expression.terms[2].value == "second" + + +def test_parse_compile_response_rbac_with_support_block(): + response = json.loads(COMPILE_RESPONE_RBAC_WITH_SUPPORT_BLOCK) + res = CompileResponse(**response) + assert res.result.queries is not None + assert len(res.result.queries.__root__) == 1 # 1 query + + first_query = res.result.queries.__root__[0] + len(first_query.__root__) == 1 # 1 expression + + first_query_expression = first_query.__root__[0] + assert first_query_expression.index == 0 + assert len(first_query_expression.terms) == 3 + + assert first_query_expression.terms[0].type == "ref" # operator + len(first_query_expression.terms[0].value) == 1 + op = CRTerm(**first_query_expression.terms[0].value[0]) + assert op.type == "var" + assert op.value == "gt" + + assert first_query_expression.terms[1].type == "call" + len(first_query_expression.terms[1].value) == 2 + + assert first_query_expression.terms[2].type == "number" + assert first_query_expression.terms[2].value == 0 + + assert res.result.support is not None + assert len(res.result.support.__root__) == 1 # 1 support module From 3d7172cf10c7352c0b6415670b1371e73ade5a0f Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Sun, 2 Jun 2024 22:37:47 +0300 Subject: [PATCH 06/23] rewrite term parsers --- horizon/enforcer/data_filtering/rego_ast.py | 178 ++++++++++-------- .../data_filtering/tests/test_ast_parser.py | 116 ++++++++++++ 2 files changed, 219 insertions(+), 75 deletions(-) create mode 100644 horizon/enforcer/data_filtering/tests/test_ast_parser.py diff --git a/horizon/enforcer/data_filtering/rego_ast.py b/horizon/enforcer/data_filtering/rego_ast.py index 86c085b2..17c7dc00 100644 --- a/horizon/enforcer/data_filtering/rego_ast.py +++ b/horizon/enforcer/data_filtering/rego_ast.py @@ -36,7 +36,10 @@ import json -from typing import Any, Optional +from types import NoneType +from typing import Any, Generic, Optional, TypeVar + +from horizon.enforcer.data_filtering.schemas import CRTerm def indent_string(s: str, indent_char: str = "\t", indent_level: int = 1): @@ -177,8 +180,8 @@ def parse(cls, data: dict): """ terms = data["terms"] if isinstance(terms, dict): - return cls([Term.parse(terms)]) - return cls([Term.parse(t) for t in terms]) + return cls([TermParser.parse(terms)]) + return cls([TermParser.parse(t) for t in terms]) @property def operator(self): @@ -199,111 +202,136 @@ def __repr__(self): return "Expression({}, [{}])".format(repr(self.operator), operands_str) -class Term: - def __init__(self, value: Any): +T = TypeVar("T") + + +class Term(Generic[T]): + def __init__(self, value: T): self.value = value @classmethod - def parse(cls, data): - if data["type"] == "null": - data["value"] = None - return cls(TERMS_BY_TYPE[data["type"]].parse(data["value"])) + def parse(cls, data: T): + return cls(data) def __repr__(self): - return repr(self.value) + return json.dumps(self.value) -class NullTerm: - def __init__(self): - self.value = None +class NullTerm(Term[NoneType]): + pass - @classmethod - def parse(cls): - return cls() - def __repr__(self): - return json.dumps(self.value) # null +class BooleanTerm(Term[bool]): + pass -class BooleanTerm: - def __init__(self, value: bool): - self.value = value +class NumberTerm(Term[int | float]): + pass - @classmethod - def parse(cls, data: bool): - return cls(data) +class StringTerm(Term[str]): + pass + + +class VarTerm(Term[str]): def __repr__(self): - return json.dumps(self.value) + return self.value -class NumberTerm: - def __init__(self, value: int | float): - self.value = value +class Ref: + def __init__(self, ref_parts: list[str]): + self._parts = ref_parts - @classmethod - def parse(cls, data: int | float): - return cls(data) + @property + def parts(self): + return self._parts - def __repr__(self): - return json.dumps(self.value) + @property + def as_string(self): + return str(self) + def __str__(self): + return ".".join(self._parts) -class StringTerm: - def __init__(self, value: str): - self.value = value +class RefTerm(Term[Ref]): @classmethod - def parse(cls, data: str): - return cls(data) + def parse(cls, terms: list[dict]): + assert len(terms) > 0 + parsed_terms: list[Term] = [TermParser.parse(CRTerm(**t)) for t in terms] + var_term = parsed_terms[0] + # TODO: support more types of refs + assert isinstance( + var_term, VarTerm + ), "first sub-term inside ref is not a variable" + string_terms = parsed_terms[1:] # might be empty + # TODO: support more types of refs + assert all( + isinstance(t, StringTerm) for t in string_terms + ), "ref parts are not string terms" + ref_parts = [t.value for t in parsed_terms] + return cls(Ref(ref_parts)) def __repr__(self): - return json.dumps(self.value) + return "Ref({})".format(self.value.as_string) -class VarTerm: - def __init__(self, value: str): - self.value = value +class Call: + """ + represents a function call + """ - @classmethod - def parse(cls, variable_name: str): - return cls(variable_name) + def __init__(self, func: Term, args: list[Term]): + self._func = func + self._args = args - def __repr__(self): - return self.value + @property + def func(self): + return self._func + @property + def args(self): + return self._args + + def __str__(self): + return "{}({})".format(self.func, ", ".join([str(arg) for arg in self.args])) -class RefTerm: - def __init__(self, ref: str): - self.ref = ref +class CallTerm(Term[Call]): @classmethod def parse(cls, terms: list[dict]): assert len(terms) > 0 - var_term = VarTerm.parse(terms[0]["value"]) - string_terms = [ - StringTerm.parse(t["value"]) for t in terms[1:] - ] # might be empty - terms = [var_term] + string_terms - ref_parts = [t.value for t in terms] - return cls(".".join(ref_parts)) + parsed_terms: list[Term] = [TermParser.parse(CRTerm(**t)) for t in terms] + func_term = parsed_terms[0] + # TODO: support more types of refs + assert isinstance(func_term, RefTerm), "first sub-term inside call is not a ref" + arg_terms = parsed_terms[1:] # might be empty + return cls(Call(func_term, arg_terms)) def __repr__(self): - return "Ref({})".format(self.ref) - - -TERMS_BY_TYPE = { - "null": NullTerm, - "boolean": BooleanTerm, - "number": NumberTerm, - "string": StringTerm, - "var": VarTerm, - "ref": RefTerm, - # "array": ArrayTerm, - # "set": SetTerm, - # "object": ObjectTerm, - # "arraycomprehension": ArrayComprehensionTerm, - # "setcomprehension": SetComprehensionTerm, - # "objectcomprehension": ObjectComprehensionTerm, - # "call": CallTerm, -} + return "call:{}".format(str(self.value)) + + +class TermParser: + TERMS_BY_TYPE: dict[str, Term] = { + "null": NullTerm, + "boolean": BooleanTerm, + "number": NumberTerm, + "string": StringTerm, + "var": VarTerm, + "ref": RefTerm, + "call": CallTerm, + # "array": ArrayTerm, + # "set": SetTerm, + # "object": ObjectTerm, + # "arraycomprehension": ArrayComprehensionTerm, + # "setcomprehension": SetComprehensionTerm, + # "objectcomprehension": ObjectComprehensionTerm, + } + + @classmethod + def parse(cls, data: CRTerm) -> Term: + if data.type == "null": + data.value = None + klass = cls.TERMS_BY_TYPE[data.type] + return klass.parse(data.value) diff --git a/horizon/enforcer/data_filtering/tests/test_ast_parser.py b/horizon/enforcer/data_filtering/tests/test_ast_parser.py new file mode 100644 index 00000000..e67a584a --- /dev/null +++ b/horizon/enforcer/data_filtering/tests/test_ast_parser.py @@ -0,0 +1,116 @@ +from horizon.enforcer.data_filtering.rego_ast import ( + BooleanTerm, + Call, + CallTerm, + NullTerm, + Ref, + RefTerm, + Term, + TermParser, + NumberTerm, + StringTerm, + VarTerm, +) +from horizon.enforcer.data_filtering.schemas import CRTerm + + +def test_parse_null_term(): + t = CRTerm(**{"type": "null"}) + term: Term = TermParser.parse(t) + assert isinstance(term, NullTerm) + assert term.value == None + + +def test_parse_boolean_term(): + for val in [True, False]: + t = CRTerm(**{"type": "boolean", "value": val}) + term: Term = TermParser.parse(t) + assert isinstance(term, BooleanTerm) + assert term.value == val + isinstance(term.value, bool) + + +def test_parse_number_term(): + for val in [0, 2, 3.14]: + t = CRTerm(**{"type": "number", "value": val}) + term: Term = TermParser.parse(t) + assert isinstance(term, NumberTerm) + assert term.value == val + assert isinstance(term.value, int) or isinstance(term.value, float) + + +def test_parse_string_term(): + for val in ["hello", "world", ""]: + t = CRTerm(**{"type": "string", "value": val}) + term: Term = TermParser.parse(t) + assert isinstance(term, StringTerm) + assert term.value == val + assert isinstance(term.value, str) + + +def test_parse_var_term(): + t = CRTerm(**{"type": "var", "value": "eq"}) + term: Term = TermParser.parse(t) + assert isinstance(term, VarTerm) + assert term.value == "eq" + assert isinstance(term.value, str) + + +def test_parse_simple_ref_term(): + simple_ref_term = { + "type": "ref", + "value": [ + {"type": "var", "value": "allowed"}, + ], + } + t = CRTerm(**simple_ref_term) + term: Term = TermParser.parse(t) + assert isinstance(term, RefTerm) + assert isinstance(term.value, Ref) + assert len(term.value.parts) == 1 + assert term.value.as_string == "allowed" + + +def test_parse_complex_ref_term(): + complex_ref_term = { + "type": "ref", + "value": [ + {"type": "var", "value": "input"}, + {"type": "string", "value": "resource"}, + {"type": "string", "value": "tenant"}, + ], + } + t = CRTerm(**complex_ref_term) + term: Term = TermParser.parse(t) + assert isinstance(term, RefTerm) + assert isinstance(term.value, Ref) + assert len(term.value.parts) == 3 + assert term.value.as_string == "input.resource.tenant" + + +def test_parse_call_term(): + call_term = { + "type": "call", + "value": [ + {"type": "ref", "value": [{"type": "var", "value": "count"}]}, + { + "type": "ref", + "value": [ + {"type": "var", "value": "data"}, + {"type": "string", "value": "partial"}, + {"type": "string", "value": "example"}, + {"type": "string", "value": "rbac4"}, + {"type": "string", "value": "allowed"}, + ], + }, + ], + } + t = CRTerm(**call_term) + term: Term = TermParser.parse(t) + assert isinstance(term, CallTerm) + assert isinstance(term.value, Call) + assert isinstance(term.value.func.value, Ref) + assert term.value.func.value.as_string == "count" + assert len(term.value.args) == 1 + assert isinstance(term.value.args[0].value, Ref) + assert term.value.args[0].value.as_string == "data.partial.example.rbac4.allowed" From c0729c78796785e40a179cd6d0027313be3a2244 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Sun, 2 Jun 2024 23:27:40 +0300 Subject: [PATCH 07/23] basic ast parsing of rbac policies --- horizon/enforcer/data_filtering/rego_ast.py | 89 ++++++-- horizon/enforcer/data_filtering/schemas.py | 4 +- .../data_filtering/tests/test_ast_parser.py | 212 +++++++++++++++++- 3 files changed, 282 insertions(+), 23 deletions(-) diff --git a/horizon/enforcer/data_filtering/rego_ast.py b/horizon/enforcer/data_filtering/rego_ast.py index 17c7dc00..069f31a2 100644 --- a/horizon/enforcer/data_filtering/rego_ast.py +++ b/horizon/enforcer/data_filtering/rego_ast.py @@ -37,14 +37,20 @@ import json from types import NoneType -from typing import Any, Generic, Optional, TypeVar +from typing import Generic, List, TypeVar -from horizon.enforcer.data_filtering.schemas import CRTerm +from horizon.enforcer.data_filtering.schemas import ( + CRExpression, + CRQuery, + CompileResponse, + CRTerm, + CRSupportModule, +) def indent_string(s: str, indent_char: str = "\t", indent_level: int = 1): indent = indent_char * indent_level - return ["{}{}".format(indent, row) for row in s.splitlines()] + return "\n".join(["{}{}".format(indent, row) for row in s.splitlines()]) class QuerySet: @@ -60,11 +66,12 @@ class QuerySet: Between each query of the queryset - there is a logical OR. """ - def __init__(self, queries: list["Query"]): - self.queries = queries + def __init__(self, queries: list["Query"], support_modules: list): + self._queries = queries + self._support_modules = support_modules @classmethod - def parse(cls, queries: list): + def parse(cls, response: CompileResponse): """ example data: # queryset @@ -82,12 +89,35 @@ def parse(cls, queries: list): ... ] """ - return cls([Query.parse(q) for q in queries]) + queries: List[CRQuery] = ( + [] if response.result.queries is None else response.result.queries.__root__ + ) + # TODO: parse support modules + # modules: List[CRSupportModule] = ( + # [] if response.result.support is None else response.result.support.__root__ + # ) + return cls([Query.parse(q) for q in queries], []) + + @property + def queries(self): + return self._queries def __repr__(self): - queries_str = "\n".join([indent_string(repr(r)) for r in self.queries]) + queries_str = "\n".join([indent_string(repr(r)) for r in self._queries]) return "QuerySet([\n{}\n])\n".format(queries_str) + @property + def always_false(self) -> bool: + return len(self._queries) == 0 + + @property + def always_true(self) -> bool: + return len(self._queries) > 0 and any(q.always_true for q in self._queries) + + @property + def conditional(self) -> bool: + return not self.always_false and not self.always_true + class Query: """ @@ -97,10 +127,10 @@ class Query: """ def __init__(self, expressions: list["Expression"]): - self.expressions = expressions + self._expressions = expressions @classmethod - def parse(cls, expressions: list): + def parse(cls, query: CRQuery): """ example data: # query (an array of expressions) @@ -114,10 +144,21 @@ def parse(cls, expressions: list): } ] """ - return cls([Expression.parse(e) for e in expressions]) + return cls([Expression.parse(e) for e in query.__root__]) + + @property + def expressions(self): + return self._expressions + + @property + def always_true(self) -> bool: + """ + returns true if the query always evaluates to TRUE + """ + return len(self._expressions) == 0 def __repr__(self): - exprs_str = "\n".join([indent_string(repr(e)) for e in self.expressions]) + exprs_str = "\n".join([indent_string(repr(e)) for e in self._expressions]) return "Query([\n{}\n])\n".format(exprs_str) @@ -130,10 +171,10 @@ class Expression: """ def __init__(self, terms: list["Term"]): - self.terms = terms + self._terms = terms @classmethod - def parse(cls, data: dict): + def parse(cls, data: CRExpression): """ example data: # expression @@ -178,27 +219,35 @@ def parse(cls, data: dict): ] } """ - terms = data["terms"] - if isinstance(terms, dict): + terms = data.terms + if isinstance(terms, CRTerm): return cls([TermParser.parse(terms)]) - return cls([TermParser.parse(t) for t in terms]) + else: + return cls([TermParser.parse(t) for t in terms]) @property def operator(self): """ returns the term that is the operator of the expression (typically the first term) """ - return self.terms[0] + return self._terms[0] @property def operands(self): """ returns the terms that are the operands of the expression """ - return self.terms[1:] + return self._terms[1:] + + @property + def terms(self): + """ + returns all the terms of the expression + """ + return self._terms def __repr__(self): - operands_str = ",".join([repr(o) for o in self.operands]) + operands_str = ", ".join([repr(o) for o in self.operands]) return "Expression({}, [{}])".format(repr(self.operator), operands_str) diff --git a/horizon/enforcer/data_filtering/schemas.py b/horizon/enforcer/data_filtering/schemas.py index b08d39fe..8b1c7dec 100644 --- a/horizon/enforcer/data_filtering/schemas.py +++ b/horizon/enforcer/data_filtering/schemas.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional +from typing import Any, List, Optional, Union from pydantic import BaseModel, Field @@ -30,7 +30,7 @@ class CRQuery(BaseSchema): class CRExpression(BaseSchema): index: int - terms: List["CRTerm"] + terms: Union["CRTerm", List["CRTerm"]] class CRTerm(BaseSchema): diff --git a/horizon/enforcer/data_filtering/tests/test_ast_parser.py b/horizon/enforcer/data_filtering/tests/test_ast_parser.py index e67a584a..2150ff0f 100644 --- a/horizon/enforcer/data_filtering/tests/test_ast_parser.py +++ b/horizon/enforcer/data_filtering/tests/test_ast_parser.py @@ -2,7 +2,10 @@ BooleanTerm, Call, CallTerm, + Expression, NullTerm, + Query, + QuerySet, Ref, RefTerm, Term, @@ -11,7 +14,13 @@ StringTerm, VarTerm, ) -from horizon.enforcer.data_filtering.schemas import CRTerm +from horizon.enforcer.data_filtering.schemas import ( + CRTerm, + CRExpression, + CRQuery, + CRQuerySet, + CompileResponse, +) def test_parse_null_term(): @@ -114,3 +123,204 @@ def test_parse_call_term(): assert len(term.value.args) == 1 assert isinstance(term.value.args[0].value, Ref) assert term.value.args[0].value.as_string == "data.partial.example.rbac4.allowed" + + +def test_parse_expression_eq(): + expr = { + "index": 0, + "terms": [ + {"type": "ref", "value": [{"type": "var", "value": "eq"}]}, + { + "type": "ref", + "value": [ + {"type": "var", "value": "input"}, + {"type": "string", "value": "resource"}, + {"type": "string", "value": "tenant"}, + ], + }, + {"type": "string", "value": "second"}, + ], + } + parsed_expr = CRExpression(**expr) + assert parsed_expr.index == 0 + assert len(parsed_expr.terms) == 3 + expression = Expression.parse(parsed_expr) + assert isinstance(expression.operator.value, Ref) + assert expression.operator.value.as_string == "eq" + assert len(expression.operands) == 2 + assert isinstance(expression.operands[0], RefTerm) + assert expression.operands[0].value.as_string == "input.resource.tenant" + assert isinstance(expression.operands[1], StringTerm) + assert expression.operands[1].value == "second" + + +def test_parse_trivial_query(): + query = Query.parse(CRQuery(__root__=[])) + assert len(query.expressions) == 0 + assert query.always_true + + +def test_parse_query(): + q = CRQuery( + __root__=[ + { + "index": 0, + "terms": [ + {"type": "ref", "value": [{"type": "var", "value": "gt"}]}, + { + "type": "ref", + "value": [ + {"type": "var", "value": "input"}, + {"type": "string", "value": "resource"}, + {"type": "string", "value": "attributes"}, + {"type": "string", "value": "age"}, + ], + }, + {"type": "number", "value": 7}, + ], + } + ] + ) + query = Query.parse(q) + assert len(query.expressions) == 1 + assert not query.always_true + + expression = query.expressions[0] + + assert isinstance(expression.operator.value, Ref) + assert expression.operator.value.as_string == "gt" + assert len(expression.operands) == 2 + assert isinstance(expression.operands[0], RefTerm) + assert expression.operands[0].value.as_string == "input.resource.attributes.age" + assert isinstance(expression.operands[1], NumberTerm) + assert expression.operands[1].value == 7 + + +def test_parse_queryset_always_false(): + response = CompileResponse(**{"result": {}}) + queryset = QuerySet.parse(response) + assert len(queryset.queries) == 0 + assert queryset.always_false + assert not queryset.always_true + assert not queryset.conditional + + +def test_parse_queryset_always_true(): + response = CompileResponse( + **{ + "result": { + "queries": [ + [ + { + "index": 0, + "terms": [ + { + "type": "ref", + "value": [{"type": "var", "value": "eq"}], + }, + { + "type": "ref", + "value": [ + {"type": "var", "value": "input"}, + {"type": "string", "value": "resource"}, + {"type": "string", "value": "tenant"}, + ], + }, + {"type": "string", "value": "default"}, + ], + } + ], + [], + ] + } + } + ) + queryset = QuerySet.parse(response) + assert queryset.always_true + assert not queryset.always_false + assert not queryset.conditional + + +def test_parse_queryset_conditional(): + response = CompileResponse( + **{ + "result": { + "queries": [ + [ + { + "index": 0, + "terms": [ + { + "type": "ref", + "value": [{"type": "var", "value": "eq"}], + }, + { + "type": "ref", + "value": [ + {"type": "var", "value": "input"}, + {"type": "string", "value": "resource"}, + {"type": "string", "value": "tenant"}, + ], + }, + {"type": "string", "value": "default"}, + ], + } + ], + [ + { + "index": 0, + "terms": [ + { + "type": "ref", + "value": [{"type": "var", "value": "eq"}], + }, + { + "type": "ref", + "value": [ + {"type": "var", "value": "input"}, + {"type": "string", "value": "resource"}, + {"type": "string", "value": "tenant"}, + ], + }, + {"type": "string", "value": "second"}, + ], + } + ], + ] + } + } + ) + queryset = QuerySet.parse(response) + assert queryset.conditional + assert not queryset.always_true + assert not queryset.always_false + + assert len(queryset.queries) == 2 + + q0 = queryset.queries[0] + assert len(q0.expressions) == 1 + assert not q0.always_true + + expression = q0.expressions[0] + + assert isinstance(expression.operator.value, Ref) + assert expression.operator.value.as_string == "eq" + assert len(expression.operands) == 2 + assert isinstance(expression.operands[0], RefTerm) + assert expression.operands[0].value.as_string == "input.resource.tenant" + assert isinstance(expression.operands[1], StringTerm) + assert expression.operands[1].value == "default" + + q1 = queryset.queries[1] + assert len(q1.expressions) == 1 + assert not q1.always_true + + expression = q1.expressions[0] + + assert isinstance(expression.operator.value, Ref) + assert expression.operator.value.as_string == "eq" + assert len(expression.operands) == 2 + assert isinstance(expression.operands[0], RefTerm) + assert expression.operands[0].value.as_string == "input.resource.tenant" + assert isinstance(expression.operands[1], StringTerm) + assert expression.operands[1].value == "second" From b0eff68ac8a9430246408183f4a2d5f5f0f1eb9a Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Sun, 28 Jul 2024 16:02:18 +0300 Subject: [PATCH 08/23] added missing docstrings and return type annotations --- .../enforcer/data_filtering/compile_client.py | 2 +- horizon/enforcer/data_filtering/rego_ast.py | 77 ++++++++++++++----- 2 files changed, 57 insertions(+), 22 deletions(-) diff --git a/horizon/enforcer/data_filtering/compile_client.py b/horizon/enforcer/data_filtering/compile_client.py index c2d9d280..bd74a299 100644 --- a/horizon/enforcer/data_filtering/compile_client.py +++ b/horizon/enforcer/data_filtering/compile_client.py @@ -19,10 +19,10 @@ def __init__(self, headers: dict): async def compile_query( self, query: str, input: AuthorizationQuery, unknowns: list[str] ): + # we don't want debug rules when we try to reduce the policy into a partial policy input = {**input.dict(), "use_debugger": False} data = { "query": query, - # we don't want debug rules when we try to reduce the policy into a partial policy "input": input, "unknowns": unknowns, } diff --git a/horizon/enforcer/data_filtering/rego_ast.py b/horizon/enforcer/data_filtering/rego_ast.py index 069f31a2..ae55366a 100644 --- a/horizon/enforcer/data_filtering/rego_ast.py +++ b/horizon/enforcer/data_filtering/rego_ast.py @@ -48,7 +48,14 @@ ) -def indent_string(s: str, indent_char: str = "\t", indent_level: int = 1): +def indent_lines(s: str, indent_char: str = "\t", indent_level: int = 1) -> str: + """ + indents all lines within a multiline string by an indent character repeated by a given indent level. + e.g: indents all lines within a string by 3 tab (\t) characters (indent level 3). + + returns a new indented string. + ) + """ indent = indent_char * indent_level return "\n".join(["{}{}".format(indent, row) for row in s.splitlines()]) @@ -63,7 +70,8 @@ class QuerySet: You can roughly translate the query set into an SQL WHERE statement. - Between each query of the queryset - there is a logical OR. + Between each query of the queryset - there is a logical OR: + Policy is true if query_1 OR query_2 OR ... OR query_n and query_i == (expr_i_1 AND ... AND expr_i_m). """ def __init__(self, queries: list["Query"], support_modules: list): @@ -71,7 +79,7 @@ def __init__(self, queries: list["Query"], support_modules: list): self._support_modules = support_modules @classmethod - def parse(cls, response: CompileResponse): + def parse(cls, response: CompileResponse) -> "QuerySet": """ example data: # queryset @@ -99,11 +107,11 @@ def parse(cls, response: CompileResponse): return cls([Query.parse(q) for q in queries], []) @property - def queries(self): + def queries(self) -> list["Query"]: return self._queries def __repr__(self): - queries_str = "\n".join([indent_string(repr(r)) for r in self._queries]) + queries_str = "\n".join([indent_lines(repr(r)) for r in self._queries]) return "QuerySet([\n{}\n])\n".format(queries_str) @property @@ -130,7 +138,7 @@ def __init__(self, expressions: list["Expression"]): self._expressions = expressions @classmethod - def parse(cls, query: CRQuery): + def parse(cls, query: CRQuery) -> "Query": """ example data: # query (an array of expressions) @@ -147,7 +155,7 @@ def parse(cls, query: CRQuery): return cls([Expression.parse(e) for e in query.__root__]) @property - def expressions(self): + def expressions(self) -> list["Expression"]: return self._expressions @property @@ -158,7 +166,7 @@ def always_true(self) -> bool: return len(self._expressions) == 0 def __repr__(self): - exprs_str = "\n".join([indent_string(repr(e)) for e in self._expressions]) + exprs_str = "\n".join([indent_lines(repr(e)) for e in self._expressions]) return "Query([\n{}\n])\n".format(exprs_str) @@ -174,7 +182,7 @@ def __init__(self, terms: list["Term"]): self._terms = terms @classmethod - def parse(cls, data: CRExpression): + def parse(cls, data: CRExpression) -> "Expression": """ example data: # expression @@ -226,21 +234,21 @@ def parse(cls, data: CRExpression): return cls([TermParser.parse(t) for t in terms]) @property - def operator(self): + def operator(self) -> "Term": """ returns the term that is the operator of the expression (typically the first term) """ return self._terms[0] @property - def operands(self): + def operands(self) -> list["Term"]: """ returns the terms that are the operands of the expression """ return self._terms[1:] @property - def terms(self): + def terms(self) -> list["Term"]: """ returns all the terms of the expression """ @@ -255,11 +263,16 @@ def __repr__(self): class Term(Generic[T]): + """ + a term is an atomic part of an expression (line of code). + it is typically a literal (number, string, etc), a reference (variable) or an operator (e.g: "==" or ">"). + """ + def __init__(self, value: T): self.value = value @classmethod - def parse(cls, data: T): + def parse(cls, data: T) -> "Term": return cls(data) def __repr__(self): @@ -288,24 +301,36 @@ def __repr__(self): class Ref: + """ + represents a reference in OPA, which is a path to a document in the OPA document tree. + when translating this into a boolean expression tree, a reference would typically translate into a variable. + """ + def __init__(self, ref_parts: list[str]): self._parts = ref_parts @property - def parts(self): + def parts(self) -> list[str]: + """ + the parts of the full path to the document, each part is a node in OPA document tree. + """ return self._parts @property - def as_string(self): - return str(self) + def as_string(self) -> str: + return str(self) # calls __str__(self) def __str__(self): return ".".join(self._parts) class RefTerm(Term[Ref]): + """ + A term that represents an OPA reference, holds a Ref object as a value. + """ + @classmethod - def parse(cls, terms: list[dict]): + def parse(cls, terms: list[dict]) -> "Term[Ref]": assert len(terms) > 0 parsed_terms: list[Term] = [TermParser.parse(CRTerm(**t)) for t in terms] var_term = parsed_terms[0] @@ -327,7 +352,7 @@ def __repr__(self): class Call: """ - represents a function call + represents a function call expression inside OPA. """ def __init__(self, func: Term, args: list[Term]): @@ -335,11 +360,17 @@ def __init__(self, func: Term, args: list[Term]): self._args = args @property - def func(self): + def func(self) -> Term: + """ + a (ref)term that holds the reference to the name of the function that is acting on the arguments of the function call + """ return self._func @property - def args(self): + def args(self) -> list[Term]: + """ + the terms representing the arguments of the function call + """ return self._args def __str__(self): @@ -347,8 +378,12 @@ def __str__(self): class CallTerm(Term[Call]): + """ + A term that represents a function call Term, holds a Call object as a value. + """ + @classmethod - def parse(cls, terms: list[dict]): + def parse(cls, terms: list[dict]) -> Term[Call]: assert len(terms) > 0 parsed_terms: list[Term] = [TermParser.parse(CRTerm(**t)) for t in terms] func_term = parsed_terms[0] From d379086c20ff92568a8ef66e9610191858e9fff6 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Sun, 28 Jul 2024 16:24:38 +0300 Subject: [PATCH 09/23] refactor file structure --- horizon/enforcer/api.py | 2 +- horizon/enforcer/data_filtering/compile_api/__init__.py | 0 .../data_filtering/{ => compile_api}/compile_client.py | 0 horizon/enforcer/data_filtering/{ => compile_api}/schemas.py | 0 horizon/enforcer/data_filtering/rego_ast/__init__.py | 0 .../data_filtering/{rego_ast.py => rego_ast/parser.py} | 2 +- horizon/enforcer/data_filtering/tests/test_ast_parser.py | 4 ++-- horizon/enforcer/data_filtering/tests/test_compile_parsing.py | 2 +- 8 files changed, 5 insertions(+), 5 deletions(-) create mode 100644 horizon/enforcer/data_filtering/compile_api/__init__.py rename horizon/enforcer/data_filtering/{ => compile_api}/compile_client.py (100%) rename horizon/enforcer/data_filtering/{ => compile_api}/schemas.py (100%) create mode 100644 horizon/enforcer/data_filtering/rego_ast/__init__.py rename horizon/enforcer/data_filtering/{rego_ast.py => rego_ast/parser.py} (99%) diff --git a/horizon/enforcer/api.py b/horizon/enforcer/api.py index a0f4b394..23d917fd 100644 --- a/horizon/enforcer/api.py +++ b/horizon/enforcer/api.py @@ -19,7 +19,7 @@ from horizon.authentication import enforce_pdp_token from horizon.config import sidecar_config -from horizon.enforcer.data_filtering.compile_client import OpaCompileClient +from horizon.enforcer.data_filtering.compile_api.compile_client import OpaCompileClient from horizon.enforcer.schemas import ( AuthorizationQuery, AuthorizationResult, diff --git a/horizon/enforcer/data_filtering/compile_api/__init__.py b/horizon/enforcer/data_filtering/compile_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/horizon/enforcer/data_filtering/compile_client.py b/horizon/enforcer/data_filtering/compile_api/compile_client.py similarity index 100% rename from horizon/enforcer/data_filtering/compile_client.py rename to horizon/enforcer/data_filtering/compile_api/compile_client.py diff --git a/horizon/enforcer/data_filtering/schemas.py b/horizon/enforcer/data_filtering/compile_api/schemas.py similarity index 100% rename from horizon/enforcer/data_filtering/schemas.py rename to horizon/enforcer/data_filtering/compile_api/schemas.py diff --git a/horizon/enforcer/data_filtering/rego_ast/__init__.py b/horizon/enforcer/data_filtering/rego_ast/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/horizon/enforcer/data_filtering/rego_ast.py b/horizon/enforcer/data_filtering/rego_ast/parser.py similarity index 99% rename from horizon/enforcer/data_filtering/rego_ast.py rename to horizon/enforcer/data_filtering/rego_ast/parser.py index ae55366a..a17485d0 100644 --- a/horizon/enforcer/data_filtering/rego_ast.py +++ b/horizon/enforcer/data_filtering/rego_ast/parser.py @@ -39,7 +39,7 @@ from types import NoneType from typing import Generic, List, TypeVar -from horizon.enforcer.data_filtering.schemas import ( +from horizon.enforcer.data_filtering.compile_api.schemas import ( CRExpression, CRQuery, CompileResponse, diff --git a/horizon/enforcer/data_filtering/tests/test_ast_parser.py b/horizon/enforcer/data_filtering/tests/test_ast_parser.py index 2150ff0f..4aa5fa9e 100644 --- a/horizon/enforcer/data_filtering/tests/test_ast_parser.py +++ b/horizon/enforcer/data_filtering/tests/test_ast_parser.py @@ -1,4 +1,4 @@ -from horizon.enforcer.data_filtering.rego_ast import ( +from horizon.enforcer.data_filtering.rego_ast.parser import ( BooleanTerm, Call, CallTerm, @@ -14,7 +14,7 @@ StringTerm, VarTerm, ) -from horizon.enforcer.data_filtering.schemas import ( +from horizon.enforcer.data_filtering.compile_api.schemas import ( CRTerm, CRExpression, CRQuery, diff --git a/horizon/enforcer/data_filtering/tests/test_compile_parsing.py b/horizon/enforcer/data_filtering/tests/test_compile_parsing.py index 19223bb4..8023cc73 100644 --- a/horizon/enforcer/data_filtering/tests/test_compile_parsing.py +++ b/horizon/enforcer/data_filtering/tests/test_compile_parsing.py @@ -1,6 +1,6 @@ import json -from horizon.enforcer.data_filtering.schemas import CRTerm, CompileResponse +from horizon.enforcer.data_filtering.compile_api.schemas import CRTerm, CompileResponse COMPILE_RESPONE_RBAC_NO_SUPPORT_BLOCK = """{ From 34974c80b95080572c162cdd4644991b4def764d Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Mon, 29 Jul 2024 13:06:35 +0300 Subject: [PATCH 10/23] boolean expression schema --- .../boolean_expression/__init__.py | 0 .../boolean_expression/schemas.py | 89 ++++++++++++++++++ .../tests/test_boolean_expression_schema.py | 93 +++++++++++++++++++ 3 files changed, 182 insertions(+) create mode 100644 horizon/enforcer/data_filtering/boolean_expression/__init__.py create mode 100644 horizon/enforcer/data_filtering/boolean_expression/schemas.py create mode 100644 horizon/enforcer/data_filtering/tests/test_boolean_expression_schema.py diff --git a/horizon/enforcer/data_filtering/boolean_expression/__init__.py b/horizon/enforcer/data_filtering/boolean_expression/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/horizon/enforcer/data_filtering/boolean_expression/schemas.py b/horizon/enforcer/data_filtering/boolean_expression/schemas.py new file mode 100644 index 00000000..b128eb13 --- /dev/null +++ b/horizon/enforcer/data_filtering/boolean_expression/schemas.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, Optional, Union + +from pydantic import BaseModel, Field, root_validator + + +class BaseSchema(BaseModel): + class Config: + orm_mode = True + allow_population_by_field_name = True + + +class Variable(BaseSchema): + """ + represents a variable in a boolean expression tree. + + if the boolean expression originated in OPA - this was originally a reference in the OPA document tree. + """ + + variable: str = Field(..., description="a path to a variable (reference)") + + +class Value(BaseSchema): + """ + Represents a value (literal) in a boolean expression tree. + Could be of any jsonable type: string, int, boolean, float, list, dict. + """ + + value: Any = Field( + ..., description="a literal value, typically compared to a variable" + ) + + +class Expr(BaseSchema): + operator: str = Field(..., description="the name of the operator") + operands: list["Operand"] = Field(..., description="the operands to the expression") + + +class Expression(BaseSchema): + """ + represents a boolean expression, comparised of logical operators (e.g: and/or/not) and comparison operators (e.g: ==, >, <) + + we translate OPA call terms to expressions, treating the operator as the function name and the operands as the args. + """ + + expression: Expr = Field( + ..., + description="represents a boolean expression, comparised of logical operators (e.g: and/or/not) and comparison operators (e.g: ==, >, <)", + ) + + +Operand = Union[Variable, Value, "Expression"] + + +class ResidualPolicyType(str, Enum): + ALWAYS_ALLOW = "always_allow" + ALWAYS_DENY = "always_deny" + CONDITIONAL = "conditional" + + +class ResidualPolicyResponse(BaseSchema): + type: ResidualPolicyType = Field(..., description="the type of the residual policy") + condition: Optional["Expression"] = Field( + None, + description="an optional condition, exists if the type of the residual policy is CONDITIONAL", + ) + + @root_validator + def check_condition_exists_when_needed(cls, values: dict): + type, condition = values.get("type"), values.get("condition", None) + if ( + type == ResidualPolicyType.ALWAYS_ALLOW + or type == ResidualPolicyType.ALWAYS_DENY + ) and condition is not None: + raise ValueError( + f"invalid residual policy: a condition exists but the type is not CONDITIONAL, instead: {type}" + ) + if type == ResidualPolicyType.CONDITIONAL and condition is None: + raise ValueError( + f"invalid residual policy: type is CONDITIONAL, but no condition is provided" + ) + return values + + +Expr.update_forward_refs() +Expression.update_forward_refs() +ResidualPolicyResponse.update_forward_refs() diff --git a/horizon/enforcer/data_filtering/tests/test_boolean_expression_schema.py b/horizon/enforcer/data_filtering/tests/test_boolean_expression_schema.py new file mode 100644 index 00000000..db115d10 --- /dev/null +++ b/horizon/enforcer/data_filtering/tests/test_boolean_expression_schema.py @@ -0,0 +1,93 @@ +import pytest +import pydantic + +from horizon.enforcer.data_filtering.boolean_expression.schemas import ( + ResidualPolicyResponse, + ResidualPolicyType, + Expression, +) + + +def test_valid_residual_policies(): + d = { + "type": "conditional", + "condition": { + "expression": { + "operator": "or", + "operands": [ + { + "expression": { + "operator": "eq", + "operands": [ + {"variable": "input.resource.tenant"}, + {"value": "default"}, + ], + } + }, + { + "expression": { + "operator": "eq", + "operands": [ + {"variable": "input.resource.tenant"}, + {"value": "second"}, + ], + } + }, + ], + } + }, + } + policy = ResidualPolicyResponse(**d) + assert policy.type == ResidualPolicyType.CONDITIONAL + assert policy.condition.expression.operator == "or" + assert len(policy.condition.expression.operands) == 2 + + d = {"type": "always_allow", "condition": None} + policy = ResidualPolicyResponse(**d) + assert policy.type == ResidualPolicyType.ALWAYS_ALLOW + assert policy.condition == None + + d = {"type": "always_deny", "condition": None} + policy = ResidualPolicyResponse(**d) + assert policy.type == ResidualPolicyType.ALWAYS_DENY + assert policy.condition == None + + +def test_invalid_residual_policies(): + for trival_residual_type in ["always_allow", "always_deny"]: + d = { + "type": trival_residual_type, + "condition": { + "expression": { + "operator": "or", + "operands": [ + { + "expression": { + "operator": "eq", + "operands": [ + {"variable": "input.resource.tenant"}, + {"value": "default"}, + ], + } + }, + { + "expression": { + "operator": "eq", + "operands": [ + {"variable": "input.resource.tenant"}, + {"value": "second"}, + ], + } + }, + ], + } + }, + } + with pytest.raises(pydantic.ValidationError) as e: + policy = ResidualPolicyResponse(**d) + assert "invalid residual policy" in str(e.value) + + d = {"type": "conditional", "condition": None} + with pytest.raises(pydantic.ValidationError) as e: + policy = ResidualPolicyResponse(**d) + assert "invalid residual policy" in str(e.value) From a71fedcea707c76ed14440be203de6d9c7572d11 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Mon, 29 Jul 2024 16:23:59 +0300 Subject: [PATCH 11/23] translate ast into a generic boolean expression --- .../boolean_expression/translator.py | 112 ++++++++++++++++++ .../data_filtering/rego_ast/parser.py | 43 ++++++- 2 files changed, 151 insertions(+), 4 deletions(-) create mode 100644 horizon/enforcer/data_filtering/boolean_expression/translator.py diff --git a/horizon/enforcer/data_filtering/boolean_expression/translator.py b/horizon/enforcer/data_filtering/boolean_expression/translator.py new file mode 100644 index 00000000..588f2412 --- /dev/null +++ b/horizon/enforcer/data_filtering/boolean_expression/translator.py @@ -0,0 +1,112 @@ +from horizon.enforcer.data_filtering.rego_ast import parser as ast +from horizon.enforcer.data_filtering.boolean_expression.schemas import ( + Operand, + ResidualPolicyResponse, + ResidualPolicyType, + Expression, + Expr, + Value, + Variable, +) + + +def translate_opa_queryset(queryset: ast.QuerySet) -> ResidualPolicyResponse: + """ + translates the Rego AST into a generic residual policy constructed as a boolean expression. + + this boolean expression can then be translated by plugins into various SQL and ORMs. + """ + if queryset.always_false: + return ResidualPolicyResponse(type=ResidualPolicyType.ALWAYS_DENY) + + if queryset.always_true: + return ResidualPolicyResponse(type=ResidualPolicyType.ALWAYS_ALLOW) + + if len(queryset.queries) == 1: + return ResidualPolicyResponse( + type=ResidualPolicyType.CONDITIONAL, + condition=translate_query(queryset.queries[0]), + ) + + queries = [query for query in queryset.queries if not query.always_true] + + if len(queries) == 0: + # no not trival queries means always true + return ResidualPolicyResponse(type=ResidualPolicyType.ALWAYS_ALLOW) + + # else, more than one query means there's a logical OR between queries + return ResidualPolicyResponse( + type=ResidualPolicyType.CONDITIONAL, + condition=Expression( + expression=Expr( + operator="or", + operands=[translate_query(query) for query in queries], + ) + ), + ) + + +def translate_query(query: ast.Query) -> Expression: + if len(query.expressions) == 1: + return translate_expression(query.expressions[0]) + + return Expression( + expression=Expr( + operator="and", + operands=[ + translate_expression(expression) for expression in query.expressions + ], + ) + ) + + +def translate_expression(expression: ast.Expression) -> Expression: + if len(expression.terms) == 1 and expression.terms[0].type == ast.TermType.CALL: + # this is a call expression + return translate_call_term(expression.terms[0].value) + + return Expression( + expression=Expr( + operator=expression.operator, + operands=[translate_term(term) for term in expression.operands], + ) + ) + + +def translate_call_term(call: ast.Call) -> Expression: + return Expression( + expression=Expr( + operator="call", + operands=[ + Expression( + expression=Expr( + operator=call.func, + operands=[translate_term(term) for term in call.args], + ) + ) + ], + ) + ) + + +def translate_term(term: ast.Term) -> Operand: + if term.type in ( + ast.TermType.NULL, + ast.TermType.BOOLEAN, + ast.TermType.NUMBER, + ast.TermType.STRING, + ): + return Value(value=term.value) + + if term.type == ast.TermType.VAR: + return Variable(variable=term.value) + + if term.type == ast.TermType.REF and isinstance(term.value, ast.Ref): + return Variable(variable=term.value.as_string) + + if term.type == ast.TermType.CALL and isinstance(term.value, ast.Call): + return translate_call_term(term.value) + + raise ValueError( + f"unable to translate term with type {term.type} and value {term.value}" + ) diff --git a/horizon/enforcer/data_filtering/rego_ast/parser.py b/horizon/enforcer/data_filtering/rego_ast/parser.py index a17485d0..1619196f 100644 --- a/horizon/enforcer/data_filtering/rego_ast/parser.py +++ b/horizon/enforcer/data_filtering/rego_ast/parser.py @@ -35,6 +35,7 @@ # +--- Term +from enum import Enum import json from types import NoneType from typing import Generic, List, TypeVar @@ -262,6 +263,16 @@ def __repr__(self): T = TypeVar("T") +class TermType(str, Enum): + NULL = "null" + BOOLEAN = "boolean" + NUMBER = "number" + STRING = "string" + VAR = "var" + REF = "ref" + CALL = "call" + + class Term(Generic[T]): """ a term is an atomic part of an expression (line of code). @@ -271,6 +282,10 @@ class Term(Generic[T]): def __init__(self, value: T): self.value = value + @property + def type(self) -> str: + return NotImplementedError() + @classmethod def parse(cls, data: T) -> "Term": return cls(data) @@ -280,25 +295,37 @@ def __repr__(self): class NullTerm(Term[NoneType]): - pass + @property + def type(self) -> str: + return TermType.NULL class BooleanTerm(Term[bool]): - pass + @property + def type(self) -> str: + return TermType.BOOLEAN class NumberTerm(Term[int | float]): - pass + @property + def type(self) -> str: + return TermType.NUMBER class StringTerm(Term[str]): - pass + @property + def type(self) -> str: + return TermType.STRING class VarTerm(Term[str]): def __repr__(self): return self.value + @property + def type(self) -> str: + return TermType.VAR + class Ref: """ @@ -329,6 +356,10 @@ class RefTerm(Term[Ref]): A term that represents an OPA reference, holds a Ref object as a value. """ + @property + def type(self) -> str: + return TermType.REF + @classmethod def parse(cls, terms: list[dict]) -> "Term[Ref]": assert len(terms) > 0 @@ -382,6 +413,10 @@ class CallTerm(Term[Call]): A term that represents a function call Term, holds a Call object as a value. """ + @property + def type(self) -> str: + return TermType.CALL + @classmethod def parse(cls, terms: list[dict]) -> Term[Call]: assert len(terms) > 0 From a0152838a24e2fe3dd4a8a142e7b269ef3207094 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Mon, 29 Jul 2024 16:43:20 +0300 Subject: [PATCH 12/23] test translation from ast into boolean expression --- .../boolean_expression/translator.py | 7 +- .../tests/test_ast_translation.py | 77 +++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 horizon/enforcer/data_filtering/tests/test_ast_translation.py diff --git a/horizon/enforcer/data_filtering/boolean_expression/translator.py b/horizon/enforcer/data_filtering/boolean_expression/translator.py index 588f2412..630934ea 100644 --- a/horizon/enforcer/data_filtering/boolean_expression/translator.py +++ b/horizon/enforcer/data_filtering/boolean_expression/translator.py @@ -65,9 +65,14 @@ def translate_expression(expression: ast.Expression) -> Expression: # this is a call expression return translate_call_term(expression.terms[0].value) + if not isinstance(expression.operator, ast.RefTerm): + raise ValueError( + f"The operator in an expression must be a term of type ref, instead got type {expression.operator.type} and value {expression.operator.value}" + ) + return Expression( expression=Expr( - operator=expression.operator, + operator=expression.operator.value.as_string, operands=[translate_term(term) for term in expression.operands], ) ) diff --git a/horizon/enforcer/data_filtering/tests/test_ast_translation.py b/horizon/enforcer/data_filtering/tests/test_ast_translation.py new file mode 100644 index 00000000..40645254 --- /dev/null +++ b/horizon/enforcer/data_filtering/tests/test_ast_translation.py @@ -0,0 +1,77 @@ +from horizon.enforcer.data_filtering.compile_api.schemas import CompileResponse +from horizon.enforcer.data_filtering.rego_ast import parser as ast +from horizon.enforcer.data_filtering.boolean_expression.schemas import ( + ResidualPolicyResponse, + ResidualPolicyType, +) +from horizon.enforcer.data_filtering.boolean_expression.translator import ( + translate_opa_queryset, +) + + +def test_ast_to_boolean_expr(): + response = CompileResponse( + **{ + "result": { + "queries": [ + [ + { + "index": 0, + "terms": [ + { + "type": "ref", + "value": [{"type": "var", "value": "eq"}], + }, + { + "type": "ref", + "value": [ + {"type": "var", "value": "input"}, + {"type": "string", "value": "resource"}, + {"type": "string", "value": "tenant"}, + ], + }, + {"type": "string", "value": "default"}, + ], + } + ], + [ + { + "index": 0, + "terms": [ + { + "type": "ref", + "value": [{"type": "var", "value": "eq"}], + }, + { + "type": "ref", + "value": [ + {"type": "var", "value": "input"}, + {"type": "string", "value": "resource"}, + {"type": "string", "value": "tenant"}, + ], + }, + {"type": "string", "value": "second"}, + ], + } + ], + ] + } + } + ) + queryset = ast.QuerySet.parse(response) + residual_policy: ResidualPolicyResponse = translate_opa_queryset(queryset) + # print(json.dumps(residual_policy.dict(), indent=2)) + + assert residual_policy.type == ResidualPolicyType.CONDITIONAL + assert residual_policy.condition.expression.operator == "or" + assert len(residual_policy.condition.expression.operands) == 2 + + assert residual_policy.condition.expression.operands[0].expression.operator == "eq" + assert ( + residual_policy.condition.expression.operands[0].expression.operands[0].variable + == "input.resource.tenant" + ) + assert ( + residual_policy.condition.expression.operands[0].expression.operands[1].value + == "default" + ) From a68bcc9ca0c38ab429851de1b262d8b94469c071 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Mon, 29 Jul 2024 19:39:26 +0300 Subject: [PATCH 13/23] filter_resources endpoint issues partial eval request and translates back to boolean expression --- horizon/enforcer/api.py | 10 +++- .../boolean_expression/schemas.py | 4 ++ .../compile_api/compile_client.py | 48 +++++++++++++++---- 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/horizon/enforcer/api.py b/horizon/enforcer/api.py index 23d917fd..6befbef0 100644 --- a/horizon/enforcer/api.py +++ b/horizon/enforcer/api.py @@ -4,7 +4,7 @@ from typing import cast, Optional, Union, Dict, List import aiohttp -from fastapi import APIRouter, Depends, Header +from fastapi import APIRouter, Depends, Header, Query from fastapi import HTTPException from fastapi import Request, Response, status from opal_client.config import opal_client_config @@ -58,6 +58,7 @@ AUTHORIZED_USERS_POLICY_PACKAGE = "permit.authorized_users.authorized_users" USER_TENANTS_POLICY_PACKAGE = USER_PERMISSIONS_POLICY_PACKAGE + ".tenants" KONG_ROUTES_TABLE_FILE = "/config/kong_routes.json" +MAIN_PARTIAL_EVAL_PACKAGE = "permit.partial_eval" stats_manager = StatisticsManager( interval_seconds=sidecar_config.OPA_CLIENT_FAILURE_THRESHOLD_INTERVAL, @@ -688,11 +689,15 @@ async def is_allowed_kong(request: Request, query: KongAuthorizationQuery): async def filter_resources( request: Request, input: AuthorizationQuery, + raw: bool = Query( + False, + description="whether we should include the OPA raw compilation result in the response. this can help us debug the translation of the AST", + ), x_permit_sdk_language: Optional[str] = Depends(notify_seen_sdk), ): headers = transform_headers(request) client = OpaCompileClient(headers=headers) - COMPILE_ROOT_RULE_REFERENCE = "data.example.rbac4.allow" + COMPILE_ROOT_RULE_REFERENCE = f"data.{MAIN_PARTIAL_EVAL_PACKAGE}.allow" query = f"{COMPILE_ROOT_RULE_REFERENCE} == true" response = await client.compile_query( query=query, @@ -702,6 +707,7 @@ async def filter_resources( "input.resource.tenant", "input.resource.attributes", ], + raw=raw, ) return response diff --git a/horizon/enforcer/data_filtering/boolean_expression/schemas.py b/horizon/enforcer/data_filtering/boolean_expression/schemas.py index b128eb13..44c47576 100644 --- a/horizon/enforcer/data_filtering/boolean_expression/schemas.py +++ b/horizon/enforcer/data_filtering/boolean_expression/schemas.py @@ -66,6 +66,10 @@ class ResidualPolicyResponse(BaseSchema): None, description="an optional condition, exists if the type of the residual policy is CONDITIONAL", ) + raw: Optional[dict] = Field( + None, + description="raw OPA compilation result, provided for debugging purposes", + ) @root_validator def check_condition_exists_when_needed(cls, values: dict): diff --git a/horizon/enforcer/data_filtering/compile_api/compile_client.py b/horizon/enforcer/data_filtering/compile_api/compile_client.py index bd74a299..85ee5234 100644 --- a/horizon/enforcer/data_filtering/compile_api/compile_client.py +++ b/horizon/enforcer/data_filtering/compile_api/compile_client.py @@ -6,6 +6,14 @@ from opal_client.logger import logger from horizon.enforcer.schemas import AuthorizationQuery +from horizon.enforcer.data_filtering.compile_api.schemas import CompileResponse +from horizon.enforcer.data_filtering.rego_ast import parser as ast +from horizon.enforcer.data_filtering.boolean_expression.schemas import ( + ResidualPolicyResponse, +) +from horizon.enforcer.data_filtering.boolean_expression.translator import ( + translate_opa_queryset, +) class OpaCompileClient: @@ -17,7 +25,11 @@ def __init__(self, headers: dict): ) async def compile_query( - self, query: str, input: AuthorizationQuery, unknowns: list[str] + self, + query: str, + input: AuthorizationQuery, + unknowns: list[str], + raw: bool = False, ): # we don't want debug rules when we try to reduce the policy into a partial policy input = {**input.dict(), "use_debugger": False} @@ -27,20 +39,35 @@ async def compile_query( "unknowns": unknowns, } try: - logger.debug("Compiling OPA query: {}", data) + logger.info("Compiling OPA query: {}", data) async with self._client as session: async with session.post( "/v1/compile", data=json.dumps(data), raise_for_status=True, ) as response: - content = await response.text() - return Response( - content=content, - status_code=response.status, - headers=dict(response.headers), - media_type="application/json", + opa_compile_result = await response.json() + logger.info( + "OPA compile query result: status={status}, response={response}", + status=response.status, + response=json.dumps(opa_compile_result), ) + try: + residual_policy = self.translate_rego_ast(opa_compile_result) + if raw: + residual_policy.raw = opa_compile_result + return Response( + content=json.dumps(residual_policy.dict()), + status_code=status.HTTP_200_OK, + media_type="application/json", + ) + except Exception as exc: + return HTTPException( + status.HTTP_406_NOT_ACCEPTABLE, + detail="failed to translate compiled OPA query (query: {query}, response: {response}, exc={exc})".format( + query=data, response=opa_compile_result, exc=exc + ), + ) except aiohttp.ClientResponseError as e: exc = HTTPException( status.HTTP_502_BAD_GATEWAY, # 502 indicates server got an error from another server @@ -57,3 +84,8 @@ async def compile_query( ) logger.warning(exc.detail) raise exc + + def translate_rego_ast(self, response: dict) -> ResidualPolicyResponse: + response = CompileResponse(**response) + queryset = ast.QuerySet.parse(response) + return translate_opa_queryset(queryset) From 1d4cd698253a06f8f09f07e793c054245db87bf7 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Mon, 29 Jul 2024 19:49:04 +0300 Subject: [PATCH 14/23] refactor compile client to return the residual policy itself and leave response encoding to the api wrapper --- horizon/enforcer/api.py | 38 +++++++++++++++++-- .../compile_api/compile_client.py | 8 +--- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/horizon/enforcer/api.py b/horizon/enforcer/api.py index 6befbef0..013fc34c 100644 --- a/horizon/enforcer/api.py +++ b/horizon/enforcer/api.py @@ -60,6 +60,33 @@ KONG_ROUTES_TABLE_FILE = "/config/kong_routes.json" MAIN_PARTIAL_EVAL_PACKAGE = "permit.partial_eval" +# TODO: a more robust policy needs to be added to permit default managed policy +# this policy is partial-eval friendly, but only supports RBAC at the moment +TEMP_PARTIAL_EVAL_POLICY = """ +package permit.partial_eval + +import future.keywords.contains +import future.keywords.if +import future.keywords.in + +default allow := false + +allow if { + checked_permission := sprintf("%s:%s", [input.resource.type, input.action]) + + some granting_role, role_data in data.roles + some resource_type, actions in role_data.grants + granted_action := actions[_] + granted_permission := sprintf("%s:%s", [resource_type, granted_action]) + + some tenant, roles in data.users[input.user.key].roleAssignments + role := roles[_] + role == granting_role + checked_permission == granted_permission + input.resource.tenant == tenant +} +""" + stats_manager = StatisticsManager( interval_seconds=sidecar_config.OPA_CLIENT_FAILURE_THRESHOLD_INTERVAL, failures_threshold_percentage=sidecar_config.OPA_CLIENT_FAILURE_THRESHOLD_PERCENTAGE, @@ -683,8 +710,7 @@ async def is_allowed_kong(request: Request, query: KongAuthorizationQuery): response_model=AuthorizationResult, status_code=status.HTTP_200_OK, response_model_exclude_none=True, - # TODO: restore authz - # dependencies=[Depends(enforce_pdp_token)], + dependencies=[Depends(enforce_pdp_token)], ) async def filter_resources( request: Request, @@ -699,7 +725,7 @@ async def filter_resources( client = OpaCompileClient(headers=headers) COMPILE_ROOT_RULE_REFERENCE = f"data.{MAIN_PARTIAL_EVAL_PACKAGE}.allow" query = f"{COMPILE_ROOT_RULE_REFERENCE} == true" - response = await client.compile_query( + residual_policy = await client.compile_query( query=query, input=input, unknowns=[ @@ -709,6 +735,10 @@ async def filter_resources( ], raw=raw, ) - return response + return Response( + content=json.dumps(residual_policy.dict()), + status_code=status.HTTP_200_OK, + media_type="application/json", + ) return router diff --git a/horizon/enforcer/data_filtering/compile_api/compile_client.py b/horizon/enforcer/data_filtering/compile_api/compile_client.py index 85ee5234..1bd774f0 100644 --- a/horizon/enforcer/data_filtering/compile_api/compile_client.py +++ b/horizon/enforcer/data_filtering/compile_api/compile_client.py @@ -30,7 +30,7 @@ async def compile_query( input: AuthorizationQuery, unknowns: list[str], raw: bool = False, - ): + ) -> ResidualPolicyResponse: # we don't want debug rules when we try to reduce the policy into a partial policy input = {**input.dict(), "use_debugger": False} data = { @@ -56,11 +56,7 @@ async def compile_query( residual_policy = self.translate_rego_ast(opa_compile_result) if raw: residual_policy.raw = opa_compile_result - return Response( - content=json.dumps(residual_policy.dict()), - status_code=status.HTTP_200_OK, - media_type="application/json", - ) + return residual_policy except Exception as exc: return HTTPException( status.HTTP_406_NOT_ACCEPTABLE, From 07ac5a2ed96f3dacc23323fde4c87c1884e8f52e Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Mon, 29 Jul 2024 22:05:52 +0300 Subject: [PATCH 15/23] wip sql mapping --- horizon/enforcer/api.py | 30 ++++ .../boolean_expression/schemas.py | 5 + .../boolean_expression/translator.py | 9 +- .../enforcer/data_filtering/sdk/__init__.py | 0 .../data_filtering/sdk/permit_filter.py | 72 +++++++++ .../data_filtering/sdk/permit_sqlalchemy.py | 143 ++++++++++++++++++ 6 files changed, 256 insertions(+), 3 deletions(-) create mode 100644 horizon/enforcer/data_filtering/sdk/__init__.py create mode 100644 horizon/enforcer/data_filtering/sdk/permit_filter.py create mode 100644 horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py diff --git a/horizon/enforcer/api.py b/horizon/enforcer/api.py index 013fc34c..3144fe13 100644 --- a/horizon/enforcer/api.py +++ b/horizon/enforcer/api.py @@ -741,4 +741,34 @@ async def filter_resources( media_type="application/json", ) + @router.post( + "/filter_resources_get_sql", + response_model=AuthorizationResult, + status_code=status.HTTP_200_OK, + response_model_exclude_none=True, + dependencies=[Depends(enforce_pdp_token)], + ) + async def filter_resources_get_sql( + request: Request, + input: AuthorizationQuery, + x_permit_sdk_language: Optional[str] = Depends(notify_seen_sdk), + ): + """ + TODO: temp endpoint, instead we should wrap the capability in the SDK + """ + headers = transform_headers(request) + client = OpaCompileClient(headers=headers) + COMPILE_ROOT_RULE_REFERENCE = f"data.{MAIN_PARTIAL_EVAL_PACKAGE}.allow" + query = f"{COMPILE_ROOT_RULE_REFERENCE} == true" + residual_policy = await client.compile_query( + query=query, + input=input, + unknowns=[ + "input.resource.key", + "input.resource.tenant", + "input.resource.attributes", + ], + raw=True, + ) + return router diff --git a/horizon/enforcer/data_filtering/boolean_expression/schemas.py b/horizon/enforcer/data_filtering/boolean_expression/schemas.py index 44c47576..5494be58 100644 --- a/horizon/enforcer/data_filtering/boolean_expression/schemas.py +++ b/horizon/enforcer/data_filtering/boolean_expression/schemas.py @@ -5,6 +5,11 @@ from pydantic import BaseModel, Field, root_validator +LOGICAL_AND = "and" +LOGICAL_OR = "or" +LOGICAL_NOT = "not" +CALL_OPERATOR = "call" + class BaseSchema(BaseModel): class Config: diff --git a/horizon/enforcer/data_filtering/boolean_expression/translator.py b/horizon/enforcer/data_filtering/boolean_expression/translator.py index 630934ea..895ee6bb 100644 --- a/horizon/enforcer/data_filtering/boolean_expression/translator.py +++ b/horizon/enforcer/data_filtering/boolean_expression/translator.py @@ -1,5 +1,6 @@ from horizon.enforcer.data_filtering.rego_ast import parser as ast from horizon.enforcer.data_filtering.boolean_expression.schemas import ( + CALL_OPERATOR, Operand, ResidualPolicyResponse, ResidualPolicyType, @@ -7,6 +8,8 @@ Expr, Value, Variable, + LOGICAL_AND, + LOGICAL_OR, ) @@ -39,7 +42,7 @@ def translate_opa_queryset(queryset: ast.QuerySet) -> ResidualPolicyResponse: type=ResidualPolicyType.CONDITIONAL, condition=Expression( expression=Expr( - operator="or", + operator=LOGICAL_OR, operands=[translate_query(query) for query in queries], ) ), @@ -52,7 +55,7 @@ def translate_query(query: ast.Query) -> Expression: return Expression( expression=Expr( - operator="and", + operator=LOGICAL_AND, operands=[ translate_expression(expression) for expression in query.expressions ], @@ -81,7 +84,7 @@ def translate_expression(expression: ast.Expression) -> Expression: def translate_call_term(call: ast.Call) -> Expression: return Expression( expression=Expr( - operator="call", + operator=CALL_OPERATOR, operands=[ Expression( expression=Expr( diff --git a/horizon/enforcer/data_filtering/sdk/__init__.py b/horizon/enforcer/data_filtering/sdk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/horizon/enforcer/data_filtering/sdk/permit_filter.py b/horizon/enforcer/data_filtering/sdk/permit_filter.py new file mode 100644 index 00000000..389b10e8 --- /dev/null +++ b/horizon/enforcer/data_filtering/sdk/permit_filter.py @@ -0,0 +1,72 @@ +from typing import Union + +from horizon.enforcer import schemas as enforcer_schemas +from horizon.enforcer.api import MAIN_PARTIAL_EVAL_PACKAGE +from horizon.enforcer.data_filtering.boolean_expression.schemas import ( + ResidualPolicyResponse, +) +from horizon.enforcer.data_filtering.compile_api.compile_client import OpaCompileClient + +User = Union[dict, str] +Action = str +Resource = Union[dict, str] + + +def normalize_user(user: User) -> dict: + if isinstance(user, str): + return dict(key=user) + else: + return user + + +def normalize_resource_type(resource: Resource) -> str: + if isinstance(resource, dict): + t = resource.get("type", None) + if t is not None and isinstance(t, str): + return t + raise ValueError("no resource type provided") + else: + return resource + + +def filter_resource_query( + user: User, action: Action, resource: Resource +) -> enforcer_schemas.AuthorizationQuery: + normalized_user = normalize_user(user) + resource_type: str = normalize_resource_type(resource) + return enforcer_schemas.AuthorizationQuery( + user=normalized_user, + action=action, + resource=enforcer_schemas.Resource(type=resource_type), + ) + + +class Permit: + """ + stub for future SDK code + """ + + def __init__(self, token: str): + self._headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + + async def filter_resources( + self, user: User, action: Action, resource: Resource + ) -> ResidualPolicyResponse: + """ + stub for future permit.filter_resources() function + """ + client = OpaCompileClient(headers=self._headers) + input = filter_resource_query(user, action, resource) + return await client.compile_query( + query=f"data.{MAIN_PARTIAL_EVAL_PACKAGE}.allow == true", + input=input, + unknowns=[ + "input.resource.key", + "input.resource.tenant", + "input.resource.attributes", + ], + raw=True, + ) diff --git a/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py b/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py new file mode 100644 index 00000000..65faf930 --- /dev/null +++ b/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py @@ -0,0 +1,143 @@ +from typing import Any, Callable, Dict, List, Tuple, Union, cast + +import sqlalchemy as sa + +# import Column, Table, and_, not_, or_, select +from sqlalchemy.orm import DeclarativeMeta, InstrumentedAttribute +from sqlalchemy.sql import Select +from sqlalchemy.sql.expression import BinaryExpression, ColumnOperators + +from horizon.enforcer.data_filtering.boolean_expression.schemas import ( + CALL_OPERATOR, + LOGICAL_AND, + LOGICAL_NOT, + LOGICAL_OR, + Expression, + Operand, + ResidualPolicyResponse, + ResidualPolicyType, + Value, + Variable, +) + +Table = Union[sa.Table, DeclarativeMeta] +Column = Union[sa.Column, InstrumentedAttribute] +Condition = Union[BinaryExpression, ColumnOperators] + + +OperatorMap = Dict[str, Callable[[Column, Any], Condition]] + +SUPPORTED_OPERATORS: OperatorMap = { + "eq": lambda c, v: c == v, + "ne": lambda c, v: c != v, + "lt": lambda c, v: c < v, + "gt": lambda c, v: c > v, + "le": lambda c, v: c <= v, + "ge": lambda c, v: c >= v, +} + + +def operator_to_sql(operator: str, column: Column, value: Any) -> Condition: + if (operator_fn := SUPPORTED_OPERATORS.get(operator)) is not None: + return operator_fn(column, value) + raise ValueError(f"Unrecognised operator: {operator}") + + +def _get_table_name(t: Table) -> str: + try: + return t.__table__.name + except AttributeError: + return t.name + + +def verify_join_conditions( + table: Table, + reference_mapping: Dict[str, Column], + join_conditions: Union[List[Tuple[Table, Condition]], None] = None, +): + def column_table_name(c: Column) -> str: + return cast(sa.Table, c.table).name + + def is_main_table_column(c: Column) -> bool: + return column_table_name(c) == _get_table_name(table) + + required_joins = set( + ( + column_table_name(column) + for column in reference_mapping.values() + if not is_main_table_column(column) + ) + ) + + if len(required_joins): + if join_conditions is None: + raise TypeError(f"to_query() is missing argument 'join_conditions'") + else: + missing_tables = required_joins.difference( + set((t for t, _ in join_conditions)) + ) + if len(missing_tables): + raise TypeError( + f"to_query() argument 'join_conditions' is missing mapping for tables: {repr(missing_tables)}" + ) + + +def to_query( + filter: ResidualPolicyResponse, + table: Table, + reference_mapping: Dict[str, Column], + join_conditions: Union[List[Tuple[Table, Condition]], None] = None, +) -> Select: + select_all = cast(Select, sa.select(table)) + + if filter.type == ResidualPolicyType.ALWAYS_ALLOW: + return select_all + + if filter.type == ResidualPolicyType.ALWAYS_DENY: + return select_all.where(False) + + verify_join_conditions(table, reference_mapping, join_conditions) + + def to_sql(expr: Expression): + operator = expr.expression.operator + operands = expr.expression.operands + + if operator == LOGICAL_AND: + return sa.and_(*[to_sql(o) for o in operands]) + if operator == LOGICAL_OR: + return sa.or_(*[to_sql(o) for o in operands]) + if operator == LOGICAL_NOT: + return sa.not_(*[to_sql(o) for o in operands]) + if operator == CALL_OPERATOR: + raise NotImplementedError("need to implement call() translation to sql") + + # otherwise, operator is a comparison operator + variables = [o for o in operands if isinstance(o, Variable)] + values = [o for o in operands if isinstance(o, Value)] + + if not (len(variables) == 1 and len(values) == 1): + raise NotImplementedError( + "need to implement support in more comparison operators" + ) + + variable_ref: str = variables[0].variable + value: Any = values[0].value + + try: + column = reference_mapping[variable_ref] + except KeyError: + raise KeyError( + f"Residual variable does not exist in the reference mapping: {variable_ref}" + ) + + # the operator handlers here are the leaf nodes of the recursion + return operator_to_sql(operator, column, value) + + query: Select = select_all.where(to_sql(filter.condition)) + + if join_conditions: + query = query.select_from(table) + for join_table, predicate in join_conditions: + query = query.join(join_table, predicate) + + return query From d1d7d253840a2c9890af0c2537e43332ca8623b6 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Mon, 29 Jul 2024 22:16:07 +0300 Subject: [PATCH 16/23] test df wip --- .../tests/test_data_filtering_usage.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py diff --git a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py b/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py new file mode 100644 index 00000000..8ea4ba8b --- /dev/null +++ b/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py @@ -0,0 +1,67 @@ +import pytest + +from horizon.enforcer.data_filtering.sdk.permit_filter import Permit + +from datetime import datetime + +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String +from sqlalchemy.orm import declarative_base, relationship + +from horizon.enforcer.data_filtering.sdk.permit_sqlalchemy import to_query + +Base = declarative_base() + + +# example db model +class User(Base): + __tablename__ = "user" + + id = Column(String, primary_key=True) + username = Column(String(255)) + email = Column(String(255)) + + +class Tenant(Base): + __tablename__ = "tenant" + + id = Column(String, primary_key=True) + name = Column(String(255)) + + +class Task(Base): + __tablename__ = "task" + + id = Column(String, primary_key=True) + created_at = Column(DateTime, default=datetime.utcnow()) + updated_at = Column(DateTime) + description = Column(String(255)) + tenant_id = Column(String(255)) + tenant_id_joined = Column(String, ForeignKey("tenant.id")) + tenant = relationship("Tenant", back_populates="tasks") + + +async def test_data_filtering_e2e(): + """ + tests how df should work e2e with stub sdk + """ + permit = Permit(token="") + filter = await permit.filter_resources("user", "read", "task") + + sa_query = to_query( + filter, + Task, + refs={ + # example how to map a column on the same model + "input.resource.tenant": Task.tenant_id, + }, + ) + + sa_query = to_query( + filter, + Task, + refs={ + # example how to map a column on a related model + "input.resource.tenant_id": Tenant.id, + }, + join_conditions=[(Tenant, Task.tenant_id_joined == Tenant.id)], + ) From 4542323770224a335d624eb58f7af3707645768f Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Thu, 8 Aug 2024 16:11:54 +0300 Subject: [PATCH 17/23] fix small bugs with sql conversion + add sql conversion tests --- .../data_filtering/sdk/permit_sqlalchemy.py | 11 +- .../tests/test_data_filtering_usage.py | 140 +++++++++++++----- 2 files changed, 111 insertions(+), 40 deletions(-) diff --git a/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py b/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py index 65faf930..b05bd4f1 100644 --- a/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py +++ b/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py @@ -73,9 +73,10 @@ def is_main_table_column(c: Column) -> bool: if join_conditions is None: raise TypeError(f"to_query() is missing argument 'join_conditions'") else: - missing_tables = required_joins.difference( - set((t for t, _ in join_conditions)) + provided_joined_tables = set( + (_get_table_name(t) for t, _ in join_conditions) ) + missing_tables = required_joins.difference(provided_joined_tables) if len(missing_tables): raise TypeError( f"to_query() argument 'join_conditions' is missing mapping for tables: {repr(missing_tables)}" @@ -85,7 +86,7 @@ def is_main_table_column(c: Column) -> bool: def to_query( filter: ResidualPolicyResponse, table: Table, - reference_mapping: Dict[str, Column], + refs: Dict[str, Column], join_conditions: Union[List[Tuple[Table, Condition]], None] = None, ) -> Select: select_all = cast(Select, sa.select(table)) @@ -96,7 +97,7 @@ def to_query( if filter.type == ResidualPolicyType.ALWAYS_DENY: return select_all.where(False) - verify_join_conditions(table, reference_mapping, join_conditions) + verify_join_conditions(table, refs, join_conditions) def to_sql(expr: Expression): operator = expr.expression.operator @@ -124,7 +125,7 @@ def to_sql(expr: Expression): value: Any = values[0].value try: - column = reference_mapping[variable_ref] + column = refs[variable_ref] except KeyError: raise KeyError( f"Residual variable does not exist in the reference mapping: {variable_ref}" diff --git a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py b/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py index 8ea4ba8b..5bfed842 100644 --- a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py +++ b/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py @@ -1,10 +1,12 @@ -import pytest - -from horizon.enforcer.data_filtering.sdk.permit_filter import Permit +from horizon.enforcer.data_filtering.boolean_expression.schemas import ( + ResidualPolicyResponse, +) from datetime import datetime -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String +from sqlalchemy import Column, DateTime, ForeignKey, String +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql import Select from sqlalchemy.orm import declarative_base, relationship from horizon.enforcer.data_filtering.sdk.permit_sqlalchemy import to_query @@ -12,40 +14,56 @@ Base = declarative_base() -# example db model -class User(Base): - __tablename__ = "user" - - id = Column(String, primary_key=True) - username = Column(String(255)) - email = Column(String(255)) - - -class Tenant(Base): - __tablename__ = "tenant" - - id = Column(String, primary_key=True) - name = Column(String(255)) - +def query_to_string(query: Select) -> str: + """ + utility function to print raw sql statement + """ + return str( + query.compile( + dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True} + ) + ) -class Task(Base): - __tablename__ = "task" - id = Column(String, primary_key=True) - created_at = Column(DateTime, default=datetime.utcnow()) - updated_at = Column(DateTime) - description = Column(String(255)) - tenant_id = Column(String(255)) - tenant_id_joined = Column(String, ForeignKey("tenant.id")) - tenant = relationship("Tenant", back_populates="tasks") +def striplines(s: str) -> str: + return "\n".join([line.strip() for line in s.splitlines()]) -async def test_data_filtering_e2e(): +def test_sql_translation_no_join(): """ - tests how df should work e2e with stub sdk + tests residual policy to sql conversion without joins """ - permit = Permit(token="") - filter = await permit.filter_resources("user", "read", "task") + # this would be an e2e test, but harder to run with pytest + # since the api key is always changing + # --- + # token = os.environ.get("PDP_API_KEY", "") + # permit = Permit(token=token) + # filter = await permit.filter_resources("user", "read", "task") + + # another option is to mock the residual policy + filter = ResidualPolicyResponse( + **{ + "type": "conditional", + "condition": { + "expression": { + "operator": "eq", + "operands": [ + {"variable": "input.resource.tenant"}, + {"value": "082f6978-6424-4e05-a706-1ab6f26c3768"}, + ], + } + }, + } + ) + + class Task(Base): + __tablename__ = "task" + + id = Column(String, primary_key=True) + created_at = Column(DateTime, default=datetime.utcnow()) + updated_at = Column(DateTime) + description = Column(String(255)) + tenant_id = Column(String(255)) sa_query = to_query( filter, @@ -56,12 +74,64 @@ async def test_data_filtering_e2e(): }, ) + str_query = query_to_string(sa_query) + + assert striplines(str_query) == striplines( + """SELECT task.id, task.created_at, task.updated_at, task.description, task.tenant_id + FROM task + WHERE task.tenant_id = '082f6978-6424-4e05-a706-1ab6f26c3768'""" + ) + + +def test_sql_translation_with_join(): + """ + tests residual policy to sql conversion without joins + """ + filter = ResidualPolicyResponse( + **{ + "type": "conditional", + "condition": { + "expression": { + "operator": "eq", + "operands": [ + {"variable": "input.resource.tenant"}, + {"value": "082f6978-6424-4e05-a706-1ab6f26c3768"}, + ], + } + }, + } + ) + + class Tenant(Base): + __tablename__ = "tenant" + + id = Column(String, primary_key=True) + key = Column(String(255)) + + class TaskJoined(Base): + __tablename__ = "task_joined" + + id = Column(String, primary_key=True) + created_at = Column(DateTime, default=datetime.utcnow()) + updated_at = Column(DateTime) + description = Column(String(255)) + tenant_id_joined = Column(String, ForeignKey("tenant.id")) + tenant = relationship("Tenant", backref="tasks") + sa_query = to_query( filter, - Task, + TaskJoined, refs={ # example how to map a column on a related model - "input.resource.tenant_id": Tenant.id, + "input.resource.tenant": Tenant.key, }, - join_conditions=[(Tenant, Task.tenant_id_joined == Tenant.id)], + join_conditions=[(Tenant, TaskJoined.tenant_id_joined == Tenant.id)], + ) + + str_query = query_to_string(sa_query) + + assert striplines(str_query) == striplines( + """SELECT task_joined.id, task_joined.created_at, task_joined.updated_at, task_joined.description, task_joined.tenant_id_joined + FROM task_joined JOIN tenant ON task_joined.tenant_id_joined = tenant.id + WHERE tenant.key = '082f6978-6424-4e05-a706-1ab6f26c3768'""" ) From b3a4ddf009602fd93e80a9e6372efee58480546a Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Thu, 8 Aug 2024 16:14:04 +0300 Subject: [PATCH 18/23] add with_only_columns to tests --- .../data_filtering/tests/test_data_filtering_usage.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py b/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py index 5bfed842..933adc0a 100644 --- a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py +++ b/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py @@ -82,6 +82,14 @@ class Task(Base): WHERE task.tenant_id = '082f6978-6424-4e05-a706-1ab6f26c3768'""" ) + str_query_only_columns = query_to_string(sa_query.with_only_columns(Task.id)) + + assert striplines(str_query_only_columns) == striplines( + """SELECT task.id + FROM task + WHERE task.tenant_id = '082f6978-6424-4e05-a706-1ab6f26c3768'""" + ) + def test_sql_translation_with_join(): """ From 2156b4bce6900968925a214260dcdfe8a8daf5fb Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Thu, 8 Aug 2024 16:34:20 +0300 Subject: [PATCH 19/23] more tests --- .../tests/test_data_filtering_usage.py | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py b/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py index 933adc0a..ed61240b 100644 --- a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py +++ b/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py @@ -1,3 +1,5 @@ +import pytest + from horizon.enforcer.data_filtering.boolean_expression.schemas import ( ResidualPolicyResponse, ) @@ -143,3 +145,123 @@ class TaskJoined(Base): FROM task_joined JOIN tenant ON task_joined.tenant_id_joined = tenant.id WHERE tenant.key = '082f6978-6424-4e05-a706-1ab6f26c3768'""" ) + + +def test_sql_translation_of_trivial_policies(): + class Tasks(Base): + __tablename__ = "tasks" + + id = Column(String, primary_key=True) + created_at = Column(DateTime, default=datetime.utcnow()) + updated_at = Column(DateTime) + description = Column(String(255)) + tenant_id = Column(String(255)) + + filter = ResidualPolicyResponse(**{"type": "always_allow", "condition": None}) + + sa_query = to_query( + filter, + Tasks, + refs={ + # example how to map a column on the same model + "input.resource.tenant": Tasks.tenant_id, + }, + ) + + str_query = query_to_string(sa_query) + assert striplines(str_query) == striplines( + """SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.description, tasks.tenant_id + FROM tasks""" + ) # this query would always return all rows from the tasks table + + filter = ResidualPolicyResponse(**{"type": "always_deny", "condition": None}) + + sa_query = to_query( + filter, + Tasks, + refs={ + # example how to map a column on the same model + "input.resource.tenant": Tasks.tenant_id, + }, + ) + + str_query = query_to_string(sa_query) + assert striplines(str_query) == striplines( + """SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.description, tasks.tenant_id + FROM tasks + WHERE false""" + ) # this query would never have any results + + +def test_missing_joins(): + filter = ResidualPolicyResponse( + **{ + "type": "conditional", + "condition": { + "expression": { + "operator": "eq", + "operands": [ + {"variable": "input.resource.tenant"}, + {"value": "082f6978-6424-4e05-a706-1ab6f26c3768"}, + ], + } + }, + } + ) + + class User2(Base): + __tablename__ = "user2" + + id = Column(String, primary_key=True) + username = Column(String(255)) + + class Tenant2(Base): + __tablename__ = "tenant2" + + id = Column(String, primary_key=True) + key = Column(String(255)) + + class TaskJoined2(Base): + __tablename__ = "task_joined2" + + id = Column(String, primary_key=True) + created_at = Column(DateTime, default=datetime.utcnow()) + updated_at = Column(DateTime) + description = Column(String(255)) + tenant_id_joined = Column(String, ForeignKey("tenant2.id")) + tenant = relationship("Tenant2", backref="tasks") + owner_id = Column(String, ForeignKey("user2.id")) + owner = relationship("User2", backref="tasks") + + with pytest.raises(TypeError) as e: + # Tenant2.key is a column outside the main table (requires a join) + # if we don't provide any join conditions, to_query() will throw a TypeError + sa_query = to_query( + filter, + TaskJoined2, + refs={ + # example how to map a column on a related model + "input.resource.tenant": Tenant2.key, + }, + ) + + assert str(e.value) == "to_query() is missing argument 'join_conditions'" + + with pytest.raises(TypeError) as e: + # Tenant2.key is a column outside the main table (requires a join) + # if we provide join conditions but not to all required tables, + # to_query() will throw a different TypeError + sa_query = to_query( + filter, + TaskJoined2, + refs={ + # example how to map a column on a related model + "input.resource.tenant": Tenant2.key, + }, + join_conditions=[(User2, TaskJoined2.owner_id == User2.id)], + ) + + assert ( + str(e.value) + == "to_query() argument 'join_conditions' is missing mapping for tables: {'tenant2'}" + ) From b593e137bad9d2eee75ac0565eee5a086edee8d9 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Sun, 11 Aug 2024 17:59:53 +0300 Subject: [PATCH 20/23] refactor api to query builder --- .../data_filtering/sdk/permit_sqlalchemy.py | 157 ++++++++++++------ .../tests/test_data_filtering_usage.py | 7 +- 2 files changed, 112 insertions(+), 52 deletions(-) diff --git a/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py b/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py index b05bd4f1..359bb452 100644 --- a/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py +++ b/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast import sqlalchemy as sa @@ -53,7 +53,7 @@ def _get_table_name(t: Table) -> str: def verify_join_conditions( table: Table, reference_mapping: Dict[str, Column], - join_conditions: Union[List[Tuple[Table, Condition]], None] = None, + join_conditions: List[Tuple[Table, Condition]] = [], ): def column_table_name(c: Column) -> str: return cast(sa.Table, c.table).name @@ -70,8 +70,10 @@ def is_main_table_column(c: Column) -> bool: ) if len(required_joins): - if join_conditions is None: - raise TypeError(f"to_query() is missing argument 'join_conditions'") + if len(join_conditions) == 0: + raise TypeError( + f"You must call QueryBuilder.join(table, condition) to map residual references to other SQL tables" + ) else: provided_joined_tables = set( (_get_table_name(t) for t, _ in join_conditions) @@ -79,66 +81,121 @@ def is_main_table_column(c: Column) -> bool: missing_tables = required_joins.difference(provided_joined_tables) if len(missing_tables): raise TypeError( - f"to_query() argument 'join_conditions' is missing mapping for tables: {repr(missing_tables)}" + f"QueryBuilder.join() was not called for these SQL tables: {repr(missing_tables)}" ) -def to_query( - filter: ResidualPolicyResponse, - table: Table, - refs: Dict[str, Column], - join_conditions: Union[List[Tuple[Table, Condition]], None] = None, -) -> Select: - select_all = cast(Select, sa.select(table)) +class QueryBuilder: + def __init__(self): + self._table: Optional[Table] = None + self._residual_policy: Optional[ResidualPolicyResponse] = None + self._refs: Optional[Dict[str, Column]] = None + self._joins: List[Tuple[Table, Condition]] = [] - if filter.type == ResidualPolicyType.ALWAYS_ALLOW: - return select_all + def select(self, table: Table) -> "QueryBuilder": + self._table = table + return self - if filter.type == ResidualPolicyType.ALWAYS_DENY: - return select_all.where(False) + def filter_by(self, residual_policy: ResidualPolicyResponse) -> "QueryBuilder": + self._residual_policy = residual_policy + return self - verify_join_conditions(table, refs, join_conditions) + def map_references(self, refs: Dict[str, Column]) -> "QueryBuilder": + self._refs = refs + return self - def to_sql(expr: Expression): - operator = expr.expression.operator - operands = expr.expression.operands + def join(self, table: Table, condition: Condition) -> "QueryBuilder": + self._joins.append((table, condition)) + return self - if operator == LOGICAL_AND: - return sa.and_(*[to_sql(o) for o in operands]) - if operator == LOGICAL_OR: - return sa.or_(*[to_sql(o) for o in operands]) - if operator == LOGICAL_NOT: - return sa.not_(*[to_sql(o) for o in operands]) - if operator == CALL_OPERATOR: - raise NotImplementedError("need to implement call() translation to sql") + def _verify_args(self): + if self._table is None: + raise ValueError( + f"You must call QueryBuilder.select(table) to specify to main table to filter on" + ) - # otherwise, operator is a comparison operator - variables = [o for o in operands if isinstance(o, Variable)] - values = [o for o in operands if isinstance(o, Value)] + if self._residual_policy is None: + raise ValueError( + f"You must call QueryBuilder.filter_by(residual_policy) to specify the compiled partial policy returned from OPA" + ) - if not (len(variables) == 1 and len(values) == 1): - raise NotImplementedError( - "need to implement support in more comparison operators" + if self._refs is None: + raise ValueError( + f"You must call QueryBuilder.map_references(refs) to specify how to map residual OPA references to SQL tables" ) - variable_ref: str = variables[0].variable - value: Any = values[0].value + def build(self) -> Select: + self._verify_args() - try: - column = refs[variable_ref] - except KeyError: - raise KeyError( - f"Residual variable does not exist in the reference mapping: {variable_ref}" - ) + table = self._table + residual_policy = self._residual_policy + refs = self._refs + join_conditions = self._joins + + select_all = cast(Select, sa.select(table)) + + if residual_policy.type == ResidualPolicyType.ALWAYS_ALLOW: + return select_all + + if residual_policy.type == ResidualPolicyType.ALWAYS_DENY: + return select_all.where(False) + + verify_join_conditions(table, refs, join_conditions) + + def to_sql(expr: Expression): + operator = expr.expression.operator + operands = expr.expression.operands - # the operator handlers here are the leaf nodes of the recursion - return operator_to_sql(operator, column, value) + if operator == LOGICAL_AND: + return sa.and_(*[to_sql(o) for o in operands]) + if operator == LOGICAL_OR: + return sa.or_(*[to_sql(o) for o in operands]) + if operator == LOGICAL_NOT: + return sa.not_(*[to_sql(o) for o in operands]) + if operator == CALL_OPERATOR: + raise NotImplementedError("need to implement call() translation to sql") - query: Select = select_all.where(to_sql(filter.condition)) + # otherwise, operator is a comparison operator + variables = [o for o in operands if isinstance(o, Variable)] + values = [o for o in operands if isinstance(o, Value)] - if join_conditions: - query = query.select_from(table) - for join_table, predicate in join_conditions: - query = query.join(join_table, predicate) + if not (len(variables) == 1 and len(values) == 1): + raise NotImplementedError( + "need to implement support in more comparison operators" + ) + + variable_ref: str = variables[0].variable + value: Any = values[0].value + + try: + column = refs[variable_ref] + except KeyError: + raise KeyError( + f"Residual variable does not exist in the reference mapping: {variable_ref}" + ) + + # the operator handlers here are the leaf nodes of the recursion + return operator_to_sql(operator, column, value) + + query: Select = select_all.where(to_sql(residual_policy.condition)) + + if join_conditions: + query = query.select_from(table) + for join_table, predicate in join_conditions: + query = query.join(join_table, predicate) - return query + return query + + +def to_query( + filters: ResidualPolicyResponse, + table: Table, + *, + refs: Dict[str, Column], + join_conditions: Optional[List[Tuple[Table, Condition]]] = None, +) -> Select: + query_builder = QueryBuilder().select(table).filter_by(filters).map_references(refs) + if join_conditions is not None: + for joined_table, join_condition in join_conditions: + query_builder = query_builder.join(joined_table, join_condition) + return query_builder.build() diff --git a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py b/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py index ed61240b..a6f53001 100644 --- a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py +++ b/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py @@ -245,7 +245,10 @@ class TaskJoined2(Base): }, ) - assert str(e.value) == "to_query() is missing argument 'join_conditions'" + assert ( + str(e.value) + == "You must call QueryBuilder.join(table, condition) to map residual references to other SQL tables" + ) with pytest.raises(TypeError) as e: # Tenant2.key is a column outside the main table (requires a join) @@ -263,5 +266,5 @@ class TaskJoined2(Base): assert ( str(e.value) - == "to_query() argument 'join_conditions' is missing mapping for tables: {'tenant2'}" + == "QueryBuilder.join() was not called for these SQL tables: {'tenant2'}" ) From c97fda11626dc66e0cea5051f6cd9b24983376a5 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Mon, 12 Aug 2024 15:52:35 +0300 Subject: [PATCH 21/23] extract datafiltering code into a separate package: permit-datafilter --- .dockerignore | 1 + .gitignore | 1 - horizon/enforcer/api.py | 38 ++-------- .../data_filtering/sdk/permit_filter.py | 72 ------------------- lib/permit-datafilter/MANIFEST.in | 1 + lib/permit-datafilter/Makefile | 11 +++ lib/permit-datafilter/README.md | 9 +++ .../permit_datafilter}/__init__.py | 0 .../boolean_expression/__init__.py | 0 .../boolean_expression/schemas.py | 0 .../boolean_expression/translator.py | 4 +- .../compile_api/__init__.py | 0 .../compile_api/compile_client.py | 26 ++++--- .../permit_datafilter}/compile_api/schemas.py | 0 .../permit_datafilter}/rego_ast/__init__.py | 0 .../permit_datafilter}/rego_ast/parser.py | 3 +- .../permit_datafilter}/sdk/__init__.py | 0 .../permit_datafilter/sdk/permit_filter.py | 72 +++++++++++++++++++ .../sdk/permit_sqlalchemy.py | 3 +- lib/permit-datafilter/requirements.txt | 7 ++ lib/permit-datafilter/setup.py | 40 +++++++++++ .../permit-datafilter}/tests/__init__.py | 0 .../tests/test_ast_parser.py | 4 +- .../tests/test_ast_translation.py | 8 +-- .../tests/test_boolean_expression_schema.py | 3 +- .../tests/test_compile_parsing.py | 2 +- .../tests/test_data_filtering_usage.py | 4 +- requirements.txt | 1 + setup.py | 2 +- 29 files changed, 174 insertions(+), 138 deletions(-) delete mode 100644 horizon/enforcer/data_filtering/sdk/permit_filter.py create mode 100644 lib/permit-datafilter/MANIFEST.in create mode 100644 lib/permit-datafilter/Makefile create mode 100644 lib/permit-datafilter/README.md rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/__init__.py (100%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/boolean_expression/__init__.py (100%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/boolean_expression/schemas.py (100%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/boolean_expression/translator.py (96%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/compile_api/__init__.py (100%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/compile_api/compile_client.py (78%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/compile_api/schemas.py (100%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/rego_ast/__init__.py (100%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/rego_ast/parser.py (99%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/sdk/__init__.py (100%) create mode 100644 lib/permit-datafilter/permit_datafilter/sdk/permit_filter.py rename {horizon/enforcer/data_filtering => lib/permit-datafilter/permit_datafilter}/sdk/permit_sqlalchemy.py (98%) create mode 100644 lib/permit-datafilter/requirements.txt create mode 100644 lib/permit-datafilter/setup.py rename {horizon/enforcer/data_filtering => lib/permit-datafilter}/tests/__init__.py (100%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter}/tests/test_ast_parser.py (98%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter}/tests/test_ast_translation.py (90%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter}/tests/test_boolean_expression_schema.py (97%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter}/tests/test_compile_parsing.py (99%) rename {horizon/enforcer/data_filtering => lib/permit-datafilter}/tests/test_data_filtering_usage.py (98%) diff --git a/.dockerignore b/.dockerignore index 756bcfa9..f40cffa7 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,3 +2,4 @@ helm/ .venv/ .github/ +lib/ \ No newline at end of file diff --git a/.gitignore b/.gitignore index 7d6da5c0..8c0fbd1a 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,6 @@ dist/ downloads/ eggs/ .eggs/ -lib/ lib64/ parts/ sdist/ diff --git a/horizon/enforcer/api.py b/horizon/enforcer/api.py index 3144fe13..00ab50a7 100644 --- a/horizon/enforcer/api.py +++ b/horizon/enforcer/api.py @@ -14,12 +14,12 @@ DEFAULT_POLICY_STORE_GETTER, ) from opal_client.utils import proxy_response +from permit_datafilter.compile_api.compile_client import OpaCompileClient from pydantic import parse_obj_as from starlette.responses import JSONResponse from horizon.authentication import enforce_pdp_token from horizon.config import sidecar_config -from horizon.enforcer.data_filtering.compile_api.compile_client import OpaCompileClient from horizon.enforcer.schemas import ( AuthorizationQuery, AuthorizationResult, @@ -722,12 +722,14 @@ async def filter_resources( x_permit_sdk_language: Optional[str] = Depends(notify_seen_sdk), ): headers = transform_headers(request) - client = OpaCompileClient(headers=headers) + client = OpaCompileClient( + base_url=f"{opal_client_config.POLICY_STORE_URL}", headers=headers + ) COMPILE_ROOT_RULE_REFERENCE = f"data.{MAIN_PARTIAL_EVAL_PACKAGE}.allow" query = f"{COMPILE_ROOT_RULE_REFERENCE} == true" residual_policy = await client.compile_query( query=query, - input=input, + input=input.dict(), unknowns=[ "input.resource.key", "input.resource.tenant", @@ -741,34 +743,4 @@ async def filter_resources( media_type="application/json", ) - @router.post( - "/filter_resources_get_sql", - response_model=AuthorizationResult, - status_code=status.HTTP_200_OK, - response_model_exclude_none=True, - dependencies=[Depends(enforce_pdp_token)], - ) - async def filter_resources_get_sql( - request: Request, - input: AuthorizationQuery, - x_permit_sdk_language: Optional[str] = Depends(notify_seen_sdk), - ): - """ - TODO: temp endpoint, instead we should wrap the capability in the SDK - """ - headers = transform_headers(request) - client = OpaCompileClient(headers=headers) - COMPILE_ROOT_RULE_REFERENCE = f"data.{MAIN_PARTIAL_EVAL_PACKAGE}.allow" - query = f"{COMPILE_ROOT_RULE_REFERENCE} == true" - residual_policy = await client.compile_query( - query=query, - input=input, - unknowns=[ - "input.resource.key", - "input.resource.tenant", - "input.resource.attributes", - ], - raw=True, - ) - return router diff --git a/horizon/enforcer/data_filtering/sdk/permit_filter.py b/horizon/enforcer/data_filtering/sdk/permit_filter.py deleted file mode 100644 index 389b10e8..00000000 --- a/horizon/enforcer/data_filtering/sdk/permit_filter.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import Union - -from horizon.enforcer import schemas as enforcer_schemas -from horizon.enforcer.api import MAIN_PARTIAL_EVAL_PACKAGE -from horizon.enforcer.data_filtering.boolean_expression.schemas import ( - ResidualPolicyResponse, -) -from horizon.enforcer.data_filtering.compile_api.compile_client import OpaCompileClient - -User = Union[dict, str] -Action = str -Resource = Union[dict, str] - - -def normalize_user(user: User) -> dict: - if isinstance(user, str): - return dict(key=user) - else: - return user - - -def normalize_resource_type(resource: Resource) -> str: - if isinstance(resource, dict): - t = resource.get("type", None) - if t is not None and isinstance(t, str): - return t - raise ValueError("no resource type provided") - else: - return resource - - -def filter_resource_query( - user: User, action: Action, resource: Resource -) -> enforcer_schemas.AuthorizationQuery: - normalized_user = normalize_user(user) - resource_type: str = normalize_resource_type(resource) - return enforcer_schemas.AuthorizationQuery( - user=normalized_user, - action=action, - resource=enforcer_schemas.Resource(type=resource_type), - ) - - -class Permit: - """ - stub for future SDK code - """ - - def __init__(self, token: str): - self._headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - } - - async def filter_resources( - self, user: User, action: Action, resource: Resource - ) -> ResidualPolicyResponse: - """ - stub for future permit.filter_resources() function - """ - client = OpaCompileClient(headers=self._headers) - input = filter_resource_query(user, action, resource) - return await client.compile_query( - query=f"data.{MAIN_PARTIAL_EVAL_PACKAGE}.allow == true", - input=input, - unknowns=[ - "input.resource.key", - "input.resource.tenant", - "input.resource.attributes", - ], - raw=True, - ) diff --git a/lib/permit-datafilter/MANIFEST.in b/lib/permit-datafilter/MANIFEST.in new file mode 100644 index 00000000..698cd966 --- /dev/null +++ b/lib/permit-datafilter/MANIFEST.in @@ -0,0 +1 @@ +include *.md requirements.txt \ No newline at end of file diff --git a/lib/permit-datafilter/Makefile b/lib/permit-datafilter/Makefile new file mode 100644 index 00000000..3ff80647 --- /dev/null +++ b/lib/permit-datafilter/Makefile @@ -0,0 +1,11 @@ +.PHONY: help + +.DEFAULT_GOAL := help + +clean: + rm -rf *.egg-info build/ dist/ + +publish: + $(MAKE) clean + python setup.py sdist bdist_wheel + python -m twine upload dist/* \ No newline at end of file diff --git a/lib/permit-datafilter/README.md b/lib/permit-datafilter/README.md new file mode 100644 index 00000000..8da0d365 --- /dev/null +++ b/lib/permit-datafilter/README.md @@ -0,0 +1,9 @@ +# Permit.io Data Filtering SDK (EAP) + +Initial SDK to enable data filtering scenarios based on compiling OPA policies from Rego AST into SQL-like expressions. + +## Installation + +```py +pip install permit-datafilter +``` diff --git a/horizon/enforcer/data_filtering/__init__.py b/lib/permit-datafilter/permit_datafilter/__init__.py similarity index 100% rename from horizon/enforcer/data_filtering/__init__.py rename to lib/permit-datafilter/permit_datafilter/__init__.py diff --git a/horizon/enforcer/data_filtering/boolean_expression/__init__.py b/lib/permit-datafilter/permit_datafilter/boolean_expression/__init__.py similarity index 100% rename from horizon/enforcer/data_filtering/boolean_expression/__init__.py rename to lib/permit-datafilter/permit_datafilter/boolean_expression/__init__.py diff --git a/horizon/enforcer/data_filtering/boolean_expression/schemas.py b/lib/permit-datafilter/permit_datafilter/boolean_expression/schemas.py similarity index 100% rename from horizon/enforcer/data_filtering/boolean_expression/schemas.py rename to lib/permit-datafilter/permit_datafilter/boolean_expression/schemas.py diff --git a/horizon/enforcer/data_filtering/boolean_expression/translator.py b/lib/permit-datafilter/permit_datafilter/boolean_expression/translator.py similarity index 96% rename from horizon/enforcer/data_filtering/boolean_expression/translator.py rename to lib/permit-datafilter/permit_datafilter/boolean_expression/translator.py index 895ee6bb..6ae2c477 100644 --- a/horizon/enforcer/data_filtering/boolean_expression/translator.py +++ b/lib/permit-datafilter/permit_datafilter/boolean_expression/translator.py @@ -1,5 +1,5 @@ -from horizon.enforcer.data_filtering.rego_ast import parser as ast -from horizon.enforcer.data_filtering.boolean_expression.schemas import ( +from permit_datafilter.rego_ast import parser as ast +from permit_datafilter.boolean_expression.schemas import ( CALL_OPERATOR, Operand, ResidualPolicyResponse, diff --git a/horizon/enforcer/data_filtering/compile_api/__init__.py b/lib/permit-datafilter/permit_datafilter/compile_api/__init__.py similarity index 100% rename from horizon/enforcer/data_filtering/compile_api/__init__.py rename to lib/permit-datafilter/permit_datafilter/compile_api/__init__.py diff --git a/horizon/enforcer/data_filtering/compile_api/compile_client.py b/lib/permit-datafilter/permit_datafilter/compile_api/compile_client.py similarity index 78% rename from horizon/enforcer/data_filtering/compile_api/compile_client.py rename to lib/permit-datafilter/permit_datafilter/compile_api/compile_client.py index 1bd774f0..cfbc72f5 100644 --- a/horizon/enforcer/data_filtering/compile_api/compile_client.py +++ b/lib/permit-datafilter/permit_datafilter/compile_api/compile_client.py @@ -1,24 +1,22 @@ import json import aiohttp -from fastapi import HTTPException, Response, status -from opal_client.config import opal_client_config -from opal_client.logger import logger +from fastapi import HTTPException, status +from loguru import logger -from horizon.enforcer.schemas import AuthorizationQuery -from horizon.enforcer.data_filtering.compile_api.schemas import CompileResponse -from horizon.enforcer.data_filtering.rego_ast import parser as ast -from horizon.enforcer.data_filtering.boolean_expression.schemas import ( +from permit_datafilter.compile_api.schemas import CompileResponse +from permit_datafilter.rego_ast import parser as ast +from permit_datafilter.boolean_expression.schemas import ( ResidualPolicyResponse, ) -from horizon.enforcer.data_filtering.boolean_expression.translator import ( +from permit_datafilter.boolean_expression.translator import ( translate_opa_queryset, ) class OpaCompileClient: - def __init__(self, headers: dict): - self._base_url = f"{opal_client_config.POLICY_STORE_URL}" + def __init__(self, base_url: str, headers: dict): + self._base_url = base_url # f"{opal_client_config.POLICY_STORE_URL}" self._headers = headers self._client = aiohttp.ClientSession( base_url=self._base_url, headers=self._headers @@ -27,19 +25,19 @@ def __init__(self, headers: dict): async def compile_query( self, query: str, - input: AuthorizationQuery, + input: dict, unknowns: list[str], raw: bool = False, ) -> ResidualPolicyResponse: # we don't want debug rules when we try to reduce the policy into a partial policy - input = {**input.dict(), "use_debugger": False} + input = {**input, "use_debugger": False} data = { "query": query, "input": input, "unknowns": unknowns, } try: - logger.info("Compiling OPA query: {}", data) + logger.debug("Compiling OPA query: {}", data) async with self._client as session: async with session.post( "/v1/compile", @@ -47,7 +45,7 @@ async def compile_query( raise_for_status=True, ) as response: opa_compile_result = await response.json() - logger.info( + logger.debug( "OPA compile query result: status={status}, response={response}", status=response.status, response=json.dumps(opa_compile_result), diff --git a/horizon/enforcer/data_filtering/compile_api/schemas.py b/lib/permit-datafilter/permit_datafilter/compile_api/schemas.py similarity index 100% rename from horizon/enforcer/data_filtering/compile_api/schemas.py rename to lib/permit-datafilter/permit_datafilter/compile_api/schemas.py diff --git a/horizon/enforcer/data_filtering/rego_ast/__init__.py b/lib/permit-datafilter/permit_datafilter/rego_ast/__init__.py similarity index 100% rename from horizon/enforcer/data_filtering/rego_ast/__init__.py rename to lib/permit-datafilter/permit_datafilter/rego_ast/__init__.py diff --git a/horizon/enforcer/data_filtering/rego_ast/parser.py b/lib/permit-datafilter/permit_datafilter/rego_ast/parser.py similarity index 99% rename from horizon/enforcer/data_filtering/rego_ast/parser.py rename to lib/permit-datafilter/permit_datafilter/rego_ast/parser.py index 1619196f..8979c739 100644 --- a/horizon/enforcer/data_filtering/rego_ast/parser.py +++ b/lib/permit-datafilter/permit_datafilter/rego_ast/parser.py @@ -40,12 +40,11 @@ from types import NoneType from typing import Generic, List, TypeVar -from horizon.enforcer.data_filtering.compile_api.schemas import ( +from permit_datafilter.compile_api.schemas import ( CRExpression, CRQuery, CompileResponse, CRTerm, - CRSupportModule, ) diff --git a/horizon/enforcer/data_filtering/sdk/__init__.py b/lib/permit-datafilter/permit_datafilter/sdk/__init__.py similarity index 100% rename from horizon/enforcer/data_filtering/sdk/__init__.py rename to lib/permit-datafilter/permit_datafilter/sdk/__init__.py diff --git a/lib/permit-datafilter/permit_datafilter/sdk/permit_filter.py b/lib/permit-datafilter/permit_datafilter/sdk/permit_filter.py new file mode 100644 index 00000000..17a46d8d --- /dev/null +++ b/lib/permit-datafilter/permit_datafilter/sdk/permit_filter.py @@ -0,0 +1,72 @@ +# from typing import Union + +# from horizon.enforcer import schemas as enforcer_schemas +# from horizon.enforcer.api import MAIN_PARTIAL_EVAL_PACKAGE +# from permit_datafilter.boolean_expression.schemas import ( +# ResidualPolicyResponse, +# ) +# from permit_datafilter.compile_api.compile_client import OpaCompileClient + +# User = Union[dict, str] +# Action = str +# Resource = Union[dict, str] + + +# def normalize_user(user: User) -> dict: +# if isinstance(user, str): +# return dict(key=user) +# else: +# return user + + +# def normalize_resource_type(resource: Resource) -> str: +# if isinstance(resource, dict): +# t = resource.get("type", None) +# if t is not None and isinstance(t, str): +# return t +# raise ValueError("no resource type provided") +# else: +# return resource + + +# def filter_resource_query( +# user: User, action: Action, resource: Resource +# ) -> enforcer_schemas.AuthorizationQuery: +# normalized_user = normalize_user(user) +# resource_type: str = normalize_resource_type(resource) +# return enforcer_schemas.AuthorizationQuery( +# user=normalized_user, +# action=action, +# resource=enforcer_schemas.Resource(type=resource_type), +# ) + + +# class Permit: +# """ +# stub for future SDK code +# """ + +# def __init__(self, token: str): +# self._headers = { +# "Authorization": f"Bearer {token}", +# "Content-Type": "application/json", +# } + +# async def filter_resources( +# self, user: User, action: Action, resource: Resource +# ) -> ResidualPolicyResponse: +# """ +# stub for future permit.filter_resources() function +# """ +# client = OpaCompileClient(headers=self._headers) +# input = filter_resource_query(user, action, resource) +# return await client.compile_query( +# query=f"data.{MAIN_PARTIAL_EVAL_PACKAGE}.allow == true", +# input=input, +# unknowns=[ +# "input.resource.key", +# "input.resource.tenant", +# "input.resource.attributes", +# ], +# raw=True, +# ) diff --git a/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py b/lib/permit-datafilter/permit_datafilter/sdk/permit_sqlalchemy.py similarity index 98% rename from horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py rename to lib/permit-datafilter/permit_datafilter/sdk/permit_sqlalchemy.py index 359bb452..515013c1 100644 --- a/horizon/enforcer/data_filtering/sdk/permit_sqlalchemy.py +++ b/lib/permit-datafilter/permit_datafilter/sdk/permit_sqlalchemy.py @@ -2,12 +2,11 @@ import sqlalchemy as sa -# import Column, Table, and_, not_, or_, select from sqlalchemy.orm import DeclarativeMeta, InstrumentedAttribute from sqlalchemy.sql import Select from sqlalchemy.sql.expression import BinaryExpression, ColumnOperators -from horizon.enforcer.data_filtering.boolean_expression.schemas import ( +from permit_datafilter.boolean_expression.schemas import ( CALL_OPERATOR, LOGICAL_AND, LOGICAL_NOT, diff --git a/lib/permit-datafilter/requirements.txt b/lib/permit-datafilter/requirements.txt new file mode 100644 index 00000000..ac3bcaf2 --- /dev/null +++ b/lib/permit-datafilter/requirements.txt @@ -0,0 +1,7 @@ +aiohttp>=3.9.4,<4 +ddtrace +fastapi>=0.109.1,<1 +loguru +pydantic[email]>=1.9.1,<2 +pytest +SQLAlchemy==1.4.46 \ No newline at end of file diff --git a/lib/permit-datafilter/setup.py b/lib/permit-datafilter/setup.py new file mode 100644 index 00000000..93b117d3 --- /dev/null +++ b/lib/permit-datafilter/setup.py @@ -0,0 +1,40 @@ +from pathlib import Path + +from setuptools import find_packages, setup + + +def get_requirements(env=""): + if env: + env = "-{}".format(env) + with open("requirements{}.txt".format(env)) as fp: + return [x.strip() for x in fp.read().split("\n") if not x.startswith("#")] + + +def get_readme() -> str: + this_directory = Path(__file__).parent + long_description = (this_directory / "README.md").read_text() + return long_description + + +setup( + name="permit-datafilter", + version="0.0.2", + packages=find_packages(), + author="Asaf Cohen", + author_email="asaf@permit.io", + license="Apache 2.0", + python_requires=">=3.8", + description="Permit.io python sdk", + install_requires=get_requirements(), + long_description=get_readme(), + long_description_content_type="text/markdown", + classifiers=[ + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + ], +) diff --git a/horizon/enforcer/data_filtering/tests/__init__.py b/lib/permit-datafilter/tests/__init__.py similarity index 100% rename from horizon/enforcer/data_filtering/tests/__init__.py rename to lib/permit-datafilter/tests/__init__.py diff --git a/horizon/enforcer/data_filtering/tests/test_ast_parser.py b/lib/permit-datafilter/tests/test_ast_parser.py similarity index 98% rename from horizon/enforcer/data_filtering/tests/test_ast_parser.py rename to lib/permit-datafilter/tests/test_ast_parser.py index 4aa5fa9e..2af4ba26 100644 --- a/horizon/enforcer/data_filtering/tests/test_ast_parser.py +++ b/lib/permit-datafilter/tests/test_ast_parser.py @@ -1,4 +1,4 @@ -from horizon.enforcer.data_filtering.rego_ast.parser import ( +from permit_datafilter.rego_ast.parser import ( BooleanTerm, Call, CallTerm, @@ -14,7 +14,7 @@ StringTerm, VarTerm, ) -from horizon.enforcer.data_filtering.compile_api.schemas import ( +from permit_datafilter.compile_api.schemas import ( CRTerm, CRExpression, CRQuery, diff --git a/horizon/enforcer/data_filtering/tests/test_ast_translation.py b/lib/permit-datafilter/tests/test_ast_translation.py similarity index 90% rename from horizon/enforcer/data_filtering/tests/test_ast_translation.py rename to lib/permit-datafilter/tests/test_ast_translation.py index 40645254..2b475d93 100644 --- a/horizon/enforcer/data_filtering/tests/test_ast_translation.py +++ b/lib/permit-datafilter/tests/test_ast_translation.py @@ -1,10 +1,10 @@ -from horizon.enforcer.data_filtering.compile_api.schemas import CompileResponse -from horizon.enforcer.data_filtering.rego_ast import parser as ast -from horizon.enforcer.data_filtering.boolean_expression.schemas import ( +from permit_datafilter.compile_api.schemas import CompileResponse +from permit_datafilter.rego_ast import parser as ast +from permit_datafilter.boolean_expression.schemas import ( ResidualPolicyResponse, ResidualPolicyType, ) -from horizon.enforcer.data_filtering.boolean_expression.translator import ( +from permit_datafilter.boolean_expression.translator import ( translate_opa_queryset, ) diff --git a/horizon/enforcer/data_filtering/tests/test_boolean_expression_schema.py b/lib/permit-datafilter/tests/test_boolean_expression_schema.py similarity index 97% rename from horizon/enforcer/data_filtering/tests/test_boolean_expression_schema.py rename to lib/permit-datafilter/tests/test_boolean_expression_schema.py index db115d10..ce169141 100644 --- a/horizon/enforcer/data_filtering/tests/test_boolean_expression_schema.py +++ b/lib/permit-datafilter/tests/test_boolean_expression_schema.py @@ -1,10 +1,9 @@ import pytest import pydantic -from horizon.enforcer.data_filtering.boolean_expression.schemas import ( +from permit_datafilter.boolean_expression.schemas import ( ResidualPolicyResponse, ResidualPolicyType, - Expression, ) diff --git a/horizon/enforcer/data_filtering/tests/test_compile_parsing.py b/lib/permit-datafilter/tests/test_compile_parsing.py similarity index 99% rename from horizon/enforcer/data_filtering/tests/test_compile_parsing.py rename to lib/permit-datafilter/tests/test_compile_parsing.py index 8023cc73..7811c095 100644 --- a/horizon/enforcer/data_filtering/tests/test_compile_parsing.py +++ b/lib/permit-datafilter/tests/test_compile_parsing.py @@ -1,6 +1,6 @@ import json -from horizon.enforcer.data_filtering.compile_api.schemas import CRTerm, CompileResponse +from permit_datafilter.compile_api.schemas import CRTerm, CompileResponse COMPILE_RESPONE_RBAC_NO_SUPPORT_BLOCK = """{ diff --git a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py b/lib/permit-datafilter/tests/test_data_filtering_usage.py similarity index 98% rename from horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py rename to lib/permit-datafilter/tests/test_data_filtering_usage.py index a6f53001..969c047b 100644 --- a/horizon/enforcer/data_filtering/tests/test_data_filtering_usage.py +++ b/lib/permit-datafilter/tests/test_data_filtering_usage.py @@ -1,6 +1,6 @@ import pytest -from horizon.enforcer.data_filtering.boolean_expression.schemas import ( +from permit_datafilter.boolean_expression.schemas import ( ResidualPolicyResponse, ) @@ -11,7 +11,7 @@ from sqlalchemy.sql import Select from sqlalchemy.orm import declarative_base, relationship -from horizon.enforcer.data_filtering.sdk.permit_sqlalchemy import to_query +from permit_datafilter.sdk.permit_sqlalchemy import to_query Base = declarative_base() diff --git a/requirements.txt b/requirements.txt index b5cccf2c..21531d0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ httpx>=0.27.0,<1 protobuf>=3.20.2 # not directly required, pinned by Snyk to avoid a vulnerability opal-common==0.7.6 opal-client==0.7.6 +permit-datafilter>=0.0.2,<1 diff --git a/setup.py b/setup.py index de7da89e..eb0ed177 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ def get_data_files(root_directory: str): setup( name="horizon", version="0.2.0", - packages=find_packages(), + packages=find_packages(exclude=("lib/*")), python_requires=">=3.8", include_package_data=True, data_files=get_data_files("horizon/static"), From 1b4bcc82470e0bcc394cb69bec2ed509852f36d4 Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Mon, 12 Aug 2024 16:13:53 +0300 Subject: [PATCH 22/23] fix precommit errors --- .dockerignore | 2 +- lib/permit-datafilter/MANIFEST.in | 2 +- lib/permit-datafilter/Makefile | 2 +- lib/permit-datafilter/requirements.txt | 2 +- .../tests/test_data_filtering_usage.py | 16 ++++++++-------- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.dockerignore b/.dockerignore index f40cffa7..ddf60e0c 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,4 +2,4 @@ helm/ .venv/ .github/ -lib/ \ No newline at end of file +lib/ diff --git a/lib/permit-datafilter/MANIFEST.in b/lib/permit-datafilter/MANIFEST.in index 698cd966..479b8867 100644 --- a/lib/permit-datafilter/MANIFEST.in +++ b/lib/permit-datafilter/MANIFEST.in @@ -1 +1 @@ -include *.md requirements.txt \ No newline at end of file +include *.md requirements.txt diff --git a/lib/permit-datafilter/Makefile b/lib/permit-datafilter/Makefile index 3ff80647..b411d380 100644 --- a/lib/permit-datafilter/Makefile +++ b/lib/permit-datafilter/Makefile @@ -8,4 +8,4 @@ clean: publish: $(MAKE) clean python setup.py sdist bdist_wheel - python -m twine upload dist/* \ No newline at end of file + python -m twine upload dist/* diff --git a/lib/permit-datafilter/requirements.txt b/lib/permit-datafilter/requirements.txt index ac3bcaf2..27c9ec1f 100644 --- a/lib/permit-datafilter/requirements.txt +++ b/lib/permit-datafilter/requirements.txt @@ -4,4 +4,4 @@ fastapi>=0.109.1,<1 loguru pydantic[email]>=1.9.1,<2 pytest -SQLAlchemy==1.4.46 \ No newline at end of file +SQLAlchemy==1.4.46 diff --git a/lib/permit-datafilter/tests/test_data_filtering_usage.py b/lib/permit-datafilter/tests/test_data_filtering_usage.py index 969c047b..2de4c8eb 100644 --- a/lib/permit-datafilter/tests/test_data_filtering_usage.py +++ b/lib/permit-datafilter/tests/test_data_filtering_usage.py @@ -79,16 +79,16 @@ class Task(Base): str_query = query_to_string(sa_query) assert striplines(str_query) == striplines( - """SELECT task.id, task.created_at, task.updated_at, task.description, task.tenant_id - FROM task + """SELECT task.id, task.created_at, task.updated_at, task.description, task.tenant_id + FROM task WHERE task.tenant_id = '082f6978-6424-4e05-a706-1ab6f26c3768'""" ) str_query_only_columns = query_to_string(sa_query.with_only_columns(Task.id)) assert striplines(str_query_only_columns) == striplines( - """SELECT task.id - FROM task + """SELECT task.id + FROM task WHERE task.tenant_id = '082f6978-6424-4e05-a706-1ab6f26c3768'""" ) @@ -141,8 +141,8 @@ class TaskJoined(Base): str_query = query_to_string(sa_query) assert striplines(str_query) == striplines( - """SELECT task_joined.id, task_joined.created_at, task_joined.updated_at, task_joined.description, task_joined.tenant_id_joined - FROM task_joined JOIN tenant ON task_joined.tenant_id_joined = tenant.id + """SELECT task_joined.id, task_joined.created_at, task_joined.updated_at, task_joined.description, task_joined.tenant_id_joined + FROM task_joined JOIN tenant ON task_joined.tenant_id_joined = tenant.id WHERE tenant.key = '082f6978-6424-4e05-a706-1ab6f26c3768'""" ) @@ -170,7 +170,7 @@ class Tasks(Base): str_query = query_to_string(sa_query) assert striplines(str_query) == striplines( - """SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.description, tasks.tenant_id + """SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.description, tasks.tenant_id FROM tasks""" ) # this query would always return all rows from the tasks table @@ -187,7 +187,7 @@ class Tasks(Base): str_query = query_to_string(sa_query) assert striplines(str_query) == striplines( - """SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.description, tasks.tenant_id + """SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.description, tasks.tenant_id FROM tasks WHERE false""" ) # this query would never have any results From 91fdfd6366fb67c5a014af0cf25e00bdd981567f Mon Sep 17 00:00:00 2001 From: Asaf Cohen Date: Mon, 12 Aug 2024 19:15:41 +0300 Subject: [PATCH 23/23] change sqlalchemy module name --- .../{sdk => plugins}/__init__.py | 0 .../sqlalchemy.py} | 0 .../permit_datafilter/sdk/permit_filter.py | 72 ------------------- lib/permit-datafilter/setup.py | 2 +- .../tests/test_data_filtering_usage.py | 2 +- requirements.txt | 2 +- 6 files changed, 3 insertions(+), 75 deletions(-) rename lib/permit-datafilter/permit_datafilter/{sdk => plugins}/__init__.py (100%) rename lib/permit-datafilter/permit_datafilter/{sdk/permit_sqlalchemy.py => plugins/sqlalchemy.py} (100%) delete mode 100644 lib/permit-datafilter/permit_datafilter/sdk/permit_filter.py diff --git a/lib/permit-datafilter/permit_datafilter/sdk/__init__.py b/lib/permit-datafilter/permit_datafilter/plugins/__init__.py similarity index 100% rename from lib/permit-datafilter/permit_datafilter/sdk/__init__.py rename to lib/permit-datafilter/permit_datafilter/plugins/__init__.py diff --git a/lib/permit-datafilter/permit_datafilter/sdk/permit_sqlalchemy.py b/lib/permit-datafilter/permit_datafilter/plugins/sqlalchemy.py similarity index 100% rename from lib/permit-datafilter/permit_datafilter/sdk/permit_sqlalchemy.py rename to lib/permit-datafilter/permit_datafilter/plugins/sqlalchemy.py diff --git a/lib/permit-datafilter/permit_datafilter/sdk/permit_filter.py b/lib/permit-datafilter/permit_datafilter/sdk/permit_filter.py deleted file mode 100644 index 17a46d8d..00000000 --- a/lib/permit-datafilter/permit_datafilter/sdk/permit_filter.py +++ /dev/null @@ -1,72 +0,0 @@ -# from typing import Union - -# from horizon.enforcer import schemas as enforcer_schemas -# from horizon.enforcer.api import MAIN_PARTIAL_EVAL_PACKAGE -# from permit_datafilter.boolean_expression.schemas import ( -# ResidualPolicyResponse, -# ) -# from permit_datafilter.compile_api.compile_client import OpaCompileClient - -# User = Union[dict, str] -# Action = str -# Resource = Union[dict, str] - - -# def normalize_user(user: User) -> dict: -# if isinstance(user, str): -# return dict(key=user) -# else: -# return user - - -# def normalize_resource_type(resource: Resource) -> str: -# if isinstance(resource, dict): -# t = resource.get("type", None) -# if t is not None and isinstance(t, str): -# return t -# raise ValueError("no resource type provided") -# else: -# return resource - - -# def filter_resource_query( -# user: User, action: Action, resource: Resource -# ) -> enforcer_schemas.AuthorizationQuery: -# normalized_user = normalize_user(user) -# resource_type: str = normalize_resource_type(resource) -# return enforcer_schemas.AuthorizationQuery( -# user=normalized_user, -# action=action, -# resource=enforcer_schemas.Resource(type=resource_type), -# ) - - -# class Permit: -# """ -# stub for future SDK code -# """ - -# def __init__(self, token: str): -# self._headers = { -# "Authorization": f"Bearer {token}", -# "Content-Type": "application/json", -# } - -# async def filter_resources( -# self, user: User, action: Action, resource: Resource -# ) -> ResidualPolicyResponse: -# """ -# stub for future permit.filter_resources() function -# """ -# client = OpaCompileClient(headers=self._headers) -# input = filter_resource_query(user, action, resource) -# return await client.compile_query( -# query=f"data.{MAIN_PARTIAL_EVAL_PACKAGE}.allow == true", -# input=input, -# unknowns=[ -# "input.resource.key", -# "input.resource.tenant", -# "input.resource.attributes", -# ], -# raw=True, -# ) diff --git a/lib/permit-datafilter/setup.py b/lib/permit-datafilter/setup.py index 93b117d3..c448579b 100644 --- a/lib/permit-datafilter/setup.py +++ b/lib/permit-datafilter/setup.py @@ -18,7 +18,7 @@ def get_readme() -> str: setup( name="permit-datafilter", - version="0.0.2", + version="0.0.3", packages=find_packages(), author="Asaf Cohen", author_email="asaf@permit.io", diff --git a/lib/permit-datafilter/tests/test_data_filtering_usage.py b/lib/permit-datafilter/tests/test_data_filtering_usage.py index 2de4c8eb..caa2d1e9 100644 --- a/lib/permit-datafilter/tests/test_data_filtering_usage.py +++ b/lib/permit-datafilter/tests/test_data_filtering_usage.py @@ -11,7 +11,7 @@ from sqlalchemy.sql import Select from sqlalchemy.orm import declarative_base, relationship -from permit_datafilter.sdk.permit_sqlalchemy import to_query +from permit_datafilter.plugins.sqlalchemy import to_query Base = declarative_base() diff --git a/requirements.txt b/requirements.txt index 21531d0f..ffcb6b80 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,4 @@ httpx>=0.27.0,<1 protobuf>=3.20.2 # not directly required, pinned by Snyk to avoid a vulnerability opal-common==0.7.6 opal-client==0.7.6 -permit-datafilter>=0.0.2,<1 +permit-datafilter>=0.0.3,<1