diff --git a/scripts/create_registry_schema.py b/scripts/create_registry_schema.py index cef27f0d..53519c56 100644 --- a/scripts/create_registry_schema.py +++ b/scripts/create_registry_schema.py @@ -318,23 +318,25 @@ def _BuildTable(schema, table_name, has_production, production): # Loop over each schema for schema in schema_list: - if schema == prod_schema: + + # Connect to database to find out what the backend is + db_connection = DbConnection(args.config, schema) + print(f"Database dialect is '{db_connection.dialect}'") + + if db_connection.dialect == "sqlite": + print(f"Creating sqlite database...") + schema = None + elif schema == prod_schema: print(f"Creating production schema {prod_schema}...") else: print( f"Creating schema '{schema}', linking to production schema '{prod_schema}'..." ) - # Connect to database to find out what the backend is - db_connection = DbConnection(args.config, schema) - # Make sure the linked production schema exists / is allowed if db_connection.dialect == "sqlite": if schema == prod_schema: raise ValueError("Production not available for sqlite databases") - # In fact we don't use schemas at all for sqlite - schema = None - prod_schema = None else: if schema != prod_schema: # production schema, tables must already exists and schema @@ -356,48 +358,54 @@ def _BuildTable(schema, table_name, has_production, production): raise RuntimeError("production schema version incompatible") # Create the schema - stmt = f"CREATE SCHEMA IF NOT EXISTS {schema}" - with db_connection.engine.connect() as conn: - conn.execute(text(stmt)) - conn.commit() + if db_connection.dialect != "sqlite": + stmt = f"CREATE SCHEMA IF NOT EXISTS {schema}" + with db_connection.engine.connect() as conn: + conn.execute(text(stmt)) + conn.commit() # Create the tables for table_name in schema_data.keys(): _BuildTable(schema, table_name, db_connection.dialect != "sqlite", prod_schema) - print(f"Built table {table_name} in {schema}") + if db_connection.dialect != "sqlite": + print(f"Built table {table_name} in {schema}") + else: + print(f"Built table {table_name}") # Generate the database - if schema != prod_schema: - Base.metadata.reflect(db_connection.engine, prod_schema) + if db_connection.dialect != "sqlite": + if schema != prod_schema: + Base.metadata.reflect(db_connection.engine, prod_schema) Base.metadata.create_all(db_connection.engine) - # Grant access to other accounts. Can only grant access to objects + # Grant access to other accounts. Can only grant access to objects # after they've been created - try: - with db_connection.engine.connect() as conn: - # Grant reg_reader access. - acct = "reg_reader" - usage_prv = f"GRANT USAGE ON SCHEMA {schema} to {acct}" - select_prv = f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}" - conn.execute(text(usage_prv)) - conn.execute(text(select_prv)) - conn.commit() - except Exception: - print(f"Could not grant access to {acct} on schema {schema}") - - if schema == prod_schema: + if db_connection.dialect != "sqlite": try: with db_connection.engine.connect() as conn: - # Grant reg_writer access. - acct = "reg_writer" - usage_priv = f"GRANT USAGE ON SCHEMA {schema} to {acct}" - select_priv = f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}" - conn.execute(text(usage_priv)) - conn.execute(text(select_priv)) + # Grant reg_reader access. + acct = "reg_reader" + usage_prv = f"GRANT USAGE ON SCHEMA {schema} to {acct}" + select_prv = f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}" + conn.execute(text(usage_prv)) + conn.execute(text(select_prv)) conn.commit() except Exception: print(f"Could not grant access to {acct} on schema {schema}") + if schema == prod_schema: + try: + with db_connection.engine.connect() as conn: + # Grant reg_writer access. + acct = "reg_writer" + usage_priv = f"GRANT USAGE ON SCHEMA {schema} to {acct}" + select_priv = f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}" + conn.execute(text(usage_priv)) + conn.execute(text(select_priv)) + conn.commit() + except Exception: + print(f"Could not grant access to {acct} on schema {schema}") + # Add initial provenance information prov_id = _insert_provenance( db_connection,