Skip to content

Commit

Permalink
Add ability to specify database connection credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbautin committed Dec 26, 2024
1 parent 0e32628 commit 179a8f7
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 7 deletions.
8 changes: 7 additions & 1 deletion ann_benchmarks/algorithms/base/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,10 @@ def get_additional(self) -> Dict[str, Any]:
return {}

def __str__(self) -> str:
return self.name
return self.name

def set_conn_params(self, conn_params: Dict[str, str]) -> None:
    """Receive connection parameters for the system under test.

    Some algorithms need to reach an external system, such as a database
    server, and may need credentials or an address to do so. This base
    implementation is a no-op hook: it ignores ``conn_params`` entirely.

    Args:
        conn_params: Mapping of connection parameter names to values.
    """
    return None
34 changes: 32 additions & 2 deletions ann_benchmarks/algorithms/pgvector/module.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
import subprocess
import sys

from typing import Dict

import pgvector.psycopg
import psycopg

from ..base.module import BaseANN


# Fallback PostgreSQL user, password and database name. fit() hands these to
# get_conn_param() as the defaults for the 'user', 'password' and 'dbname'
# connection parameters when no value has been supplied for them.
DEFAULT_POSTGRES_USER = 'ann'
DEFAULT_POSTGRES_PASSWORD = 'ann'
DEFAULT_POSTGRES_DB_NAME = 'ann'


class PGVector(BaseANN):
_conn_params: Dict[str, str]

def __init__(self, metric, method_param):
self._metric = metric
self._m = method_param['M']
self._ef_construction = method_param['efConstruction']
self._cur = None
self._conn_params = {}

if metric == "angular":
self._query = "SELECT id FROM items ORDER BY embedding <=> %s LIMIT %s"
Expand All @@ -21,9 +31,29 @@ def __init__(self, metric, method_param):
else:
raise RuntimeError(f"unknown metric {metric}")

def set_conn_params(self, conn_params: Dict[str, str]) -> None:
    """Remember the connection parameters to use for later connections.

    The mapping is stored as-is (no copy is made) in ``self._conn_params``.

    Args:
        conn_params: Mapping of connection parameter names to values.
    """
    self._conn_params = conn_params
    return None

def get_conn_param(self, key: str, default_value: str) -> str:
    """Look up one connection parameter, falling back to a default.

    Args:
        key: Name of the connection parameter to look up.
        default_value: Value to return when the parameter is unavailable.

    Returns:
        The stored value for ``key``, or ``default_value`` when ``key`` is
        missing from ``self._conn_params`` or is stored as ``None``. Other
        falsy values (for example an empty string) are returned unchanged.
    """
    stored = self._conn_params.get(key)
    # An explicit None is treated the same as a missing key, which is why a
    # plain ``.get(key, default_value)`` would not be equivalent here.
    return default_value if stored is None else stored

def fit(self, X):
subprocess.run("service postgresql start", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
conn = psycopg.connect(user="ann", password="ann", dbname="ann", autocommit=True)
psycopg_connect_kwargs: Dict[str, Any] = dict(
autocommit=True,
user=self.get_conn_param('user', DEFAULT_POSTGRES_USER),
password=self.get_conn_param('password', DEFAULT_POSTGRES_PASSWORD),
dbname=self.get_conn_param('dbname', DEFAULT_POSTGRES_DB_NAME)
)
for arg_name in ['host', 'port']:
# For these arguments, if they are not specified, leave the default
# choice to the psycopg driver.
if self._conn_params.get(arg_name) is not None:
psycopg_connect_kwargs[arg_name] = self._conn_params[arg_name]

conn = psycopg.connect(**psycopg_connect_kwargs)
pgvector.psycopg.register_vector(conn)
cur = conn.cursor()
cur.execute("DROP TABLE IF EXISTS items")
Expand Down
9 changes: 7 additions & 2 deletions ann_benchmarks/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .constants import INDEX_DIR
from .datasets import DATASETS, get_dataset
from .results import build_result_filepath
from .runner import run, run_docker
from .runner import run, run_docker, get_conn_params_from_args


logging.config.fileConfig("logging.conf")
Expand Down Expand Up @@ -68,7 +68,7 @@ def run_worker(cpu: int, mem_limit: int, args: argparse.Namespace, queue: multip
while not queue.empty():
definition = queue.get()
if args.local:
run(definition, args.dataset, args.count, args.runs, args.batch)
run(definition, args.dataset, args.count, args.runs, args.batch, get_conn_params_from_args(args))
else:
cpu_limit = str(cpu) if not args.batch else f"0-{multiprocessing.cpu_count() - 1}"

Expand Down Expand Up @@ -122,6 +122,11 @@ def parse_arguments() -> argparse.Namespace:
)
parser.add_argument("--run-disabled", help="run algorithms that are disabled in algos.yml", action="store_true")
parser.add_argument("--parallelism", type=positive_int, help="Number of Docker containers in parallel", default=1)
parser.add_argument("--user", help="Username to connect to server")
parser.add_argument("--password", help="Password to connect to server")
parser.add_argument("--dbname", help="Database name to use when connecting to server")
parser.add_argument("--host", help="Server to which to connect")
parser.add_argument("--port", type=int, help="Port to use to connect to server")

args = parser.parse_args()
if args.timeout == -1:
Expand Down
20 changes: 18 additions & 2 deletions ann_benchmarks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,20 @@ def build_index(algo: BaseANN, X_train: numpy.ndarray) -> Tuple:
return build_time, index_size


def run(definition: Definition, dataset_name: str, count: int, run_count: int, batch: bool) -> None:
def get_conn_params_from_args(args: argparse.Namespace) -> Dict[str, str]:
    """Extract server connection parameters from parsed command-line arguments.

    Args:
        args: Parsed arguments namespace. Each of the attributes ``user``,
            ``password``, ``dbname``, ``host`` and ``port`` may be set,
            ``None``, or absent altogether.

    Returns:
        A dict containing only the connection parameters that are present
        and not ``None``, keyed by attribute name. Values are passed through
        exactly as found on ``args``; no conversion to ``str`` is performed.
    """
    # A namespace produced by a parser that does not define the connection
    # options has no such attributes at all. Treat a missing attribute like
    # an unset option instead of letting getattr() raise AttributeError.
    candidates = (
        (key, getattr(args, key, None))
        for key in ('user', 'password', 'dbname', 'host', 'port')
    )
    return {key: value for key, value in candidates if value is not None}

def run(definition: Definition,
dataset_name: str,
count: int,
run_count: int,
batch: bool,
conn_params: Dict[str, str]) -> None:
"""Run the algorithm benchmarking.
Args:
Expand All @@ -203,6 +216,7 @@ def run(definition: Definition, dataset_name: str, count: int, run_count: int, b
count (int): The number of results to return.
run_count (int): The number of runs.
batch (bool): If true, runs in batch mode.
conn_params (dict): Parameters for connecting to the server.
"""
algo = instantiate_algorithm(definition)
assert not definition.query_argument_groups or hasattr(
Expand All @@ -211,6 +225,7 @@ def run(definition: Definition, dataset_name: str, count: int, run_count: int, b
error: query argument groups have been specified for {definition.module}.{definition.constructor}({definition.arguments}), but the \
algorithm instantiated from it does not implement the set_query_arguments \
function"""
algo.set_conn_params(conn_params)

X_train, X_test, distance = load_and_transform_dataset(dataset_name)

Expand Down Expand Up @@ -288,7 +303,8 @@ def run_from_cmdline():
query_argument_groups=query_args,
disabled=False,
)
run(definition, args.dataset, args.count, args.runs, args.batch)
run(definition, args.dataset, args.count, args.runs, args.batch,
get_conn_params_from_args(args))


def run_docker(
Expand Down

0 comments on commit 179a8f7

Please sign in to comment.