diff --git a/.github/workflows/code-check.yml b/.github/workflows/code-check.yml index 1431262..6c92d67 100644 --- a/.github/workflows/code-check.yml +++ b/.github/workflows/code-check.yml @@ -43,3 +43,7 @@ jobs: - name: Run pre-commit checks run: pre-commit run --all-files + + - name: Run race condition tests + run: | + go test ./... -race -timeout=10s -v --tags=race diff --git a/driver.go b/driver.go index 3cd3dff..2a253d1 100644 --- a/driver.go +++ b/driver.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" + "sync" "github.com/firebolt-db/firebolt-go-sdk/client" @@ -12,6 +13,7 @@ import ( ) type FireboltDriver struct { + mutex sync.RWMutex engineUrl string cachedParams map[string]string client client.Client @@ -38,29 +40,39 @@ func copyMap(original map[string]string) map[string]string { func (d *FireboltDriver) OpenConnector(dsn string) (driver.Connector, error) { logging.Infolog.Println("Opening firebolt connector") - if d.lastUsedDsn != dsn || d.lastUsedDsn == "" { - - d.lastUsedDsn = "" //nolintd - logging.Infolog.Println("constructing new client") - // parsing dsn string to get configuration settings - settings, err := ParseDSNString(dsn) - if err != nil { - return nil, errors.Wrap(errors.DSNParseError, err) - } - - // authenticating and getting access token - logging.Infolog.Println("dsn parsed correctly, trying to authenticate") - d.client, err = client.ClientFactory(settings, client.GetHostNameURL()) - if err != nil { - return nil, errors.ConstructNestedError("error during initializing client", err) - } - - d.engineUrl, d.cachedParams, err = d.client.GetConnectionParameters(context.TODO(), settings.EngineName, settings.Database) - if err != nil { - return nil, errors.ConstructNestedError("error during getting connection parameters", err) - } - d.lastUsedDsn = dsn //nolint + d.mutex.RLock() + if d.lastUsedDsn == dsn && d.lastUsedDsn != "" { + connector := &FireboltConnector{d.engineUrl, d.client, copyMap(d.cachedParams), d} + d.mutex.RUnlock() + return connector, nil } + d.mutex.RUnlock() + + d.mutex.Lock() + defer d.mutex.Unlock() + + if d.lastUsedDsn == dsn && d.lastUsedDsn != "" { + return &FireboltConnector{d.engineUrl, d.client, copyMap(d.cachedParams), d}, nil + } + + d.lastUsedDsn = "" + logging.Infolog.Println("constructing new client") + settings, err := ParseDSNString(dsn) + if err != nil { + return nil, errors.Wrap(errors.DSNParseError, err) + } + + logging.Infolog.Println("dsn parsed correctly, trying to authenticate") + d.client, err = client.ClientFactory(settings, client.GetHostNameURL()) + if err != nil { + return nil, errors.ConstructNestedError("error during initializing client", err) + } + + d.engineUrl, d.cachedParams, err = d.client.GetConnectionParameters(context.TODO(), settings.EngineName, settings.Database) + if err != nil { + return nil, errors.ConstructNestedError("error during getting connection parameters", err) + } + d.lastUsedDsn = dsn return &FireboltConnector{d.engineUrl, d.client, copyMap(d.cachedParams), d}, nil } diff --git a/driver_options.go b/driver_options.go index 1cda433..9c14ed6 100644 --- a/driver_options.go +++ b/driver_options.go @@ -20,6 +20,8 @@ func NoError(option driverOption) driverOptionWithError { // WithEngineUrl defines engine url for the driver func WithEngineUrl(engineUrl string) driverOption { return func(d *FireboltDriver) { + d.mutex.Lock() + defer d.mutex.Unlock() d.engineUrl = engineUrl } } @@ -27,6 +29,8 @@ func WithEngineUrl(engineUrl string) driverOption { // WithDatabaseName defines database name for the driver func WithDatabaseName(databaseName string) driverOption { return func(d *FireboltDriver) { + d.mutex.Lock() + defer d.mutex.Unlock() if d.cachedParams == nil { d.cachedParams = map[string]string{} } @@ -37,6 +41,8 @@ func WithDatabaseName(databaseName string) driverOption { // WithAccountID defines account ID for the driver func WithAccountID(accountID string) driverOption { return func(d *FireboltDriver) { + d.mutex.Lock() + defer d.mutex.Unlock() if d.cachedParams == nil { d.cachedParams = map[string]string{} } @@ -48,6 +54,8 @@ func WithAccountID(accountID string) driverOption { func withClientOption(setter func(baseClient *client.BaseClient)) driverOption { return func(d *FireboltDriver) { + d.mutex.Lock() + defer d.mutex.Unlock() if d.client != nil { if clientImpl, ok := d.client.(*client.ClientImpl); ok { setter(&clientImpl.BaseClient) @@ -95,6 +103,8 @@ func WithClientParams(accountID string, token string, userAgent string) driverOp // WithAccountName defines account name for the driver func WithAccountName(accountName string) driverOptionWithError { return func(d *FireboltDriver) error { + d.mutex.Lock() + defer d.mutex.Unlock() if d.client != nil { if clientImpl, ok := d.client.(*client.ClientImpl); ok { clientImpl.AccountName = accountName @@ -118,6 +128,8 @@ func WithAccountName(accountName string) driverOptionWithError { // WithDatabaseAndEngineName defines database name and engine name for the driver func WithDatabaseAndEngineName(databaseName, engineName string) driverOptionWithError { return func(d *FireboltDriver) error { + d.mutex.Lock() + defer d.mutex.Unlock() if d.client == nil { return errors.New("client must be initialized before setting database and engine name") } diff --git a/driver_race_test.go b/driver_race_test.go new file mode 100644 index 0000000..be3a7fb --- /dev/null +++ b/driver_race_test.go @@ -0,0 +1,49 @@ +//go:build race +// +build race + +package fireboltgosdk + +import ( + "database/sql" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/firebolt-db/firebolt-go-sdk/client" + "github.com/firebolt-db/firebolt-go-sdk/utils" +) + +func TestDriverOpenRace(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case client.UsernamePasswordURLSuffix: + _, _ = w.Write(utils.GetAuthResponse(10000)) + case client.DefaultAccountURL: + _, _ = w.Write(getDefaultAccountResponse()) + } + })) + currentEndpoint := os.Getenv("FIREBOLT_ENDPOINT") + utils.Must(os.Setenv("FIREBOLT_ENDPOINT", server.URL)) + defer func() { utils.Must(os.Setenv("FIREBOLT_ENDPOINT", currentEndpoint)) }() + defer server.Close() + numGoroutines := 10 + done := make(chan bool, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + db, err := sql.Open("firebolt", "firebolt://user@fb:pass@db_name/eng.firebolt.io") + if err != nil { + t.Errorf("connection failed unexpectedly: %v", err) + } + if _, ok := db.Driver().(*FireboltDriver); !ok { + t.Errorf("returned connector is not a firebolt connector") + } + done <- true + }() + } + + for i := 0; i < numGoroutines; i++ { + <-done + } +}