diff --git a/tests/tool.go b/tests/tool.go index 488108fc624..5c5c2fde5c6 100644 --- a/tests/tool.go +++ b/tests/tool.go @@ -34,6 +34,7 @@ import ( "github.com/google/uuid" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) @@ -3661,26 +3662,43 @@ func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) } } +func CreateAndLockPostgresTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string) func() { + _, err := pool.Exec(ctx, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INT PRIMARY KEY)", pgx.Identifier{tableName}.Sanitize())) + if err != nil { + t.Fatalf("unable to create table: %s", err) + } + + tx, err := pool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + t.Fatalf("unable to create transaction: %s", err) + } + if _, err := tx.Exec(ctx, fmt.Sprintf("LOCK TABLE %s IN ACCESS EXCLUSIVE MODE", pgx.Identifier{tableName}.Sanitize())); err != nil { + t.Fatalf("unable to acquire lock: %s", err) + } + + return func() { + if err := tx.Rollback(ctx); err != nil { + t.Fatalf("unable to rollback transaction: %s", err) + } + if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pgx.Identifier{tableName}.Sanitize())); err != nil { + t.Fatalf("unable to drop table: %s", err) + } + } +} + // RunPostgresListLocksTest runs tests for the postgres list-locks tool func RunPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { + + // Create and lock a test table + cleanup := CreateAndLockPostgresTable(t, ctx, pool, "test_postgres_list_locks_table") + defer cleanup() + type lockDetails struct { - Pid any `json:"pid"` - Usename string `json:"usename"` - Database string `json:"database"` - RelName string `json:"relname"` - LockType string `json:"locktype"` - Mode string `json:"mode"` - Granted bool `json:"granted"` - FastPath bool `json:"fastpath"` - VirtualXid any `json:"virtualxid"` - TransactionId any `json:"transactionid"` - ClassId any `json:"classid"` - ObjId any `json:"objid"` - ObjSubId any `json:"objsubid"` - PageNumber any `json:"page"` - TupleNumber any `json:"tuple"` - VirtualBlock any `json:"virtualblock"` - BlockNumber any `json:"blockno"` + Pid any `json:"pid"` + Usename string `json:"usename"` + Query string `json:"query"` + TrxID string `json:"trxid"` + Locks string `json:"locks"` } invokeTcs := []struct { @@ -3693,7 +3711,7 @@ func RunPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpool.P name: "invoke list_locks with no arguments", requestBody: bytes.NewBuffer([]byte(`{}`)), wantStatusCode: http.StatusOK, - expectResults: false, // locks may or may not exist + expectResults: true, }, } for _, tc := range invokeTcs { @@ -3725,12 +3743,9 @@ func RunPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpool.P t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString) } } - - // Verify that if results exist, they have the expected structure - for _, lock := range got { - if lock.LockType == "" { - t.Errorf("lock type should not be empty") - } + // Verify that we got results if expected + if tc.expectResults && len(got) == 0 { + t.Errorf("expected results but got none") } }) }