Skip to content

Commit

Permalink
Merge pull request #385 from sylwiaszunejko/merge-upstream
Browse files Browse the repository at this point in the history
Merge upstream
  • Loading branch information
sylwiaszunejko authored Jan 23, 2025
2 parents fe68ec3 + b2d210a commit d95cf2f
Show file tree
Hide file tree
Showing 23 changed files with 734 additions and 311 deletions.
16 changes: 9 additions & 7 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ func TestBatch_Errors(t *testing.T) {
t.Fatal(err)
}

b := session.NewBatch(LoggedBatch)
b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil)
if err := session.ExecuteBatch(b); err == nil {
b := session.Batch(LoggedBatch)
b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil)
if err := b.Exec(); err == nil {
t.Fatal("expected to get error for invalid query in batch")
}
}
Expand All @@ -44,15 +44,17 @@ func TestBatch_WithTimestamp(t *testing.T) {

micros := time.Now().UnixNano()/1e3 - 1000

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.WithTimestamp(micros)
b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val")
if err := session.ExecuteBatch(b); err != nil {
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val")
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val")

if err := b.Exec(); err != nil {
t.Fatal(err)
}

var storedTs int64
if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
t.Fatal(err)
}

Expand Down
34 changes: 17 additions & 17 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"time"
"unicode"

inf "gopkg.in/inf.v0"
"gopkg.in/inf.v0"
)

func TestEmptyHosts(t *testing.T) {
Expand Down Expand Up @@ -565,15 +565,15 @@ func TestCAS(t *testing.T) {
t.Fatal("truncate:", err)
}

successBatch := session.NewBatch(LoggedBatch)
successBatch := session.Batch(LoggedBatch)
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
t.Fatal("insert:", err)
} else if !applied {
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
}

successBatch = session.NewBatch(LoggedBatch)
successBatch = session.Batch(LoggedBatch)
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified)
casMap := make(map[string]interface{})
if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil {
Expand All @@ -582,22 +582,22 @@ func TestCAS(t *testing.T) {
t.Fatal("insert should have been applied")
}

failBatch := session.NewBatch(LoggedBatch)
failBatch := session.Batch(LoggedBatch)
failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
t.Fatal("insert:", err)
} else if applied {
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
}

insertBatch := session.NewBatch(LoggedBatch)
insertBatch := session.Batch(LoggedBatch)
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
if err := session.ExecuteBatch(insertBatch); err != nil {
t.Fatal("insert:", err)
}

failBatch = session.NewBatch(LoggedBatch)
failBatch = session.Batch(LoggedBatch)
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
Expand Down Expand Up @@ -722,7 +722,7 @@ func TestBatch(t *testing.T) {
t.Fatal("create table:", err)
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
for i := 0; i < 100; i++ {
batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i)
}
Expand Down Expand Up @@ -754,9 +754,9 @@ func TestUnpreparedBatch(t *testing.T) {

var batch *Batch
if session.cfg.ProtoVersion == 2 {
batch = session.NewBatch(CounterBatch)
batch = session.Batch(CounterBatch)
} else {
batch = session.NewBatch(UnloggedBatch)
batch = session.Batch(UnloggedBatch)
}

for i := 0; i < 100; i++ {
Expand Down Expand Up @@ -795,7 +795,7 @@ func TestBatchLimit(t *testing.T) {
t.Fatal("create table:", err)
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
for i := 0; i < 65537; i++ {
batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i)
}
Expand Down Expand Up @@ -849,7 +849,7 @@ func TestTooManyQueryArgs(t *testing.T) {
t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error")
}

batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3)
err = session.ExecuteBatch(batch)

Expand Down Expand Up @@ -881,7 +881,7 @@ func TestNotEnoughQueryArgs(t *testing.T) {
t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error")
}

batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2)
err = session.ExecuteBatch(batch)

Expand Down Expand Up @@ -1454,7 +1454,7 @@ func TestBatchQueryInfo(t *testing.T) {
return values, nil
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write)

if err := session.ExecuteBatch(batch); err != nil {
Expand Down Expand Up @@ -1582,7 +1582,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) {
}

stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query(stmt, "bar")
if err := conn.executeBatch(ctx, batch).Close(); err != nil {
t.Fatalf("Failed to execute query for reprepare statement: %v", err)
Expand Down Expand Up @@ -1966,7 +1966,7 @@ func TestBatchStats(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.Query("INSERT INTO batchStats (id) VALUES (?)", 1)
b.Query("INSERT INTO batchStats (id) VALUES (?)", 2)

Expand Down Expand Up @@ -2009,7 +2009,7 @@ func TestBatchObserve(t *testing.T) {

var observedBatch *observation

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) {
if observedBatch != nil {
t.Fatal("batch observe called more than once")
Expand Down Expand Up @@ -2632,7 +2632,7 @@ func TestUnsetColBatch(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue)
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "")
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue)
Expand Down
8 changes: 7 additions & 1 deletion cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ type ClusterConfig struct {
// Initial keyspace. Optional.
Keyspace string

// Number of connections per host.
// The size of the connection pool for each host.
// The pool filling runs in separate gourutine during the session initialization phase.
// gocql will always try to get 1 connection on each host pool
// during session initialization AND it will attempt
// to fill each pool afterward asynchronously if NumConns > 1.
// Notice: There is no guarantee that pool filling will be finished in the initialization phase.
// Also, it describes a maximum number of connections at the same time.
// Default: 2
NumConns int

Expand Down
30 changes: 11 additions & 19 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,11 @@ import (
"github.com/gocql/gocql/internal/streams"
)

var (
defaultApprovedAuthenticators = []string{
"org.apache.cassandra.auth.PasswordAuthenticator",
"com.instaclustr.cassandra.auth.SharedSecretAuthenticator",
"com.datastax.bdp.cassandra.auth.DseAuthenticator",
"io.aiven.cassandra.auth.AivenAuthenticator",
"com.ericsson.bss.cassandra.ecaudit.auth.AuditPasswordAuthenticator",
"com.amazon.helenus.auth.HelenusAuthenticator",
"com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator",
"com.scylladb.auth.SaslauthdAuthenticator",
"com.scylladb.auth.TransitionalAuthenticator",
"com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator",
}
)

// approve the authenticator with the list of allowed authenticators or default list if approvedAuthenticators is empty.
// approve the authenticator with the list of allowed authenticators. If the provided list is empty,
// the given authenticator is allowed.
func approve(authenticator string, approvedAuthenticators []string) bool {
if len(approvedAuthenticators) == 0 {
approvedAuthenticators = defaultApprovedAuthenticators
return true
}
for _, s := range approvedAuthenticators {
if authenticator == s {
Expand Down Expand Up @@ -72,9 +58,15 @@ type WarningHandler interface {
HandleWarnings(qry ExecutableQuery, host *HostInfo, warnings []string)
}

// PasswordAuthenticator specifies credentials to be used when authenticating.
// It can be configured with an "allow list" of authenticator class names to avoid
// attempting to authenticate with Cassandra if it doesn't provide an expected authenticator.
type PasswordAuthenticator struct {
Username string
Password string
Username string
Password string
// Setting this to nil or empty will allow authenticating with any authenticator
// provided by the server. This is the default behavior of most other driver
// implementations.
AllowedAuthenticators []string
}

Expand Down
27 changes: 15 additions & 12 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,21 @@ const (

func TestApprove(t *testing.T) {
tests := map[bool]bool{
approve("org.apache.cassandra.auth.PasswordAuthenticator", []string{}): true,
approve("com.instaclustr.cassandra.auth.SharedSecretAuthenticator", []string{}): true,
approve("com.datastax.bdp.cassandra.auth.DseAuthenticator", []string{}): true,
approve("io.aiven.cassandra.auth.AivenAuthenticator", []string{}): true,
approve("com.amazon.helenus.auth.HelenusAuthenticator", []string{}): true,
approve("com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator", []string{}): true,
approve("com.scylladb.auth.SaslauthdAuthenticator", []string{}): true,
approve("com.scylladb.auth.TransitionalAuthenticator", []string{}): true,
approve("com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator", []string{}): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{}): false,
approve("com.apache.cassandra.auth.FakeAuthenticator", nil): false,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.FakeAuthenticator"}): true,
approve("org.apache.cassandra.auth.PasswordAuthenticator", []string{}): true,
approve("org.apache.cassandra.auth.MutualTlsWithPasswordFallbackAuthenticator", []string{}): true,
approve("org.apache.cassandra.auth.MutualTlsAuthenticator", []string{}): true,
approve("com.instaclustr.cassandra.auth.SharedSecretAuthenticator", []string{}): true,
approve("com.datastax.bdp.cassandra.auth.DseAuthenticator", []string{}): true,
approve("io.aiven.cassandra.auth.AivenAuthenticator", []string{}): true,
approve("com.amazon.helenus.auth.HelenusAuthenticator", []string{}): true,
approve("com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator", []string{}): true,
approve("com.scylladb.auth.SaslauthdAuthenticator", []string{}): true,
approve("com.scylladb.auth.TransitionalAuthenticator", []string{}): true,
approve("com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator", []string{}): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{}): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", nil): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.FakeAuthenticator"}): true,
approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.NotFakeAuthenticator"}): false,
}
for k, v := range tests {
if k != v {
Expand Down
12 changes: 11 additions & 1 deletion doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@
// }
// defer session.Close()
//
// By default, PasswordAuthenticator will attempt to authenticate regardless of what implementation the server returns
// in its AUTHENTICATE message as its authenticator, (e.g. org.apache.cassandra.auth.PasswordAuthenticator). If you
// wish to restrict this you may use PasswordAuthenticator.AllowedAuthenticators:
//
// cluster.Authenticator = gocql.PasswordAuthenticator {
// Username: "user",
// Password: "password"
// AllowedAuthenticators: []string{"org.apache.cassandra.auth.PasswordAuthenticator"},
// }
//
// # Transport layer security
//
// It is possible to secure traffic between the client and server with TLS.
Expand Down Expand Up @@ -280,7 +290,7 @@
// # Batches
//
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
// Use Session.Batch to create a new batch and then fill-in details of individual queries.
// Then execute the batch with Session.ExecuteBatch.
//
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
Expand Down
15 changes: 13 additions & 2 deletions example_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package gocql_test
import (
"context"
"fmt"
"github.com/gocql/gocql"
"log"

"github.com/gocql/gocql"
)

// Example_batch demonstrates how to execute a batch of statements.
Expand All @@ -24,7 +25,7 @@ func Example_batch() {

ctx := context.Background()

b := session.NewBatch(gocql.UnloggedBatch).WithContext(ctx)
b := session.Batch(gocql.UnloggedBatch).WithContext(ctx)
b.Entries = append(b.Entries, gocql.BatchEntry{
Stmt: "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)",
Args: []interface{}{1, 2, "1.2"},
Expand All @@ -35,11 +36,19 @@ func Example_batch() {
Args: []interface{}{1, 3, "1.3"},
Idempotent: true,
})

err = session.ExecuteBatch(b)
if err != nil {
log.Fatal(err)
}

err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4").
Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5").
Exec()
if err != nil {
log.Fatal(err)
}

scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner()
for scanner.Next() {
var pk, ck int32
Expand All @@ -52,4 +61,6 @@ func Example_batch() {
}
// 1 2 1.2
// 1 3 1.3
// 1 4 1.4
// 1 5 1.5
}
5 changes: 3 additions & 2 deletions example_lwt_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package gocql_test
import (
"context"
"fmt"
"github.com/gocql/gocql"
"log"

"github.com/gocql/gocql"
)

// ExampleSession_MapExecuteBatchCAS demonstrates how to execute a batch lightweight transaction.
Expand Down Expand Up @@ -37,7 +38,7 @@ func ExampleSession_MapExecuteBatchCAS() {
}

executeBatch := func(ck2Version int) {
b := session.NewBatch(gocql.LoggedBatch)
b := session.Batch(gocql.LoggedBatch)
b.Entries = append(b.Entries, gocql.BatchEntry{
Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? IF version=?",
Args: []interface{}{"b", "pk1", "ck1", 1},
Expand Down
2 changes: 1 addition & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func TestCustomPayloadMessages(t *testing.T) {
iter.Close()

// Batch Message
b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.CustomPayload = customPayload
b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)")
if err := session.ExecuteBatch(b); err != nil {
Expand Down
Loading

0 comments on commit d95cf2f

Please sign in to comment.