Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 39 additions & 24 deletions tests/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)

Expand Down Expand Up @@ -3661,26 +3662,43 @@ func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string)
}
}

func CreateAndLockPostgresTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string) func() {
_, err := pool.Exec(ctx, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INT PRIMARY KEY)", pgx.Identifier{tableName}.Sanitize()))
if err != nil {
t.Fatalf("unable to create table: %s", err)
}

tx, err := pool.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
t.Fatalf("unable to create transaction: %s", err)
}
if _, err := tx.Exec(ctx, fmt.Sprintf("LOCK TABLE %s IN ACCESS EXCLUSIVE MODE", pgx.Identifier{tableName}.Sanitize())); err != nil {
t.Fatalf("unable to acquire lock: %s", err)
}

return func() {
if err := tx.Rollback(ctx); err != nil {
t.Fatalf("unable to rollback transaction: %s", err)
}
if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pgx.Identifier{tableName}.Sanitize())); err != nil {
t.Fatalf("unable to drop table: %s", err)
}
}
}

// RunPostgresListLocksTest runs tests for the postgres list-locks tool
func RunPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {

// Create and lock a test table
cleanup := CreateAndLockPostgresTable(t, ctx, pool, "test_postgres_list_locks_table")
defer cleanup()

type lockDetails struct {
Pid any `json:"pid"`
Usename string `json:"usename"`
Database string `json:"database"`
RelName string `json:"relname"`
LockType string `json:"locktype"`
Mode string `json:"mode"`
Granted bool `json:"granted"`
FastPath bool `json:"fastpath"`
VirtualXid any `json:"virtualxid"`
TransactionId any `json:"transactionid"`
ClassId any `json:"classid"`
ObjId any `json:"objid"`
ObjSubId any `json:"objsubid"`
PageNumber any `json:"page"`
TupleNumber any `json:"tuple"`
VirtualBlock any `json:"virtualblock"`
BlockNumber any `json:"blockno"`
Pid any `json:"pid"`
Usename string `json:"usename"`
Query string `json:"query"`
TrxID string `json:"trxid"`
Locks string `json:"locks"`
}

invokeTcs := []struct {
Expand All @@ -3693,7 +3711,7 @@ func RunPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpool.P
name: "invoke list_locks with no arguments",
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantStatusCode: http.StatusOK,
expectResults: false, // locks may or may not exist
expectResults: true,
},
}
for _, tc := range invokeTcs {
Expand Down Expand Up @@ -3725,12 +3743,9 @@ func RunPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpool.P
t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString)
}
}

// Verify that if results exist, they have the expected structure
for _, lock := range got {
if lock.LockType == "" {
t.Errorf("lock type should not be empty")
}
// Verify that we got results if expected
if tc.expectResults && len(got) == 0 {
t.Errorf("expected results but got none")
}
})
}
Expand Down
Loading