Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-bisonai committed Jul 10, 2024
1 parent f8bed4e commit bf1e185
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 2 deletions.
4 changes: 2 additions & 2 deletions node/pkg/dal/tests/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"bisonai.com/orakl/node/pkg/common/types"
"bisonai.com/orakl/node/pkg/dal/api"
"bisonai.com/orakl/node/pkg/dal/collector"
"bisonai.com/orakl/node/pkg/dal/utils"
"bisonai.com/orakl/node/pkg/dal/utils/initializer"
"bisonai.com/orakl/node/pkg/db"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog"
Expand Down Expand Up @@ -82,7 +82,7 @@ func setup(ctx context.Context) (func() error, *TestItems, error) {
}
testItems.TmpConfig = tmpConfig

app, err := utils.Setup(ctx)
app, err := initializer.Setup(ctx)
if err != nil {
return nil, nil, err
}
Expand Down
148 changes: 148 additions & 0 deletions node/pkg/dal/tests/multipgs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package test

import (
"context"
"testing"

"bisonai.com/orakl/node/pkg/dal/utils/multipgs"
)

func TestNewDatabase(t *testing.T) {
ctx := context.Background()

// Test with a valid connection string
_, err := multipgs.NewDatabase(ctx, "DATABASE_URL")
if err != nil {
t.Errorf("NewDatabase with valid connection failed: %v", err)
}

// Test with an invalid connection string
_, err = multipgs.NewDatabase(ctx, "INVALID")
if err == nil {
t.Error("NewDatabase with invalid connection did not return error")
}
}

func TestQueryWithoutResult(t *testing.T) {
ctx := context.Background()

pool, err := multipgs.NewDatabase(ctx, "DATABASE_URL")
if err != nil {
t.Errorf("NewDatabase failed: %v", err)
}

// Create a temporary table
_, err = pool.Exec(ctx, `CREATE TEMPORARY TABLE test (id SERIAL PRIMARY KEY, name TEXT)`)
if err != nil {
t.Fatalf("Failed to create temporary table: %v", err)
}
defer func() {
// Clean up the temporary table
_, err = pool.Exec(ctx, "DROP TABLE test")
if err != nil {
t.Fatalf("Failed to drop table: %v", err)
}
}()

// Insert some test data
_, err = pool.Exec(ctx, `INSERT INTO test (name) VALUES ('Alice'), ('Bob')`)
if err != nil {
t.Fatalf("Failed to insert test data: %v", err)
}

// Test with a valid query
err = multipgs.QueryWithoutResult(ctx, "DATABASE_URL", "SELECT FROM test WHERE id = @name", map[string]interface{}{"name": "Alice"})
if err != nil {
t.Errorf("QueryWithoutResult with valid query failed: %v", err)
}

// Test with an invalid connection
err = multipgs.QueryWithoutResult(ctx, "INVALID", "SELECT FROM test WHERE id = @name", map[string]interface{}{"name": "Alice"})
if err == nil {
t.Error("QueryWithoutResult with invalid connection did not return error")
}
}

func TestQueryRow(t *testing.T) {
ctx := context.Background()

pool, err := multipgs.NewDatabase(ctx, "DATABASE_URL")
if err != nil {
t.Errorf("NewDatabase failed: %v", err)
}

// Create a temporary table
_, err = pool.Exec(ctx, `CREATE TEMPORARY TABLE test (id SERIAL PRIMARY KEY, name TEXT)`)
if err != nil {
t.Fatalf("Failed to create temporary table: %v", err)
}
defer func() {
// Clean up the temporary table
_, err = pool.Exec(ctx, "DROP TABLE test")
if err != nil {
t.Fatalf("Failed to drop table: %v", err)
}
}()

// Insert some test data
_, err = pool.Exec(ctx, `INSERT INTO test (name) VALUES ('Alice'), ('Bob')`)
if err != nil {
t.Fatalf("Failed to insert test data: %v", err)
}

// Call the function being tested
result, err := multipgs.QueryRow[struct {
Name string `db:"name"`
}](ctx, "DATABASE_URL", `SELECT name FROM test WHERE id = 1`, nil)
if err != nil {
t.Fatalf("QueryRow failed: %v", err)
}

// Check the result
if result.Name != "Alice" {
t.Errorf("Unexpected result: got %s, want Alice", result)
}

}

func TestQueryRows(t *testing.T) {
ctx := context.Background()

pool, err := multipgs.NewDatabase(ctx, "DATABASE_URL")
if err != nil {
t.Errorf("NewDatabase failed: %v", err)
}

// Create a temporary table
_, err = pool.Exec(ctx, `CREATE TEMPORARY TABLE test (id SERIAL PRIMARY KEY, name TEXT)`)
if err != nil {
t.Fatalf("Failed to create temporary table: %v", err)
}
defer func() {
// Clean up the temporary table
_, err = pool.Exec(ctx, "DROP TABLE test")
if err != nil {
t.Fatalf("Failed to drop table: %v", err)
}
}()

// Insert some test data
_, err = pool.Exec(ctx, `INSERT INTO test (name) VALUES ('Alice'), ('Bob')`)
if err != nil {
t.Fatalf("Failed to insert test data: %v", err)
}

// Call the function being tested
results, err := multipgs.QueryRows[struct {
ID int `db:"id"`
Name string `db:"name"`
}](ctx, "DATABASE_URL", `SELECT * FROM test`, nil)
if err != nil {
t.Fatalf("QueryRows failed: %v", err)
}

// Check the results
if len(results) != 2 {
t.Errorf("Unexpected number of results: got %d, want 2", len(results))
}
}
117 changes: 117 additions & 0 deletions node/pkg/dal/utils/multipgs/multipgs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package multipgs

import (
"context"
"errors"
"sync"

errorSentinel "bisonai.com/orakl/node/pkg/error"
"bisonai.com/orakl/node/pkg/secrets"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rs/zerolog/log"
)

var (
dbInstances sync.Map
)

func NewDatabase(ctx context.Context, connectionEnv string) (*pgxpool.Pool, error) {
if db, ok := dbInstances.Load(connectionEnv); ok {
return db.(*pgxpool.Pool), nil
}

connectionString := loadPgsqlConnectionString(connectionEnv)
if connectionString == "" {
log.Error().Msg("DATABASE_URL is not set for " + connectionEnv)
return nil, errorSentinel.ErrDbDatabaseUrlNotFound
}
pool, err := connectToPgsql(ctx, connectionString)
if err != nil {
return nil, err
}

db := pool
dbInstances.Store(connectionEnv, db)

return db, nil
}

func CloseAll() {
dbInstances.Range(func(key, value any) bool {
value.(*pgxpool.Pool).Close()
dbInstances.Delete(key)
return true
})
}

func loadPgsqlConnectionString(connectionEnv string) string {
return secrets.GetSecret(connectionEnv)
}

func connectToPgsql(ctx context.Context, connectionString string) (*pgxpool.Pool, error) {
return pgxpool.New(ctx, connectionString)
}

func query(ctx context.Context, pool *pgxpool.Pool, queryString string, args map[string]any) (pgx.Rows, error) {
return pool.Query(ctx, queryString, pgx.NamedArgs(args))
}

func QueryWithoutResult(ctx context.Context, dbEnv string, queryString string, args map[string]any) error {
pool, err := NewDatabase(ctx, dbEnv)
if err != nil {
return err
}

rows, err := query(ctx, pool, queryString, args)
if err != nil {
log.Error().Err(err).Str("query", queryString).Msg("Error querying")
return err
}
defer rows.Close()
return nil
}

func QueryRow[T any](ctx context.Context, dbEnv string, queryString string, args map[string]any) (T, error) {
var result T

pool, err := NewDatabase(ctx, dbEnv)
if err != nil {
return result, err
}

rows, err := query(ctx, pool, queryString, args)
if err != nil {
log.Error().Err(err).Str("query", queryString).Msg("Error querying")
return result, err
}

result, err = pgx.CollectOneRow(rows, pgx.RowToStructByName[T])
if errors.Is(err, pgx.ErrNoRows) {
return result, nil
}
defer rows.Close()
return result, err
}

func QueryRows[T any](ctx context.Context, dbEnv string, queryString string, args map[string]any) ([]T, error) {
results := []T{}

pool, err := NewDatabase(ctx, dbEnv)
if err != nil {
return results, err
}

rows, err := query(ctx, pool, queryString, args)
if err != nil {
log.Error().Err(err).Str("query", queryString).Msg("Error querying")
return results, err
}

results, err = pgx.CollectRows(rows, pgx.RowToStructByName[T])
if errors.Is(err, pgx.ErrNoRows) {
return results, nil
}
defer rows.Close()
return results, err
}

0 comments on commit bf1e185

Please sign in to comment.