Skip to content

Commit

Permalink
feat: update with testcode
Browse files Browse the repository at this point in the history
  • Loading branch information
nick-bisonai committed Aug 22, 2024
1 parent 5bd05cb commit 6e2d4f5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
8 changes: 7 additions & 1 deletion node/pkg/db/pgsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"strings"
"sync"
"time"

errorSentinel "bisonai.com/miko/node/pkg/error"
"bisonai.com/miko/node/pkg/secrets"
Expand All @@ -22,6 +23,8 @@ var (
poolErr error
)

const DefaultDBTimeout = 15 * time.Second

func GetPool(ctx context.Context) (*pgxpool.Pool, error) {
return getPool(ctx, &initPgxOnce)
}
Expand Down Expand Up @@ -100,7 +103,10 @@ func QueryRows[T any](ctx context.Context, queryString string, args map[string]a
}

func query(ctx context.Context, pool *pgxpool.Pool, query string, args map[string]any) (pgx.Rows, error) {
return pool.Query(ctx, query, pgx.NamedArgs(args))
ctxWithTimeout, cancel := context.WithTimeout(ctx, DefaultDBTimeout)
defer cancel()

return pool.Query(ctxWithTimeout, query, pgx.NamedArgs(args))
}

func queryRow[T any](ctx context.Context, pool *pgxpool.Pool, queryString string, args map[string]any) (T, error) {
Expand Down
63 changes: 62 additions & 1 deletion node/pkg/db/pgsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package db

import (
"context"
"errors"
"os"
"reflect"
"testing"
)
Expand Down Expand Up @@ -109,8 +111,8 @@ func TestQueryRow(t *testing.T) {
if result.Name != "Alice" {
t.Errorf("Unexpected result: got %s, want Alice", result)
}

}

func TestQueryRows(t *testing.T) {
ctx := context.Background()
pool, err := GetPool(ctx)
Expand Down Expand Up @@ -500,3 +502,62 @@ func TestBulkSelect(t *testing.T) {
})

}

func TestQueryTimeout(t *testing.T) {
// Setting up the context with a short timeout
ctx := context.Background()

pool, err := GetPool(ctx)
if err != nil {
t.Fatalf("GetPool failed: %v", err)
}

// Create a temporary table (optional, depending on your test needs)
_, err = pool.Exec(ctx, `CREATE TEMPORARY TABLE test_timeout (id SERIAL PRIMARY KEY, name TEXT)`)
if err != nil {
t.Fatalf("Failed to create temporary table: %v", err)
}

// Simulate a long-running query using pg_sleep (2 seconds)
_, err = QueryRow[struct {
Name string `db:"name"`
}](ctx, `SELECT pg_sleep(16)`, nil)

// Check for context.DeadlineExceeded error
if err == nil {
t.Fatalf("Expected timeout error but got none")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("Expected context.DeadlineExceeded error, but got: %v", err)
}
}

func TestQueryTimeoutTransient(t *testing.T) {
ctx := context.Background()

dbUrl := os.Getenv("DATABASE_URL")
pool, err := GetTransientPool(ctx, dbUrl)
if err != nil {
t.Fatalf("GetPool failed: %v", err)
}
defer pool.Close()

// Create a temporary table (optional, depending on your test needs)
_, err = pool.Exec(ctx, `CREATE TEMPORARY TABLE test_timeout (id SERIAL PRIMARY KEY, name TEXT)`)
if err != nil {
t.Fatalf("Failed to create temporary table: %v", err)
}

// Simulate a long-running query using pg_sleep (2 seconds)
_, err = QueryRow[struct {
Name string `db:"name"`
}](ctx, `SELECT pg_sleep(16)`, nil)

// Check for context.DeadlineExceeded error
if err == nil {
t.Fatalf("Expected timeout error but got none")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("Expected context.DeadlineExceeded error, but got: %v", err)
}
}

0 comments on commit 6e2d4f5

Please sign in to comment.