diff --git a/src/dataregistry/db_basic.py b/src/dataregistry/db_basic.py
index eb8132d..0c78119 100644
--- a/src/dataregistry/db_basic.py
+++ b/src/dataregistry/db_basic.py
@@ -8,6 +8,7 @@
 from dataregistry import __version__
 from dataregistry.exceptions import DataRegistryException
 from dataregistry.schema import DEFAULT_SCHEMA_WORKING
+from functools import cached_property
 
 """
 Low-level utility routines and classes for accessing the registry
@@ -284,6 +285,41 @@ def _get_db_info(prov_table, get_associated_production=False):
         # Store metadata
         self.metadata["tables"] = metadata.tables
 
+    @cached_property
+    def duplicate_column_names(self):
+        """
+        Probe the database for tables which share column names. This is used
+        later for querying.
+
+        Only tables belonging to the active schema are considered. The result
+        is computed on first access and then cached on the instance, so it is
+        not recomputed if the reflected metadata changes afterwards.
+
+        Returns
+        -------
+        duplicates : list
+            Sorted list of column names that are duplicated across tables
+        """
+
+        # Database hasn't been reflected yet
+        if len(self.metadata) == 0:
+            self._reflect()
+
+        # Find duplicate column names (sets give O(1) membership checks)
+        seen = set()
+        duplicates = set()
+        for table in self.metadata["tables"].values():
+            # Schema is a per-table property, so filter before the column loop
+            if table.schema != self.active_schema:
+                continue
+
+            for column in table.c:
+                if column.name in seen:
+                    duplicates.add(column.name)
+                seen.add(column.name)
+
+        return sorted(duplicates)
+
     def get_table(self, tbl, schema=None):
         """
         Get metadata for a specific table in the database.
diff --git a/src/dataregistry/query.py b/src/dataregistry/query.py index e8b08ae..63e1070 100644 --- a/src/dataregistry/query.py +++ b/src/dataregistry/query.py @@ -199,9 +199,20 @@ def _parse_selected_columns(self, column_names): input_parts = col_name.split(".") num_parts = len(input_parts) + # Make sure column name is valid if num_parts > 2: raise ValueError(f"{col_name} is not a valid column") + if num_parts == 1: + if col_name in self.db_connection.duplicate_column_names: + raise DataRegistryException( + ( + f"Column name '{col_name}' is not unique to one table " + f"in the database, use <table_name>.<column_name> " + f"format instead" + ) + ) + # Loop over each column in the database and find matches for table in self.db_connection.metadata["tables"]: for column in self.db_connection.metadata["tables"][table].c: @@ -216,6 +227,7 @@ def _parse_selected_columns(self, column_names): # Input is in <column_name> format if input_parts[0] == table_parts[-1]: tmp_column_list[column.table.schema].append(column) + tables_required.add(column.table.name) elif num_parts == 2: # Input is in <table_name>.<column_name> format if ( @@ -223,23 +235,7 @@ def _parse_selected_columns(self, column_names): and input_parts[1] == table_parts[-1] ): tmp_column_list[column.table.schema].append(column) - - # Make sure we don't find multiple matches - for s in tmp_column_list.keys(): # Each schema - chk = [] - for x in tmp_column_list[s]: # Each column in schema - if x.name in chk: - raise DataRegistryException( - ( - f"Column name '{col_name}' is not unique to one table " - f"in the database, use <table_name>.<column_name> " - f"format instead" - ) - ) - chk.append(x.name) - - # Add this table to the list - tables_required.add(x.table.name) + tables_required.add(column.table.name) # Store results for att in tmp_column_list.keys():