-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f8bed4e
commit bf1e185
Showing
3 changed files
with
267 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
package test | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"bisonai.com/orakl/node/pkg/dal/utils/multipgs" | ||
) | ||
|
||
func TestNewDatabase(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
// Test with a valid connection string | ||
_, err := multipgs.NewDatabase(ctx, "DATABASE_URL") | ||
if err != nil { | ||
t.Errorf("NewDatabase with valid connection failed: %v", err) | ||
} | ||
|
||
// Test with an invalid connection string | ||
_, err = multipgs.NewDatabase(ctx, "INVALID") | ||
if err == nil { | ||
t.Error("NewDatabase with invalid connection did not return error") | ||
} | ||
} | ||
|
||
func TestQueryWithoutResult(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
pool, err := multipgs.NewDatabase(ctx, "DATABASE_URL") | ||
if err != nil { | ||
t.Errorf("NewDatabase failed: %v", err) | ||
} | ||
|
||
// Create a temporary table | ||
_, err = pool.Exec(ctx, `CREATE TEMPORARY TABLE test (id SERIAL PRIMARY KEY, name TEXT)`) | ||
if err != nil { | ||
t.Fatalf("Failed to create temporary table: %v", err) | ||
} | ||
defer func() { | ||
// Clean up the temporary table | ||
_, err = pool.Exec(ctx, "DROP TABLE test") | ||
if err != nil { | ||
t.Fatalf("Failed to drop table: %v", err) | ||
} | ||
}() | ||
|
||
// Insert some test data | ||
_, err = pool.Exec(ctx, `INSERT INTO test (name) VALUES ('Alice'), ('Bob')`) | ||
if err != nil { | ||
t.Fatalf("Failed to insert test data: %v", err) | ||
} | ||
|
||
// Test with a valid query | ||
err = multipgs.QueryWithoutResult(ctx, "DATABASE_URL", "SELECT FROM test WHERE id = @name", map[string]interface{}{"name": "Alice"}) | ||
if err != nil { | ||
t.Errorf("QueryWithoutResult with valid query failed: %v", err) | ||
} | ||
|
||
// Test with an invalid connection | ||
err = multipgs.QueryWithoutResult(ctx, "INVALID", "SELECT FROM test WHERE id = @name", map[string]interface{}{"name": "Alice"}) | ||
if err == nil { | ||
t.Error("QueryWithoutResult with invalid connection did not return error") | ||
} | ||
} | ||
|
||
func TestQueryRow(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
pool, err := multipgs.NewDatabase(ctx, "DATABASE_URL") | ||
if err != nil { | ||
t.Errorf("NewDatabase failed: %v", err) | ||
} | ||
|
||
// Create a temporary table | ||
_, err = pool.Exec(ctx, `CREATE TEMPORARY TABLE test (id SERIAL PRIMARY KEY, name TEXT)`) | ||
if err != nil { | ||
t.Fatalf("Failed to create temporary table: %v", err) | ||
} | ||
defer func() { | ||
// Clean up the temporary table | ||
_, err = pool.Exec(ctx, "DROP TABLE test") | ||
if err != nil { | ||
t.Fatalf("Failed to drop table: %v", err) | ||
} | ||
}() | ||
|
||
// Insert some test data | ||
_, err = pool.Exec(ctx, `INSERT INTO test (name) VALUES ('Alice'), ('Bob')`) | ||
if err != nil { | ||
t.Fatalf("Failed to insert test data: %v", err) | ||
} | ||
|
||
// Call the function being tested | ||
result, err := multipgs.QueryRow[struct { | ||
Name string `db:"name"` | ||
}](ctx, "DATABASE_URL", `SELECT name FROM test WHERE id = 1`, nil) | ||
if err != nil { | ||
t.Fatalf("QueryRow failed: %v", err) | ||
} | ||
|
||
// Check the result | ||
if result.Name != "Alice" { | ||
t.Errorf("Unexpected result: got %s, want Alice", result) | ||
} | ||
|
||
} | ||
|
||
func TestQueryRows(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
pool, err := multipgs.NewDatabase(ctx, "DATABASE_URL") | ||
if err != nil { | ||
t.Errorf("NewDatabase failed: %v", err) | ||
} | ||
|
||
// Create a temporary table | ||
_, err = pool.Exec(ctx, `CREATE TEMPORARY TABLE test (id SERIAL PRIMARY KEY, name TEXT)`) | ||
if err != nil { | ||
t.Fatalf("Failed to create temporary table: %v", err) | ||
} | ||
defer func() { | ||
// Clean up the temporary table | ||
_, err = pool.Exec(ctx, "DROP TABLE test") | ||
if err != nil { | ||
t.Fatalf("Failed to drop table: %v", err) | ||
} | ||
}() | ||
|
||
// Insert some test data | ||
_, err = pool.Exec(ctx, `INSERT INTO test (name) VALUES ('Alice'), ('Bob')`) | ||
if err != nil { | ||
t.Fatalf("Failed to insert test data: %v", err) | ||
} | ||
|
||
// Call the function being tested | ||
results, err := multipgs.QueryRows[struct { | ||
ID int `db:"id"` | ||
Name string `db:"name"` | ||
}](ctx, "DATABASE_URL", `SELECT * FROM test`, nil) | ||
if err != nil { | ||
t.Fatalf("QueryRows failed: %v", err) | ||
} | ||
|
||
// Check the results | ||
if len(results) != 2 { | ||
t.Errorf("Unexpected number of results: got %d, want 2", len(results)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package multipgs | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"sync" | ||
|
||
errorSentinel "bisonai.com/orakl/node/pkg/error" | ||
"bisonai.com/orakl/node/pkg/secrets" | ||
"github.com/jackc/pgx/v5" | ||
"github.com/jackc/pgx/v5/pgxpool" | ||
"github.com/rs/zerolog/log" | ||
) | ||
|
||
var ( | ||
dbInstances sync.Map | ||
) | ||
|
||
func NewDatabase(ctx context.Context, connectionEnv string) (*pgxpool.Pool, error) { | ||
if db, ok := dbInstances.Load(connectionEnv); ok { | ||
return db.(*pgxpool.Pool), nil | ||
} | ||
|
||
connectionString := loadPgsqlConnectionString(connectionEnv) | ||
if connectionString == "" { | ||
log.Error().Msg("DATABASE_URL is not set for " + connectionEnv) | ||
return nil, errorSentinel.ErrDbDatabaseUrlNotFound | ||
} | ||
pool, err := connectToPgsql(ctx, connectionString) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
db := pool | ||
dbInstances.Store(connectionEnv, db) | ||
|
||
return db, nil | ||
} | ||
|
||
func CloseAll() { | ||
dbInstances.Range(func(key, value any) bool { | ||
value.(*pgxpool.Pool).Close() | ||
dbInstances.Delete(key) | ||
return true | ||
}) | ||
} | ||
|
||
func loadPgsqlConnectionString(connectionEnv string) string { | ||
return secrets.GetSecret(connectionEnv) | ||
} | ||
|
||
func connectToPgsql(ctx context.Context, connectionString string) (*pgxpool.Pool, error) { | ||
return pgxpool.New(ctx, connectionString) | ||
} | ||
|
||
func query(ctx context.Context, pool *pgxpool.Pool, queryString string, args map[string]any) (pgx.Rows, error) { | ||
return pool.Query(ctx, queryString, pgx.NamedArgs(args)) | ||
} | ||
|
||
func QueryWithoutResult(ctx context.Context, dbEnv string, queryString string, args map[string]any) error { | ||
pool, err := NewDatabase(ctx, dbEnv) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
rows, err := query(ctx, pool, queryString, args) | ||
if err != nil { | ||
log.Error().Err(err).Str("query", queryString).Msg("Error querying") | ||
return err | ||
} | ||
defer rows.Close() | ||
return nil | ||
} | ||
|
||
func QueryRow[T any](ctx context.Context, dbEnv string, queryString string, args map[string]any) (T, error) { | ||
var result T | ||
|
||
pool, err := NewDatabase(ctx, dbEnv) | ||
if err != nil { | ||
return result, err | ||
} | ||
|
||
rows, err := query(ctx, pool, queryString, args) | ||
if err != nil { | ||
log.Error().Err(err).Str("query", queryString).Msg("Error querying") | ||
return result, err | ||
} | ||
|
||
result, err = pgx.CollectOneRow(rows, pgx.RowToStructByName[T]) | ||
if errors.Is(err, pgx.ErrNoRows) { | ||
return result, nil | ||
} | ||
defer rows.Close() | ||
return result, err | ||
} | ||
|
||
func QueryRows[T any](ctx context.Context, dbEnv string, queryString string, args map[string]any) ([]T, error) { | ||
results := []T{} | ||
|
||
pool, err := NewDatabase(ctx, dbEnv) | ||
if err != nil { | ||
return results, err | ||
} | ||
|
||
rows, err := query(ctx, pool, queryString, args) | ||
if err != nil { | ||
log.Error().Err(err).Str("query", queryString).Msg("Error querying") | ||
return results, err | ||
} | ||
|
||
results, err = pgx.CollectRows(rows, pgx.RowToStructByName[T]) | ||
if errors.Is(err, pgx.ErrNoRows) { | ||
return results, nil | ||
} | ||
defer rows.Close() | ||
return results, err | ||
} |