Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

database: Implement connection pooling and context cacellation #18

Merged
merged 5 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions database/db_operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ import (
"reflect"

"github.com/Masterminds/squirrel"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)

var psql = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
var db *pgx.Conn
var pool *pgxpool.Pool

func InitDB() (*pgx.Conn, error) {
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
func InitDB() (*pgxpool.Pool, error) {
conn, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL"))
if err != nil {
return nil, fmt.Errorf("unable to connect to database: %v", err)
}

db = conn // Store the connection in a package-level variable
pool = conn // Store the connection in a package-level variable

return conn, nil
}
Expand All @@ -36,7 +36,13 @@ func SelectOne[T any](table string, filter map[string]string, result interface{}
return err
}

row := db.QueryRow(context.Background(), sql, args...)
conn, err := pool.Acquire(context.Background())
defer conn.Release()
if err != nil {
return err
}

row := conn.QueryRow(context.Background(), sql, args...)

// Get type information of the result
valueType := reflect.TypeOf(result).Elem()
Expand Down Expand Up @@ -75,7 +81,13 @@ func SelectMany[T any](table string, filter map[string]string, result *[]T) erro
return err
}

rows, err := db.Query(context.Background(), sql, args...)
conn, err := pool.Acquire(context.Background())
defer conn.Release()
if err != nil {
return err
}

rows, err := conn.Query(context.Background(), sql, args...)
if err != nil {
return err
}
Expand Down
152 changes: 41 additions & 111 deletions database/guilds.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/Masterminds/squirrel"
)

func SelectGuild(f map[string]string) (models.Guild, error) {
func SelectGuild(ctx context.Context, f map[string]string) (models.Guild, error) {
query := psql.Select("*").From("guilds")

for key, value := range f {
Expand All @@ -21,7 +21,13 @@ func SelectGuild(f map[string]string) (models.Guild, error) {
return models.Guild{}, err
}

row := db.QueryRow(context.Background(), sql, args...)
conn, err := pool.Acquire(ctx)
if err != nil {
return models.Guild{}, err
}
defer conn.Release()

row := conn.QueryRow(ctx, sql, args...)

var guild models.Guild

Expand All @@ -33,47 +39,57 @@ func SelectGuild(f map[string]string) (models.Guild, error) {
return guild, nil
}

func InsertGuild(g string) error {
tx, err := db.Begin(context.Background())
func InsertGuild(ctx context.Context, g string) error {
conn, err := pool.Acquire(ctx)
if err != nil {
return err
}
defer tx.Rollback(context.Background()) // Rollback the transaction if it hasn't been committed
defer conn.Release()

query := psql.Insert("guilds").Columns("guild_id").Values(g)
sql, args, err := query.ToSql()
if err != nil {
return err
}

commandTag, err := db.Exec(context.Background(), sql, args...)
tx, err := conn.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)

if commandTag.RowsAffected() != 1 {
return fmt.Errorf("expected 1 row to be affected, got %d", commandTag.RowsAffected())
commandTagGuild, err := tx.Exec(ctx, sql, args...)
if err != nil {
return err
}

err = InitializeGuildCategories(g)
sql = `INSERT INTO guild_categories(guild_id, category, message_id)
SELECT $1, name, '' FROM categories`
_, err = tx.Exec(ctx, sql, g)
if err != nil {
return err
}

err = InitializeGuildBosses(g)
sql = `INSERT INTO guild_bosses(guild_id, boss, pb_id)
SELECT $1, name, NULL FROM bosses`
_, err = tx.Exec(ctx, sql, g)
if err != nil {
return err
}

err = tx.Commit(context.Background())
err = tx.Commit(ctx)
if err != nil {
return err
}

if rows := commandTagGuild.RowsAffected(); rows != 1 {
return fmt.Errorf("expected 1 row to be affected, got %d", rows)
}

return nil
}

func DeleteGuild(f map[string]string) error {
func DeleteGuild(ctx context.Context, f map[string]string) error {
query := psql.Delete("guilds")

for key, value := range f {
Expand All @@ -85,7 +101,13 @@ func DeleteGuild(f map[string]string) error {
return err
}

commandTag, err := db.Exec(context.Background(), sql, args...)
conn, err := pool.Acquire(ctx)
if err != nil {
return err
}
defer conn.Release()

commandTag, err := conn.Exec(ctx, sql, args...)
if err != nil {
return err
}
Expand All @@ -97,7 +119,7 @@ func DeleteGuild(f map[string]string) error {
return nil
}

func UpdateGuild(g string, f map[string]string) error {
func UpdateGuild(ctx context.Context, g string, f map[string]string) error {
// .SetMap requires an empty interface map so we convert here
// As f passes parameters caught from HTTP requests, thus always string
filter := utils.StringMapToInterfaceMap(f)
Expand All @@ -108,111 +130,19 @@ func UpdateGuild(g string, f map[string]string) error {
return err
}

commandTag, err := db.Exec(context.Background(), sql, args...)
if err != nil {
return err
}

if commandTag.RowsAffected() != 1 {
return fmt.Errorf("expected 1 row to be affected, got %d", commandTag.RowsAffected())
}

return nil
}

func InitializeGuildBosses(g string) error {
sql, args, err := psql.Select("name").From("bosses").ToSql()
if err != nil {
return err
}

rows, err := db.Query(context.Background(), sql, args...)
if err != nil {
return err
}

bosses := make([]map[string]interface{}, 0)

for rows.Next() {
var b string
// Assuming the columns are named "column1", "column2", etc.
if err := rows.Scan(&b); err != nil {
return err
}

boss := map[string]interface{}{
"boss": b,
"guild_id": g,
"pb_id": nil,
}

bosses = append(bosses, boss)
}

// Insert data into guild_bosses
query := psql.Insert("guild_bosses").Columns("boss", "guild_id", "pb_id")

for _, v := range bosses {
query = query.Values(v["boss"], v["guild_id"], v["pb_id"])
}

sql, args, err = query.ToSql()
if err != nil {
return err
}

_, err = db.Exec(context.Background(), sql, args...)
if err != nil {
return err
}

return nil
}

func InitializeGuildCategories(g string) error {
sql, args, err := psql.Select("name").From("categories").ToSql()
if err != nil {
return err
}

rows, err := db.Query(context.Background(), sql, args...)
conn, err := pool.Acquire(ctx)
if err != nil {
return err
}
defer conn.Release()

categories := make([]map[string]interface{}, 0)

for rows.Next() {
var c string
// Assuming the columns are named "column1", "column2", etc.
if err := rows.Scan(&c); err != nil {
return err
}

category := map[string]interface{}{
"guild_id": g,
"category": c,
"message_id": "",
}

categories = append(categories, category)
}

// Insert data into guild_bosses
query := psql.Insert("guild_categories").Columns("guild_id", "category", "message_id")

for _, v := range categories {
query = query.Values(v["guild_id"], v["category"], v["message_id"])
}

sql, args, err = query.ToSql()
commandTag, err := conn.Exec(ctx, sql, args...)
if err != nil {
return err
}

_, err = db.Exec(context.Background(), sql, args...)
if err != nil {
return err
if commandTag.RowsAffected() != 1 {
return fmt.Errorf("expected 1 row to be affected, got %d", commandTag.RowsAffected())
}

return nil
Expand Down
10 changes: 8 additions & 2 deletions database/leaderboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/Masterminds/squirrel"
)

func SelectLeaderboard(filter map[string]string) (models.Users, error) {
func SelectLeaderboard(ctx context.Context, filter map[string]string) (models.Users, error) {
query := psql.Select("*").From("users").OrderBy("points DESC").Limit(50)

for key, value := range filter {
Expand All @@ -19,8 +19,14 @@ func SelectLeaderboard(filter map[string]string) (models.Users, error) {
return models.Users{}, err
}

conn, err := pool.Acquire(ctx)
if err != nil {
return models.Users{}, err
}
defer conn.Release()

// Executing the query
rows, err := db.Query(context.Background(), sql, args...)
rows, err := conn.Query(ctx, sql, args...)
if err != nil {
return models.Users{}, err
}
Expand Down
Loading