Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 100 additions & 10 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,95 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
return applied, iter, iter.err
}

// connectionType is a custom type that represents the different stages
// of a client connection in a Cassandra cluster. It is used to filter and categorize
// connections based on their current state.
type connectionType string

const (
Ready connectionType = "ready"
Connecting connectionType = "connecting"
Idle connectionType = "idle"
Closed connectionType = "closed"
Failed connectionType = "failed"
)

// ClientConnection represents a client connection to a Cassandra node. It holds detailed
// information about the connection, including the client address, connection stage, driver details,
// and various configuration options.
type ClientConnection struct {
Address string
Port int
ConnectionStage string
DriverName string
DriverVersion string
Hostname string
KeyspaceName *string
ProtocolVersion int
RequestCount int
SSLCipherSuite *string
SSLEnabled bool
SSLProtocol *string
Username string
}

// RetrieveClientConnections retrieves a list of client connections from the
// `system_views.clients` table based on the specified connection type. The function
// queries the Cassandra database for connections with a given `connection_stage` and
// scans the results into a slice of `ClientConnection` structs. It handles nullable
// fields and returns the list of connections or an error if the operation fails.
func (s *Session) RetrieveClientConnections(connectionType connectionType) ([]*ClientConnection, error) {
const stmt = `
SELECT address, port, connection_stage, driver_name, driver_version,
hostname, keyspace_name, protocol_version, request_count,
ssl_cipher_suite, ssl_enabled, ssl_protocol, username
FROM system_views.clients
WHERE connection_stage = ?`

iter := s.control.query(stmt, connectionType)
if iter.NumRows() == 0 {
return nil, ErrConnectionsDoNotExist
}
defer iter.Close()

var connections []*ClientConnection
for {
conn := &ClientConnection{}

// Variables to hold nullable fields
var keyspaceName, sslCipherSuite, sslProtocol *string

if !iter.Scan(
&conn.Address,
&conn.Port,
&conn.ConnectionStage,
&conn.DriverName,
&conn.DriverVersion,
&conn.Hostname,
&keyspaceName,
&conn.ProtocolVersion,
&conn.RequestCount,
&sslCipherSuite,
&conn.SSLEnabled,
&sslProtocol,
&conn.Username,
) {
if err := iter.Close(); err != nil {
return nil, err
}
break
}

conn.KeyspaceName = keyspaceName
conn.SSLCipherSuite = sslCipherSuite
conn.SSLProtocol = sslProtocol

connections = append(connections, conn)
}

return connections, nil
}

type hostMetrics struct {
// Attempts is count of how many times this query has been attempted for this host.
// An attempt is either a retry or fetching next page of results.
Expand Down Expand Up @@ -2279,16 +2368,17 @@ func (e Error) Error() string {
}

var (
ErrNotFound = errors.New("not found")
ErrUnavailable = errors.New("unavailable")
ErrUnsupported = errors.New("feature not supported")
ErrTooManyStmts = errors.New("too many statements")
ErrUseStmt = errors.New("use statements aren't supported. Please see https://github.com/apache/cassandra-gocql-driver for explanation.")
ErrSessionClosed = errors.New("session has been closed")
ErrNoConnections = errors.New("gocql: no hosts available in the pool")
ErrNoKeyspace = errors.New("no keyspace provided")
ErrKeyspaceDoesNotExist = errors.New("keyspace does not exist")
ErrNoMetadata = errors.New("no metadata available")
ErrNotFound = errors.New("not found")
ErrUnavailable = errors.New("unavailable")
ErrUnsupported = errors.New("feature not supported")
ErrTooManyStmts = errors.New("too many statements")
ErrUseStmt = errors.New("use statements aren't supported. Please see https://github.com/gocql/gocql for explanation.")
ErrSessionClosed = errors.New("session has been closed")
ErrNoConnections = errors.New("gocql: no hosts available in the pool")
ErrNoKeyspace = errors.New("no keyspace provided")
ErrKeyspaceDoesNotExist = errors.New("keyspace does not exist")
ErrConnectionsDoNotExist = errors.New("connections do not exist")
ErrNoMetadata = errors.New("no metadata available")
)

type ErrProtocol struct{ error }
Expand Down
74 changes: 74 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,77 @@ func TestIsUseStatement(t *testing.T) {
}
}
}

func TestRetrieveClientConnections(t *testing.T) {
testCases := []struct {
name string
connectionType connectionType
expectedResult []*ClientConnection
expectError bool
}{
{
name: "Valid ready connections",
connectionType: Ready,
expectedResult: []*ClientConnection{
{
Address: "127.0.0.1",
Port: 9042,
ConnectionStage: "ready",
DriverName: "gocql",
DriverVersion: "v1.0.0",
Hostname: "localhost",
KeyspaceName: nil,
ProtocolVersion: 4,
RequestCount: 10,
SSLCipherSuite: nil,
SSLEnabled: true,
SSLProtocol: nil,
Username: "user1",
},
},
expectError: false,
},
{
name: "No connections found",
connectionType: Closed,
expectedResult: nil,
expectError: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
session := &Session{
control: &controlConn{},
}

results, err := session.RetrieveClientConnections(tc.connectionType)

if tc.expectError {
if err == nil {
t.Fatalf("expected an error but got none")
}
} else {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !compareClientConnections(results, tc.expectedResult) {
t.Fatalf("expected result %+v, got %+v", tc.expectedResult, results)
}
}
})
}
}

// Helper function to compare two slices of ClientConnection pointers
func compareClientConnections(a, b []*ClientConnection) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if *a[i] != *b[i] {
return false
}
}
return true
}