diff --git a/node/pkg/db/pgsql.go b/node/pkg/db/pgsql.go index 4213f0290..caf723143 100644 --- a/node/pkg/db/pgsql.go +++ b/node/pkg/db/pgsql.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" "sync" + "time" errorSentinel "bisonai.com/miko/node/pkg/error" "bisonai.com/miko/node/pkg/secrets" @@ -22,6 +23,8 @@ var ( poolErr error ) +const DefaultDBTimeout = 15 * time.Second + func GetPool(ctx context.Context) (*pgxpool.Pool, error) { return getPool(ctx, &initPgxOnce) } @@ -100,7 +103,10 @@ func QueryRows[T any](ctx context.Context, queryString string, args map[string]a } func query(ctx context.Context, pool *pgxpool.Pool, query string, args map[string]any) (pgx.Rows, error) { - return pool.Query(ctx, query, pgx.NamedArgs(args)) + ctxWithTimeout, cancel := context.WithTimeout(ctx, DefaultDBTimeout) + defer cancel() + + return pool.Query(ctxWithTimeout, query, pgx.NamedArgs(args)) } func queryRow[T any](ctx context.Context, pool *pgxpool.Pool, queryString string, args map[string]any) (T, error) { diff --git a/node/pkg/db/pgsql_test.go b/node/pkg/db/pgsql_test.go index 2b42f98d8..98a723dfb 100644 --- a/node/pkg/db/pgsql_test.go +++ b/node/pkg/db/pgsql_test.go @@ -2,6 +2,8 @@ package db import ( "context" + "errors" + "os" "reflect" "testing" ) @@ -109,8 +111,8 @@ func TestQueryRow(t *testing.T) { if result.Name != "Alice" { t.Errorf("Unexpected result: got %s, want Alice", result) } - } + func TestQueryRows(t *testing.T) { ctx := context.Background() pool, err := GetPool(ctx) @@ -500,3 +502,62 @@ func TestBulkSelect(t *testing.T) { }) } + +func TestQueryTimeout(t *testing.T) { + // Setting up the context with a short timeout + ctx := context.Background() + + pool, err := GetPool(ctx) + if err != nil { + t.Fatalf("GetPool failed: %v", err) + } + + // Create a temporary table (optional, depending on your test needs) + _, err = pool.Exec(ctx, `CREATE TEMPORARY TABLE test_timeout (id SERIAL PRIMARY KEY, name TEXT)`) + if err != nil { + t.Fatalf("Failed to create temporary table: %v", err) + } + + // Simulate a long-running query using pg_sleep (2 seconds) + _, err = QueryRow[struct { + Name string `db:"name"` + }](ctx, `SELECT pg_sleep(16)`, nil) + + // Check for context.DeadlineExceeded error + if err == nil { + t.Fatalf("Expected timeout error but got none") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Expected context.DeadlineExceeded error, but got: %v", err) + } +} + +func TestQueryTimeoutTransient(t *testing.T) { + ctx := context.Background() + + dbUrl := os.Getenv("DATABASE_URL") + pool, err := GetTransientPool(ctx, dbUrl) + if err != nil { + t.Fatalf("GetPool failed: %v", err) + } + defer pool.Close() + + // Create a temporary table (optional, depending on your test needs) + _, err = pool.Exec(ctx, `CREATE TEMPORARY TABLE test_timeout (id SERIAL PRIMARY KEY, name TEXT)`) + if err != nil { + t.Fatalf("Failed to create temporary table: %v", err) + } + + // Simulate a long-running query using pg_sleep (2 seconds) + _, err = QueryRow[struct { + Name string `db:"name"` + }](ctx, `SELECT pg_sleep(16)`, nil) + + // Check for context.DeadlineExceeded error + if err == nil { + t.Fatalf("Expected timeout error but got none") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Expected context.DeadlineExceeded error, but got: %v", err) + } +}