Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
enabling connection pooling (#89)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Rammer <[email protected]>

Signed-off-by: Daniel Rammer <[email protected]>
  • Loading branch information
hamersaw committed Jan 25, 2023
1 parent faa86db commit 24d8aba
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 6 deletions.
24 changes: 18 additions & 6 deletions pkg/repositories/config/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@ import (
"context"
"fmt"

"github.com/flyteorg/flytestdlib/database"
stdlibLogger "github.com/flyteorg/flytestdlib/logger"

"gorm.io/gorm/logger"

"github.com/flyteorg/flytestdlib/promutils"

"github.com/flyteorg/flytestdlib/database"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)

const (
Expand Down Expand Up @@ -70,12 +68,26 @@ func (p *PostgresConfigProvider) GetDBConfig() database.DbConfig {
// Opens a connection to the database specified in the config.
// You must call CloseDbConnection at the end of your session!
func OpenDbConnection(ctx context.Context, config DbConnectionConfigProvider) (*gorm.DB, error) {
dbConfig := config.GetDBConfig()

db, err := gorm.Open(config.GetDialector(), &gorm.Config{
Logger: database.GetGormLogger(ctx, stdlibLogger.GetConfig()),
DisableForeignKeyConstraintWhenMigrating: !config.GetDBConfig().EnableForeignKeyConstraintWhenMigrating,
DisableForeignKeyConstraintWhenMigrating: !dbConfig.EnableForeignKeyConstraintWhenMigrating,
})
if err != nil {
return nil, err
}
return db, nil

return db, setupDbConnectionPool(db, &dbConfig)
}

func setupDbConnectionPool(gormDb *gorm.DB, dbConfig *database.DbConfig) error {
genericDb, err := gormDb.DB()
if err != nil {
return err
}
genericDb.SetConnMaxLifetime(dbConfig.ConnMaxLifeTime.Duration)
genericDb.SetMaxIdleConns(dbConfig.MaxIdleConnections)
genericDb.SetMaxOpenConns(dbConfig.MaxOpenConnections)
return nil
}
54 changes: 54 additions & 0 deletions pkg/repositories/config/postgres_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package config

import (
"os"
"path/filepath"
"testing"
"time"

"github.com/flyteorg/flytestdlib/config"
"github.com/flyteorg/flytestdlib/database"
mockScope "github.com/flyteorg/flytestdlib/promutils"

"github.com/stretchr/testify/assert"

"gorm.io/driver/sqlite"
"gorm.io/gorm"
)

func TestConstructGormArgs(t *testing.T) {
Expand Down Expand Up @@ -50,3 +57,50 @@ func TestConstructGormArgsWithPasswordNoExtra(t *testing.T) {

assert.Equal(t, "host=localhost port=5432 dbname=postgres user=postgres password=pass ", postgresConfigProvider.GetDSN())
}

func TestSetupDbConnectionPool(t *testing.T) {
t.Run("successful", func(t *testing.T) {
gormDb, err := gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{})
assert.Nil(t, err)
dbConfig := &database.DbConfig{
DeprecatedPort: 5432,
MaxIdleConnections: 10,
MaxOpenConnections: 1000,
ConnMaxLifeTime: config.Duration{Duration: time.Hour},
}
err = setupDbConnectionPool(gormDb, dbConfig)
assert.Nil(t, err)
genericDb, err := gormDb.DB()
assert.Nil(t, err)
assert.Equal(t, genericDb.Stats().MaxOpenConnections, 1000)
})
t.Run("unset duration", func(t *testing.T) {
gormDb, err := gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{})
assert.Nil(t, err)
dbConfig := &database.DbConfig{
DeprecatedPort: 5432,
MaxIdleConnections: 10,
MaxOpenConnections: 1000,
}
err = setupDbConnectionPool(gormDb, dbConfig)
assert.Nil(t, err)
genericDb, err := gormDb.DB()
assert.Nil(t, err)
assert.Equal(t, genericDb.Stats().MaxOpenConnections, 1000)
})
t.Run("failed to get DB", func(t *testing.T) {
gormDb := &gorm.DB{
Config: &gorm.Config{
ConnPool: &gorm.PreparedStmtDB{},
},
}
dbConfig := &database.DbConfig{
DeprecatedPort: 5432,
MaxIdleConnections: 10,
MaxOpenConnections: 1000,
ConnMaxLifeTime: config.Duration{Duration: time.Hour},
}
err := setupDbConnectionPool(gormDb, dbConfig)
assert.NotNil(t, err)
})
}

0 comments on commit 24d8aba

Please sign in to comment.