diff --git a/grate.unittests/Generic/GenericMigrationTables.cs b/grate.unittests/Generic/GenericMigrationTables.cs index 8cf0736e..b66cc018 100644 --- a/grate.unittests/Generic/GenericMigrationTables.cs +++ b/grate.unittests/Generic/GenericMigrationTables.cs @@ -98,7 +98,90 @@ public async Task Migration_does_not_fail_if_table_already_exists(string tableNa Assert.DoesNotThrowAsync(() => migrator.Migrate()); } } + + [TestCase("version")] + [TestCase("vErSiON")] + public async Task Does_not_create_Version_table_if_it_exists_with_another_casing(string existingTable) + { + await CheckTableCasing("Version", existingTable, (config, name) => config.VersionTableName = name); + } + + [TestCase("scriptsrun")] + [TestCase("SCRiptSrUN")] + public async Task Does_not_create_ScriptsRun_table_if_it_exists_with_another_casing(string existingTable) + { + await CheckTableCasing("ScriptsRun", existingTable, (config, name) => config.ScriptsRunTableName = name); + } + + [TestCase("scriptsrunerrors")] + [TestCase("ScripTSRunErrors")] + public async Task Does_not_create_ScriptsRunErrors_table_if_it_exists_with_another_casing(string existingTable) + { + await CheckTableCasing("ScriptsRunErrors", existingTable, (config, name) => config.ScriptsRunErrorsTableName = name); + } + + protected virtual async Task CheckTableCasing(string tableName, string funnyCasing, Action setTableName) + { + var db = TestConfig.RandomDatabase(); + + var parent = TestConfig.CreateRandomTempDirectory(); + var knownFolders = FoldersConfiguration.Default(); + + // Set the version table name to be lower-case first, and run one migration. + var config = Context.GetConfiguration(db, parent, knownFolders); + + setTableName(config, funnyCasing); + + await using (var migrator = Context.GetMigrator(config)) + { + await migrator.Migrate(); + } + + // Check that the table is indeed created with lower-case + var errorCaseCountAfterFirstMigration = await TableCountIn(db, funnyCasing); + var normalCountAfterFirstMigration = await TableCountIn(db, tableName); + Assert.Multiple(() => + { + errorCaseCountAfterFirstMigration.Should().Be(1); + normalCountAfterFirstMigration.Should().Be(0); + }); + + // Run migration again - make sure it does not create the table with different casing too + setTableName(config, tableName); + await using (var migrator = Context.GetMigrator(config)) + { + await migrator.Migrate(); + } + + var errorCaseCountAfterSecondMigration = await TableCountIn(db, funnyCasing); + var normalCountAfterSecondMigration = await TableCountIn(db, tableName); + Assert.Multiple(() => + { + errorCaseCountAfterSecondMigration.Should().Be(1); + normalCountAfterSecondMigration.Should().Be(0); + }); + + } + + private async Task TableCountIn(string db, string tableName) + { + var schemaName = Context.DefaultConfiguration.SchemaName; + var supportsSchemas = Context.DatabaseMigrator.SupportsSchemas; + var fullTableName = supportsSchemas ? tableName : Context.Syntax.TableWithSchema(schemaName, tableName); + var tableSchema = supportsSchemas ? schemaName : db; + + int count; + string countSql = CountTableSql(tableSchema, fullTableName); + + await using (var conn = Context.GetDbConnection(Context.ConnectionString(db))) + { + count = await conn.ExecuteScalarAsync(countSql); + } + + return count; + } + [Test()] public async Task Inserts_version_in_version_table() { @@ -156,4 +239,14 @@ protected static DirectoryInfo MakeSurePathExists(DirectoryInfo? path) private static DirectoryInfo Wrap(DirectoryInfo root, string? relativePath) => new(Path.Combine(root.ToString(), relativePath ?? "")); + protected virtual string CountTableSql(string schemaName, string tableName) + { + return $@" +SELECT count(table_name) FROM INFORMATION_SCHEMA.TABLES +WHERE +table_schema = '{schemaName}' AND +table_name = '{tableName}' +"; + } + } diff --git a/grate.unittests/Oracle/MigrationTables.cs b/grate.unittests/Oracle/MigrationTables.cs index 0e10bb88..5902728b 100644 --- a/grate.unittests/Oracle/MigrationTables.cs +++ b/grate.unittests/Oracle/MigrationTables.cs @@ -1,3 +1,6 @@ +using System; +using System.Threading.Tasks; +using grate.Configuration; using grate.unittests.TestInfrastructure; using NUnit.Framework; @@ -5,7 +8,21 @@ namespace grate.unittests.Oracle; [TestFixture] [Category("Oracle")] -public class MigrationTables: Generic.GenericMigrationTables +public class MigrationTables : Generic.GenericMigrationTables { protected override IGrateTestContext Context => GrateTestContext.Oracle; -} \ No newline at end of file + + protected override Task CheckTableCasing(string tableName, string funnyCasing, Action setTableName) + { + Assert.Ignore("Oracle has never been case-sensitive for grate. No need to introduce that now."); + return Task.CompletedTask; + } + + protected override string CountTableSql(string schemaName, string tableName) + { + return $@" +SELECT COUNT(table_name) FROM user_tables +WHERE +lower(table_name) = '{tableName.ToLowerInvariant()}'"; + } +} diff --git a/grate.unittests/SqLite/MigrationTables.cs b/grate.unittests/SqLite/MigrationTables.cs index e9a9495b..5cb536d2 100644 --- a/grate.unittests/SqLite/MigrationTables.cs +++ b/grate.unittests/SqLite/MigrationTables.cs @@ -8,4 +8,13 @@ namespace grate.unittests.Sqlite; public class MigrationTables: Generic.GenericMigrationTables { protected override IGrateTestContext Context => GrateTestContext.Sqlite; -} \ No newline at end of file + + protected override string CountTableSql(string schemaName, string tableName) + { + return $@" +SELECT COUNT(name) FROM sqlite_master +WHERE type ='table' AND +name = '{tableName}'; +"; + } +} diff --git a/grate.unittests/SqlServer/MigrationTables.cs b/grate.unittests/SqlServer/MigrationTables.cs index 28db250a..aaf55daf 100644 --- a/grate.unittests/SqlServer/MigrationTables.cs +++ b/grate.unittests/SqlServer/MigrationTables.cs @@ -8,4 +8,14 @@ namespace grate.unittests.SqlServer; public class MigrationTables: Generic.GenericMigrationTables { protected override IGrateTestContext Context => GrateTestContext.SqlServer; -} \ No newline at end of file + + protected override string CountTableSql(string schemaName, string tableName) + { + return $@" +SELECT count(table_name) FROM INFORMATION_SCHEMA.TABLES +WHERE +TABLE_SCHEMA = '{schemaName}' AND +TABLE_NAME = '{tableName}' COLLATE Latin1_General_CS_AS +"; + } +} diff --git a/grate.unittests/TestContext.cs b/grate.unittests/TestContext.cs index 8e0b04b0..5e75cb7b 100644 --- a/grate.unittests/TestContext.cs +++ b/grate.unittests/TestContext.cs @@ -1,4 +1,8 @@ using grate.unittests.TestInfrastructure; +using NUnit.Framework; + +// There are some parallelism issues, but this does not solve it +//[assembly:LevelOfParallelism(1)] namespace grate.unittests; diff --git a/grate.unittests/TestInfrastructure/IGrateTestContext.cs b/grate.unittests/TestInfrastructure/IGrateTestContext.cs index 1756a8c0..dd8d72b1 100644 --- a/grate.unittests/TestInfrastructure/IGrateTestContext.cs +++ b/grate.unittests/TestInfrastructure/IGrateTestContext.cs @@ -56,6 +56,17 @@ DefaultConfiguration with SqlFilesDirectory = sqlFilesDirectory }; + public GrateConfiguration GetConfiguration(string databaseName, DirectoryInfo sqlFilesDirectory, + IFoldersConfiguration knownFolders, string? env, bool runInTransaction) => + DefaultConfiguration with + { + ConnectionString = ConnectionString(databaseName), + Folders = knownFolders, + Environment = env != null ? new GrateEnvironment(env) : null, + Transaction = runInTransaction, + SqlFilesDirectory = sqlFilesDirectory + }; + public GrateMigrator GetMigrator(GrateConfiguration config) { var factory = Substitute.For(); diff --git a/grate/Configuration/GrateConfiguration.cs b/grate/Configuration/GrateConfiguration.cs index 65547f38..fd646cc4 100644 --- a/grate/Configuration/GrateConfiguration.cs +++ b/grate/Configuration/GrateConfiguration.cs @@ -26,6 +26,10 @@ public record GrateConfiguration public string? ConnectionString { get; init; } = null; public string SchemaName { get; init; } = "grate"; + + public string ScriptsRunTableName { get; set; } = "ScriptsRun"; + public string ScriptsRunErrorsTableName { get; set; } = "ScriptsRunErrors"; + public string VersionTableName { get; set; } = "Version"; public string? AdminConnectionString { diff --git a/grate/Migration/AnsiSqlDatabase.cs b/grate/Migration/AnsiSqlDatabase.cs index bcb29b4a..72f38ec1 100644 --- a/grate/Migration/AnsiSqlDatabase.cs +++ b/grate/Migration/AnsiSqlDatabase.cs @@ -46,14 +46,18 @@ protected AnsiSqlDatabase(ILogger logger, ISyntax syntax) .Split("=", TrimEntries | RemoveEmptyEntries).Last(); public abstract bool SupportsDdlTransactions { get; } - protected abstract bool SupportsSchemas { get; } + public abstract bool SupportsSchemas { get; } public bool SplitBatchStatements => true; public string StatementSeparatorRegex => _syntax.StatementSeparatorRegex; - public string ScriptsRunTable => _syntax.TableWithSchema(SchemaName, "ScriptsRun"); - public string ScriptsRunErrorsTable => _syntax.TableWithSchema(SchemaName, "ScriptsRunErrors"); - public string VersionTable => _syntax.TableWithSchema(SchemaName, "Version"); + public string ScriptsRunTable => _syntax.TableWithSchema(SchemaName, ScriptsRunTableName); + public string ScriptsRunErrorsTable => _syntax.TableWithSchema(SchemaName, ScriptsRunErrorsTableName); + public string VersionTable => _syntax.TableWithSchema(SchemaName, VersionTableName); + + private string ScriptsRunTableName { get; set; } + private string ScriptsRunErrorsTableName { get; set; } + private string VersionTableName { get; set; } public virtual Task InitializeConnections(GrateConfiguration configuration) { @@ -61,11 +65,23 @@ public virtual Task InitializeConnections(GrateConfiguration configuration) ConnectionString = configuration.ConnectionString; AdminConnectionString = configuration.AdminConnectionString; + SchemaName = configuration.SchemaName; + + VersionTableName = configuration.VersionTableName; + ScriptsRunTableName = configuration.ScriptsRunTableName; + ScriptsRunErrorsTableName = configuration.ScriptsRunErrorsTableName; + Config = configuration; + return Task.CompletedTask; } + private async Task ExistingOrDefault(string schemaName, string tableName) => + await ExistingTable(schemaName, tableName) ?? tableName; + + + private string? AdminConnectionString { get; set; } protected string? ConnectionString { get; set; } @@ -266,6 +282,10 @@ private async Task RunSchemaExists() protected virtual async Task CreateScriptsRunTable() { + // Update scripts run table name with the correct casing, should it differ from the standard + + ScriptsRunTableName = await ExistingOrDefault(SchemaName, ScriptsRunTableName); + string createSql = $@" CREATE TABLE {ScriptsRunTable}( {_syntax.PrimaryKeyColumn("id")}, @@ -288,6 +308,9 @@ protected virtual async Task CreateScriptsRunTable() protected virtual async Task CreateScriptsRunErrorsTable() { + // Update scripts run errors table name with the correct casing, should it differ from the standard + ScriptsRunErrorsTableName = await ExistingOrDefault(SchemaName, ScriptsRunErrorsTableName); + string createSql = $@" CREATE TABLE {ScriptsRunErrorsTable}( {_syntax.PrimaryKeyColumn("id")}, @@ -310,6 +333,9 @@ protected virtual async Task CreateScriptsRunErrorsTable() protected virtual async Task CreateVersionTable() { + // Update version table name with the correct casing, should it differ from the standard + VersionTableName = await ExistingOrDefault(SchemaName, VersionTableName); + string createSql = $@" CREATE TABLE {VersionTable}( {_syntax.PrimaryKeyColumn("id")}, @@ -320,6 +346,7 @@ protected virtual async Task CreateVersionTable() entered_by {_syntax.VarcharType}(50) NULL {_syntax.PrimaryKeyConstraint("Version", "id")} )"; + if (!await VersionTableExists()) { await ExecuteNonQuery(ActiveConnection, createSql, Config?.CommandTimeout); @@ -338,21 +365,25 @@ ALTER TABLE {VersionTable} } } - protected async Task ScriptsRunTableExists() => await TableExists(SchemaName, "ScriptsRun"); - protected async Task ScriptsRunErrorsTableExists() => await TableExists(SchemaName, "ScriptsRunErrors"); - public async Task VersionTableExists() => await TableExists(SchemaName, "Version"); - protected async Task StatusColumnInVersionTableExists() => await ColumnExists(SchemaName, "Version", "status"); + protected async Task ScriptsRunTableExists() => (await ExistingTable(SchemaName, ScriptsRunTableName) is not null) ; + protected async Task ScriptsRunErrorsTableExists() => (await ExistingTable(SchemaName, ScriptsRunErrorsTableName) is not null); + public async Task VersionTableExists() => (await ExistingTable(SchemaName, VersionTableName) is not null); + + protected async Task StatusColumnInVersionTableExists() => await ColumnExists(SchemaName, VersionTableName, "status"); - public async Task TableExists(string schemaName, string tableName) + public async Task ExistingTable(string schemaName, string tableName) { var fullTableName = SupportsSchemas ? tableName : _syntax.TableWithSchema(schemaName, tableName); var tableSchema = SupportsSchemas ? schemaName : DatabaseName; - + string existsSql = ExistsSql(tableSchema, fullTableName); var res = await ExecuteScalarAsync(ActiveConnection, existsSql); - return !DBNull.Value.Equals(res) && res is not null; + var name = (!DBNull.Value.Equals(res) && res is not null) ? (string) res : null; + + var prefix = SupportsSchemas ? string.Empty : _syntax.TableWithSchema(schemaName, string.Empty); + return name?[prefix.Length..] ; } private async Task ColumnExists(string schemaName, string tableName, string columnName) @@ -369,10 +400,10 @@ private async Task ColumnExists(string schemaName, string tableName, strin protected virtual string ExistsSql(string tableSchema, string fullTableName) { return $@" -SELECT * FROM INFORMATION_SCHEMA.TABLES +SELECT table_name FROM INFORMATION_SCHEMA.TABLES WHERE -TABLE_SCHEMA = '{tableSchema}' AND -TABLE_NAME = '{fullTableName}' +LOWER(TABLE_SCHEMA) = LOWER('{tableSchema}') AND +LOWER(TABLE_NAME) = LOWER('{fullTableName}') "; } @@ -381,9 +412,9 @@ protected virtual string ExistsSql(string tableSchema, string fullTableName, str return $@" SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE -TABLE_SCHEMA = '{tableSchema}' AND -TABLE_NAME = '{fullTableName}' AND -COLUMN_NAME = '{columnName}' +LOWER(TABLE_SCHEMA) = LOWER('{tableSchema}') AND +LOWER(TABLE_NAME) = LOWER('{fullTableName}') AND +LOWER(COLUMN_NAME) = LOWER('{columnName}') "; } diff --git a/grate/Migration/IDatabase.cs b/grate/Migration/IDatabase.cs index 708b13d7..26755cef 100644 --- a/grate/Migration/IDatabase.cs +++ b/grate/Migration/IDatabase.cs @@ -17,6 +17,7 @@ public interface IDatabase : IAsyncDisposable public string ScriptsRunErrorsTable { get; } public string VersionTable { get; } DbConnection ActiveConnection { set; } + bool SupportsSchemas { get; } Task InitializeConnections(GrateConfiguration configuration); Task OpenConnection(); @@ -48,4 +49,5 @@ Task InsertScriptRun(string scriptName, string? sql, string hash, bool runOnce, void SetDefaultConnectionActive(); Task OpenNewActiveConnection(); Task OpenActiveConnection(); + Task ExistingTable(string schemaName, string tableName); } diff --git a/grate/Migration/MariaDbDatabase.cs b/grate/Migration/MariaDbDatabase.cs index 527e2e06..93cda516 100644 --- a/grate/Migration/MariaDbDatabase.cs +++ b/grate/Migration/MariaDbDatabase.cs @@ -15,7 +15,7 @@ public MariaDbDatabase(ILogger logger) { } public override bool SupportsDdlTransactions => false; - protected override bool SupportsSchemas => false; + public override bool SupportsSchemas => false; protected override DbConnection GetSqlConnection(string? connectionString) => new MySqlConnection(connectionString); public override Task RestoreDatabase(string backupPath) diff --git a/grate/Migration/OracleDatabase.cs b/grate/Migration/OracleDatabase.cs index 940f28fa..576a1d1d 100644 --- a/grate/Migration/OracleDatabase.cs +++ b/grate/Migration/OracleDatabase.cs @@ -23,13 +23,13 @@ public OracleDatabase(ILogger logger) } public override bool SupportsDdlTransactions => false; - protected override bool SupportsSchemas => false; + public override bool SupportsSchemas => false; protected override DbConnection GetSqlConnection(string? connectionString) => new OracleConnection(connectionString); protected override string ExistsSql(string tableSchema, string fullTableName) => $@" -SELECT * FROM user_tables +SELECT table_name FROM user_tables WHERE lower(table_name) = '{fullTableName.ToLowerInvariant()}' "; diff --git a/grate/Migration/PostgreSqlDatabase.cs b/grate/Migration/PostgreSqlDatabase.cs index 2d7e7377..1dfb5afc 100644 --- a/grate/Migration/PostgreSqlDatabase.cs +++ b/grate/Migration/PostgreSqlDatabase.cs @@ -13,11 +13,11 @@ public PostgreSqlDatabase(ILogger logger) { } public override bool SupportsDdlTransactions => true; - protected override bool SupportsSchemas => true; + public override bool SupportsSchemas => true; protected override DbConnection GetSqlConnection(string? connectionString) => new NpgsqlConnection(connectionString); public override Task RestoreDatabase(string backupPath) { throw new System.NotImplementedException("Restoring a database from file is not currently supported for Postgresql."); } -} \ No newline at end of file +} diff --git a/grate/Migration/SqLiteDatabase.cs b/grate/Migration/SqLiteDatabase.cs index c087af24..25f1d607 100644 --- a/grate/Migration/SqLiteDatabase.cs +++ b/grate/Migration/SqLiteDatabase.cs @@ -17,14 +17,14 @@ public SqliteDatabase(ILogger logger) { } public override bool SupportsDdlTransactions => false; - protected override bool SupportsSchemas => false; + public override bool SupportsSchemas => false; protected override DbConnection GetSqlConnection(string? connectionString) => new SqliteConnection(connectionString); protected override string ExistsSql(string tableSchema, string fullTableName) => $@" SELECT name FROM sqlite_master WHERE type ='table' AND -name = '{fullTableName}'; +LOWER(name) = LOWER('{fullTableName}'); "; protected override string ExistsSql(string tableSchema, string fullTableName, string columnName) => diff --git a/grate/Migration/SqlServerDatabase.cs b/grate/Migration/SqlServerDatabase.cs index 51c39f6a..bf98f57c 100644 --- a/grate/Migration/SqlServerDatabase.cs +++ b/grate/Migration/SqlServerDatabase.cs @@ -16,7 +16,7 @@ public SqlServerDatabase(ILogger logger) { } public override bool SupportsDdlTransactions => true; - protected override bool SupportsSchemas => true; + public override bool SupportsSchemas => true; protected override DbConnection GetSqlConnection(string? connectionString) { // If pooling is not explicitly mentioned in the connection string, turn it off, as enabling it