diff --git a/.dockerignore b/.dockerignore index 756bcfa9..ddf60e0c 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,3 +2,4 @@ helm/ .venv/ .github/ +lib/ diff --git a/.gitignore b/.gitignore index 7d6da5c0..8c0fbd1a 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,6 @@ dist/ downloads/ eggs/ .eggs/ -lib/ lib64/ parts/ sdist/ diff --git a/Makefile b/Makefile index 4cf8e62b..52b1c8f5 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ build-arm64: prepare @docker buildx build --platform linux/arm64 -t permitio/pdp-v2:$(VERSION) . --load run: run-prepare - @docker run -p 7766:7000 --env PDP_API_KEY=$(API_KEY) --env PDP_DEBUG=true permitio/pdp-v2:$(VERSION) + @docker run -it -p 7766:7000 -p 8181:8181 --env PDP_API_KEY=$(API_KEY) --env PDP_DEBUG=true permitio/pdp-v2:$(VERSION) run-on-background: run-prepare - @docker run -d -p 7766:7000 --env PDP_API_KEY=$(API_KEY) --env PDP_DEBUG=true permitio/pdp-v2:$(VERSION) + @docker run -d -p 7766:7000 -p 8181:8181 --env PDP_API_KEY=$(API_KEY) --env PDP_DEBUG=true permitio/pdp-v2:$(VERSION) diff --git a/horizon/enforcer/api.py b/horizon/enforcer/api.py index 24b1dc8e..00ab50a7 100644 --- a/horizon/enforcer/api.py +++ b/horizon/enforcer/api.py @@ -4,7 +4,7 @@ from typing import cast, Optional, Union, Dict, List import aiohttp -from fastapi import APIRouter, Depends, Header +from fastapi import APIRouter, Depends, Header, Query from fastapi import HTTPException from fastapi import Request, Response, status from opal_client.config import opal_client_config @@ -14,6 +14,7 @@ DEFAULT_POLICY_STORE_GETTER, ) from opal_client.utils import proxy_response +from permit_datafilter.compile_api.compile_client import OpaCompileClient from pydantic import parse_obj_as from starlette.responses import JSONResponse @@ -57,6 +58,34 @@ AUTHORIZED_USERS_POLICY_PACKAGE = "permit.authorized_users.authorized_users" USER_TENANTS_POLICY_PACKAGE = USER_PERMISSIONS_POLICY_PACKAGE + ".tenants" KONG_ROUTES_TABLE_FILE = "/config/kong_routes.json" 
+MAIN_PARTIAL_EVAL_PACKAGE = "permit.partial_eval" + +# TODO: a more robust policy needs to be added to permit default managed policy +# this policy is partial-eval friendly, but only supports RBAC at the moment +TEMP_PARTIAL_EVAL_POLICY = """ +package permit.partial_eval + +import future.keywords.contains +import future.keywords.if +import future.keywords.in + +default allow := false + +allow if { + checked_permission := sprintf("%s:%s", [input.resource.type, input.action]) + + some granting_role, role_data in data.roles + some resource_type, actions in role_data.grants + granted_action := actions[_] + granted_permission := sprintf("%s:%s", [resource_type, granted_action]) + + some tenant, roles in data.users[input.user.key].roleAssignments + role := roles[_] + role == granting_role + checked_permission == granted_permission + input.resource.tenant == tenant +} +""" stats_manager = StatisticsManager( interval_seconds=sidecar_config.OPA_CLIENT_FAILURE_THRESHOLD_INTERVAL, @@ -676,4 +705,42 @@ async def is_allowed_kong(request: Request, query: KongAuthorizationQuery): ) return result + @router.post( + "/filter_resources", + response_model=AuthorizationResult, + status_code=status.HTTP_200_OK, + response_model_exclude_none=True, + dependencies=[Depends(enforce_pdp_token)], + ) + async def filter_resources( + request: Request, + input: AuthorizationQuery, + raw: bool = Query( + False, + description="whether we should include the OPA raw compilation result in the response. 
this can help us debug the translation of the AST", + ), + x_permit_sdk_language: Optional[str] = Depends(notify_seen_sdk), + ): + headers = transform_headers(request) + client = OpaCompileClient( + base_url=f"{opal_client_config.POLICY_STORE_URL}", headers=headers + ) + COMPILE_ROOT_RULE_REFERENCE = f"data.{MAIN_PARTIAL_EVAL_PACKAGE}.allow" + query = f"{COMPILE_ROOT_RULE_REFERENCE} == true" + residual_policy = await client.compile_query( + query=query, + input=input.dict(), + unknowns=[ + "input.resource.key", + "input.resource.tenant", + "input.resource.attributes", + ], + raw=raw, + ) + return Response( + content=json.dumps(residual_policy.dict()), + status_code=status.HTTP_200_OK, + media_type="application/json", + ) + return router diff --git a/horizon/state.py b/horizon/state.py index 2ba862f5..364ab0be 100644 --- a/horizon/state.py +++ b/horizon/state.py @@ -125,7 +125,7 @@ def _get_pdp_version(cls) -> Optional[str]: if os.path.exists(PDP_VERSION_FILENAME): with open(PDP_VERSION_FILENAME) as f: return f.read().strip() - return None + return "0.0.0" @classmethod def _get_pdp_runtime(cls) -> dict: diff --git a/lib/permit-datafilter/MANIFEST.in b/lib/permit-datafilter/MANIFEST.in new file mode 100644 index 00000000..479b8867 --- /dev/null +++ b/lib/permit-datafilter/MANIFEST.in @@ -0,0 +1 @@ +include *.md requirements.txt diff --git a/lib/permit-datafilter/Makefile b/lib/permit-datafilter/Makefile new file mode 100644 index 00000000..b411d380 --- /dev/null +++ b/lib/permit-datafilter/Makefile @@ -0,0 +1,11 @@ +.PHONY: help + +.DEFAULT_GOAL := help + +clean: + rm -rf *.egg-info build/ dist/ + +publish: + $(MAKE) clean + python setup.py sdist bdist_wheel + python -m twine upload dist/* diff --git a/lib/permit-datafilter/README.md b/lib/permit-datafilter/README.md new file mode 100644 index 00000000..8da0d365 --- /dev/null +++ b/lib/permit-datafilter/README.md @@ -0,0 +1,9 @@ +# Permit.io Data Filtering SDK (EAP) + +Initial SDK to enable data filtering scenarios 
based on compiling OPA policies from Rego AST into SQL-like expressions. + +## Installation + +```py +pip install permit-datafilter +``` diff --git a/lib/permit-datafilter/permit_datafilter/__init__.py b/lib/permit-datafilter/permit_datafilter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/permit-datafilter/permit_datafilter/boolean_expression/__init__.py b/lib/permit-datafilter/permit_datafilter/boolean_expression/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/permit-datafilter/permit_datafilter/boolean_expression/schemas.py b/lib/permit-datafilter/permit_datafilter/boolean_expression/schemas.py new file mode 100644 index 00000000..5494be58 --- /dev/null +++ b/lib/permit-datafilter/permit_datafilter/boolean_expression/schemas.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, Optional, Union + +from pydantic import BaseModel, Field, root_validator + +LOGICAL_AND = "and" +LOGICAL_OR = "or" +LOGICAL_NOT = "not" +CALL_OPERATOR = "call" + + +class BaseSchema(BaseModel): + class Config: + orm_mode = True + allow_population_by_field_name = True + + +class Variable(BaseSchema): + """ + represents a variable in a boolean expression tree. + + if the boolean expression originated in OPA - this was originally a reference in the OPA document tree. + """ + + variable: str = Field(..., description="a path to a variable (reference)") + + +class Value(BaseSchema): + """ + Represents a value (literal) in a boolean expression tree. + Could be of any jsonable type: string, int, boolean, float, list, dict. 
+ """ + + value: Any = Field( + ..., description="a literal value, typically compared to a variable" + ) + + +class Expr(BaseSchema): + operator: str = Field(..., description="the name of the operator") + operands: list["Operand"] = Field(..., description="the operands to the expression") + + +class Expression(BaseSchema): + """ + represents a boolean expression, comparised of logical operators (e.g: and/or/not) and comparison operators (e.g: ==, >, <) + + we translate OPA call terms to expressions, treating the operator as the function name and the operands as the args. + """ + + expression: Expr = Field( + ..., + description="represents a boolean expression, comparised of logical operators (e.g: and/or/not) and comparison operators (e.g: ==, >, <)", + ) + + +Operand = Union[Variable, Value, "Expression"] + + +class ResidualPolicyType(str, Enum): + ALWAYS_ALLOW = "always_allow" + ALWAYS_DENY = "always_deny" + CONDITIONAL = "conditional" + + +class ResidualPolicyResponse(BaseSchema): + type: ResidualPolicyType = Field(..., description="the type of the residual policy") + condition: Optional["Expression"] = Field( + None, + description="an optional condition, exists if the type of the residual policy is CONDITIONAL", + ) + raw: Optional[dict] = Field( + None, + description="raw OPA compilation result, provided for debugging purposes", + ) + + @root_validator + def check_condition_exists_when_needed(cls, values: dict): + type, condition = values.get("type"), values.get("condition", None) + if ( + type == ResidualPolicyType.ALWAYS_ALLOW + or type == ResidualPolicyType.ALWAYS_DENY + ) and condition is not None: + raise ValueError( + f"invalid residual policy: a condition exists but the type is not CONDITIONAL, instead: {type}" + ) + if type == ResidualPolicyType.CONDITIONAL and condition is None: + raise ValueError( + f"invalid residual policy: type is CONDITIONAL, but no condition is provided" + ) + return values + + +Expr.update_forward_refs() 
+Expression.update_forward_refs() +ResidualPolicyResponse.update_forward_refs() diff --git a/lib/permit-datafilter/permit_datafilter/boolean_expression/translator.py b/lib/permit-datafilter/permit_datafilter/boolean_expression/translator.py new file mode 100644 index 00000000..6ae2c477 --- /dev/null +++ b/lib/permit-datafilter/permit_datafilter/boolean_expression/translator.py @@ -0,0 +1,120 @@ +from permit_datafilter.rego_ast import parser as ast +from permit_datafilter.boolean_expression.schemas import ( + CALL_OPERATOR, + Operand, + ResidualPolicyResponse, + ResidualPolicyType, + Expression, + Expr, + Value, + Variable, + LOGICAL_AND, + LOGICAL_OR, +) + + +def translate_opa_queryset(queryset: ast.QuerySet) -> ResidualPolicyResponse: + """ + translates the Rego AST into a generic residual policy constructed as a boolean expression. + + this boolean expression can then be translated by plugins into various SQL and ORMs. + """ + if queryset.always_false: + return ResidualPolicyResponse(type=ResidualPolicyType.ALWAYS_DENY) + + if queryset.always_true: + return ResidualPolicyResponse(type=ResidualPolicyType.ALWAYS_ALLOW) + + if len(queryset.queries) == 1: + return ResidualPolicyResponse( + type=ResidualPolicyType.CONDITIONAL, + condition=translate_query(queryset.queries[0]), + ) + + queries = [query for query in queryset.queries if not query.always_true] + + if len(queries) == 0: + # no not trival queries means always true + return ResidualPolicyResponse(type=ResidualPolicyType.ALWAYS_ALLOW) + + # else, more than one query means there's a logical OR between queries + return ResidualPolicyResponse( + type=ResidualPolicyType.CONDITIONAL, + condition=Expression( + expression=Expr( + operator=LOGICAL_OR, + operands=[translate_query(query) for query in queries], + ) + ), + ) + + +def translate_query(query: ast.Query) -> Expression: + if len(query.expressions) == 1: + return translate_expression(query.expressions[0]) + + return Expression( + expression=Expr( + 
operator=LOGICAL_AND, + operands=[ + translate_expression(expression) for expression in query.expressions + ], + ) + ) + + +def translate_expression(expression: ast.Expression) -> Expression: + if len(expression.terms) == 1 and expression.terms[0].type == ast.TermType.CALL: + # this is a call expression + return translate_call_term(expression.terms[0].value) + + if not isinstance(expression.operator, ast.RefTerm): + raise ValueError( + f"The operator in an expression must be a term of type ref, instead got type {expression.operator.type} and value {expression.operator.value}" + ) + + return Expression( + expression=Expr( + operator=expression.operator.value.as_string, + operands=[translate_term(term) for term in expression.operands], + ) + ) + + +def translate_call_term(call: ast.Call) -> Expression: + return Expression( + expression=Expr( + operator=CALL_OPERATOR, + operands=[ + Expression( + expression=Expr( + operator=call.func, + operands=[translate_term(term) for term in call.args], + ) + ) + ], + ) + ) + + +def translate_term(term: ast.Term) -> Operand: + if term.type in ( + ast.TermType.NULL, + ast.TermType.BOOLEAN, + ast.TermType.NUMBER, + ast.TermType.STRING, + ): + return Value(value=term.value) + + if term.type == ast.TermType.VAR: + return Variable(variable=term.value) + + if term.type == ast.TermType.REF and isinstance(term.value, ast.Ref): + return Variable(variable=term.value.as_string) + + if term.type == ast.TermType.CALL and isinstance(term.value, ast.Call): + return translate_call_term(term.value) + + raise ValueError( + f"unable to translate term with type {term.type} and value {term.value}" + ) diff --git a/lib/permit-datafilter/permit_datafilter/compile_api/__init__.py b/lib/permit-datafilter/permit_datafilter/compile_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/permit-datafilter/permit_datafilter/compile_api/compile_client.py b/lib/permit-datafilter/permit_datafilter/compile_api/compile_client.py new file 
mode 100644 index 00000000..cfbc72f5 --- /dev/null +++ b/lib/permit-datafilter/permit_datafilter/compile_api/compile_client.py @@ -0,0 +1,85 @@ +import json + +import aiohttp +from fastapi import HTTPException, status +from loguru import logger + +from permit_datafilter.compile_api.schemas import CompileResponse +from permit_datafilter.rego_ast import parser as ast +from permit_datafilter.boolean_expression.schemas import ( + ResidualPolicyResponse, +) +from permit_datafilter.boolean_expression.translator import ( + translate_opa_queryset, +) + + +class OpaCompileClient: + def __init__(self, base_url: str, headers: dict): + self._base_url = base_url # f"{opal_client_config.POLICY_STORE_URL}" + self._headers = headers + self._client = aiohttp.ClientSession( + base_url=self._base_url, headers=self._headers + ) + + async def compile_query( + self, + query: str, + input: dict, + unknowns: list[str], + raw: bool = False, + ) -> ResidualPolicyResponse: + # we don't want debug rules when we try to reduce the policy into a partial policy + input = {**input, "use_debugger": False} + data = { + "query": query, + "input": input, + "unknowns": unknowns, + } + try: + logger.debug("Compiling OPA query: {}", data) + async with self._client as session: + async with session.post( + "/v1/compile", + data=json.dumps(data), + raise_for_status=True, + ) as response: + opa_compile_result = await response.json() + logger.debug( + "OPA compile query result: status={status}, response={response}", + status=response.status, + response=json.dumps(opa_compile_result), + ) + try: + residual_policy = self.translate_rego_ast(opa_compile_result) + if raw: + residual_policy.raw = opa_compile_result + return residual_policy + except Exception as exc: + return HTTPException( + status.HTTP_406_NOT_ACCEPTABLE, + detail="failed to translate compiled OPA query (query: {query}, response: {response}, exc={exc})".format( + query=data, response=opa_compile_result, exc=exc + ), + ) + except 
aiohttp.ClientResponseError as e: + exc = HTTPException( + status.HTTP_502_BAD_GATEWAY, # 502 indicates server got an error from another server + detail="OPA request failed (url: {url}, status: {status}, message: {message})".format( + url=self._base_url, status=e.status, message=e.message + ), + ) + except aiohttp.ClientError as e: + exc = HTTPException( + status.HTTP_502_BAD_GATEWAY, + detail="OPA request failed (url: {url}, error: {error}".format( + url=self._base_url, error=str(e) + ), + ) + logger.warning(exc.detail) + raise exc + + def translate_rego_ast(self, response: dict) -> ResidualPolicyResponse: + response = CompileResponse(**response) + queryset = ast.QuerySet.parse(response) + return translate_opa_queryset(queryset) diff --git a/lib/permit-datafilter/permit_datafilter/compile_api/schemas.py b/lib/permit-datafilter/permit_datafilter/compile_api/schemas.py new file mode 100644 index 00000000..8b1c7dec --- /dev/null +++ b/lib/permit-datafilter/permit_datafilter/compile_api/schemas.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import Any, List, Optional, Union + +from pydantic import BaseModel, Field + + +class BaseSchema(BaseModel): + class Config: + orm_mode = True + allow_population_by_field_name = True + + +class CompileResponse(BaseSchema): + result: "CompileResponseComponents" + + +class CompileResponseComponents(BaseSchema): + queries: Optional["CRQuerySet"] = None + support: Optional["CRSupportBlock"] = None + + +class CRQuerySet(BaseSchema): + __root__: List["CRQuery"] + + +class CRQuery(BaseSchema): + __root__: List["CRExpression"] + + +class CRExpression(BaseSchema): + index: int + terms: Union["CRTerm", List["CRTerm"]] + + +class CRTerm(BaseSchema): + type: str + value: Any + + +class CRSupportBlock(BaseSchema): + __root__: List["CRSupportModule"] + + +class CRSupportModule(BaseSchema): + package: "CRSupportModulePackage" + rules: List["CRRegoRule"] + + +class CRSupportModulePackage(BaseSchema): + path: List["CRTerm"] 
+ + +class CRRegoRule(BaseSchema): + body: List["CRExpression"] + head: "CRRuleHead" + + +class CRRuleHead(BaseSchema): + name: str + key: "CRTerm" + ref: List["CRTerm"] + + +CompileResponse.update_forward_refs() +CompileResponseComponents.update_forward_refs() +CRQuerySet.update_forward_refs() +CRQuery.update_forward_refs() +CRExpression.update_forward_refs() +CRTerm.update_forward_refs() +CRSupportBlock.update_forward_refs() +CRSupportModule.update_forward_refs() +CRSupportModulePackage.update_forward_refs() +CRRegoRule.update_forward_refs() +CRRuleHead.update_forward_refs() diff --git a/lib/permit-datafilter/permit_datafilter/plugins/__init__.py b/lib/permit-datafilter/permit_datafilter/plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/permit-datafilter/permit_datafilter/plugins/sqlalchemy.py b/lib/permit-datafilter/permit_datafilter/plugins/sqlalchemy.py new file mode 100644 index 00000000..515013c1 --- /dev/null +++ b/lib/permit-datafilter/permit_datafilter/plugins/sqlalchemy.py @@ -0,0 +1,200 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast + +import sqlalchemy as sa + +from sqlalchemy.orm import DeclarativeMeta, InstrumentedAttribute +from sqlalchemy.sql import Select +from sqlalchemy.sql.expression import BinaryExpression, ColumnOperators + +from permit_datafilter.boolean_expression.schemas import ( + CALL_OPERATOR, + LOGICAL_AND, + LOGICAL_NOT, + LOGICAL_OR, + Expression, + Operand, + ResidualPolicyResponse, + ResidualPolicyType, + Value, + Variable, +) + +Table = Union[sa.Table, DeclarativeMeta] +Column = Union[sa.Column, InstrumentedAttribute] +Condition = Union[BinaryExpression, ColumnOperators] + + +OperatorMap = Dict[str, Callable[[Column, Any], Condition]] + +SUPPORTED_OPERATORS: OperatorMap = { + "eq": lambda c, v: c == v, + "ne": lambda c, v: c != v, + "lt": lambda c, v: c < v, + "gt": lambda c, v: c > v, + "le": lambda c, v: c <= v, + "ge": lambda c, v: c >= v, +} + + +def 
operator_to_sql(operator: str, column: Column, value: Any) -> Condition: + if (operator_fn := SUPPORTED_OPERATORS.get(operator)) is not None: + return operator_fn(column, value) + raise ValueError(f"Unrecognised operator: {operator}") + + +def _get_table_name(t: Table) -> str: + try: + return t.__table__.name + except AttributeError: + return t.name + + +def verify_join_conditions( + table: Table, + reference_mapping: Dict[str, Column], + join_conditions: List[Tuple[Table, Condition]] = [], +): + def column_table_name(c: Column) -> str: + return cast(sa.Table, c.table).name + + def is_main_table_column(c: Column) -> bool: + return column_table_name(c) == _get_table_name(table) + + required_joins = set( + ( + column_table_name(column) + for column in reference_mapping.values() + if not is_main_table_column(column) + ) + ) + + if len(required_joins): + if len(join_conditions) == 0: + raise TypeError( + f"You must call QueryBuilder.join(table, condition) to map residual references to other SQL tables" + ) + else: + provided_joined_tables = set( + (_get_table_name(t) for t, _ in join_conditions) + ) + missing_tables = required_joins.difference(provided_joined_tables) + if len(missing_tables): + raise TypeError( + f"QueryBuilder.join() was not called for these SQL tables: {repr(missing_tables)}" + ) + + +class QueryBuilder: + def __init__(self): + self._table: Optional[Table] = None + self._residual_policy: Optional[ResidualPolicyResponse] = None + self._refs: Optional[Dict[str, Column]] = None + self._joins: List[Tuple[Table, Condition]] = [] + + def select(self, table: Table) -> "QueryBuilder": + self._table = table + return self + + def filter_by(self, residual_policy: ResidualPolicyResponse) -> "QueryBuilder": + self._residual_policy = residual_policy + return self + + def map_references(self, refs: Dict[str, Column]) -> "QueryBuilder": + self._refs = refs + return self + + def join(self, table: Table, condition: Condition) -> "QueryBuilder": + 
self._joins.append((table, condition)) + return self + + def _verify_args(self): + if self._table is None: + raise ValueError( + f"You must call QueryBuilder.select(table) to specify to main table to filter on" + ) + + if self._residual_policy is None: + raise ValueError( + f"You must call QueryBuilder.filter_by(residual_policy) to specify the compiled partial policy returned from OPA" + ) + + if self._refs is None: + raise ValueError( + f"You must call QueryBuilder.map_references(refs) to specify how to map residual OPA references to SQL tables" + ) + + def build(self) -> Select: + self._verify_args() + + table = self._table + residual_policy = self._residual_policy + refs = self._refs + join_conditions = self._joins + + select_all = cast(Select, sa.select(table)) + + if residual_policy.type == ResidualPolicyType.ALWAYS_ALLOW: + return select_all + + if residual_policy.type == ResidualPolicyType.ALWAYS_DENY: + return select_all.where(False) + + verify_join_conditions(table, refs, join_conditions) + + def to_sql(expr: Expression): + operator = expr.expression.operator + operands = expr.expression.operands + + if operator == LOGICAL_AND: + return sa.and_(*[to_sql(o) for o in operands]) + if operator == LOGICAL_OR: + return sa.or_(*[to_sql(o) for o in operands]) + if operator == LOGICAL_NOT: + return sa.not_(*[to_sql(o) for o in operands]) + if operator == CALL_OPERATOR: + raise NotImplementedError("need to implement call() translation to sql") + + # otherwise, operator is a comparison operator + variables = [o for o in operands if isinstance(o, Variable)] + values = [o for o in operands if isinstance(o, Value)] + + if not (len(variables) == 1 and len(values) == 1): + raise NotImplementedError( + "need to implement support in more comparison operators" + ) + + variable_ref: str = variables[0].variable + value: Any = values[0].value + + try: + column = refs[variable_ref] + except KeyError: + raise KeyError( + f"Residual variable does not exist in the reference 
mapping: {variable_ref}" + ) + + # the operator handlers here are the leaf nodes of the recursion + return operator_to_sql(operator, column, value) + + query: Select = select_all.where(to_sql(residual_policy.condition)) + + if join_conditions: + query = query.select_from(table) + for join_table, predicate in join_conditions: + query = query.join(join_table, predicate) + + return query + + +def to_query( + filters: ResidualPolicyResponse, + table: Table, + *, + refs: Dict[str, Column], + join_conditions: Optional[List[Tuple[Table, Condition]]] = None, +) -> Select: + query_builder = QueryBuilder().select(table).filter_by(filters).map_references(refs) + if join_conditions is not None: + for joined_table, join_condition in join_conditions: + query_builder = query_builder.join(joined_table, join_condition) + return query_builder.build() diff --git a/lib/permit-datafilter/permit_datafilter/rego_ast/__init__.py b/lib/permit-datafilter/permit_datafilter/rego_ast/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/permit-datafilter/permit_datafilter/rego_ast/parser.py b/lib/permit-datafilter/permit_datafilter/rego_ast/parser.py new file mode 100644 index 00000000..8979c739 --- /dev/null +++ b/lib/permit-datafilter/permit_datafilter/rego_ast/parser.py @@ -0,0 +1,455 @@ +# Rego policy structure +# +# Rego policies are defined using a relatively small set of types: +# modules, package and import declarations, rules, expressions, and terms. +# At their core, policies consist of rules that are defined by one or more expressions over documents available to the policy engine. +# The expressions are defined by intrinsic values (terms) such as strings, objects, variables, etc. +# Rego policies are typically defined in text files and then parsed and compiled by the policy engine at runtime. 
+# +# The parsing stage takes the text or string representation of the policy and converts it into an abstract syntax tree (AST) +# that consists of the types mentioned above. The AST is organized as follows: +# Module +# | +# +--- Package (Reference) +# | +# +--- Imports +# | | +# | +--- Import (Term) +# | +# +--- Rules +# | +# +--- Rule +# | +# +--- Head +# | | +# | +--- Name (Variable) +# | | +# | +--- Key (Term) +# | | +# | +--- Value (Term) +# | +# +--- Body +# | +# +--- Expression (Term | Terms | Variable Declaration) +# | +# +--- Term + + +from enum import Enum +import json +from types import NoneType +from typing import Generic, List, TypeVar + +from permit_datafilter.compile_api.schemas import ( + CRExpression, + CRQuery, + CompileResponse, + CRTerm, +) + + +def indent_lines(s: str, indent_char: str = "\t", indent_level: int = 1) -> str: + """ + indents all lines within a multiline string by an indent character repeated by a given indent level. + e.g: indents all lines within a string by 3 tab (\t) characters (indent level 3). + + returns a new indented string. + ) + """ + indent = indent_char * indent_level + return "\n".join(["{}{}".format(indent, row) for row in s.splitlines()]) + + +class QuerySet: + """ + A queryset is a result of partial evaluation, creating a residual policy consisting + of multiple queries (each query consists of multiple rego expressions). + + The query essentially outlines a set of conditions for the residual policy to be true. + All the expressions of the query must be true (logical AND) in order for the query to evaluate to TRUE. + + You can roughly translate the query set into an SQL WHERE statement. + + Between each query of the queryset - there is a logical OR: + Policy is true if query_1 OR query_2 OR ... OR query_n and query_i == (expr_i_1 AND ... AND expr_i_m). 
+ """ + + def __init__(self, queries: list["Query"], support_modules: list): + self._queries = queries + self._support_modules = support_modules + + @classmethod + def parse(cls, response: CompileResponse) -> "QuerySet": + """ + example data: + # queryset + [ + # query (an array of expressions) + [ + # expression (an array of terms) + { + "index": 0, + "terms": [ + ... + ] + } + ], + ... + ] + """ + queries: List[CRQuery] = ( + [] if response.result.queries is None else response.result.queries.__root__ + ) + # TODO: parse support modules + # modules: List[CRSupportModule] = ( + # [] if response.result.support is None else response.result.support.__root__ + # ) + return cls([Query.parse(q) for q in queries], []) + + @property + def queries(self) -> list["Query"]: + return self._queries + + def __repr__(self): + queries_str = "\n".join([indent_lines(repr(r)) for r in self._queries]) + return "QuerySet([\n{}\n])\n".format(queries_str) + + @property + def always_false(self) -> bool: + return len(self._queries) == 0 + + @property + def always_true(self) -> bool: + return len(self._queries) > 0 and any(q.always_true for q in self._queries) + + @property + def conditional(self) -> bool: + return not self.always_false and not self.always_true + + +class Query: + """ + A residual query is a result of partial evaluation. + The query essentially outlines a set of conditions for the residual policy to be true. + All the expressions of the query must be true (logical AND) in order for the query to evaluate to TRUE. + """ + + def __init__(self, expressions: list["Expression"]): + self._expressions = expressions + + @classmethod + def parse(cls, query: CRQuery) -> "Query": + """ + example data: + # query (an array of expressions) + [ + # expression (an array of terms) + { + "index": 0, + "terms": [ + ... 
+ ] + } + ] + """ + return cls([Expression.parse(e) for e in query.__root__]) + + @property + def expressions(self) -> list["Expression"]: + return self._expressions + + @property + def always_true(self) -> bool: + """ + returns true if the query always evaluates to TRUE + """ + return len(self._expressions) == 0 + + def __repr__(self): + exprs_str = "\n".join([indent_lines(repr(e)) for e in self._expressions]) + return "Query([\n{}\n])\n".format(exprs_str) + + +class Expression: + """ + An expression roughly translate into one line of rego code. + Typically a rego rule consists of multiple expressions. + + An expression is comprised of multiple terms (typically 3), where the first is an operator and the rest are operands. + """ + + def __init__(self, terms: list["Term"]): + self._terms = terms + + @classmethod + def parse(cls, data: CRExpression) -> "Expression": + """ + example data: + # expression + { + "index": 0, + # terms + "terms": [ + # first term is typically an operator (e.g: ==, !=, >, <, etc) + # the operator will typically be a *reference* to a built in function. + # for example the "equals" (or "==") operator (within OPA operators are called builtins) is actually the builtin function "eq()". 
+ { + "type": "ref", + "value": [ + { + "type": "var", + "value": "eq" + } + ] + }, + # rest of terms are typically operands + { + "type": "ref", + "value": [ + { + "type": "var", + "value": "input" + }, + { + "type": "string", + "value": "resource" + }, + { + "type": "string", + "value": "tenant" + } + ] + }, + { + "type": "string", + "value": "default" + } + ] + } + """ + terms = data.terms + if isinstance(terms, CRTerm): + return cls([TermParser.parse(terms)]) + else: + return cls([TermParser.parse(t) for t in terms]) + + @property + def operator(self) -> "Term": + """ + returns the term that is the operator of the expression (typically the first term) + """ + return self._terms[0] + + @property + def operands(self) -> list["Term"]: + """ + returns the terms that are the operands of the expression + """ + return self._terms[1:] + + @property + def terms(self) -> list["Term"]: + """ + returns all the terms of the expression + """ + return self._terms + + def __repr__(self): + operands_str = ", ".join([repr(o) for o in self.operands]) + return "Expression({}, [{}])".format(repr(self.operator), operands_str) + + +T = TypeVar("T") + + +class TermType(str, Enum): + NULL = "null" + BOOLEAN = "boolean" + NUMBER = "number" + STRING = "string" + VAR = "var" + REF = "ref" + CALL = "call" + + +class Term(Generic[T]): + """ + a term is an atomic part of an expression (line of code). + it is typically a literal (number, string, etc), a reference (variable) or an operator (e.g: "==" or ">"). 
+ """ + + def __init__(self, value: T): + self.value = value + + @property + def type(self) -> str: + return NotImplementedError() + + @classmethod + def parse(cls, data: T) -> "Term": + return cls(data) + + def __repr__(self): + return json.dumps(self.value) + + +class NullTerm(Term[NoneType]): + @property + def type(self) -> str: + return TermType.NULL + + +class BooleanTerm(Term[bool]): + @property + def type(self) -> str: + return TermType.BOOLEAN + + +class NumberTerm(Term[int | float]): + @property + def type(self) -> str: + return TermType.NUMBER + + +class StringTerm(Term[str]): + @property + def type(self) -> str: + return TermType.STRING + + +class VarTerm(Term[str]): + def __repr__(self): + return self.value + + @property + def type(self) -> str: + return TermType.VAR + + +class Ref: + """ + represents a reference in OPA, which is a path to a document in the OPA document tree. + when translating this into a boolean expression tree, a reference would typically translate into a variable. + """ + + def __init__(self, ref_parts: list[str]): + self._parts = ref_parts + + @property + def parts(self) -> list[str]: + """ + the parts of the full path to the document, each part is a node in OPA document tree. + """ + return self._parts + + @property + def as_string(self) -> str: + return str(self) # calls __str__(self) + + def __str__(self): + return ".".join(self._parts) + + +class RefTerm(Term[Ref]): + """ + A term that represents an OPA reference, holds a Ref object as a value. 
+ """ + + @property + def type(self) -> str: + return TermType.REF + + @classmethod + def parse(cls, terms: list[dict]) -> "Term[Ref]": + assert len(terms) > 0 + parsed_terms: list[Term] = [TermParser.parse(CRTerm(**t)) for t in terms] + var_term = parsed_terms[0] + # TODO: support more types of refs + assert isinstance( + var_term, VarTerm + ), "first sub-term inside ref is not a variable" + string_terms = parsed_terms[1:] # might be empty + # TODO: support more types of refs + assert all( + isinstance(t, StringTerm) for t in string_terms + ), "ref parts are not string terms" + ref_parts = [t.value for t in parsed_terms] + return cls(Ref(ref_parts)) + + def __repr__(self): + return "Ref({})".format(self.value.as_string) + + +class Call: + """ + represents a function call expression inside OPA. + """ + + def __init__(self, func: Term, args: list[Term]): + self._func = func + self._args = args + + @property + def func(self) -> Term: + """ + a (ref)term that holds the reference to the name of the function that is acting on the arguments of the function call + """ + return self._func + + @property + def args(self) -> list[Term]: + """ + the terms representing the arguments of the function call + """ + return self._args + + def __str__(self): + return "{}({})".format(self.func, ", ".join([str(arg) for arg in self.args])) + + +class CallTerm(Term[Call]): + """ + A term that represents a function call Term, holds a Call object as a value. 
+ """ + + @property + def type(self) -> str: + return TermType.CALL + + @classmethod + def parse(cls, terms: list[dict]) -> Term[Call]: + assert len(terms) > 0 + parsed_terms: list[Term] = [TermParser.parse(CRTerm(**t)) for t in terms] + func_term = parsed_terms[0] + # TODO: support more types of refs + assert isinstance(func_term, RefTerm), "first sub-term inside call is not a ref" + arg_terms = parsed_terms[1:] # might be empty + return cls(Call(func_term, arg_terms)) + + def __repr__(self): + return "call:{}".format(str(self.value)) + + +class TermParser: + TERMS_BY_TYPE: dict[str, Term] = { + "null": NullTerm, + "boolean": BooleanTerm, + "number": NumberTerm, + "string": StringTerm, + "var": VarTerm, + "ref": RefTerm, + "call": CallTerm, + # "array": ArrayTerm, + # "set": SetTerm, + # "object": ObjectTerm, + # "arraycomprehension": ArrayComprehensionTerm, + # "setcomprehension": SetComprehensionTerm, + # "objectcomprehension": ObjectComprehensionTerm, + } + + @classmethod + def parse(cls, data: CRTerm) -> Term: + if data.type == "null": + data.value = None + klass = cls.TERMS_BY_TYPE[data.type] + return klass.parse(data.value) diff --git a/lib/permit-datafilter/requirements.txt b/lib/permit-datafilter/requirements.txt new file mode 100644 index 00000000..27c9ec1f --- /dev/null +++ b/lib/permit-datafilter/requirements.txt @@ -0,0 +1,7 @@ +aiohttp>=3.9.4,<4 +ddtrace +fastapi>=0.109.1,<1 +loguru +pydantic[email]>=1.9.1,<2 +pytest +SQLAlchemy==1.4.46 diff --git a/lib/permit-datafilter/setup.py b/lib/permit-datafilter/setup.py new file mode 100644 index 00000000..c448579b --- /dev/null +++ b/lib/permit-datafilter/setup.py @@ -0,0 +1,40 @@ +from pathlib import Path + +from setuptools import find_packages, setup + + +def get_requirements(env=""): + if env: + env = "-{}".format(env) + with open("requirements{}.txt".format(env)) as fp: + return [x.strip() for x in fp.read().split("\n") if not x.startswith("#")] + + +def get_readme() -> str: + this_directory = 
# --- lib/permit-datafilter/setup.py (tail) ---
# NOTE(review): get_readme() starts in the previous chunk; reconstructed in full here.


def get_readme() -> str:
    """Returns the contents of README.md, used as the package long description."""
    this_directory = Path(__file__).parent
    long_description = (this_directory / "README.md").read_text()
    return long_description


setup(
    name="permit-datafilter",
    version="0.0.3",
    packages=find_packages(),
    author="Asaf Cohen",
    author_email="asaf@permit.io",
    license="Apache 2.0",
    python_requires=">=3.8",
    # NOTE(review): says "sdk" but this package is the data-filter library;
    # confirm before changing published metadata
    description="Permit.io python sdk",
    install_requires=get_requirements(),
    long_description=get_readme(),
    long_description_content_type="text/markdown",
    classifiers=[
        "Operating System :: OS Independent",
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
    ],
)


# --- lib/permit-datafilter/tests/test_ast_parser.py ---

from permit_datafilter.rego_ast.parser import (
    BooleanTerm,
    Call,
    CallTerm,
    Expression,
    NullTerm,
    NumberTerm,
    Query,
    QuerySet,
    Ref,
    RefTerm,
    StringTerm,
    Term,
    TermParser,
    VarTerm,
)
from permit_datafilter.compile_api.schemas import (
    CompileResponse,
    CRExpression,
    CRQuery,
    CRQuerySet,
    CRTerm,
)


def _tenant_eq_expression(value: str) -> dict:
    """A compile-API expression dict for: input.resource.tenant == <value>."""
    return {
        "index": 0,
        "terms": [
            {"type": "ref", "value": [{"type": "var", "value": "eq"}]},
            {
                "type": "ref",
                "value": [
                    {"type": "var", "value": "input"},
                    {"type": "string", "value": "resource"},
                    {"type": "string", "value": "tenant"},
                ],
            },
            {"type": "string", "value": value},
        ],
    }


def test_parse_null_term():
    term: Term = TermParser.parse(CRTerm(**{"type": "null"}))
    assert isinstance(term, NullTerm)
    assert term.value is None


def test_parse_boolean_term():
    for val in (True, False):
        term: Term = TermParser.parse(CRTerm(**{"type": "boolean", "value": val}))
        assert isinstance(term, BooleanTerm)
        assert term.value == val
        # BUG FIX: this check was missing its `assert` keyword in the original
        assert isinstance(term.value, bool)


def test_parse_number_term():
    for val in (0, 2, 3.14):
        term: Term = TermParser.parse(CRTerm(**{"type": "number", "value": val}))
        assert isinstance(term, NumberTerm)
        assert term.value == val
        assert isinstance(term.value, (int, float))


def test_parse_string_term():
    for val in ("hello", "world", ""):
        term: Term = TermParser.parse(CRTerm(**{"type": "string", "value": val}))
        assert isinstance(term, StringTerm)
        assert term.value == val
        assert isinstance(term.value, str)


def test_parse_var_term():
    term: Term = TermParser.parse(CRTerm(**{"type": "var", "value": "eq"}))
    assert isinstance(term, VarTerm)
    assert term.value == "eq"
    assert isinstance(term.value, str)


def test_parse_simple_ref_term():
    t = CRTerm(**{"type": "ref", "value": [{"type": "var", "value": "allowed"}]})
    term: Term = TermParser.parse(t)
    assert isinstance(term, RefTerm)
    assert isinstance(term.value, Ref)
    assert len(term.value.parts) == 1
    assert term.value.as_string == "allowed"


def test_parse_complex_ref_term():
    t = CRTerm(
        **{
            "type": "ref",
            "value": [
                {"type": "var", "value": "input"},
                {"type": "string", "value": "resource"},
                {"type": "string", "value": "tenant"},
            ],
        }
    )
    term: Term = TermParser.parse(t)
    assert isinstance(term, RefTerm)
    assert isinstance(term.value, Ref)
    assert len(term.value.parts) == 3
    assert term.value.as_string == "input.resource.tenant"


def test_parse_call_term():
    t = CRTerm(
        **{
            "type": "call",
            "value": [
                {"type": "ref", "value": [{"type": "var", "value": "count"}]},
                {
                    "type": "ref",
                    "value": [
                        {"type": "var", "value": "data"},
                        {"type": "string", "value": "partial"},
                        {"type": "string", "value": "example"},
                        {"type": "string", "value": "rbac4"},
                        {"type": "string", "value": "allowed"},
                    ],
                },
            ],
        }
    )
    term: Term = TermParser.parse(t)
    assert isinstance(term, CallTerm)
    assert isinstance(term.value, Call)
    assert isinstance(term.value.func.value, Ref)
    assert term.value.func.value.as_string == "count"
    assert len(term.value.args) == 1
    assert isinstance(term.value.args[0].value, Ref)
    assert term.value.args[0].value.as_string == "data.partial.example.rbac4.allowed"


def test_parse_expression_eq():
    parsed_expr = CRExpression(**_tenant_eq_expression("second"))
    assert parsed_expr.index == 0
    assert len(parsed_expr.terms) == 3
    expression = Expression.parse(parsed_expr)
    assert isinstance(expression.operator.value, Ref)
    assert expression.operator.value.as_string == "eq"
    assert len(expression.operands) == 2
    assert isinstance(expression.operands[0], RefTerm)
    assert expression.operands[0].value.as_string == "input.resource.tenant"
    assert isinstance(expression.operands[1], StringTerm)
    assert expression.operands[1].value == "second"


def test_parse_trivial_query():
    # a query with no expressions is trivially (always) true
    query = Query.parse(CRQuery(__root__=[]))
    assert len(query.expressions) == 0
    assert query.always_true


def test_parse_query():
    q = CRQuery(
        __root__=[
            {
                "index": 0,
                "terms": [
                    {"type": "ref", "value": [{"type": "var", "value": "gt"}]},
                    {
                        "type": "ref",
                        "value": [
                            {"type": "var", "value": "input"},
                            {"type": "string", "value": "resource"},
                            {"type": "string", "value": "attributes"},
                            {"type": "string", "value": "age"},
                        ],
                    },
                    {"type": "number", "value": 7},
                ],
            }
        ]
    )
    query = Query.parse(q)
    assert len(query.expressions) == 1
    assert not query.always_true

    expression = query.expressions[0]
    assert isinstance(expression.operator.value, Ref)
    assert expression.operator.value.as_string == "gt"
    assert len(expression.operands) == 2
    assert isinstance(expression.operands[0], RefTerm)
    assert expression.operands[0].value.as_string == "input.resource.attributes.age"
    assert isinstance(expression.operands[1], NumberTerm)
    assert expression.operands[1].value == 7


def test_parse_queryset_always_false():
    # an empty compile result (no successful query) means "never allowed"
    queryset = QuerySet.parse(CompileResponse(**{"result": {}}))
    assert len(queryset.queries) == 0
    assert queryset.always_false
    assert not queryset.always_true
    assert not queryset.conditional


def test_parse_queryset_always_true():
    # one of the queries is empty (trivially true), so the whole set is always true
    response = CompileResponse(
        **{"result": {"queries": [[_tenant_eq_expression("default")], []]}}
    )
    queryset = QuerySet.parse(response)
    assert queryset.always_true
    assert not queryset.always_false
    assert not queryset.conditional


def test_parse_queryset_conditional():
    response = CompileResponse(
        **{
            "result": {
                "queries": [
                    [_tenant_eq_expression("default")],
                    [_tenant_eq_expression("second")],
                ]
            }
        }
    )
    queryset = QuerySet.parse(response)
    assert queryset.conditional
    assert not queryset.always_true
    assert not queryset.always_false
    assert len(queryset.queries) == 2

    for query, expected_value in zip(queryset.queries, ("default", "second")):
        assert len(query.expressions) == 1
        assert not query.always_true
        expression = query.expressions[0]
        assert isinstance(expression.operator.value, Ref)
        assert expression.operator.value.as_string == "eq"
        assert len(expression.operands) == 2
        assert isinstance(expression.operands[0], RefTerm)
        assert expression.operands[0].value.as_string == "input.resource.tenant"
        assert isinstance(expression.operands[1], StringTerm)
        assert expression.operands[1].value == expected_value


# --- lib/permit-datafilter/tests/test_ast_translation.py ---

from permit_datafilter.rego_ast import parser as ast
from permit_datafilter.boolean_expression.schemas import (
    ResidualPolicyResponse,
    ResidualPolicyType,
)
from permit_datafilter.boolean_expression.translator import translate_opa_queryset


def test_ast_to_boolean_expr():
    response = CompileResponse(
        **{
            "result": {
                "queries": [
                    [_tenant_eq_expression("default")],
                    [_tenant_eq_expression("second")],
                ]
            }
        }
    )
    queryset = ast.QuerySet.parse(response)
    residual_policy: ResidualPolicyResponse = translate_opa_queryset(queryset)

    assert residual_policy.type == ResidualPolicyType.CONDITIONAL
    assert residual_policy.condition.expression.operator == "or"
    assert len(residual_policy.condition.expression.operands) == 2

    first_operand = residual_policy.condition.expression.operands[0]
    assert first_operand.expression.operator == "eq"
    assert first_operand.expression.operands[0].variable == "input.resource.tenant"
    assert first_operand.expression.operands[1].value == "default"
# --- lib/permit-datafilter/tests/test_boolean_expression_schema.py ---
# NOTE(review): test_valid_residual_policies() starts in the previous chunk;
# it is reconstructed in full here.

import pytest
import pydantic

from permit_datafilter.boolean_expression.schemas import (
    ResidualPolicyResponse,
    ResidualPolicyType,
)

# a conditional residual policy condition with two tenant-equality branches,
# shared by the valid and invalid test cases below
TENANT_OR_CONDITION = {
    "expression": {
        "operator": "or",
        "operands": [
            {
                "expression": {
                    "operator": "eq",
                    "operands": [
                        {"variable": "input.resource.tenant"},
                        {"value": "default"},
                    ],
                }
            },
            {
                "expression": {
                    "operator": "eq",
                    "operands": [
                        {"variable": "input.resource.tenant"},
                        {"value": "second"},
                    ],
                }
            },
        ],
    }
}


def test_valid_residual_policies():
    policy = ResidualPolicyResponse(
        **{"type": "conditional", "condition": TENANT_OR_CONDITION}
    )
    assert policy.type == ResidualPolicyType.CONDITIONAL
    assert policy.condition.expression.operator == "or"
    assert len(policy.condition.expression.operands) == 2

    policy = ResidualPolicyResponse(**{"type": "always_allow", "condition": None})
    assert policy.type == ResidualPolicyType.ALWAYS_ALLOW
    assert policy.condition is None

    policy = ResidualPolicyResponse(**{"type": "always_deny", "condition": None})
    assert policy.type == ResidualPolicyType.ALWAYS_DENY
    assert policy.condition is None


def test_invalid_residual_policies():
    # trivial (allow/deny) residual policies must not carry a condition
    for trivial_residual_type in ("always_allow", "always_deny"):
        with pytest.raises(pydantic.ValidationError) as e:
            ResidualPolicyResponse(
                **{"type": trivial_residual_type, "condition": TENANT_OR_CONDITION}
            )
        # NOTE(review): this assert must sit OUTSIDE the `with` block —
        # any statement after the raising call inside the block never runs
        assert "invalid residual policy" in str(e.value)

    # a conditional residual policy must carry a condition
    with pytest.raises(pydantic.ValidationError) as e:
        ResidualPolicyResponse(**{"type": "conditional", "condition": None})
    assert "invalid residual policy" in str(e.value)


# --- lib/permit-datafilter/tests/test_compile_parsing.py (head) ---

import json

from permit_datafilter.compile_api.schemas import CRTerm, CompileResponse


# NOTE(review): "RESPONE" (sic) is kept from the original to avoid renaming
# a module-level constant. A raw OPA compile-API response with two queries
# (tenant == "default" OR tenant == "second") and no support modules.
COMPILE_RESPONE_RBAC_NO_SUPPORT_BLOCK = """{
  "result": {
    "queries": [
      [{"index": 0, "terms": [
        {"type": "ref", "value": [{"type": "var", "value": "eq"}]},
        {"type": "ref", "value": [
          {"type": "var", "value": "input"},
          {"type": "string", "value": "resource"},
          {"type": "string", "value": "tenant"}]},
        {"type": "string", "value": "default"}]}],
      [{"index": 0, "terms": [
        {"type": "ref", "value": [{"type": "var", "value": "eq"}]},
        {"type": "ref", "value": [
          {"type": "var", "value": "input"},
          {"type": "string", "value": "resource"},
          {"type": "string", "value": "tenant"}]},
        {"type": "string", "value": "second"}]}]
    ]
  }
}
"""


# A raw compile-API response where the residual condition is
# count(data.partial.example.rbac4.allowed) > 0 plus a "support" module
# defining the partial `allowed` rules.
COMPILE_RESPONE_RBAC_WITH_SUPPORT_BLOCK = """{
  "result": {
    "queries": [
      [{"index": 0, "terms": [
        {"type": "ref", "value": [{"type": "var", "value": "gt"}]},
        {"type": "call", "value": [
          {"type": "ref", "value": [{"type": "var", "value": "count"}]},
          {"type": "ref", "value": [
            {"type": "var", "value": "data"},
            {"type": "string", "value": "partial"},
            {"type": "string", "value": "example"},
            {"type": "string", "value": "rbac4"},
            {"type": "string", "value": "allowed"}]}]},
        {"type": "number", "value": 0}]}]
    ],
    "support": [
      {
        "package": {"path": [
          {"type": "var", "value": "data"},
          {"type": "string", "value": "partial"},
          {"type": "string", "value": "example"},
          {"type": "string", "value": "rbac4"}]},
        "rules": [
          {
            "body": [{"index": 0, "terms": [
              {"type": "ref", "value": [{"type": "var", "value": "eq"}]},
              {"type": "ref", "value": [
                {"type": "var", "value": "input"},
                {"type": "string", "value": "resource"},
                {"type": "string", "value": "tenant"}]},
              {"type": "string", "value": "default"}]}],
            "head": {
              "name": "allowed",
              "key": {"type": "string", "value": "user 'asaf' has role 'editor' in tenant 'default' which grants permission 'task:read'"},
              "ref": [{"type": "var", "value": "allowed"}]
            }
          },
          {
            "body": [{"index": 0, "terms": [
              {"type": "ref", "value": [{"type": "var", "value": "eq"}]},
              {"type": "ref", "value": [
                {"type": "var", "value": "input"},
                {"type": "string", "value": "resource"},
                {"type": "string", "value": "tenant"}]},
              {"type": "string", "value": "second"}]}],
            "head": {
              "name": "allowed",
              "key": {"type": "string", "value": "user 'asaf' has role 'viewer' in tenant 'second' which grants permission 'task:read'"},
              "ref": [{"type": "var", "value": "allowed"}]
            }
          }
        ]
      }
    ]
  }
}
"""


def test_parse_compile_response_rbac_no_support_block():
    response = json.loads(COMPILE_RESPONE_RBAC_NO_SUPPORT_BLOCK)
    res = CompileResponse(**response)
    assert res.result.queries is not None
    assert len(res.result.queries.__root__) == 2

    first_query = res.result.queries.__root__[0]
    # BUG FIX: this comparison (and the one for the second query below) was
    # missing its `assert` keyword in the original, so it checked nothing
    assert len(first_query.__root__) == 1

    first_query_expression = first_query.__root__[0]
    assert first_query_expression.index == 0
    assert len(first_query_expression.terms) == 3

    assert first_query_expression.terms[0].type == "ref"
    assert first_query_expression.terms[1].type == "ref"
    assert first_query_expression.terms[2].type == "string"
    assert first_query_expression.terms[2].value == "default"

    second_query = res.result.queries.__root__[1]
    assert len(second_query.__root__) == 1

    second_query_expression = second_query.__root__[0]
    assert second_query_expression.index == 0
    assert len(second_query_expression.terms) == 3

    assert second_query_expression.terms[0].type == "ref"
    assert second_query_expression.terms[1].type == "ref"
    assert second_query_expression.terms[2].type == "string"
    assert second_query_expression.terms[2].value == "second"
# --- lib/permit-datafilter/tests/test_compile_parsing.py (tail) ---


def test_parse_compile_response_rbac_with_support_block():
    response = json.loads(COMPILE_RESPONE_RBAC_WITH_SUPPORT_BLOCK)
    res = CompileResponse(**response)
    assert res.result.queries is not None
    assert len(res.result.queries.__root__) == 1  # 1 query

    first_query = res.result.queries.__root__[0]
    # BUG FIX: several comparisons below were missing their `assert` keyword
    # in the original, so they checked nothing
    assert len(first_query.__root__) == 1  # 1 expression

    first_query_expression = first_query.__root__[0]
    assert first_query_expression.index == 0
    assert len(first_query_expression.terms) == 3

    assert first_query_expression.terms[0].type == "ref"  # operator
    assert len(first_query_expression.terms[0].value) == 1
    op = CRTerm(**first_query_expression.terms[0].value[0])
    assert op.type == "var"
    assert op.value == "gt"

    assert first_query_expression.terms[1].type == "call"
    assert len(first_query_expression.terms[1].value) == 2

    assert first_query_expression.terms[2].type == "number"
    assert first_query_expression.terms[2].value == 0

    assert res.result.support is not None
    assert len(res.result.support.__root__) == 1  # 1 support module


# --- lib/permit-datafilter/tests/test_data_filtering_usage.py ---

from datetime import datetime

import pytest
from sqlalchemy import Column, DateTime, ForeignKey, String
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.sql import Select

from permit_datafilter.boolean_expression.schemas import ResidualPolicyResponse
from permit_datafilter.plugins.sqlalchemy import to_query

Base = declarative_base()

# the tenant id used by all the mocked residual policies below
TENANT_ID = "082f6978-6424-4e05-a706-1ab6f26c3768"


def query_to_string(query: Select) -> str:
    """Renders a SQLAlchemy query as a raw PostgreSQL statement with inlined literals."""
    return str(
        query.compile(
            dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}
        )
    )


def striplines(s: str) -> str:
    """Strips every line, making SQL comparisons insensitive to indentation."""
    return "\n".join(line.strip() for line in s.splitlines())


def _tenant_eq_policy() -> ResidualPolicyResponse:
    """
    A mocked conditional residual policy: input.resource.tenant == TENANT_ID.

    An e2e test would fetch this via `permit.filter_resources("user", "read",
    "task")`, but that needs a live, ever-changing API key.
    """
    return ResidualPolicyResponse(
        **{
            "type": "conditional",
            "condition": {
                "expression": {
                    "operator": "eq",
                    "operands": [
                        {"variable": "input.resource.tenant"},
                        {"value": TENANT_ID},
                    ],
                }
            },
        }
    )


def test_sql_translation_no_join():
    """Tests residual-policy-to-SQL conversion that requires no joins."""
    residual = _tenant_eq_policy()

    class Task(Base):
        __tablename__ = "task"

        id = Column(String, primary_key=True)
        # BUG FIX: pass the callable (not datetime.utcnow()) so the default is
        # computed per insert, not once at class-definition time
        created_at = Column(DateTime, default=datetime.utcnow)
        updated_at = Column(DateTime)
        description = Column(String(255))
        tenant_id = Column(String(255))

    sa_query = to_query(
        residual,
        Task,
        refs={
            # maps a residual reference to a column on the same model
            "input.resource.tenant": Task.tenant_id,
        },
    )

    assert striplines(query_to_string(sa_query)) == striplines(
        f"""SELECT task.id, task.created_at, task.updated_at, task.description, task.tenant_id
        FROM task
        WHERE task.tenant_id = '{TENANT_ID}'"""
    )

    assert striplines(query_to_string(sa_query.with_only_columns(Task.id))) == striplines(
        f"""SELECT task.id
        FROM task
        WHERE task.tenant_id = '{TENANT_ID}'"""
    )


def test_sql_translation_with_join():
    """Tests residual-policy-to-SQL conversion that requires a join."""
    # BUG FIX: the original docstring said "without joins" (copy-paste)
    residual = _tenant_eq_policy()

    class Tenant(Base):
        __tablename__ = "tenant"

        id = Column(String, primary_key=True)
        key = Column(String(255))

    class TaskJoined(Base):
        __tablename__ = "task_joined"

        id = Column(String, primary_key=True)
        created_at = Column(DateTime, default=datetime.utcnow)
        updated_at = Column(DateTime)
        description = Column(String(255))
        tenant_id_joined = Column(String, ForeignKey("tenant.id"))
        tenant = relationship("Tenant", backref="tasks")

    sa_query = to_query(
        residual,
        TaskJoined,
        refs={
            # maps a residual reference to a column on a related model
            "input.resource.tenant": Tenant.key,
        },
        join_conditions=[(Tenant, TaskJoined.tenant_id_joined == Tenant.id)],
    )

    assert striplines(query_to_string(sa_query)) == striplines(
        f"""SELECT task_joined.id, task_joined.created_at, task_joined.updated_at, task_joined.description, task_joined.tenant_id_joined
        FROM task_joined JOIN tenant ON task_joined.tenant_id_joined = tenant.id
        WHERE tenant.key = '{TENANT_ID}'"""
    )


def test_sql_translation_of_trivial_policies():
    """always_allow / always_deny policies translate to unfiltered / WHERE false queries."""

    class Tasks(Base):
        __tablename__ = "tasks"

        id = Column(String, primary_key=True)
        created_at = Column(DateTime, default=datetime.utcnow)
        updated_at = Column(DateTime)
        description = Column(String(255))
        tenant_id = Column(String(255))

    refs = {"input.resource.tenant": Tasks.tenant_id}

    allow = ResidualPolicyResponse(**{"type": "always_allow", "condition": None})
    sa_query = to_query(allow, Tasks, refs=refs)
    assert striplines(query_to_string(sa_query)) == striplines(
        """SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.description, tasks.tenant_id
        FROM tasks"""
    )  # this query always returns all rows of the tasks table

    deny = ResidualPolicyResponse(**{"type": "always_deny", "condition": None})
    sa_query = to_query(deny, Tasks, refs=refs)
    assert striplines(query_to_string(sa_query)) == striplines(
        """SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.description, tasks.tenant_id
        FROM tasks
        WHERE false"""
    )  # this query never has any results


def test_missing_joins():
    """to_query() must raise when a mapped column requires a join that was not provided."""
    residual = _tenant_eq_policy()

    class User2(Base):
        __tablename__ = "user2"

        id = Column(String, primary_key=True)
        username = Column(String(255))

    class Tenant2(Base):
        __tablename__ = "tenant2"

        id = Column(String, primary_key=True)
        key = Column(String(255))

    class TaskJoined2(Base):
        __tablename__ = "task_joined2"

        id = Column(String, primary_key=True)
        created_at = Column(DateTime, default=datetime.utcnow)
        updated_at = Column(DateTime)
        description = Column(String(255))
        tenant_id_joined = Column(String, ForeignKey("tenant2.id"))
        tenant = relationship("Tenant2", backref="tasks")
        owner_id = Column(String, ForeignKey("user2.id"))
        owner = relationship("User2", backref="tasks")

    # Tenant2.key lives outside the main table (requires a join); with no join
    # conditions at all, to_query() throws a TypeError
    with pytest.raises(TypeError) as e:
        to_query(
            residual,
            TaskJoined2,
            refs={"input.resource.tenant": Tenant2.key},
        )
    assert (
        str(e.value)
        == "You must call QueryBuilder.join(table, condition) to map residual references to other SQL tables"
    )

    # join conditions provided, but not for every required table — a different TypeError
    with pytest.raises(TypeError) as e:
        to_query(
            residual,
            TaskJoined2,
            refs={"input.resource.tenant": Tenant2.key},
            join_conditions=[(User2, TaskJoined2.owner_id == User2.id)],
        )
    assert (
        str(e.value)
        == "QueryBuilder.join() was not called for these SQL tables: {'tenant2'}"
    )