diff --git a/ann_benchmarks/algorithms/base/module.py b/ann_benchmarks/algorithms/base/module.py index 785b800e3..4c96034f5 100644 --- a/ann_benchmarks/algorithms/base/module.py +++ b/ann_benchmarks/algorithms/base/module.py @@ -76,4 +76,10 @@ def get_additional(self) -> Dict[str, Any]: return {} def __str__(self) -> str: - return self.name \ No newline at end of file + return self.name + + def set_conn_params(self, conn_params: Dict[str, str]) -> None: + """Set connection parameters that might be required for connecting to + the system under test, such as a database server. + """ + pass diff --git a/ann_benchmarks/algorithms/pgvector/module.py b/ann_benchmarks/algorithms/pgvector/module.py index 443ac1ffd..3f345daa8 100644 --- a/ann_benchmarks/algorithms/pgvector/module.py +++ b/ann_benchmarks/algorithms/pgvector/module.py @@ -1,18 +1,28 @@ import subprocess import sys +from typing import Dict + import pgvector.psycopg import psycopg from ..base.module import BaseANN +DEFAULT_POSTGRES_USER = 'ann' +DEFAULT_POSTGRES_PASSWORD = 'ann' +DEFAULT_POSTGRES_DB_NAME = 'ann' + + class PGVector(BaseANN): + _conn_params: Dict[str, str] + def __init__(self, metric, method_param): self._metric = metric self._m = method_param['M'] self._ef_construction = method_param['efConstruction'] self._cur = None + self._conn_params = {} if metric == "angular": self._query = "SELECT id FROM items ORDER BY embedding <=> %s LIMIT %s" @@ -21,9 +31,29 @@ def __init__(self, metric, method_param): else: raise RuntimeError(f"unknown metric {metric}") + def set_conn_params(self, conn_params: Dict[str, str]) -> None: + self._conn_params = conn_params + + def get_conn_param(self, key: str, default_value: str) -> str: + value = self._conn_params.get(key) + if value is None: + return default_value + return value + def fit(self, X): - subprocess.run("service postgresql start", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr) - conn = psycopg.connect(user="ann", password="ann", dbname="ann", autocommit=True) + psycopg_connect_kwargs: Dict[str, Any] = dict( + autocommit=True, + user=self.get_conn_param('user', DEFAULT_POSTGRES_USER), + password=self.get_conn_param('password', DEFAULT_POSTGRES_PASSWORD), + dbname=self.get_conn_param('dbname', DEFAULT_POSTGRES_DB_NAME) + ) + for arg_name in ['host', 'port']: + # For these arguments, if they are not specified, leave the default + # choice to the psycopg driver. + if self._conn_params.get(arg_name) is not None: + psycopg_connect_kwargs[arg_name] = self._conn_params[arg_name] + + conn = psycopg.connect(**psycopg_connect_kwargs) pgvector.psycopg.register_vector(conn) cur = conn.cursor() cur.execute("DROP TABLE IF EXISTS items") diff --git a/ann_benchmarks/main.py b/ann_benchmarks/main.py index 07539775d..4f6a203a8 100644 --- a/ann_benchmarks/main.py +++ b/ann_benchmarks/main.py @@ -18,7 +18,7 @@ from .constants import INDEX_DIR from .datasets import DATASETS, get_dataset from .results import build_result_filepath -from .runner import run, run_docker +from .runner import run, run_docker, get_conn_params_from_args logging.config.fileConfig("logging.conf") @@ -68,7 +68,7 @@ def run_worker(cpu: int, mem_limit: int, args: argparse.Namespace, queue: multip while not queue.empty(): definition = queue.get() if args.local: - run(definition, args.dataset, args.count, args.runs, args.batch) + run(definition, args.dataset, args.count, args.runs, args.batch, get_conn_params_from_args(args)) else: cpu_limit = str(cpu) if not args.batch else f"0-{multiprocessing.cpu_count() - 1}" @@ -122,6 +122,11 @@ def parse_arguments() -> argparse.Namespace: ) parser.add_argument("--run-disabled", help="run algorithms that are disabled in algos.yml", action="store_true") parser.add_argument("--parallelism", type=positive_int, help="Number of Docker containers in parallel", default=1) + parser.add_argument("--user", help="Username to connect to server") + parser.add_argument("--password", help="Password to connect to server") + parser.add_argument("--dbname", help="Database name to use when connecting to server") + parser.add_argument("--host", help="Server to which to connect") + parser.add_argument("--port", type=int, help="Port to use to connect to server") args = parser.parse_args() if args.timeout == -1: diff --git a/ann_benchmarks/runner.py b/ann_benchmarks/runner.py index 81428114c..1c7830609 100644 --- a/ann_benchmarks/runner.py +++ b/ann_benchmarks/runner.py @@ -194,7 +194,20 @@ def build_index(algo: BaseANN, X_train: numpy.ndarray) -> Tuple: return build_time, index_size -def run(definition: Definition, dataset_name: str, count: int, run_count: int, batch: bool) -> None: +def get_conn_params_from_args(args: argparse.Namespace) -> Dict[str, str]: + """Extracts server connection parameters from the given arguments object.""" + return { + key: getattr(args, key) + for key in ('user', 'password', 'dbname', 'host', 'port') + if getattr(args, key) is not None + } + +def run(definition: Definition, + dataset_name: str, + count: int, + run_count: int, + batch: bool, + conn_params: Dict[str, str]) -> None: """Run the algorithm benchmarking. Args: @@ -203,6 +216,7 @@ def run(definition: Definition, dataset_name: str, count: int, run_count: int, b count (int): The number of results to return. run_count (int): The number of runs. batch (bool): If true, runs in batch mode. + conn_params (dict): Parameters for connecting to the server. """ algo = instantiate_algorithm(definition) assert not definition.query_argument_groups or hasattr( @@ -211,6 +225,7 @@ def run(definition: Definition, dataset_name: str, count: int, run_count: int, b error: query argument groups have been specified for {definition.module}.{definition.constructor}({definition.arguments}), but the \ algorithm instantiated from it does not implement the set_query_arguments \ function""" + algo.set_conn_params(conn_params) X_train, X_test, distance = load_and_transform_dataset(dataset_name) @@ -288,7 +303,8 @@ def run_from_cmdline(): query_argument_groups=query_args, disabled=False, ) - run(definition, args.dataset, args.count, args.runs, args.batch) + run(definition, args.dataset, args.count, args.runs, args.batch, + get_conn_params_from_args(args)) def run_docker(