From d91e009b2e5be695e41d7c2cff9ed32e17fbaf06 Mon Sep 17 00:00:00 2001 From: "Stathi C." Date: Tue, 10 Dec 2024 08:00:27 -0600 Subject: [PATCH] [SNOW-1669514] Honor the Valuer/Stringer methods to resolve #1209 (#1211) Co-authored-by: Piotr Fus --- .github/workflows/build-test.yml | 34 ++++---- .gitignore | 3 + README.md | 22 ++--- bind_uploader.go | 2 +- bindings_test.go | 114 +++++++++++++++++++++----- converter.go | 134 ++++++++++++++++++++++--------- converter_test.go | 48 +++++++++-- driver_test.go | 119 ++++++++++++++++++++++++++- htap_test.go | 12 +-- parameters.json.local | 3 +- parameters.json.tmpl | 3 +- put_get_user_stage_test.go | 4 +- rows_test.go | 4 +- structured_type.go | 5 ++ structured_type_read_test.go | 93 +++++++++++++++++---- structured_type_write_test.go | 10 ++- 16 files changed, 484 insertions(+), 126 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index f91f80cf3..6d6216576 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -1,22 +1,22 @@ name: Build and Test on: - push: - braches: - - master - tags: - - v* - pull_request: - branches: - - master - - SNOW-* - schedule: - - cron: '7 3 * * *' - workflow_dispatch: - inputs: - goTestParams: - default: - description: "Parameters passed to go test" + push: + branches: + - master + tags: + - v* + pull_request: + branches: + - master + - SNOW-* + schedule: + - cron: '7 3 * * *' + workflow_dispatch: + inputs: + goTestParams: + default: + description: 'Parameters passed to go test' concurrency: # older builds for the same pull request numer or branch should be cancelled @@ -34,7 +34,7 @@ jobs: - name: Setup go uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version-file: './go.mod' - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: diff --git a/.gitignore b/.gitignore index c8fad7b27..04e2c639d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea/ +.vscode/ parameters*.json parameters*.bat *.p8 @@ -11,6 +12,8 @@ wss-golang-agent.config wss-unified-agent.jar whitesource/ *.swp +cp.out +__debug_bin* # exclude vendor vendor diff --git a/README.md b/README.md index 719a2fdf8..290bbe7c3 100644 --- a/README.md +++ b/README.md @@ -56,13 +56,13 @@ This driver currently does not support GCP regional endpoints. Please ensure tha Snowflake provides a set of sample programs to test with. Set the environment variable ``$GOPATH`` to the top directory of your workspace, e.g., ``~/go`` and make certain to include ``$GOPATH/bin`` in the environment variable ``$PATH``. Run the ``make`` command to build all sample programs. -``` +```sh make install ``` In the following example, the program ``select1.go`` is built and installed in ``$GOPATH/bin`` and can be run from the command line: -``` +```sh SNOWFLAKE_TEST_ACCOUNT= \ SNOWFLAKE_TEST_USER= \ SNOWFLAKE_TEST_PASSWORD= \ @@ -79,7 +79,7 @@ The developer notes are hosted with the source code on [GitHub](https://github.c Set the Snowflake connection info in ``parameters.json``: -``` +```json { "testconnection": { "SNOWFLAKE_TEST_USER": "", @@ -88,21 +88,25 @@ Set the Snowflake connection info in ``parameters.json``: "SNOWFLAKE_TEST_WAREHOUSE": "", "SNOWFLAKE_TEST_DATABASE": "", "SNOWFLAKE_TEST_SCHEMA": "", - "SNOWFLAKE_TEST_ROLE": "" + "SNOWFLAKE_TEST_ROLE": "", + "SNOWFLAKE_TEST_DEBUG": "false" } } ``` Install [jq](https://stedolan.github.io/jq) so that the parameters can get parsed correctly, and run ``make test`` in your Go development environment: -``` +```sh make test ``` +### Setting debug mode during tests +This is for debugging Large SQL statements (greater than 300 characters). If you want to enable debug mode, set `SNOWFLAKE_TEST_DEBUG` to `true` in `parameters.json`, or export it in your shell instance. + ## customizing Logging Tags If you would like to ensure that certain tags are always present in the logs, `RegisterClientLogContextHook` can be used in your init function. See example below. -``` +```go import "github.com/snowflakedb/gosnowflake" func init() { @@ -116,7 +120,7 @@ func init() { ## Setting Log Level If you want to change the log level, `SetLogLevel` can be used in your init function like this: -``` +```go import "github.com/snowflakedb/gosnowflake" func init() { @@ -138,7 +142,7 @@ The following is a list of options you can pass in to set the level from least t Configure your testing environment as described above and run ``make cov``. The coverage percentage will be printed on the console when the testing completes. -``` +```sh make cov ``` @@ -146,7 +150,7 @@ For more detailed analysis, results are printed to ``coverage.txt`` in the proje To read the coverage report, run: -``` +```sh go tool cover -html=coverage.txt ``` diff --git a/bind_uploader.go b/bind_uploader.go index 04a266a8e..553648d01 100644 --- a/bind_uploader.go +++ b/bind_uploader.go @@ -118,7 +118,7 @@ func (bu *bindUploader) createStageIfNeeded() error { return (&SnowflakeError{ Number: code, SQLState: data.Data.SQLState, - Message: err.Error(), + Message: data.Message, QueryID: data.Data.QueryID, }).exceptionTelemetry(bu.sc) } diff --git a/bindings_test.go b/bindings_test.go index 91530dc5e..c12c15bad 100644 --- a/bindings_test.go +++ b/bindings_test.go @@ -8,9 +8,11 @@ import ( "database/sql" "fmt" "log" + "math" "math/big" "math/rand" "reflect" + "slices" "strconv" "strings" "testing" @@ -70,7 +72,7 @@ func TestBindingFloat64(t *testing.T) { dbt.mustExec("INSERT INTO test VALUES (1, ?)", expected) rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if rows.Next() { assertNilF(t, rows.Scan(&out)) @@ -203,14 +205,14 @@ func TestBindingTimestampTZ(t *testing.T) { dbt.Fatal(err.Error()) } defer func() { - assertNilF(t, stmt.Close()) + assertNilF(t, stmt.Close()) }() if _, err = stmt.Exec(DataTypeTimestampTz, expected); err != nil { dbt.Fatal(err) } rows := dbt.mustQuery("SELECT tz FROM tztest WHERE id=?", 1) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v time.Time if rows.Next() { @@ -258,7 +260,7 @@ func TestBindingTimePtrInStruct(t *testing.T) { rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v time.Time if rows.Next() { @@ -307,7 +309,7 @@ func TestBindingTimeInStruct(t *testing.T) { rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v time.Time if rows.Next() { @@ -329,7 +331,7 @@ func TestBindingInterface(t *testing.T) { rows := dbt.mustQueryContext( WithHigherPrecision(context.Background()), selectVariousTypes) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if !rows.Next() { dbt.Error("failed to query") @@ -357,7 +359,7 @@ func TestBindingInterfaceString(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery(selectVariousTypes) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if !rows.Next() { dbt.Error("failed to query") @@ -382,6 +384,74 @@ func TestBindingInterfaceString(t *testing.T) { }) } +func TestBulkArrayBindingUUID(t *testing.T) { + max := math.Pow10(5) // 100K because my power is maximum + expectedUuids := make([]any, int(max)) + + createTable := "CREATE OR REPLACE TABLE TEST_PREP_STATEMENT (uuid VARCHAR)" + insert := "INSERT INTO TEST_PREP_STATEMENT (uuid) VALUES (?)" + + for i := range expectedUuids { + expectedUuids[i] = newTestUUID() + } + + slices.SortStableFunc(expectedUuids, func(i, j any) int { + return strings.Compare(i.(testUUID).String(), j.(testUUID).String()) + }) + + runDBTest(t, func(dbt *DBTest) { + var rows *RowsExtended + t.Cleanup(func() { + if rows != nil { + assertNilF(t, rows.Close()) + } + + _, err := dbt.exec(deleteTableSQL) + if err != nil { + t.Logf("failed to drop table. err: %s", err) + } + }) + + dbt.mustExec(createTable) + + res := dbt.mustExec(insert, Array(&expectedUuids)) + + affected, err := res.RowsAffected() + if err != nil { + t.Fatalf("failed to get affected rows. err: %s", err) + } else if affected != int64(max) { + t.Fatalf("failed to insert all rows. expected: %f.0, got: %v", max, affected) + } + + rows = dbt.mustQuery("SELECT * FROM TEST_PREP_STATEMENT ORDER BY uuid") + if rows == nil { + t.Fatal("failed to query") + } + + if rows.Err() != nil { + t.Fatalf("failed to query. err: %s", rows.Err()) + } + + var actual = make([]testUUID, len(expectedUuids)) + + for i := 0; rows.Next(); i++ { + var ( + out testUUID + ) + if err := rows.Scan(&out); err != nil { + t.Fatal(err) + } + + actual[i] = out + } + + for i := range expectedUuids { + assertEqualE(t, actual[i], expectedUuids[i]) + } + }) + +} + func TestBulkArrayBindingInterfaceNil(t *testing.T) { nilArray := make([]any, 1) @@ -396,7 +466,7 @@ func TestBulkArrayBindingInterfaceNil(t *testing.T) { Array(&nilArray, TimeType)) rows := dbt.mustQuery(selectAllSQL) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v0 sql.NullInt32 @@ -481,7 +551,7 @@ func TestBulkArrayBindingInterface(t *testing.T) { Array(&boolArray), Array(&strArray), Array(&byteArray), Array(&int64Array)) rows := dbt.mustQuery(selectAllSQLBulkArray) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v0 sql.NullInt32 @@ -586,7 +656,7 @@ func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) { rows := dbt.mustQuery(selectAllSQLBulkArrayDateTimeTimestamp) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v0, v1, v2, v3, v4 sql.NullTime @@ -695,7 +765,7 @@ func testBindingArray(t *testing.T, bulk bool) { Array(&tmArray, TimeType)) rows := dbt.mustQuery(selectAllSQL) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v0 int @@ -777,7 +847,7 @@ func TestBulkArrayBinding(t *testing.T) { dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?)", dbname), Array(&intArr), Array(&strArr), Array(<zArr, TimestampLTZType), Array(&tzArr, TimestampTZType), Array(&ntzArr, TimestampNTZType), Array(&dateArr, DateType), Array(&timeArr, TimeType), Array(&binArr)) rows := dbt.mustQuery("select * from " + dbname + " order by c1") defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() cnt := 0 var i int @@ -825,7 +895,7 @@ func TestBulkArrayBindingTimeWithPrecision(t *testing.T) { dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?)", dbname), Array(&secondsArr, TimeType), Array(&millisecondsArr, TimeType), Array(µsecondsArr, TimeType), Array(&nanosecondsArr, TimeType)) rows := dbt.mustQuery("select * from " + dbname) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() cnt := 0 var s, ms, us, ns time.Time @@ -866,7 +936,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { Array(&randomStrings)) rows := dbt.mustQuery("select count(*) from " + tempTableName) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if rows.Next() { var count int @@ -878,7 +948,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { rows := dbt.mustQuery("select count(*) from " + tempTableName) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if rows.Next() { var count int @@ -909,7 +979,7 @@ func TestBulkArrayMultiPartBindingInt(t *testing.T) { rows := dbt.mustQuery("select * from binding_test order by c1") defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() cnt := startNum var i int @@ -959,7 +1029,7 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { rows := dbt.mustQuery("select * from binding_test order by c1,c2") defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() cnt := startNum var i sql.NullInt32 @@ -1042,7 +1112,7 @@ func TestFunctionParameters(t *testing.T) { t.Fatal(err) } defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if rows.Err() != nil { t.Fatal(err) @@ -1144,7 +1214,7 @@ func TestVariousBindingModes(t *testing.T) { t.Fatal(err) } defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if !rows.Next() { t.Fatal("Expected to return a row") @@ -1194,7 +1264,7 @@ func testLOBRetrieval(t *testing.T, useArrowFormat bool) { rows, err := dbt.query(fmt.Sprintf("SELECT randstr(%v, 124)", testSize)) assertNilF(t, err) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() assertTrueF(t, rows.Next(), fmt.Sprintf("no rows returned for the LOB size %v", testSize)) @@ -1227,7 +1297,7 @@ func TestMaxLobSize(t *testing.T) { rows, err := dbt.query("select randstr(20000000, random())") assertNilF(t, err) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() }) }) @@ -1308,7 +1378,7 @@ func testInsertLOBData(t *testing.T, useArrowFormat bool, isLiteral bool) { rows, err := dbt.query("SELECT * FROM lob_test_table") assertNilF(t, err) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() assertTrueF(t, rows.Next(), fmt.Sprintf("%s: no rows returned", tc.testDesc)) diff --git a/converter.go b/converter.go index 57ccd9293..48570f284 100644 --- a/converter.go +++ b/converter.go @@ -14,6 +14,7 @@ import ( "math" "math/big" "reflect" + "regexp" "strconv" "strings" "time" @@ -28,6 +29,7 @@ import ( const format = "2006-01-02 15:04:05.999999999" const numberDefaultPrecision = 38 +const jsonFormatStr = "json" type timezoneType int @@ -78,9 +80,13 @@ func isInterfaceArrayBinding(t interface{}) bool { } } +func isJSONFormatType(tsmode snowflakeType) bool { + return tsmode == objectType || tsmode == arrayType || tsmode == sliceType +} + // goTypeToSnowflake translates Go data type to Snowflake data type. func goTypeToSnowflake(v driver.Value, tsmode snowflakeType) snowflakeType { - if tsmode == objectType || tsmode == arrayType || tsmode == sliceType { + if isJSONFormatType(tsmode) { return tsmode } if v == nil { @@ -237,13 +243,27 @@ func snowflakeTypeToGoForMaps[K comparable](ctx context.Context, valueMetadata f // in queries. func valueToString(v driver.Value, tsmode snowflakeType, params map[string]*string) (bindingValue, error) { logger.Debugf("TYPE: %v, %v", reflect.TypeOf(v), reflect.ValueOf(v)) + isJSONFormat := isJSONFormatType(tsmode) if v == nil { - if tsmode == objectType || tsmode == arrayType || tsmode == sliceType { - return bindingValue{nil, "json", nil}, nil + if isJSONFormat { + return bindingValue{nil, jsonFormatStr, nil}, nil } return bindingValue{nil, "", nil}, nil } v1 := reflect.Indirect(reflect.ValueOf(v)) + + if valuer, ok := v.(driver.Valuer); ok { // check for driver.Valuer satisfaction and honor that first + if value, err := valuer.Value(); err == nil && value != nil { + // if the output value is a valid string, return that + if strVal, ok := value.(string); ok { + if isJSONFormat { + return bindingValue{&strVal, jsonFormatStr, nil}, nil + } + return bindingValue{&strVal, "", nil}, nil + } + } + } + switch v1.Kind() { case reflect.Bool: s := strconv.FormatBool(v1.Bool()) @@ -256,8 +276,8 @@ func valueToString(v driver.Value, tsmode snowflakeType, params map[string]*stri return bindingValue{&s, "", nil}, nil case reflect.String: s := v1.String() - if tsmode == objectType || tsmode == arrayType || tsmode == sliceType { - return bindingValue{&s, "json", nil}, nil + if isJSONFormat { + return bindingValue{&s, jsonFormatStr, nil}, nil } return bindingValue{&s, "", nil}, nil case reflect.Slice, reflect.Array: @@ -271,10 +291,37 @@ func valueToString(v driver.Value, tsmode snowflakeType, params map[string]*stri return bindingValue{}, fmt.Errorf("unsupported type: %v", v1.Kind()) } +// isUUIDImplementer checks if a value is a UUID that satisfies RFC 4122 +func isUUIDImplementer(v reflect.Value) bool { + rt := v.Type() + + // Check if the type is an array of 16 bytes + if v.Kind() == reflect.Array && rt.Elem().Kind() == reflect.Uint8 && rt.Len() == 16 { + // Check if the type implements fmt.Stringer + vInt := v.Interface() + if stringer, ok := vInt.(fmt.Stringer); ok { + uuidStr := stringer.String() + + rfc4122Regex := `^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$` + matched, err := regexp.MatchString(rfc4122Regex, uuidStr) + if err != nil { + return false + } + + if matched { + // parse the UUID and ensure it is the same as the original string + u := ParseUUID(uuidStr) + return u.String() == uuidStr + } + } + } + return false +} + func arrayToString(v driver.Value, tsmode snowflakeType, params map[string]*string) (bindingValue, error) { v1 := reflect.Indirect(reflect.ValueOf(v)) if v1.Kind() == reflect.Slice && v1.IsNil() { - return bindingValue{nil, "json", nil}, nil + return bindingValue{nil, jsonFormatStr, nil}, nil } if bd, ok := v.([][]byte); ok && tsmode == binaryType { schema := bindingSchema{ @@ -289,14 +336,14 @@ func arrayToString(v driver.Value, tsmode snowflakeType, params map[string]*stri } if len(bd) == 0 { res := "[]" - return bindingValue{value: &res, format: "json", schema: &schema}, nil + return bindingValue{value: &res, format: jsonFormatStr, schema: &schema}, nil } s := "" for _, b := range bd { s += "\"" + hex.EncodeToString(b) + "\"," } s = "[" + s[:len(s)-1] + "]" - return bindingValue{&s, "json", &schema}, nil + return bindingValue{&s, jsonFormatStr, &schema}, nil } else if times, ok := v.([]time.Time); ok { typ := driverTypeToSnowflake[tsmode] sfFormat, err := dateTimeInputFormatByType(typ, params) @@ -313,7 +360,7 @@ func arrayToString(v driver.Value, tsmode snowflakeType, params map[string]*stri } res, err := json.Marshal(v) if err != nil { - return bindingValue{nil, "json", &bindingSchema{ + return bindingValue{nil, jsonFormatStr, &bindingSchema{ Typ: "array", Nullable: true, Fields: []fieldMetadata{ @@ -325,7 +372,7 @@ func arrayToString(v driver.Value, tsmode snowflakeType, params map[string]*stri }}, err } resString := string(res) - return bindingValue{&resString, "json", nil}, nil + return bindingValue{&resString, jsonFormatStr, nil}, nil } else if isArrayOfStructs(v) { stringEntries := make([]string, v1.Len()) sowcForSingleElement, err := buildSowcFromType(params, reflect.TypeOf(v).Elem()) @@ -337,7 +384,7 @@ func arrayToString(v driver.Value, tsmode snowflakeType, params map[string]*stri if sow, ok := potentialSow.Interface().(StructuredObjectWriter); ok { bv, err := structValueToString(sow, tsmode, params) if err != nil { - return bindingValue{nil, "json", nil}, err + return bindingValue{nil, jsonFormatStr, nil}, err } stringEntries[i] = *bv.value } @@ -354,14 +401,14 @@ func arrayToString(v driver.Value, tsmode snowflakeType, params map[string]*stri }, }, } - return bindingValue{&value, "json", arraySchema}, nil + return bindingValue{&value, jsonFormatStr, arraySchema}, nil } else if reflect.ValueOf(v).Len() == 0 { value := "[]" - return bindingValue{&value, "json", nil}, nil + return bindingValue{&value, jsonFormatStr, nil}, nil } else if barr, ok := v.([]byte); ok { if tsmode == binaryType { res := hex.EncodeToString(barr) - return bindingValue{&res, "json", nil}, nil + return bindingValue{&res, jsonFormatStr, nil}, nil } schemaForBytes := bindingSchema{ Typ: "array", @@ -375,23 +422,27 @@ func arrayToString(v driver.Value, tsmode snowflakeType, params map[string]*stri } if len(barr) == 0 { res := "[]" - return bindingValue{&res, "json", &schemaForBytes}, nil + return bindingValue{&res, jsonFormatStr, &schemaForBytes}, nil } res := "[" for _, b := range barr { res += fmt.Sprint(b) + "," } res = res[0:len(res)-1] + "]" - return bindingValue{&res, "json", &schemaForBytes}, nil + return bindingValue{&res, jsonFormatStr, &schemaForBytes}, nil + } else if isUUIDImplementer(v1) { // special case for UUIDs (snowflake type and other implementers) + stringer := v.(fmt.Stringer) // we don't need to validate if it's a fmt.Stringer because we already checked if it's a UUID type with a stringer + value := stringer.String() + return bindingValue{&value, "", nil}, nil } else if isSliceOfSlices(v) { return bindingValue{}, errors.New("array of arrays is not supported") } res, err := json.Marshal(v) if err != nil { - return bindingValue{nil, "json", nil}, err + return bindingValue{nil, jsonFormatStr, nil}, err } resString := string(res) - return bindingValue{&resString, "json", nil}, nil + return bindingValue{&resString, jsonFormatStr, nil}, nil } func mapToString(v driver.Value, tsmode snowflakeType, params map[string]*string) (bindingValue, error) { @@ -549,7 +600,7 @@ func mapToString(v driver.Value, tsmode snowflakeType, params map[string]*string Typ: "MAP", Fields: []fieldMetadata{keyMetadata, *valueMetadata}, } - return bindingValue{&jsonString, "json", &schema}, nil + return bindingValue{&jsonString, jsonFormatStr, &schema}, nil } else { jsonBytes, err = json.Marshal(v) if err != nil { @@ -569,7 +620,7 @@ func mapToString(v driver.Value, tsmode snowflakeType, params map[string]*string Typ: "MAP", Fields: []fieldMetadata{keyMetadata, valueMetadata}, } - return bindingValue{&jsonString, "json", &schema}, nil + return bindingValue{&jsonString, jsonFormatStr, &schema}, nil } func toNullableInt64(val any) (int64, bool) { @@ -770,8 +821,8 @@ func structValueToString(v driver.Value, tsmode snowflakeType, params map[string return bindingValue{&s, "", nil}, nil case sql.NullString: fmt := "" - if tsmode == objectType || tsmode == arrayType || tsmode == sliceType { - fmt = "json" + if isJSONFormatType(tsmode) { + fmt = jsonFormatStr } if !typedVal.Valid { return bindingValue{nil, fmt, nil}, nil @@ -795,7 +846,7 @@ func structValueToString(v driver.Value, tsmode snowflakeType, params map[string Nullable: true, Fields: sowc.toFields(), } - return bindingValue{&jsonString, "json", &schema}, nil + return bindingValue{&jsonString, jsonFormatStr, &schema}, nil } else if typ, ok := v.(reflect.Type); ok && tsmode == nilArrayType { metadata, err := goTypeToFieldMetadata(typ, tsmode, params) if err != nil { @@ -808,7 +859,7 @@ func structValueToString(v driver.Value, tsmode snowflakeType, params map[string metadata, }, } - return bindingValue{nil, "json", &schema}, nil + return bindingValue{nil, jsonFormatStr, &schema}, nil } else if types, ok := v.(NilMapTypes); ok && tsmode == nilMapType { keyMetadata, err := goTypeToFieldMetadata(types.Key, tsmode, params) if err != nil { @@ -823,7 +874,7 @@ func structValueToString(v driver.Value, tsmode snowflakeType, params map[string Nullable: true, Fields: []fieldMetadata{keyMetadata, valueMetadata}, } - return bindingValue{nil, "json", &schema}, nil + return bindingValue{nil, jsonFormatStr, &schema}, nil } else if typ, ok := v.(reflect.Type); ok && tsmode == nilObjectType { metadata, err := goTypeToFieldMetadata(typ, tsmode, params) if err != nil { @@ -834,7 +885,7 @@ func structValueToString(v driver.Value, tsmode snowflakeType, params map[string Nullable: true, Fields: metadata.Fields, } - return bindingValue{nil, "json", &schema}, nil + return bindingValue{nil, jsonFormatStr, &schema}, nil } return bindingValue{}, fmt.Errorf("unknown binding for type %T and mode %v", v, tsmode) } @@ -2653,44 +2704,38 @@ func interfaceSliceToString(interfaceSlice reflect.Value, stream bool, tzType .. for i := 0; i < interfaceSlice.Len(); i++ { val := interfaceSlice.Index(i) if val.CanInterface() { - switch val.Interface().(type) { + v := val.Interface() + + switch x := v.(type) { case int: t = fixedType - x := val.Interface().(int) v := strconv.Itoa(x) arr = append(arr, &v) case int32: t = fixedType - x := val.Interface().(int32) v := strconv.Itoa(int(x)) arr = append(arr, &v) case int64: t = fixedType - x := val.Interface().(int64) v := strconv.FormatInt(x, 10) arr = append(arr, &v) case float32: t = realType - x := val.Interface().(float32) v := fmt.Sprintf("%g", x) arr = append(arr, &v) case float64: t = realType - x := val.Interface().(float64) v := fmt.Sprintf("%g", x) arr = append(arr, &v) case bool: t = booleanType - x := val.Interface().(bool) v := strconv.FormatBool(x) arr = append(arr, &v) case string: t = textType - x := val.Interface().(string) arr = append(arr, &x) case []byte: t = binaryType - x := val.Interface().([]byte) v := hex.EncodeToString(x) arr = append(arr, &v) case time.Time: @@ -2698,7 +2743,6 @@ func interfaceSliceToString(interfaceSlice reflect.Value, stream bool, tzType .. return unSupportedType, nil } - x := val.Interface().(time.Time) switch tzType[0] { case TimestampNTZType: t = timestampNtzType @@ -2738,8 +2782,26 @@ func interfaceSliceToString(interfaceSlice reflect.Value, stream bool, tzType .. default: return unSupportedType, nil } + case driver.Valuer: // honor each driver's Valuer interface + if value, err := x.Value(); err == nil && value != nil { + // if the output value is a valid string, return that + if strVal, ok := value.(string); ok { + t = textType + arr = append(arr, &strVal) + } + } else if v != nil { + return unSupportedType, nil + } else { + arr = append(arr, nil) + } default: if val.Interface() != nil { + if isUUIDImplementer(val) { + t = textType + x := v.(fmt.Stringer).String() + arr = append(arr, &x) + continue + } return unSupportedType, nil } diff --git a/converter_test.go b/converter_test.go index 375359465..5cb12741c 100644 --- a/converter_test.go +++ b/converter_test.go @@ -269,22 +269,56 @@ func TestValueToString(t *testing.T) { assertNilE(t, bv.schema) assertEqualE(t, *bv.value, expectedString) + t.Run("SQL Time", func(t *testing.T) { + bv, err := valueToString(sql.NullTime{Time: localTime, Valid: true}, timestampLtzType, nil) + assertNilF(t, err) + assertEmptyStringE(t, bv.format) + assertNilE(t, bv.schema) + assertEqualE(t, *bv.value, expectedUnixTime) + }) + t.Run("arrays", func(t *testing.T) { bv, err := valueToString([2]int{1, 2}, objectType, nil) assertNilF(t, err) - assertEqualE(t, bv.format, "json") + assertEqualE(t, bv.format, jsonFormatStr) assertEqualE(t, *bv.value, "[1,2]") }) t.Run("slices", func(t *testing.T) { bv, err := valueToString([]int{1, 2}, objectType, nil) assertNilF(t, err) - assertEqualE(t, bv.format, "json") + assertEqualE(t, bv.format, jsonFormatStr) assertEqualE(t, *bv.value, "[1,2]") }) + t.Run("UUID - should return string", func(t *testing.T) { + u := NewUUID() + bv, err := valueToString(u, textType, nil) + assertNilF(t, err) + assertEmptyStringE(t, bv.format) + assertEqualE(t, *bv.value, u.String()) + }) + + t.Run("database/sql/driver - Valuer interface", func(t *testing.T) { + u := newTestUUID() + bv, err := valueToString(u, textType, nil) + assertNilF(t, err) + assertEmptyStringE(t, bv.format) + assertEqualE(t, *bv.value, u.String()) + }) + + t.Run("testUUID", func(t *testing.T) { + u := newTestUUID() + assertEqualE(t, u.String(), parseTestUUID(u.String()).String()) + + bv, err := valueToString(u, textType, nil) + assertNilF(t, err) + assertEmptyStringE(t, bv.format) + assertEqualE(t, *bv.value, u.String()) + }) + bv, err = valueToString(&testValueToStringStructuredObject{s: "some string", i: 123, date: time.Date(2024, time.May, 24, 0, 0, 0, 0, time.UTC)}, timestampLtzType, params) assertNilF(t, err) - assertEqualE(t, bv.format, "json") + assertEqualE(t, bv.format, jsonFormatStr) assertDeepEqualE(t, *bv.schema, bindingSchema{ Typ: "object", Nullable: true, @@ -2175,7 +2209,7 @@ func TestSmallTimestampBinding(t *testing.T) { rows := sct.mustQueryContext(ctx, "SELECT ?", parameters) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() scanValues := make([]driver.Value, 1) @@ -2213,7 +2247,7 @@ func TestTimestampConversionWithoutArrowBatches(t *testing.T) { query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale) rows := sct.mustQueryContext(ctx, query, nil) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if rows.Next() { @@ -2295,7 +2329,7 @@ func TestTimestampConversionWithArrowBatchesMicrosecondPassesForDistantDates(t * t.Fatalf("failed to query: %v", err) } defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() // getting result batches @@ -2356,7 +2390,7 @@ func TestTimestampConversionWithArrowBatchesAndWithOriginalTimestamp(t *testing. query := fmt.Sprintf("SELECT '%s'::%s(%v)", tsStr, tp, scale) rows := sct.mustQueryContext(ctx, query, []driver.NamedValue{}) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() // getting result batches diff --git a/driver_test.go b/driver_test.go index cfb8501f1..d6245653c 100644 --- a/driver_test.go +++ b/driver_test.go @@ -34,6 +34,7 @@ var ( protocol string customPrivateKey bool // Whether user has specified the private key path testPrivKey *rsa.PrivateKey // Valid private key used for all test cases + debugMode bool ) const ( @@ -76,6 +77,8 @@ func init() { setupPrivateKey() createDSN("UTC") + + debugMode, _ = strconv.ParseBool(os.Getenv("SNOWFLAKE_TEST_DEBUG")) } func createDSN(timezone string) { @@ -270,7 +273,7 @@ func (dbt *DBTest) prepare(query string) (*sql.Stmt, error) { } func (dbt *DBTest) fail(method, query string, err error) { - if len(query) > 300 { + if !debugMode && len(query) > 300 { query = "[query too large to print]" } dbt.Fatalf("error on %s [%s]: %s", method, query, err.Error()) @@ -398,7 +401,7 @@ type SCTest struct { } func (sct *SCTest) fail(method, query string, err error) { - if len(query) > 300 { + if !debugMode && len(query) > 300 { query = "[query too large to print]" } sct.Fatalf("error on %s [%s]: %s", method, query, err.Error()) @@ -959,6 +962,118 @@ func testString(t *testing.T, json bool) { }) } +/** TESTING TYPES **/ +// testUUID is a wrapper around UUID for unit testing purposes and should not be used in production +type testUUID struct { + UUID +} + +func newTestUUID() testUUID { + return testUUID{NewUUID()} +} + +func parseTestUUID(str string) testUUID { + if str == "" { + return testUUID{} + } + return testUUID{ParseUUID(str)} +} + +// Scan implements sql.Scanner so UUIDs can be read from databases transparently. +// Currently, database types that map to string and []byte are supported. Please +// consult database-specific driver documentation for matching types. +func (uuid *testUUID) Scan(src interface{}) error { + switch src := src.(type) { + case nil: + return nil + + case string: + // if an empty UUID comes from a table, we return a null UUID + if src == "" { + return nil + } + + // see Parse for required string format + u := ParseUUID(src) + + *uuid = testUUID{u} + + case []byte: + // if an empty UUID comes from a table, we return a null UUID + if len(src) == 0 { + return nil + } + + // assumes a simple slice of bytes if 16 bytes + // otherwise attempts to parse + if len(src) != 16 { + return uuid.Scan(string(src)) + } + copy((uuid.UUID)[:], src) + + default: + return fmt.Errorf("Scan: unable to scan type %T into UUID", src) + } + + return nil +} + +// Value implements sql.Valuer so that UUIDs can be written to databases +// transparently. Currently, UUIDs map to strings. Please consult +// database-specific driver documentation for matching types. +func (uuid testUUID) Value() (driver.Value, error) { + return uuid.String(), nil +} + +func TestUUID(t *testing.T) { + t.Run("JSON", func(t *testing.T) { + testUUIDWithFormat(t, true, false) + }) + t.Run("Arrow", func(t *testing.T) { + testUUIDWithFormat(t, false, true) + }) +} + +func testUUIDWithFormat(t *testing.T, json, arrow bool) { + runDBTest(t, func(dbt *DBTest) { + if json { + dbt.mustExec(forceJSON) + } else if arrow { + dbt.mustExec(forceARROW) + } + + types := []string{"CHAR(255)", "VARCHAR(255)", "TEXT", "STRING"} + + in := make([]testUUID, len(types)) + + for i := range types { + in[i] = newTestUUID() + } + + for i, v := range types { + t.Run(v, func(t *testing.T) { + dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in[i]) + + rows := dbt.mustQuery("SELECT value FROM test") + defer func() { + assertNilF(t, rows.Close()) + }() + if rows.Next() { + var out testUUID + assertNilF(t, rows.Scan(&out)) + if in[i] != out { + dbt.Errorf("%s: %s != %s", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + }) + } + dbt.mustExec("DROP TABLE IF EXISTS test") + }) +} + type tcDateTimeTimestamp struct { dbtype string tlayout string diff --git a/htap_test.go b/htap_test.go index 8a8f82ccd..3dc95191e 100644 --- a/htap_test.go +++ b/htap_test.go @@ -348,7 +348,7 @@ func TestHybridTablesE2E(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil) defer func() { - assertNilF(t, dbQuery.Close()) + assertNilF(t, dbQuery.Close()) }() currentDb := make([]driver.Value, 1) assertNilF(t, dbQuery.Next(currentDb)) @@ -365,7 +365,7 @@ func TestHybridTablesE2E(t *testing.T) { sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil) rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() row := make([]driver.Value, 2) assertNilF(t, rows.Next(row)) @@ -376,7 +376,7 @@ func TestHybridTablesE2E(t *testing.T) { sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil) rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) defer func() { - assertNilF(t, rows2.Close()) + assertNilF(t, rows2.Close()) }() assertNilF(t, rows2.Next(row)) if row[0] != "1" || row[1] != "a" { @@ -397,7 +397,7 @@ func TestHybridTablesE2E(t *testing.T) { rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() row := make([]driver.Value, 2) assertNilF(t, rows.Next(row)) @@ -415,7 +415,7 @@ func TestHybridTablesE2E(t *testing.T) { rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if len(sct.sc.queryContextCache.entries) != 3 { t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries) @@ -578,7 +578,7 @@ func TestConnIsCleanAfterClose(t *testing.T) { var dbName2 string rows2 := dbt2.mustQuery("SELECT CURRENT_DATABASE()") defer func() { - assertNilF(t, rows2.Close()) + assertNilF(t, rows2.Close()) }() rows2.Next() assertNilF(t, rows2.Scan(&dbName2)) diff --git a/parameters.json.local b/parameters.json.local index 8b526b5a1..6e4635374 100644 --- a/parameters.json.local +++ b/parameters.json.local @@ -9,6 +9,7 @@ "SNOWFLAKE_TEST_WAREHOUSE": "regress", "SNOWFLAKE_TEST_DATABASE": "testdb", "SNOWFLAKE_TEST_SCHEMA": "testschema", - "SNOWFLAKE_TEST_ROLE": "sysadmin" + "SNOWFLAKE_TEST_ROLE": "sysadmin", + "SNOWFLAKE_TEST_DEBUG": "false" } } diff --git a/parameters.json.tmpl b/parameters.json.tmpl index 8448f605d..19ec33bd5 100644 --- a/parameters.json.tmpl +++ b/parameters.json.tmpl @@ -6,6 +6,7 @@ "SNOWFLAKE_TEST_WAREHOUSE": "testwarehouse", "SNOWFLAKE_TEST_DATABASE": "testdatabase", "SNOWFLAKE_TEST_SCHEMA": "testschema", - "SNOWFLAKE_TEST_ROLE": "testrole" + "SNOWFLAKE_TEST_ROLE": "testrole", + "SNOWFLAKE_TEST_DEBUG": "false", } } diff --git a/put_get_user_stage_test.go b/put_get_user_stage_test.go index 7fe927d4e..94a308f09 100644 --- a/put_get_user_stage_test.go +++ b/put_get_user_stage_test.go @@ -83,7 +83,7 @@ func putGetUserStage(t *testing.T, numberOfFiles int, numberOfLines int, isStrea rows := dbt.mustQuery("select count(*) from " + dbname) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var cnt string if rows.Next() { @@ -133,7 +133,7 @@ func TestPutLoadFromUserStage(t *testing.T) { file_format = (field_delimiter = '|' error_on_column_count_mismatch =false) purge=true`, data.stage)) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var s0, s1, s2, s3, s4, s5 string var s6, s7, s8, s9 interface{} diff --git a/rows_test.go b/rows_test.go index d365abdf5..bd576162a 100644 --- a/rows_test.go +++ b/rows_test.go @@ -506,7 +506,7 @@ func TestLocationChangesAfterAlterSession(t *testing.T) { dbt.mustExec("INSERT INTO location_timestamp_ltz VALUES('2023-08-09 10:00:00')") rows1 := dbt.mustQuery("SELECT * FROM location_timestamp_ltz") defer func() { - assertNilF(t, rows1.Close()) + assertNilF(t, rows1.Close()) }() if !rows1.Next() { t.Fatalf("cannot read a record") @@ -519,7 +519,7 @@ func TestLocationChangesAfterAlterSession(t *testing.T) { dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Pacific/Honolulu'") rows2 := dbt.mustQuery("SELECT * FROM location_timestamp_ltz") defer func() { - assertNilF(t, rows2.Close()) + assertNilF(t, rows2.Close()) }() if !rows2.Next() { t.Fatalf("cannot read a record") diff --git a/structured_type.go b/structured_type.go index 7fdf174fa..9df2f8cc4 100644 --- a/structured_type.go +++ b/structured_type.go @@ -3,6 +3,7 @@ package gosnowflake import ( "context" "database/sql" + "database/sql/driver" "encoding/hex" "encoding/json" "errors" @@ -450,6 +451,10 @@ func buildSowcFromType(params map[string]*string, typ reflect.Type) (*structured if err := childSowc.WriteNullableStruct(fieldName, nil, field.Type); err != nil { return nil, err } + } else if t.Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) { + if err := childSowc.WriteNullString(fieldName, sql.NullString{}); err != nil { + return nil, err + } } else { return nil, fmt.Errorf("field %s has unsupported type", field.Name) } diff --git a/structured_type_read_test.go b/structured_type_read_test.go index 362568182..69406bb56 100644 --- a/structured_type_read_test.go +++ b/structured_type_read_test.go @@ -36,10 +36,15 @@ type objectWithAllTypes struct { sArr []string f64Arr []float64 someMap map[string]bool + uuid testUUID } func (o *objectWithAllTypes) Scan(val any) error { - st := val.(StructuredObject) + st, ok := val.(StructuredObject) + if !ok { + return fmt.Errorf("expected StructuredObject, got %T", val) + } + var err error if o.s, err = st.GetString("s"); err != nil { return err @@ -112,6 +117,13 @@ func (o *objectWithAllTypes) Scan(val any) error { if someMap != nil { o.someMap = someMap.(map[string]bool) } + uuidStr, err := st.GetString("uuid") + if err != nil { + return err + } + + o.uuid = parseTestUUID(uuidStr) + return nil } @@ -173,6 +185,9 @@ func (o objectWithAllTypes) Write(sowc StructuredObjectWriterContext) error { if err := sowc.WriteRaw("someMap", o.someMap); err != nil { return err } + if err := sowc.WriteString("uuid", o.uuid.String()); err != nil { + return err + } return nil } @@ -182,7 +197,11 @@ type simpleObject struct { } func (so *simpleObject) Scan(val any) error { - st := val.(StructuredObject) + st, ok := val.(StructuredObject) + if !ok { + return fmt.Errorf("expected StructuredObject, got %T", val) + } + var err error if so.s, err = st.GetString("s"); err != nil { return err @@ -225,7 +244,8 @@ func TestObjectWithAllTypesAsObject(t *testing.T) { runDBTest(t, func(dbt *DBTest) { dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { - rows := dbt.mustQueryContextT(ctx, t, "SELECT 1, {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))") + uid := newTestUUID() + rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("SELECT 1, {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}, 'uuid': '%s'}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uid)) defer rows.Close() rows.Next() var ignore int @@ -253,6 +273,7 @@ func TestObjectWithAllTypesAsObject(t *testing.T) { assertDeepEqualE(t, res.sArr, []string{"x", "y", "z"}) assertDeepEqualE(t, res.f64Arr, []float64{1.1, 2.2, 3.3}) assertDeepEqualE(t, res.someMap, map[string]bool{"x": true, "y": false}) + assertEqualE(t, res.uuid.String(), uid.String()) }) }) } @@ -262,7 +283,7 @@ func TestNullObject(t *testing.T) { runDBTest(t, func(dbt *DBTest) { forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { t.Run("null", func(t *testing.T) { - rows := dbt.mustQueryContextT(ctx, t, "SELECT null::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))") + rows := dbt.mustQueryContextT(ctx, t, "SELECT null::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)") defer rows.Close() assertTrueF(t, rows.Next()) var res *objectWithAllTypes @@ -271,7 +292,8 @@ func TestNullObject(t *testing.T) { assertNilE(t, res) }) t.Run("not null", func(t *testing.T) { - rows := dbt.mustQueryContextT(ctx, t, "SELECT {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))") + uid := newTestUUID() + rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("SELECT {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}, 'uuid': '%s'}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uid)) defer rows.Close() assertTrueF(t, rows.Next()) var res *objectWithAllTypes @@ -301,10 +323,15 @@ type objectWithAllTypesNullable struct { sArr []string f64Arr []float64 someMap map[string]bool + uuid testUUID } func (o *objectWithAllTypesNullable) Scan(val any) error { - st := val.(StructuredObject) + st, ok := val.(StructuredObject) + if !ok { + return fmt.Errorf("expected StructuredObject, got %T", val) + } + var err error if o.s, err = st.GetNullString("s"); err != nil { return err @@ -375,6 +402,13 @@ func (o *objectWithAllTypesNullable) Scan(val any) error { if someMap != nil { o.someMap = someMap.(map[string]bool) } + uuidStr, err := st.GetNullString("uuid") + if err != nil { + return err + } + + o.uuid = parseTestUUID(uuidStr.String) + return nil } @@ -430,6 +464,9 @@ func (o *objectWithAllTypesNullable) Write(sowc StructuredObjectWriterContext) e if err := sowc.WriteRaw("someMap", o.someMap); err != nil { return err } + if err := sowc.WriteNullString("uuid", sql.NullString{String: o.uuid.String(), Valid: true}); err != nil { + return err + } return nil } @@ -441,9 +478,9 @@ func TestObjectWithAllTypesNullable(t *testing.T) { dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { t.Run("null", func(t *testing.T) { - rows := dbt.mustQueryContextT(ctx, t, "select null, object_construct_keep_null('s', null, 'b', null, 'i16', null, 'i32', null, 'i64', null, 'f64', null, 'bo', null, 'bi', null, 'date', null, 'time', null, 'ltz', null, 'tz', null, 'ntz', null, 'so', null, 'sArr', null, 'f64Arr', null, 'someMap', null)::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))") + rows := dbt.mustQueryContextT(ctx, t, "select null, object_construct_keep_null('s', null, 'b', null, 'i16', null, 'i32', null, 'i64', null, 'f64', null, 'bo', null, 'bi', null, 'date', null, 'time', null, 'ltz', null, 'tz', null, 'ntz', null, 'so', null, 'sArr', null, 'f64Arr', null, 'someMap', null, 'uuid', null)::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)") defer rows.Close() - rows.Next() + assertTrueF(t, rows.Next()) var ignore sql.NullInt32 var res objectWithAllTypesNullable err := rows.Scan(&ignore, &res) @@ -464,9 +501,11 @@ func TestObjectWithAllTypesNullable(t *testing.T) { assertEqualE(t, res.ntz, sql.NullTime{Valid: false}) var so *simpleObject assertDeepEqualE(t, res.so, so) + assertEqualE(t, res.uuid, testUUID{}) }) t.Run("not null", func(t *testing.T) { - rows := dbt.mustQueryContextT(ctx, t, "select 1, object_construct_keep_null('s', 'abc', 'b', 1, 'i16', 2, 'i32', 3, 'i64', 9223372036854775807, 'f64', 2.2, 'bo', true, 'bi', TO_BINARY('616263', 'HEX'), 'date', '2024-03-21'::DATE, 'time', '13:03:02'::TIME, 'ltz', '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz', '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz', '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so', {'s': 'child', 'i': 9}::OBJECT, 'sArr', ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr', ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap', {'x': true, 'y': false})::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))") + uuid := newTestUUID() + rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("select 1, object_construct_keep_null('s', 'abc', 'b', 1, 'i16', 2, 'i32', 3, 'i64', 9223372036854775807, 'f64', 2.2, 'bo', true, 'bi', TO_BINARY('616263', 'HEX'), 'date', '2024-03-21'::DATE, 'time', '13:03:02'::TIME, 'ltz', '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz', '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz', '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so', {'s': 'child', 'i': 9}::OBJECT, 'sArr', ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr', ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap', {'x': true, 'y': false}, 'uuid', '%s')::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uuid)) defer rows.Close() rows.Next() var ignore sql.NullInt32 @@ -497,6 +536,7 @@ func TestObjectWithAllTypesNullable(t *testing.T) { assertDeepEqualE(t, res.sArr, []string{"x", "y", "z"}) assertDeepEqualE(t, res.f64Arr, []float64{1.1, 2.2, 3.3}) assertDeepEqualE(t, res.someMap, map[string]bool{"x": true, "y": false}) + assertEqualE(t, res.uuid.String(), uuid.String()) }) }) }) @@ -525,7 +565,11 @@ type objectWithAllTypesSimpleScan struct { } func (so *objectWithAllTypesSimpleScan) Scan(val any) error { - st := val.(StructuredObject) + st, ok := val.(StructuredObject) + if !ok { + return fmt.Errorf("expected StructuredObject, got %T", val) + } + return st.ScanTo(so) } @@ -534,13 +578,14 @@ func (so *objectWithAllTypesSimpleScan) Write(sowc StructuredObjectWriterContext } func TestObjectWithAllTypesSimpleScan(t *testing.T) { + uid := newTestUUID() warsawTz, err := time.LoadLocation("Europe/Warsaw") assertNilF(t, err) ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { - rows := dbt.mustQueryContextT(ctx, t, "SELECT 1, {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))") + rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("SELECT 1, {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}, 'uuid': '%s'}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uid)) defer rows.Close() rows.Next() var ignore int @@ -577,7 +622,7 @@ func TestNullObjectSimpleScan(t *testing.T) { runDBTest(t, func(dbt *DBTest) { forAllStructureTypeFormats(dbt, func(t *testing.T, format string) { t.Run("null", func(t *testing.T) { - rows := dbt.mustQueryContextT(ctx, t, "SELECT null::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))") + rows := dbt.mustQueryContextT(ctx, t, "SELECT null::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)") defer rows.Close() assertTrueF(t, rows.Next()) var res *objectWithAllTypesSimpleScan @@ -586,7 +631,8 @@ func TestNullObjectSimpleScan(t *testing.T) { assertNilE(t, res) }) t.Run("not null", func(t *testing.T) { - rows := dbt.mustQueryContextT(ctx, t, "SELECT {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))") + uid := newTestUUID() + rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("SELECT {'s': 'some string', 'b': 1, 'i16': 2, 'i32': 3, 'i64': 9223372036854775807, 'f32': '1.1', 'f64': 2.2, 'nfraction': 3.3, 'bo': true, 'bi': TO_BINARY('616263', 'HEX'), 'date': '2024-03-21'::DATE, 'time': '13:03:02'::TIME, 'ltz': '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz': '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz': '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so': {'s': 'child', 'i': 9}, 'sArr': ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr': ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap': {'x': true, 'y': false}, 'uuid': '%s'}::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 19), bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uid)) defer rows.Close() assertTrueF(t, rows.Next()) var res *objectWithAllTypesSimpleScan @@ -619,7 +665,11 @@ type objectWithAllTypesNullableSimpleScan struct { } func (o *objectWithAllTypesNullableSimpleScan) Scan(val any) error { - st := val.(StructuredObject) + st, ok := val.(StructuredObject) + if !ok { + return fmt.Errorf("expected StructuredObject, got %T", val) + } + return st.ScanTo(o) } @@ -660,7 +710,8 @@ func TestObjectWithAllTypesSimpleScanNullable(t *testing.T) { assertDeepEqualE(t, res.So, so) }) t.Run("not null", func(t *testing.T) { - rows := dbt.mustQueryContextT(ctx, t, "select 1, object_construct_keep_null('s', 'abc', 'b', 1, 'i16', 2, 'i32', 3, 'i64', 9223372036854775807, 'f64', 2.2, 'bo', true, 'bi', TO_BINARY('616263', 'HEX'), 'date', '2024-03-21'::DATE, 'time', '13:03:02'::TIME, 'ltz', '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz', '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz', '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so', {'s': 'child', 'i': 9}::OBJECT, 'sArr', ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr', ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap', {'x': true, 'y': false})::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))") + uuid := newTestUUID() + rows := dbt.mustQueryContextT(ctx, t, fmt.Sprintf("select 1, object_construct_keep_null('s', 'abc', 'b', 1, 'i16', 2, 'i32', 3, 'i64', 9223372036854775807, 'f64', 2.2, 'bo', true, 'bi', TO_BINARY('616263', 'HEX'), 'date', '2024-03-21'::DATE, 'time', '13:03:02'::TIME, 'ltz', '2021-07-21 11:22:33'::TIMESTAMP_LTZ, 'tz', '2022-08-31 13:43:22 +0200'::TIMESTAMP_TZ, 'ntz', '2023-05-22 01:17:19'::TIMESTAMP_NTZ, 'so', {'s': 'child', 'i': 9}::OBJECT, 'sArr', ARRAY_CONSTRUCT('x', 'y', 'z'), 'f64Arr', ARRAY_CONSTRUCT(1.1, 2.2, 3.3), 'someMap', {'x': true, 'y': false}, 'uuid', '%s')::OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo BOOLEAN, bi BINARY, date DATE, time TIME, ltz TIMESTAMP_LTZ, tz TIMESTAMP_TZ, ntz TIMESTAMP_NTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)", uuid)) defer rows.Close() rows.Next() var ignore sql.NullInt32 @@ -702,7 +753,11 @@ type objectWithCustomNameAndIgnoredField struct { } func (o *objectWithCustomNameAndIgnoredField) Scan(val any) error { - st := val.(StructuredObject) + st, ok := val.(StructuredObject) + if !ok { + return fmt.Errorf("expected StructuredObject, got %T", val) + } + return st.ScanTo(o) } @@ -1804,7 +1859,11 @@ type HigherPrecisionStruct struct { } func (hps *HigherPrecisionStruct) Scan(val any) error { - st := val.(StructuredObject) + st, ok := val.(StructuredObject) + if !ok { + return fmt.Errorf("expected StructuredObject, got %T", val) + } + var err error if hps.i, err = st.GetBigInt("i"); err != nil { return err diff --git a/structured_type_write_test.go b/structured_type_write_test.go index 2be5e49c3..31720124e 100644 --- a/structured_type_write_test.go +++ b/structured_type_write_test.go @@ -132,7 +132,7 @@ func TestBindingObjectWithSchema(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) assertNilF(t, err) runDBTest(t, func(dbt *DBTest) { - dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, ntz TIMESTAMPNTZ, tz TIMESTAMPTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN)))") + dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, ntz TIMESTAMPNTZ, tz TIMESTAMPTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }() @@ -159,6 +159,7 @@ func TestBindingObjectWithSchema(t *testing.T) { sArr: []string{"a", "b"}, f64Arr: []float64{1.1, 2.2}, someMap: map[string]bool{"a": true, "b": false}, + uuid: newTestUUID(), } dbt.mustExecT(t, "INSERT INTO test_object_binding SELECT (?)", o) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_object_binding WHERE obj = ?", o) @@ -189,6 +190,7 @@ func TestBindingObjectWithSchema(t *testing.T) { assertDeepEqualE(t, res.sArr, o.sArr) assertDeepEqualE(t, res.f64Arr, o.f64Arr) assertDeepEqualE(t, res.someMap, o.someMap) + assertEqualE(t, res.uuid.String(), o.uuid.String()) }) } @@ -197,7 +199,7 @@ func TestBindingObjectWithNullableFieldsWithSchema(t *testing.T) { assertNilF(t, err) ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { - dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, ntz TIMESTAMPNTZ, tz TIMESTAMPTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN)))") + dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (obj OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f64 DOUBLE, bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, ntz TIMESTAMPNTZ, tz TIMESTAMPTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }() @@ -223,6 +225,7 @@ func TestBindingObjectWithNullableFieldsWithSchema(t *testing.T) { sArr: []string{"a", "b"}, f64Arr: []float64{1.1, 2.2}, someMap: map[string]bool{"a": true, "b": false}, + uuid: newTestUUID(), } dbt.mustExecT(t, "INSERT INTO test_object_binding SELECT (?)", o) rows := dbt.mustQueryContextT(ctx, t, "SELECT * FROM test_object_binding WHERE obj = ?", o) @@ -251,6 +254,7 @@ func TestBindingObjectWithNullableFieldsWithSchema(t *testing.T) { assertDeepEqualE(t, res.sArr, o.sArr) assertDeepEqualE(t, res.f64Arr, o.f64Arr) assertDeepEqualE(t, res.someMap, o.someMap) + assertEqualE(t, res.uuid.String(), o.uuid.String()) }) t.Run("null", func(t *testing.T) { o := &objectWithAllTypesNullable{ @@ -503,7 +507,7 @@ func TestBindingObjectWithAllTypesNullable(t *testing.T) { ctx := WithStructuredTypesEnabled(context.Background()) runDBTest(t, func(dbt *DBTest) { dbt.forceJSON() - dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (o OBJECT(o OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, tz TIMESTAMPTZ, ntz TIMESTAMPNTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN))))") + dbt.mustExec("CREATE OR REPLACE TABLE test_object_binding (o OBJECT(o OBJECT(s VARCHAR, b TINYINT, i16 SMALLINT, i32 INTEGER, i64 BIGINT, f32 FLOAT, f64 DOUBLE, nfraction NUMBER(38, 9), bo boolean, bi BINARY, date DATE, time TIME, ltz TIMESTAMPLTZ, tz TIMESTAMPTZ, ntz TIMESTAMPNTZ, so OBJECT(s VARCHAR, i INTEGER), sArr ARRAY(VARCHAR), f64Arr ARRAY(DOUBLE), someMap MAP(VARCHAR, BOOLEAN), uuid VARCHAR)))") defer func() { dbt.mustExec("DROP TABLE IF EXISTS test_object_binding") }()