From b5faae6a7fbbbbbe1aee70a0552044179a7b9bff Mon Sep 17 00:00:00 2001 From: JulienDavat Date: Wed, 19 Jan 2022 17:13:26 +0100 Subject: [PATCH] (Values/Filter)-push-down optimization --- sage/cli/debug.py | 2 +- sage/cli/explain.py | 7 +- sage/database/core/dataset.py | 14 +- sage/database/core/yaml_config.py | 23 +- sage/http_server/server.py | 16 +- sage/query_engine/iterators/filter.py | 10 +- sage/query_engine/iterators/nlj.py | 6 +- .../iterators/preemptable_iterator.py | 2 +- sage/query_engine/iterators/projection.py | 7 +- sage/query_engine/iterators/scan.py | 2 +- sage/query_engine/iterators/union.py | 6 +- sage/query_engine/iterators/values.py | 9 +- .../optimizer/logical/optimizer.py | 3 +- .../optimizer/logical/plan_visitor.py | 5 + .../visitors/filter_variables_extractor.py | 47 +++ .../logical/visitors/pipeline_builder.py | 6 + sage/query_engine/optimizer/optimizer.py | 6 +- .../optimizer/physical/optimizer.py | 9 +- .../optimizer/physical/plan_visitor.py | 18 +- .../physical/visitors/cost_estimator.py | 4 +- .../physical/visitors/filter_push_down.py | 288 ++++++------------ .../physical/visitors/values_push_down.py | 214 +++++++++++++ 22 files changed, 467 insertions(+), 237 deletions(-) create mode 100644 sage/query_engine/optimizer/logical/visitors/filter_variables_extractor.py create mode 100644 sage/query_engine/optimizer/physical/visitors/values_push_down.py diff --git a/sage/cli/debug.py b/sage/cli/debug.py index c7b51db..7908fca 100644 --- a/sage/cli/debug.py +++ b/sage/cli/debug.py @@ -82,7 +82,7 @@ def sage_query_debug(config_file, default_graph_uri, query, file, limit): exit(1) logical_plan = Parser.parse(query) - iterator, cardinalities = Optimizer.get_default().optimize( + iterator, cardinalities = Optimizer.get_default(dataset).optimize( logical_plan, dataset, default_graph_uri ) # iterator, cards = parse_query(query, dataset, default_graph_uri) diff --git a/sage/cli/explain.py b/sage/cli/explain.py index 5ed7ecf..b4337b8 100644 --- a/sage/cli/explain.py +++ b/sage/cli/explain.py @@ -119,7 +119,7 @@ def explain( print(pprintAlgebra(tq)) logical_plan = Parser.parse(query) - iterator, cardinalities = Optimizer.get_default().optimize( + iterator, cardinalities = Optimizer.get_default(dataset).optimize( logical_plan, dataset, graph_uri ) # iterator, cards = parse_query(query, dataset, graph_uri) @@ -132,6 +132,11 @@ def explain( with open(output, 'w') as outfile: outfile.write(QueryPlanStringifier().visit(iterator)) + print("-----------------") + print("Optimized query") + print("-----------------") + print(QueryPlanStringifier().visit(iterator)) + print("-----------------") print("Cardinalities") print("-----------------") diff --git a/sage/database/core/dataset.py b/sage/database/core/dataset.py index 46daea9..259fbce 100644 --- a/sage/database/core/dataset.py +++ b/sage/database/core/dataset.py @@ -23,7 +23,9 @@ def __init__( public_url: Optional[str] = None, default_query: Optional[str] = None, analytics=None, - stateless=True + stateless=True, + filter_push_down=True, + values_push_down=True ): super(Dataset, self).__init__() self._name = name @@ -33,6 +35,8 @@ def __init__( self._default_query = default_query self._analytics = analytics self._stateless = stateless + self._filter_push_down = filter_push_down + self._values_push_down = values_push_down self._force_order = False @property @@ -43,6 +47,14 @@ def name(self) -> str: def is_stateless(self) -> bool: return self._stateless + @property + def do_filter_push_down(self) -> bool: + return self._filter_push_down + + @property + def do_values_push_down(self) -> bool: + return self._values_push_down + @property def force_order(self) -> bool: return self._force_order diff --git a/sage/database/core/yaml_config.py b/sage/database/core/yaml_config.py index 159a15d..a69367d 100644 --- a/sage/database/core/yaml_config.py +++ b/sage/database/core/yaml_config.py @@ -49,13 +49,6 @@ def load_config(config_file: str) -> Dataset: else: stateless = True - # # if statefull, load the saved plan storage backend to use - # statefull_manager = None - # if not is_stateless: - # # TODO allow use of custom backend for saved plans - # # same kind of usage than custom DB backends - # statefull_manager = HashMapManager() - # get default time quantum & maximum number of results per page if 'quota' in config: if config['quota'] == 'inf': @@ -71,7 +64,17 @@ def load_config(config_file: str) -> Dataset: logging.warning("You are using SaGe without limitations on the number of results sent per page. This is fine, but be carefull as very large page of results can have unexpected serialization time.") max_results = inf - # build all RDF graphs found in the configuration file + # load debug parameters + if 'filter_push_down' in config: + filter_push_down = config['filter_push_down'] + else: + filter_push_down = True + if 'values_push_down' in config: + values_push_down = config['values_push_down'] + else: + values_push_down = True + + # build all RDF graphs found in the configuration file graphs = dict() if "graphs" not in config: raise SyntaxError("Np RDF graphs found in the configuration file. Please refers to the documentation to see how to declare RDF graphs in a SaGe YAML configuration file.") @@ -102,5 +105,7 @@ def load_config(config_file: str) -> Dataset: public_url=public_url, default_query=default_query, analytics=analytics, - stateless=stateless + stateless=stateless, + filter_push_down=filter_push_down, + values_push_down=values_push_down ) diff --git a/sage/http_server/server.py b/sage/http_server/server.py index 98d60c6..d24a55b 100644 --- a/sage/http_server/server.py +++ b/sage/http_server/server.py @@ -25,7 +25,6 @@ from sage.query_engine.optimizer.parser import Parser from sage.query_engine.optimizer.optimizer import Optimizer from sage.query_engine.optimizer.physical.visitors.query_plan_stringifier import QueryPlanStringifier -# from sage.query_engine.optimizer.query_parser import parse_query from sage.database.saved_plan.saved_plan_manager import SavedPlanManager from sage.database.saved_plan.stateless_manager import StatelessManager from sage.database.saved_plan.statefull_manager import StatefullManager @@ -83,6 +82,8 @@ async def execute_query( raise HTTPException(status_code=404, detail=f"RDF Graph {default_graph_uri} not found on the server.") graph = dataset.get_graph(default_graph_uri) + optimizer = Optimizer.get_default(dataset) + # decode next_link or build query execution plan cardinalities = dict() loadin_start = time() @@ -92,7 +93,7 @@ async def execute_query( else: start_timestamp = datetime.now() logical_plan = Parser.parse(query) - plan, cardinalities = Optimizer.get_default().optimize( + plan, cardinalities = optimizer.optimize( logical_plan, dataset, default_graph_uri, as_of=start_timestamp ) # plan, cardinalities = parse_query(query, dataset, default_graph_uri) @@ -131,8 +132,8 @@ async def execute_query( "metrics": { "progression": coverage_after, "coverage": coverage_after - coverage_before, - "cost": Optimizer.get_default().cost(plan), - "cardinality": Optimizer.get_default().cardinality(plan) + "cost": optimizer.cost(plan), + "cardinality": optimizer.cardinality(plan) } } print(stats['metrics']) @@ -149,17 +150,18 @@ async def explain_query( query: str, default_graph_uri: str, next_link: Optional[str], dataset: Dataset ) -> str: + optimizer = Optimizer.get_default(dataset) if next_link is not None: plan = StatelessManager().get_plan(next_link, dataset) else: logical_plan = Parser.parse(query) - plan, cardinalities = Optimizer.get_default().optimize( + plan, cardinalities = optimizer.optimize( logical_plan, dataset, default_graph_uri ) return JSONResponse({ "query": QueryPlanStringifier().visit(plan), - "cost": Optimizer.get_default().cost(plan), - "cardinality": Optimizer.get_default().cardinality(plan) + "cost": optimizer.cost(plan), + "cardinality": optimizer.cardinality(plan) }) diff --git a/sage/query_engine/iterators/filter.py b/sage/query_engine/iterators/filter.py index 41e5fa1..c6c9218 100644 --- a/sage/query_engine/iterators/filter.py +++ b/sage/query_engine/iterators/filter.py @@ -1,6 +1,6 @@ # filter.py # Author: Thomas MINIER - MIT License 2017-2020 -from typing import Dict, Optional, Union, Set, Any +from typing import Dict, Optional, Union, Set, Any, List from rdflib.term import Literal, URIRef, Variable from rdflib.plugins.sparql.parserutils import Expr from rdflib.plugins.sparql.sparql import Bindings, QueryContext @@ -9,6 +9,7 @@ from sage.query_engine.iterators.preemptable_iterator import PreemptableIterator from sage.query_engine.protobuf.iterators_pb2 import SavedFilterIterator from sage.query_engine.protobuf.utils import pyDict_to_protoDict +from sage.query_engine.optimizer.logical.visitors.filter_variables_extractor import FilterVariablesExtractor def to_rdflib_term(value: str) -> Union[Literal, URIRef, Variable]: @@ -63,8 +64,11 @@ def explain(self, height: int = 0, step: int = 3) -> None: print(f'{prefix}FilterIterator <{str(self._expression.vars)}>') self._source.explain(height=(height + step), step=step) - def variables(self) -> Set[str]: - return self._source.variables() + def constrained_variables(self) -> List[str]: + return FilterVariablesExtractor().visit(self._expression) + + def variables(self, include_values: bool = False) -> Set[str]: + return self._source.variables(include_values=include_values) def __evaluate__(self, mappings: Dict[str, str]) -> bool: """Evaluate the FILTER expression with a set mappings. diff --git a/sage/query_engine/iterators/nlj.py b/sage/query_engine/iterators/nlj.py index a7ef9c8..88fec99 100644 --- a/sage/query_engine/iterators/nlj.py +++ b/sage/query_engine/iterators/nlj.py @@ -41,8 +41,10 @@ def explain(self, height: int = 0, step: int = 3) -> None: self._left.explain(height=(height + step), step=step) self._right.explain(height=(height + step), step=step) - def variables(self) -> Set[str]: - return self._left.variables().union(self._right.variables()) + def variables(self, include_values: bool = False) -> Set[str]: + return self._left.variables(include_values=include_values).union( + self._right.variables(include_values=include_values) + ) def next_stage(self, mappings: Dict[str, str]): """Propagate mappings to the bottom of the pipeline in order to compute nested loop joins""" diff --git a/sage/query_engine/iterators/preemptable_iterator.py b/sage/query_engine/iterators/preemptable_iterator.py index d719c2a..5c5c32a 100644 --- a/sage/query_engine/iterators/preemptable_iterator.py +++ b/sage/query_engine/iterators/preemptable_iterator.py @@ -18,7 +18,7 @@ def explain(self, height: int = 0, step: int = 3) -> None: pass @abstractmethod - def variables(self) -> Set[str]: + def variables(self, include_values: bool = False) -> Set[str]: """Return the domain of the iterator""" pass diff --git a/sage/query_engine/iterators/projection.py b/sage/query_engine/iterators/projection.py index 2cb2291..a1734df 100644 --- a/sage/query_engine/iterators/projection.py +++ b/sage/query_engine/iterators/projection.py @@ -34,8 +34,11 @@ def explain(self, height: int = 0, step: int = 3) -> None: print(f'{prefix}ProjectionIterator SELECT {self._projection}') self._source.explain(height=(height + step), step=step) - def variables(self) -> Set[str]: - return set(self._projection) + def variables(self, include_values: bool = False) -> Set[str]: + if self._projection is None: + return self._source.variables(include_values=include_values) + else: + return set(self._projection) def next_stage(self, mappings: Dict[str, str]): """Propagate mappings to the bottom of the pipeline in order to compute nested loop joins""" diff --git a/sage/query_engine/iterators/scan.py b/sage/query_engine/iterators/scan.py index 9dac652..51747da 100644 --- a/sage/query_engine/iterators/scan.py +++ b/sage/query_engine/iterators/scan.py @@ -88,7 +88,7 @@ def explain(self, height: int = 0, step: int = 3) -> None: object = self._pattern['object'] print(f'{prefix}ScanIterator <({subject} {predicate} {object})>') - def variables(self) -> Set[str]: + def variables(self, include_values: bool = False) -> Set[str]: vars = set() if self._pattern['subject'].startswith('?'): vars.add(self._pattern['subject']) diff --git a/sage/query_engine/iterators/union.py b/sage/query_engine/iterators/union.py index 71cea97..09aee00 100644 --- a/sage/query_engine/iterators/union.py +++ b/sage/query_engine/iterators/union.py @@ -40,8 +40,10 @@ def explain(self, height: int = 0, step: int = 3) -> None: self._left.explain(height=(height + step), step=step) self._right.explain(height=(height + step), step=step) - def variables(self) -> Set[str]: - return self._left.variables().union(self._right.variables()) + def variables(self, include_values: bool = False) -> Set[str]: + return self._left.variables(include_values=include_values).union( + self._right.variables(include_values=include_values) + ) def next_stage(self, mappings: Dict[str, str]): """Propagate mappings to the bottom of the pipeline in order to compute nested loop joins""" diff --git a/sage/query_engine/iterators/values.py b/sage/query_engine/iterators/values.py index 03eb563..68ab25b 100644 --- a/sage/query_engine/iterators/values.py +++ b/sage/query_engine/iterators/values.py @@ -22,7 +22,7 @@ def __len__(self) -> int: return len(self._values) def __repr__(self) -> str: - return f"" + return f"" def serialized_name(self): """Get the name of the iterator, as used in the plan serialization protocol""" @@ -33,10 +33,10 @@ def explain(self, height: int = 0, step: int = 3) -> None: if height > step: prefix = ('|' + (' ' * (step - 1))) * (int(height / step) - 1) prefix += ('|' + ('-' * (step - 1))) - print(f'{prefix}ValuesIterator <{self._values}>') + print(f'{prefix}ValuesIterator <{self.variables()}>') - def variables(self) -> Set[str]: - return set(self._values[0].keys()) + def variables(self, include_values: bool = True) -> Set[str]: + return set(self._values[0].keys()) if include_values else set() def next_stage(self, mappings: Dict[str, str]): self._current_mappings = mappings @@ -55,7 +55,6 @@ async def next(self, context: Dict[str, Any] = {}) -> Optional[Dict[str, str]]: mappings = {**self._current_mappings, **mu} else: mappings = mu - print(mappings) return mappings def save(self) -> SavedValuesIterator: diff --git a/sage/query_engine/optimizer/logical/optimizer.py b/sage/query_engine/optimizer/logical/optimizer.py index b31e407..8fb1948 100644 --- a/sage/query_engine/optimizer/logical/optimizer.py +++ b/sage/query_engine/optimizer/logical/optimizer.py @@ -1,5 +1,6 @@ from __future__ import annotations +from sage.database.core.dataset import Dataset from sage.query_engine.optimizer.logical.plan_visitor import LogicalPlanVisitor, Node from sage.query_engine.optimizer.logical.visitors.filter_splitter import FilterSplitter @@ -10,7 +11,7 @@ def __init__(self): self._visitors = [] @staticmethod - def get_default() -> LogicalPlanOptimizer: + def get_default(dataset: Dataset) -> LogicalPlanOptimizer: optimizer = LogicalPlanOptimizer() optimizer.add_visitor(FilterSplitter()) return optimizer diff --git a/sage/query_engine/optimizer/logical/plan_visitor.py b/sage/query_engine/optimizer/logical/plan_visitor.py index 715728f..79726f0 100644 --- a/sage/query_engine/optimizer/logical/plan_visitor.py +++ b/sage/query_engine/optimizer/logical/plan_visitor.py @@ -77,6 +77,8 @@ def visit_expression(self, node: Expr) -> Any: return self.visit_conditional_or_expression(node) elif node.name == 'RelationalExpression': return self.visit_relational_expression(node) + elif node.name == 'AdditiveExpression': + return self.visit_additive_expression(node) elif node.name == 'Builtin_REGEX': return self.visit_regex_expression(node) elif node.name == 'Builtin_NOTEXISTS': @@ -136,6 +138,9 @@ def visit_conditional_or_expression(self, node: Expr) -> Any: def visit_relational_expression(self, node: Expr) -> Any: raise UnsupportedSPARQL(f'The {node.name} expressions are not implemented') + def visit_additive_expression(self, node: Expr) -> Any: + raise UnsupportedSPARQL(f'The {node.name} expressions are not implemented') + def visit_regex_expression(self, node: Expr) -> Any: raise UnsupportedSPARQL(f'The {node.name} expressions are not implemented') diff --git a/sage/query_engine/optimizer/logical/visitors/filter_variables_extractor.py b/sage/query_engine/optimizer/logical/visitors/filter_variables_extractor.py new file mode 100644 index 0000000..26161a6 --- /dev/null +++ b/sage/query_engine/optimizer/logical/visitors/filter_variables_extractor.py @@ -0,0 +1,47 @@ +from typing import Set +from rdflib.term import Variable +from rdflib.plugins.sparql.parserutils import Expr + +from sage.query_engine.optimizer.logical.plan_visitor import LogicalPlanVisitor, RDFTerm + + +class FilterVariablesExtractor(LogicalPlanVisitor): + + def visit_rdfterm(self, node: RDFTerm) -> Set[str]: + if isinstance(node, Variable): + return set([node.n3()]) + else: + return set() + + def visit_conditional_and_expression(self, node: Expr) -> Set[str]: + variables = self.visit(node.expr) + for other in node.other: + variables.update(self.visit(other)) + return variables + + def visit_conditional_or_expression(self, node: Expr) -> Set[str]: + variables = self.visit(node.expr) + for other in node.other: + variables.update(self.visit(other)) + return variables + + def visit_relational_expression(self, node: Expr) -> Set[str]: + return self.visit(node.expr) + + def visit_additive_expression(self, node: Expr) -> Set[str]: + variables = self.visit(node.expr) + for other in node.other: + variables.update(self.visit(other)) + return variables + + def visit_regex_expression(self, node: Expr) -> Set[str]: + return self.visit(node.text) + + def visit_not_exists_expression(self, node: Expr) -> Set[str]: + return self.visit(node.expr) + + def visit_str_expression(self, node: Expr) -> Set[str]: + return self.visit(node.arg) + + def visit_unary_not_expression(self, node: Expr) -> Set[str]: + return self.visit(node.expr) diff --git a/sage/query_engine/optimizer/logical/visitors/pipeline_builder.py b/sage/query_engine/optimizer/logical/visitors/pipeline_builder.py index 4a3f1d5..aa51577 100644 --- a/sage/query_engine/optimizer/logical/visitors/pipeline_builder.py +++ b/sage/query_engine/optimizer/logical/visitors/pipeline_builder.py @@ -54,6 +54,12 @@ def visit_unary_not_expression(self, node: Expr) -> str: def visit_str_expression(self, node: Expr) -> str: return f'str({self.visit(node.arg)})' + def visit_additive_expression(self, node: Expr) -> str: + expression = self.visit(node.expr) + for index, operator in enumerate(node.op): + expression += f' {operator} {self.visit(node.other[index])}' + return f'({expression})' + class PipelineBuilder(LogicalPlanVisitor): diff --git a/sage/query_engine/optimizer/optimizer.py b/sage/query_engine/optimizer/optimizer.py index 1cd710d..5353c7c 100644 --- a/sage/query_engine/optimizer/optimizer.py +++ b/sage/query_engine/optimizer/optimizer.py @@ -19,10 +19,10 @@ def __init__(self): self._physical_optimizer = None @staticmethod - def get_default() -> Optimizer: + def get_default(dataset: Dataset) -> Optimizer: optimizer = Optimizer() - optimizer.set_logical_optimizer(LogicalPlanOptimizer.get_default()) - optimizer.set_physical_optimizer(PhysicalPlanOptimizer.get_default()) + optimizer.set_logical_optimizer(LogicalPlanOptimizer.get_default(dataset)) + optimizer.set_physical_optimizer(PhysicalPlanOptimizer.get_default(dataset)) return optimizer def set_logical_optimizer(self, optimizer: LogicalPlanOptimizer) -> None: diff --git a/sage/query_engine/optimizer/physical/optimizer.py b/sage/query_engine/optimizer/physical/optimizer.py index 58ed5b4..7759c46 100644 --- a/sage/query_engine/optimizer/physical/optimizer.py +++ b/sage/query_engine/optimizer/physical/optimizer.py @@ -1,8 +1,10 @@ from __future__ import annotations +from sage.database.core.dataset import Dataset from sage.query_engine.iterators.preemptable_iterator import PreemptableIterator from sage.query_engine.optimizer.physical.plan_visitor import PhysicalPlanVisitor from sage.query_engine.optimizer.physical.visitors.filter_push_down import FilterPushDown +from sage.query_engine.optimizer.physical.visitors.values_push_down import ValuesPushDown class PhysicalPlanOptimizer(): @@ -11,9 +13,12 @@ def __init__(self): self._visitors = [] @staticmethod - def get_default() -> PhysicalPlanOptimizer: + def get_default(dataset: Dataset) -> PhysicalPlanOptimizer: optimizer = PhysicalPlanOptimizer() - optimizer.add_visitor(FilterPushDown()) + if dataset.do_filter_push_down: + optimizer.add_visitor(FilterPushDown()) + if dataset.do_values_push_down: + optimizer.add_visitor(ValuesPushDown()) return optimizer def add_visitor(self, visitor: PhysicalPlanVisitor) -> None: diff --git a/sage/query_engine/optimizer/physical/plan_visitor.py b/sage/query_engine/optimizer/physical/plan_visitor.py index 1249dec..a733833 100644 --- a/sage/query_engine/optimizer/physical/plan_visitor.py +++ b/sage/query_engine/optimizer/physical/plan_visitor.py @@ -3,6 +3,12 @@ from sage.query_engine.exceptions import UnsupportedSPARQL from sage.query_engine.iterators.preemptable_iterator import PreemptableIterator +from sage.query_engine.iterators.projection import ProjectionIterator +from sage.query_engine.iterators.values import ValuesIterator +from sage.query_engine.iterators.filter import FilterIterator +from sage.query_engine.iterators.nlj import IndexJoinIterator +from sage.query_engine.iterators.union import BagUnionIterator +from sage.query_engine.iterators.scan import ScanIterator class PhysicalPlanVisitor(ABC): @@ -26,20 +32,20 @@ def visit(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> Any: else: raise UnsupportedSPARQL(f'Unsupported SPARQL iterator: {node.serialized_name()}') - def visit_projection(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> Any: + def visit_projection(self, node: ProjectionIterator, context: Dict[str, Any] = {}) -> Any: raise UnsupportedSPARQL(f'The {node.serialized_name()} iterator is not implemented') - def visit_values(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> Any: + def visit_values(self, node: ValuesIterator, context: Dict[str, Any] = {}) -> Any: raise UnsupportedSPARQL(f'The {node.serialized_name()} iterator is not implemented') - def visit_filter(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> Any: + def visit_filter(self, node: FilterIterator, context: Dict[str, Any] = {}) -> Any: raise UnsupportedSPARQL(f'The {node.serialized_name()} iterator is not implemented') - def visit_join(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> Any: + def visit_join(self, node: IndexJoinIterator, context: Dict[str, Any] = {}) -> Any: raise UnsupportedSPARQL(f'The {node.serialized_name()} iterator is not implemented') - def visit_union(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> Any: + def visit_union(self, node: BagUnionIterator, context: Dict[str, Any] = {}) -> Any: raise UnsupportedSPARQL(f'The {node.serialized_name()} iterator is not implemented') - def visit_scan(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> Any: + def visit_scan(self, node: ScanIterator, context: Dict[str, Any] = {}) -> Any: raise UnsupportedSPARQL(f'The {node.serialized_name()} iterator is not implemented') diff --git a/sage/query_engine/optimizer/physical/visitors/cost_estimator.py b/sage/query_engine/optimizer/physical/visitors/cost_estimator.py index e69d495..3921efc 100644 --- a/sage/query_engine/optimizer/physical/visitors/cost_estimator.py +++ b/sage/query_engine/optimizer/physical/visitors/cost_estimator.py @@ -40,7 +40,7 @@ def visit_values( output_size = input_size * len(node._values) context['input-size'] = output_size print( - f'Card({node._values}) = {input_size} x {len(node._values)} = ' + + f'Cout({list(node._values[0].keys())}) = {input_size} x {len(node._values)} = ' + f'{output_size}' ) return output_size @@ -58,7 +58,7 @@ def visit_filter( output_size = input_size * selectivity context['input-size'] = output_size print( - f'Card({node._raw_expression}) = ' + + f'Cout({node._raw_expression}) = ' + f'{input_size} x {selectivity} = {output_size}' ) return input_size diff --git a/sage/query_engine/optimizer/physical/visitors/filter_push_down.py b/sage/query_engine/optimizer/physical/visitors/filter_push_down.py index af26996..f399c7e 100644 --- a/sage/query_engine/optimizer/physical/visitors/filter_push_down.py +++ b/sage/query_engine/optimizer/physical/visitors/filter_push_down.py @@ -1,118 +1,89 @@ from hashlib import md5 -from typing import Set, Tuple, List, Dict, Any -from rdflib.term import Variable -from rdflib.plugins.sparql.parserutils import Expr +from typing import Tuple, List, Dict, Any, Union from sage.query_engine.optimizer.physical.plan_visitor import PhysicalPlanVisitor -from sage.query_engine.optimizer.logical.plan_visitor import LogicalPlanVisitor, RDFTerm from sage.query_engine.iterators.preemptable_iterator import PreemptableIterator +from sage.query_engine.iterators.projection import ProjectionIterator +from sage.query_engine.iterators.values import ValuesIterator from sage.query_engine.iterators.filter import FilterIterator +from sage.query_engine.iterators.nlj import IndexJoinIterator +from sage.query_engine.iterators.union import BagUnionIterator +from sage.query_engine.iterators.scan import ScanIterator SOURCE = 0 LEFT = 1 RIGHT = 2 -class FilterVariablesExtractor(LogicalPlanVisitor): - - def visit_rdfterm(self, node: RDFTerm) -> Set[str]: - if isinstance(node, Variable): - return set([node.n3()]) - else: - return set() - - def visit_conditional_and_expression(self, node: Expr) -> Set[str]: - variables = self.visit(node.expr) - for other in node.other: - variables.update(self.visit(other)) - return variables - - def visit_conditional_or_expression(self, node: Expr) -> Set[str]: - variables = self.visit(node.expr) - for other in node.other: - variables.update(self.visit(other)) - return variables - - def visit_relational_expression(self, node: Expr) -> Set[str]: - return self.visit(node.expr) - - def visit_regex_expression(self, node: Expr) -> Set[str]: - return self.visit(node.text) - - def visit_not_exists_expression(self, node: Expr) -> Set[str]: - return self.visit(node.expr) - - def visit_str_expression(self, node: Expr) -> Set[str]: - return self.visit(node.arg) - - def visit_unary_not_expression(self, node: Expr) -> Set[str]: - return self.visit(node.expr) - - class FilterTargets(PhysicalPlanVisitor): - def __init__(self, filter: FilterIterator): - super().__init__() - if filter._expression.vars is None: - variables = FilterVariablesExtractor().visit(filter._expression) - filter._expression.vars = variables - self._constrained_variables = filter._expression.vars - - def visit_projection(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> List[Tuple[PreemptableIterator, int]]: - # can be moved at a lower level than the current iterator - targets = self.visit(node._source) + def visit_projection( + self, node: ProjectionIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: + # can be moved at a lower level than the current iterator's children + targets = self.visit(node._source, context=context) if len(targets) > 0: return targets # can be a child of the current iterator - if node._source.variables().issuperset(self._constrained_variables): + if node._source.variables().issuperset(context['variables']): return [(node, SOURCE)] # projection is the top iterator, something wrong... raise Exception('Malformed FILTER clause') - def visit_filter(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> List[Tuple[PreemptableIterator, int]]: - # can be moved at a lower level than the current iterator - targets = self.visit(node._source) + def visit_filter( + self, node: FilterIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: + # can be moved at a lower level than the current iterator's children + targets = self.visit(node._source, context=context) if len(targets) > 0: return targets # can be a child of the current iterator - if node._source.variables().issuperset(self._constrained_variables): + if node._source.variables().issuperset(context['variables']): return [(node, SOURCE)] # cannot be moved after this iterator return [] - def visit_join(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> List[Tuple[PreemptableIterator, int]]: - # can be moved at a lower level than the current iterator - targets = self.visit(node._left) + self.visit(node._right) + def visit_join( + self, node: IndexJoinIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: + # can be moved at a lower level than the current iterator's children + targets = self.visit(node._left, context=context) + if len(targets) > 0: + return targets + targets = self.visit(node._right, context=context) if len(targets) > 0: return targets # can be a child of the current iterator - if node._left.variables().issuperset(self._constrained_variables): + if node._left.variables().issuperset(context['variables']): return [(node, LEFT)] - if node._right.variables().issuperset(self._constrained_variables): + if node._right.variables().issuperset(context['variables']): return [(node, RIGHT)] # cannot be moved after this iterator return [] - def visit_union(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> List[Tuple[PreemptableIterator, int]]: - # can be moved at a lower level than the current iterator - targets = self.visit(node._left) + self.visit(node._right) + def visit_union( + self, node: BagUnionIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: + # can be moved at a lower level than the current iterator's children + targets = self.visit(node._left, context=context) + self.visit(node._right, context=context) if len(targets) > 0: return targets # can be a child of the current iterator targets = [] - if node._left.variables().issuperset(self._constrained_variables): + if node._left.variables().issuperset(context['variables']): targets.append((node, LEFT)) - if node._right.variables().issuperset(self._constrained_variables): + if node._right.variables().issuperset(context['variables']): targets.append((node, RIGHT)) - if len(targets) > 0: - return targets - # cannot be moved after this iterator return targets - def visit_values(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> List[Tuple[PreemptableIterator, int]]: + def visit_values( + self, node: ValuesIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: return [] # cannot be moved after a leaf iterator - def visit_scan(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> List[Tuple[PreemptableIterator, int]]: + def visit_scan( + self, node: ScanIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: return [] # cannot be moved after a leaf iterator @@ -122,7 +93,7 @@ def __init__(self): super().__init__() self._moved = dict() - def __has_already_been_moved__(self, filter: PreemptableIterator) -> bool: + def __has_already_been_moved__(self, filter: FilterIterator) -> bool: key = md5(filter._raw_expression.encode()).hexdigest() if key not in self._moved: self._moved[key] = None @@ -130,7 +101,10 @@ def __has_already_been_moved__(self, filter: PreemptableIterator) -> bool: else: return True - def __push_filter__(self, filter, targets) -> bool: + def __push_filter__( + self, filter: FilterIterator, + targets: List[Tuple[PreemptableIterator, int]] + ) -> bool: if self.__has_already_been_moved__(filter): return False for (iterator, position) in targets: @@ -151,131 +125,69 @@ def __push_filter__(self, filter, targets) -> bool: raise Exception(f'FilterPushDownError: {message}') return True - def __process_unary_iterator__(self, node: PreemptableIterator) -> PreemptableIterator: - updated = False + def __process_unary_iterator__( + self, node: Union[ProjectionIterator, FilterIterator], + context: Dict[str, Any] = {} + ) -> PreemptableIterator: + node._source = self.visit(node._source, context=context) if node._source.serialized_name() == 'filter': - targets = FilterTargets(node._source).visit(node) - updated = self.__push_filter__(node._source, targets) - if updated: - node._source = node._source._source - return self.visit(node) - node._source = self.visit(node._source) + targets = FilterTargets().visit( + context['root'], + {'variables': node._source.constrained_variables()} + ) + if self.__push_filter__(node._source, targets): + node._source = node._source._source return node - def visit_projection(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> PreemptableIterator: - return self.__process_unary_iterator__(node) - - def visit_filter(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> PreemptableIterator: - return self.__process_unary_iterator__(node) - - def __process_binary_iterator__(self, node: PreemptableIterator) -> PreemptableIterator: - updated = False - # remove the left filter if it has been moved + def visit_projection( + self, node: ProjectionIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: + context['root'] = node + return self.__process_unary_iterator__(node, context=context) + + def visit_filter( + self, node: FilterIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: + return self.__process_unary_iterator__(node, context=context) + + def __process_binary_iterator__( + self, node: Union[IndexJoinIterator, BagUnionIterator], + context: Dict[str, Any] = {} + ) -> PreemptableIterator: + node._left = self.visit(node._left, context=context) if node._left.serialized_name() == 'filter': - targets = FilterTargets(node._left).visit(node) - updated = self.__push_filter__(node._left, targets) - if updated: - node._left = node._left._source - return self.visit(node) - # remove the right filter if it has been moved + targets = FilterTargets().visit( + context['root'], + {'variables': node._left.constrained_variables()} + ) + if self.__push_filter__(node._left, targets): + node._left = node._left._source + node._right = self.visit(node._right, context=context) if node._right.serialized_name() == 'filter': - targets = FilterTargets(node._right).visit(node) - updated = self.__push_filter__(node._right, targets) - if updated: - node._right = node._right._source - return self.visit(node) - # continue the exploration of the tree - node._left = self.visit(node._left) - node._right = self.visit(node._right) + targets = FilterTargets().visit( + context['root'], + {'variables': node._right.constrained_variables()} + ) + if self.__push_filter__(node._right, targets): + node._right = node._right._source return node - def visit_join(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> PreemptableIterator: - return self.__process_binary_iterator__(node) + def visit_join( + self, node: IndexJoinIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: + return self.__process_binary_iterator__(node, context=context) - def visit_union(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> PreemptableIterator: - return self.__process_binary_iterator__(node) + def visit_union( + self, node: BagUnionIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: + return self.__process_binary_iterator__(node, context=context) - def visit_values(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> PreemptableIterator: + def visit_values( + self, node: ValuesIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: return node - def visit_scan(self, node: PreemptableIterator, context: Dict[str, Any] = {}) -> PreemptableIterator: + def visit_scan( + self, node: ScanIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: return node - - -# class FilterPushDown(PhysicalPlanVisitor): -# -# def __init__(self): -# super().__init__() -# -# def __push_filter__(self, filter, targets) -> bool: -# updated = False -# for (iterator, position) in targets: -# if iterator.variables() == filter.variables(): -# continue -# updated = True -# if position == SOURCE: -# iterator._source = FilterIterator( -# iterator._source, filter._raw_expression, filter._expression -# ) -# elif position == LEFT: -# iterator._left = FilterIterator( -# iterator._left, filter._raw_expression, filter._expression -# ) -# elif position == RIGHT: -# iterator._right = FilterIterator( -# iterator._right, filter._raw_expression, filter._expression -# ) -# else: -# message = f'Unexpected relative position {position}' -# raise Exception(f'FilterPushDownError: {message}') -# return updated -# -# def __process_unary_iterator__(self, node: PreemptableIterator) -> PreemptableIterator: -# updated = False -# if node._source.serialized_name() == 'filter': -# targets = FilterTargets(node._source).visit(node) -# updated = self.__push_filter__(node._source, targets) -# if updated: -# node._source = node._source._source -# return self.visit(node) -# node._source = self.visit(node._source) -# return node -# -# def visit_projection(self, node: PreemptableIterator) -> PreemptableIterator: -# return self.__process_unary_iterator__(node) -# -# def visit_filter(self, node: PreemptableIterator) -> PreemptableIterator: -# return self.__process_unary_iterator__(node) -# -# def __process_binary_iterator__(self, node: PreemptableIterator) -> PreemptableIterator: -# updated = False -# # remove the left filter if it has been moved -# if node._left.serialized_name() == 'filter': -# targets = FilterTargets(node._left).visit(node) -# updated = self.__push_filter__(node._left, targets) -# if updated: -# node._left = node._left._source -# return self.visit(node) -# # remove the right filter if it has been moved -# if node._right.serialized_name() == 'filter': -# targets = FilterTargets(node._right).visit(node) -# updated = self.__push_filter__(node._right, targets) -# if updated: -# node._right = node._right._source -# return self.visit(node) -# # continue the exploration of the tree -# node._left = self.visit(node._left) -# node._right = self.visit(node._right) -# return node -# -# def visit_join(self, node: PreemptableIterator) -> PreemptableIterator: -# return self.__process_binary_iterator__(node) -# -# def visit_union(self, node: PreemptableIterator) -> PreemptableIterator: -# return self.__process_binary_iterator__(node) -# -# def visit_values(self, node: PreemptableIterator) -> PreemptableIterator: -# return node -# -# def visit_scan(self, node: PreemptableIterator) -> PreemptableIterator: -# return node diff --git a/sage/query_engine/optimizer/physical/visitors/values_push_down.py b/sage/query_engine/optimizer/physical/visitors/values_push_down.py new file mode 100644 index 0000000..825f1df --- /dev/null +++ b/sage/query_engine/optimizer/physical/visitors/values_push_down.py @@ -0,0 +1,214 @@ +from hashlib import md5 +from typing import Tuple, List, Dict, Any, Union + +from sage.query_engine.optimizer.physical.plan_visitor import PhysicalPlanVisitor +from sage.query_engine.iterators.preemptable_iterator import PreemptableIterator +from sage.query_engine.iterators.projection import ProjectionIterator +from sage.query_engine.iterators.values import ValuesIterator +from sage.query_engine.iterators.filter import FilterIterator +from sage.query_engine.iterators.nlj import IndexJoinIterator +from sage.query_engine.iterators.union import BagUnionIterator +from sage.query_engine.iterators.scan import ScanIterator + +SOURCE = 0 +LEFT = 1 +RIGHT = 2 + + +class ValuesTargets(PhysicalPlanVisitor): + + def visit_projection( + self, node: ProjectionIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: + # can be moved at a lower level than the current iterator's children + targets = self.visit(node._source, context=context) + if len(targets) > 0: + return targets + # cannot be moved elsewhere (not really efficient, can be improved...) + return [(node, SOURCE)] + + def visit_filter( + self, node: FilterIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: + # can be moved at a lower level than the current iterator's children + targets = self.visit(node._source, context=context) + if len(targets) > 0: + return targets + # can be a child of the current iterator + if node._source.variables(include_values=False).issuperset(context['variables']): + return [(node, SOURCE)] + # cannot be moved after this iterator + return [] + + def visit_join( + self, node: IndexJoinIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: + # can be moved at a lower level than the current iterator's children + targets = self.visit(node._left, context=context) + if len(targets) > 0: + return targets + targets = self.visit(node._right, context=context) + if len(targets) > 0: + return targets + # can be a child of the current iterator + if node._left.variables(include_values=False).issuperset(context['variables']): + return [(node, LEFT)] + if node._right.variables(include_values=False).issuperset(context['variables']): + return [(node, RIGHT)] + # cannot be moved after this iterator + return [] + + def visit_union( + self, node: BagUnionIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: + # can be moved at a lower level than the current iterator's children + targets = self.visit(node._left, context=context) + self.visit(node._right, context=context) + if len(targets) > 0: + return targets + # can be a child of the current iterator + targets = [] + if node._left.variables(include_values=False).issuperset(context['variables']): + targets.append((node, LEFT)) + if node._right.variables(include_values=False).issuperset(context['variables']): + targets.append((node, RIGHT)) + return targets + + def visit_values( + self, node: ValuesIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: + return [] # cannot be moved after a leaf iterator + + def visit_scan( + self, node: ScanIterator, context: Dict[str, Any] = {} + ) -> List[Tuple[PreemptableIterator, int]]: + return [] # cannot be moved after a leaf iterator + + +class ValuesPushDown(PhysicalPlanVisitor): + + def __init__(self): + super().__init__() + self._moved = dict() + + def __has_already_been_moved__(self, values: ValuesIterator) -> bool: + key = md5(''.join(values.variables()).encode()).hexdigest() + if key not in self._moved: + self._moved[key] = None + return False + else: + return True + + def __push_values__( + self, values: ValuesIterator, + targets: List[Tuple[PreemptableIterator, int]] + ) -> bool: + if self.__has_already_been_moved__(values): + return False + for (iterator, position) in targets: + if position == SOURCE: + iterator._source = IndexJoinIterator( + ValuesIterator(values._values), iterator._source + ) + elif position == LEFT: + iterator._left = IndexJoinIterator( + ValuesIterator(values._values), iterator._left + ) + elif position == RIGHT: + iterator._right = IndexJoinIterator( + ValuesIterator(values._values), iterator._right + ) + else: + message = f'Unexpected relative position {position}' + raise Exception(f'ValuesPushDownError: {message}') + return True + + def __process_unary_iterator__( + self, node: Union[ProjectionIterator, FilterIterator], + context: Dict[str, Any] = {} + ) -> PreemptableIterator: + node._source = self.visit(node._source, context=context) + if node._source.serialized_name() == 'join': + if node._source._left.serialized_name() == 'values': + targets = ValuesTargets().visit( + context['root'], + {'variables': node._source._left.variables()} + ) + if self.__push_values__(node._source._left, targets): + node._source = node._source._right + elif node._source._right.serialized_name() == 'values': + targets = ValuesTargets().visit( + context['root'], + {'variables': node._source._right.variables()} + ) + if self.__push_values__(node._source._right, targets): + node._source = node._source._left + return node + + def visit_projection( + self, node: ProjectionIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: + context['root'] = node + return self.__process_unary_iterator__(node, context=context) + + def visit_filter( + self, node: FilterIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: + return self.__process_unary_iterator__(node, context=context) + + def __process_binary_iterator__( + self, node: Union[IndexJoinIterator, BagUnionIterator], + context: Dict[str, Any] = {} + ) -> PreemptableIterator: + node._left = self.visit(node._left, context=context) + if node._left.serialized_name() == 'join': + if node._left._left.serialized_name() == 'values': + targets = ValuesTargets().visit( + context['root'], + {'variables': node._left._left.variables()} + ) + if self.__push_values__(node._left._left, targets): + node._left = node._left._right + elif node._left._right.serialized_name() == 'values': + targets = ValuesTargets().visit( + context['root'], + {'variables': node._left._right.variables()} + ) + if self.__push_values__(node._left._right, targets): + node._left = node._left._left + node._right = self.visit(node._right, context=context) + if node._right.serialized_name() == 'join': + if node._right._left.serialized_name() == 'values': + targets = ValuesTargets().visit( + context['root'], + {'variables': node._right._left.variables()} + ) + if self.__push_values__(node._right._left, targets): + node._right = node._right._right + elif node._right._right.serialized_name() == 'values': + targets = ValuesTargets().visit( + context['root'], + {'variables': node._right._right.variables()} + ) + if self.__push_values__(node._right._right, targets): + node._right = node._right._left + return node + + def visit_join( + self, node: IndexJoinIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: + return self.__process_binary_iterator__(node, context=context) + + def visit_union( + self, node: BagUnionIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: + return self.__process_binary_iterator__(node, context=context) + + def visit_values( + self, node: ValuesIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: + return node + + def visit_scan( + self, node: ScanIterator, context: Dict[str, Any] = {} + ) -> PreemptableIterator: + return node