diff --git a/CHANGELOG.md b/CHANGELOG.md index 248e14bbb3..71ac01f617 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project are documented below. The format is based on [keep a changelog](http://keepachangelog.com/) and this project uses [semantic versioning](http://semver.org/). ## [Unreleased] +### Added +- Dynamic leaderboards feature. +- Presence updates now report the user's handle. + ### Changed - The build system now strips up to current dir in recorded source file paths at compile. diff --git a/cmd/admin.go b/cmd/admin.go new file mode 100644 index 0000000000..d8abcdda19 --- /dev/null +++ b/cmd/admin.go @@ -0,0 +1,151 @@ +// Copyright 2017 The Nakama Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "database/sql" + "encoding/base64" + "encoding/json" + "flag" + "fmt" + "github.com/gorhill/cronexpr" + "github.com/satori/go.uuid" + "github.com/uber-go/zap" + "net/url" + "os" +) + +type adminService struct { + DSNS string + logger zap.Logger +} + +func AdminParse(args []string, logger zap.Logger) { + if len(args) == 0 { + logger.Fatal("Admin requires a subcommand. Available commands are: 'create-leaderboard'.") + } + + var exec func([]string, zap.Logger) + switch args[0] { + case "create-leaderboard": + exec = createLeaderboard + default: + logger.Fatal("Unrecognized admin subcommand. Available commands are: 'create-leaderboard'.") + } + + exec(args[1:], logger) + os.Exit(0) +} + +func createLeaderboard(args []string, logger zap.Logger) { + var dsns string + var id string + var authoritative bool + var sortOrder string + var resetSchedule string + var metadata string + + flags := flag.NewFlagSet("admin", flag.ExitOnError) + flags.StringVar(&dsns, "db", "root@localhost:26257", "CockroachDB JDBC connection details.") + flags.StringVar(&id, "id", "", "ID to assign to the leaderboard.") + flags.BoolVar(&authoritative, "authoritative", false, "True if clients may not submit scores directly, false otherwise.") + flags.StringVar(&sortOrder, "sort", "descending", "Leaderboard sort order, 'asc' or 'desc'.") + flags.StringVar(&resetSchedule, "reset", "", "Optional reset schedule in CRON format.") + flags.StringVar(&metadata, "metadata", "{}", "Optional additional metadata as a JSON string.") + + if err := flags.Parse(args); err != nil { + logger.Fatal("Could not parse admin flags.") + } + + if dsns == "" { + logger.Fatal("Database connection details are required.") + } + + query := `INSERT INTO leaderboard (id, authoritative, sort_order, reset_schedule, metadata) + VALUES ($1, $2, $3, $4, $5)` + params := []interface{}{} + + // ID. + if id == "" { + params = append(params, uuid.NewV4().Bytes()) + } else { + params = append(params, []byte(id)) + } + + // Authoritative. + params = append(params, authoritative) + + // Sort order. + if sortOrder == "asc" { + params = append(params, 0) + } else if sortOrder == "desc" { + params = append(params, 1) + } else { + logger.Fatal("Invalid sort value, must be 'asc' or 'desc'.") + } + + // Count is hardcoded in the INSERT above. + + // Reset schedule. + if resetSchedule != "" { + _, err := cronexpr.Parse(resetSchedule) + if err != nil { + logger.Fatal("Reset schedule must be a valid CRON expression.") + } + params = append(params, resetSchedule) + } else { + params = append(params, nil) + } + + // Metadata. + metadataBytes := []byte(metadata) + var maybeJSON map[string]interface{} + if json.Unmarshal(metadataBytes, &maybeJSON) != nil { + logger.Fatal("Metadata must be a valid JSON string.") + } + params = append(params, metadataBytes) + + rawurl := fmt.Sprintf("postgresql://%s?sslmode=disable", dsns) + url, err := url.Parse(rawurl) + if err != nil { + logger.Fatal("Bad connection URL", zap.Error(err)) + } + + logger.Info("Database connection", zap.String("dsns", dsns)) + + // Default to "nakama" as DB name. + dbname := "nakama" + if len(url.Path) > 1 { + dbname = url.Path[1:] + } + url.Path = fmt.Sprintf("/%s", dbname) + db, err := sql.Open(dialect, url.String()) + if err != nil { + logger.Fatal("Failed to open database", zap.Error(err)) + } + if err = db.Ping(); err != nil { + logger.Fatal("Error pinging database", zap.Error(err)) + } + + res, err := db.Exec(query, params...) + if err != nil { + logger.Fatal("Error creating leaderboard", zap.Error(err)) + } + if rowsAffected, _ := res.RowsAffected(); rowsAffected != 1 { + logger.Fatal("Error creating leaderboard, unexpected insert result") + } + + logger.Info("Leaderboard created", zap.String("base64(id)", base64.StdEncoding.EncodeToString(params[0].([]byte)))) +} diff --git a/glide.lock b/glide.lock index a1bc1f66a3..017748bf84 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: 8320f72a78e69c58350e25d60a59f6b36fc5cf4da055ce9c9a0c6a63083912d1 -updated: 2017-01-13T19:42:47.231844584Z +hash: d332790eaf0dd90a5d91b4fddfd82897055e261ff360d616cf363c0e689ab4f6 +updated: 2017-03-10T18:12:22.221261057Z imports: - name: github.com/armon/go-metrics version: 97c69685293dce4c0a2d0b19535179bbc976e4d2 @@ -16,11 +16,11 @@ imports: - name: github.com/gogo/protobuf version: 909568be09de550ed094403c2bf8a261b5bb730a subpackages: - - gogoproto - proto - - protoc-gen-gogo/descriptor - name: github.com/golang/protobuf version: 8ee79997227bf9b34611aee7946ae64735e6fd93 +- name: github.com/gorhill/cronexpr + version: a557574d6c024ed6e36acc8b610f5f211c91568a - name: github.com/gorilla/context version: 08b5f424b9271eedf6f9f0ce86cb9396ed337a42 - name: github.com/gorilla/handlers @@ -31,8 +31,12 @@ imports: version: 1f512fc3f05332ba7117626cdfb4e07474e58e60 - name: github.com/lib/pq version: 22cb3e4c487ce6242e2b03369219e5631eed1221 + subpackages: + - oid - name: github.com/rubenv/sql-migrate version: a3ed23a40ebd39f82bf2a36768ed7d595f2bdc1e + subpackages: + - sqlparse - name: github.com/satori/go.uuid version: b061729afc07e77a8aa4fad0a2fd840958f1942a - name: github.com/uber-go/atomic @@ -43,6 +47,7 @@ imports: version: f6b343c37ca80bfa8ea539da67a0b621f84fab1d subpackages: - bcrypt + - blowfish - name: golang.org/x/net version: 69d4b8aa71caaaa75c3dfc11211d1be495abec7c subpackages: diff --git a/glide.yaml b/glide.yaml index 68df033ebc..7b4a57078c 100644 --- a/glide.yaml +++ b/glide.yaml @@ -9,8 +9,12 @@ owners: - name: Mo Firouz email: mo@herioclabs.com import: -- package: golang.org/x/net/context -- package: golang.org/x/crypto/bcrypt +- package: golang.org/x/net + subpackages: + - context +- package: golang.org/x/crypto + subpackages: + - bcrypt - package: github.com/golang/protobuf - package: github.com/gogo/protobuf version: ~0.3.0 @@ -22,7 +26,7 @@ import: version: ~1.1 - package: github.com/lib/pq - package: github.com/rubenv/sql-migrate -- package: github.com/go-gorp/gorp/ +- package: github.com/go-gorp/gorp version: ~2.0.0 - package: github.com/go-yaml/yaml version: v2 @@ -32,4 +36,8 @@ import: - package: github.com/satori/go.uuid - package: github.com/dgrijalva/jwt-go version: ~3.0.0 -- package: github.com/elazarl/go-bindata-assetfs/... +- package: github.com/elazarl/go-bindata-assetfs + subpackages: + - '...' +- package: github.com/gorhill/cronexpr + version: ~1.0.0 diff --git a/main.go b/main.go index e406d66d45..072a6faf3c 100644 --- a/main.go +++ b/main.go @@ -67,6 +67,8 @@ func main() { cmd.DoctorParse(os.Args[2:]) case "migrate": cmd.MigrateParse(os.Args[2:], clogger) + case "admin": + cmd.AdminParse(os.Args[2:], clogger) } } diff --git a/migrations/20170228205100_leaderboards.sql b/migrations/20170228205100_leaderboards.sql new file mode 100644 index 0000000000..12ed4af8b2 --- /dev/null +++ b/migrations/20170228205100_leaderboards.sql @@ -0,0 +1,69 @@ +/* + * Copyright 2017 The Nakama Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- +migrate Up +CREATE TABLE IF NOT EXISTS leaderboard ( + PRIMARY KEY (id), + FOREIGN KEY (next_id) REFERENCES leaderboard(id), + FOREIGN KEY (prev_id) REFERENCES leaderboard(id), + id BYTEA NOT NULL, + authoritative BOOLEAN DEFAULT FALSE, + sort_order SMALLINT DEFAULT 1 NOT NULL, -- asc(0), desc(1) + count BIGINT DEFAULT 0 CHECK (count >= 0) NOT NULL, + reset_schedule VARCHAR(64), -- e.g. cron format: "* * * * * * *" + metadata BYTEA DEFAULT '{}' CHECK (length(metadata) < 16000) NOT NULL, + next_id BYTEA DEFAULT NULL::BYTEA CHECK (next_id <> id), + prev_id BYTEA DEFAULT NULL::BYTEA CHECK (prev_id <> id) +); + +CREATE TABLE IF NOT EXISTS leaderboard_record ( + PRIMARY KEY (leaderboard_id, expires_at, owner_id), + -- Creating a foreign key constraint and defining indexes that include it + -- in the same transaction breaks. See issue cockroachdb/cockroach#13505. + -- In this case we prefer the indexes over the constraint. + -- FOREIGN KEY (leaderboard_id) REFERENCES leaderboard(id), + id BYTEA UNIQUE NOT NULL, + leaderboard_id BYTEA NOT NULL, + owner_id BYTEA NOT NULL, + handle VARCHAR(20) NOT NULL, + lang VARCHAR(18) DEFAULT 'en' NOT NULL, + location VARCHAR(64), -- e.g. "San Francisco, CA" + timezone VARCHAR(64), -- e.g. "Pacific Time (US & Canada)" + rank_value BIGINT DEFAULT 0 CHECK (rank_value >= 0) NOT NULL, + score BIGINT DEFAULT 0 NOT NULL, + num_score INT DEFAULT 0 CHECK (num_score >= 0) NOT NULL, + -- FIXME replace with JSONB + metadata BYTEA DEFAULT '{}' CHECK (length(metadata) < 16000) NOT NULL, + ranked_at INT CHECK (ranked_at >= 0) DEFAULT 0 NOT NULL, + updated_at INT CHECK (updated_at > 0) NOT NULL, + -- Used to enable proper order in revscan when sorting by score descending. + updated_at_inverse INT CHECK (updated_at > 0) NOT NULL, + expires_at INT CHECK (expires_at >= 0) DEFAULT 0 NOT NULL, + banned_at INT CHECK (expires_at >= 0) DEFAULT 0 NOT NULL +); +CREATE INDEX IF NOT EXISTS owner_id_leaderboard_id_idx ON leaderboard_record (owner_id, leaderboard_id); +CREATE INDEX IF NOT EXISTS leaderboard_id_expires_at_score_updated_at_inverse_id_idx ON leaderboard_record (leaderboard_id, expires_at, score, updated_at_inverse, id); +CREATE INDEX IF NOT EXISTS leaderboard_id_expires_at_score_updated_at_id_idx ON leaderboard_record (leaderboard_id, expires_at, score, updated_at, id); +CREATE INDEX IF NOT EXISTS leaderboard_id_expires_at_lang_score_updated_at_inverse_id_idx ON leaderboard_record (leaderboard_id, expires_at, lang, score, updated_at_inverse, id); +CREATE INDEX IF NOT EXISTS leaderboard_id_expires_at_lang_score_updated_at_id_idx ON leaderboard_record (leaderboard_id, expires_at, lang, score, updated_at, id); +CREATE INDEX IF NOT EXISTS leaderboard_id_expires_at_location_score_updated_at_inverse_id_idx ON leaderboard_record (leaderboard_id, expires_at, location, score, updated_at_inverse, id); +CREATE INDEX IF NOT EXISTS leaderboard_id_expires_at_location_score_updated_at_id_idx ON leaderboard_record (leaderboard_id, expires_at, location, score, updated_at, id); +CREATE INDEX IF NOT EXISTS leaderboard_id_expires_at_timezone_score_updated_at_inverse_id_idx ON leaderboard_record (leaderboard_id, expires_at, timezone, score, updated_at_inverse, id); +CREATE INDEX IF NOT EXISTS leaderboard_id_expires_at_timezone_score_updated_at_id_idx ON leaderboard_record (leaderboard_id, expires_at, timezone, score, updated_at, id); + +-- +migrate Down +DROP TABLE IF EXISTS leaderboard_record; +DROP TABLE IF EXISTS leaderboard CASCADE; diff --git a/server/api.proto b/server/api.proto index 201403144d..82a8b2613e 100644 --- a/server/api.proto +++ b/server/api.proto @@ -131,6 +131,14 @@ message Envelope { TStorageRemove storage_remove = 50; TStorageData storage_data = 51; TStorageKey storage_key = 52; + + TLeaderboardsList leaderboards_list = 53; + TLeaderboardRecordWrite leaderboard_record_write = 54; + TLeaderboardRecordsFetch leaderboard_records_fetch = 55; + TLeaderboardRecordsList leaderboard_records_list = 56; + TLeaderboards leaderboards = 57; + TLeaderboardRecord leaderboard_record = 58; + TLeaderboardRecords leaderboard_records = 59; } } @@ -338,6 +346,7 @@ message TopicId { message UserPresence { bytes user_id = 1; bytes session_id = 2; + string handle = 3; } message TTopicJoin { @@ -492,3 +501,81 @@ message TStorageRemove { } repeated StorageKey keys = 1; } + +message Leaderboard { + bytes id = 1; + bool authoritative = 2; + int64 sort = 3; + int64 count = 4; + string reset_schedule = 5; + bytes metadata = 6; + bytes next_id = 7; + bytes prev_id = 8; +} + +message LeaderboardRecord { + bytes leaderboard_id = 1; + bytes owner_id = 2; + string handle = 3; + string lang = 4; + string location = 5; + string timezone = 6; + int64 rank = 7; + int64 score = 8; + int64 num_score = 9; + bytes metadata = 10; + int64 ranked_at = 11; + int64 updated_at = 12; + int64 expires_at = 13; +} + +message TLeaderboardsList { + int64 limit = 1; + bytes cursor = 2; +} +message TLeaderboards { + repeated Leaderboard leaderboards = 1; + bytes cursor = 2; +} + +message TLeaderboardRecordWrite { + bytes leaderboard_id = 1; + oneof op { + int64 incr = 2; + int64 decr = 3; + int64 set = 4; + int64 best = 5; + } + string location = 6; + string timezone = 7; + bytes metadata = 8; +} +message TLeaderboardRecord { + LeaderboardRecord record = 1; +} + +message TLeaderboardRecordsFetch { + repeated bytes leaderboard_ids = 1; + int64 limit = 2; + bytes cursor = 3; +} +message TLeaderboardRecordsList { + message Owners { + repeated bytes owner_ids = 1; + } + + bytes leaderboard_id = 1; + oneof filter { + bytes owner_id = 2; // "haystack" lookup + Owners owner_ids = 3; + string lang = 4; + string location = 5; + string timezone = 6; + } + int64 limit = 7; + bytes cursor = 8; +} +message TLeaderboardRecords { + repeated LeaderboardRecord records = 1; + bytes cursor = 2; +} diff --git a/server/pipeline.go b/server/pipeline.go index d533bedbc9..3908ef6b72 100644 --- a/server/pipeline.go +++ b/server/pipeline.go @@ -124,6 +124,15 @@ func (p *pipeline) processRequest(logger zap.Logger, session *session, envelope case *Envelope_StorageRemove: p.storageRemove(logger, session, envelope) + case *Envelope_LeaderboardsList: + p.leaderboardsList(logger, session, envelope) + case *Envelope_LeaderboardRecordWrite: + p.leaderboardRecordWrite(logger, session, envelope) + case *Envelope_LeaderboardRecordsFetch: + p.leaderboardRecordsFetch(logger, session, envelope) + case *Envelope_LeaderboardRecordsList: + p.leaderboardRecordsList(logger, session, envelope) + case nil: session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "No payload found"}}}) default: diff --git a/server/pipeline_group.go b/server/pipeline_group.go index 7a6c3ac868..08840837b9 100644 --- a/server/pipeline_group.go +++ b/server/pipeline_group.go @@ -385,9 +385,7 @@ func (p *pipeline) groupsList(logger zap.Logger, session *session, envelope *Env limit := incoming.PageLimit if limit == 0 { limit = 10 - } - - if limit < 10 || limit > 100 { + } else if limit < 10 || limit > 100 { session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Page limit must be between 10 and 100"}}}) return } @@ -472,6 +470,7 @@ LIMIT $` + strconv.Itoa(len(params)) if gob.NewEncoder(cursorBuf).Encode(newCursor); err != nil { logger.Error("Error creating group list cursor", zap.Error(err)) session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Database request failed"}}}) + return } cursor = cursorBuf.Bytes() break diff --git a/server/pipeline_leaderboard.go b/server/pipeline_leaderboard.go new file mode 100644 index 0000000000..4e05ae03bc --- /dev/null +++ b/server/pipeline_leaderboard.go @@ -0,0 +1,828 @@ +// Copyright 2017 The Nakama Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "database/sql" + "encoding/gob" + "encoding/json" + "github.com/gorhill/cronexpr" + "github.com/satori/go.uuid" + "github.com/uber-go/zap" + "strconv" + "strings" +) + +type leaderboardCursor struct { + Id []byte +} + +type leaderboardRecordFetchCursor struct { + OwnerId []byte + LeaderboardId []byte +} + +type leaderboardRecordListCursor struct { + Score int64 + UpdatedAt int64 + Id []byte +} + +func (p *pipeline) leaderboardsList(logger zap.Logger, session *session, envelope *Envelope) { + incoming := envelope.GetLeaderboardsList() + + limit := incoming.Limit + if limit == 0 { + limit = 10 + } else if limit < 10 || limit > 100 { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Limit must be between 10 and 100"}}}) + return + } + + query := "SELECT id, authoritative, sort_order, count, reset_schedule, metadata, next_id, prev_id FROM leaderboard" + params := []interface{}{} + + if len(incoming.Cursor) != 0 { + var incomingCursor leaderboardCursor + if err := gob.NewDecoder(bytes.NewReader(incoming.Cursor)).Decode(&incomingCursor); err != nil { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Invalid cursor data"}}}) + return + } + query += " WHERE id > $1" + params = append(params, incomingCursor.Id) + } + + params = append(params, limit+1) + query += " LIMIT $" + strconv.Itoa(len(params)) + + logger.Debug("Leaderboards list", zap.String("query", query)) + rows, err := p.db.Query(query, params...) + if err != nil { + logger.Error("Could not execute leaderboards list query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboards"}}}) + return + } + defer rows.Close() + + leaderboards := []*Leaderboard{} + var outgoingCursor []byte + + var id []byte + var authoritative bool + var sortOrder int64 + var count int64 + var resetSchedule sql.NullString + var metadata []byte + var nextId []byte + var prevId []byte + for rows.Next() { + if int64(len(leaderboards)) >= limit { + cursorBuf := new(bytes.Buffer) + newCursor := &leaderboardCursor{ + Id: id, + } + if gob.NewEncoder(cursorBuf).Encode(newCursor); err != nil { + logger.Error("Error creating leaderboards list cursor", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error encoding cursor"}}}) + return + } + outgoingCursor = cursorBuf.Bytes() + break + } + + err = rows.Scan(&id, &authoritative, &sortOrder, &count, &resetSchedule, &metadata, &nextId, &prevId) + if err != nil { + logger.Error("Could not scan leaderboards list query results", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboards"}}}) + return + } + + leaderboards = append(leaderboards, &Leaderboard{ + Id: id, + Authoritative: authoritative, + Sort: sortOrder, + Count: count, + ResetSchedule: resetSchedule.String, + Metadata: metadata, + NextId: nextId, + PrevId: prevId, + }) + } + if err = rows.Err(); err != nil { + logger.Error("Could not process leaderboards list query results", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboards"}}}) + return + } + + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Leaderboards{Leaderboards: &TLeaderboards{ + Leaderboards: leaderboards, + Cursor: outgoingCursor, + }}}) +} + +func (p *pipeline) leaderboardRecordWrite(logger zap.Logger, session *session, envelope *Envelope) { + incoming := envelope.GetLeaderboardRecordWrite() + if len(incoming.LeaderboardId) == 0 { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Leaderboard ID must be present"}}}) + return + } + + if len(incoming.Metadata) != 0 { + // Make this `var js interface{}` if we want to allow top-level JSON arrays. + var maybeJSON map[string]interface{} + if json.Unmarshal(incoming.Metadata, &maybeJSON) != nil { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Metadata must be a valid JSON object"}}}) + return + } + } + + var authoritative bool + var sortOrder int64 + var resetSchedule sql.NullString + query := "SELECT authoritative, sort_order, reset_schedule FROM leaderboard WHERE id = $1" + logger.Debug("Leaderboard lookup", zap.String("query", query)) + err := p.db.QueryRow(query, incoming.LeaderboardId). + Scan(&authoritative, &sortOrder, &resetSchedule) + if err != nil { + logger.Error("Could not execute leaderboard record write metadata query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error writing leaderboard record"}}}) + return + } + + now := now() + updatedAt := timeToMs(now) + expiresAt := int64(0) + if resetSchedule.Valid { + expr, err := cronexpr.Parse(resetSchedule.String) + if err != nil { + logger.Error("Could not parse leaderboard reset schedule query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error writing leaderboard record"}}}) + return + } + expiresAt = timeToMs(expr.Next(now)) + } + + if authoritative == true { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Cannot submit to authoritative leaderboard"}}}) + return + } + + var scoreOpSql string + var scoreDelta int64 + var scoreAbs int64 + switch incoming.Op.(type) { + case *TLeaderboardRecordWrite_Incr: + scoreOpSql = "score = leaderboard_record.score + $17::BIGINT" + scoreDelta = incoming.GetIncr() + scoreAbs = incoming.GetIncr() + case *TLeaderboardRecordWrite_Decr: + scoreOpSql = "score = leaderboard_record.score - $17::BIGINT" + scoreDelta = incoming.GetDecr() + scoreAbs = 0 - incoming.GetDecr() + case *TLeaderboardRecordWrite_Set: + scoreOpSql = "score = $17::BIGINT" + scoreDelta = incoming.GetSet() + scoreAbs = incoming.GetSet() + case *TLeaderboardRecordWrite_Best: + if sortOrder == 0 { + // Lower score is better. + scoreOpSql = "score = (leaderboard_record.score + $17::BIGINT - abs(leaderboard_record.score - $17::BIGINT)) / 2" + } else { + // Higher score is better. + scoreOpSql = "score = (leaderboard_record.score + $17::BIGINT + abs(leaderboard_record.score - $17::BIGINT)) / 2" + } + scoreDelta = incoming.GetBest() + scoreAbs = incoming.GetBest() + case nil: + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "No leaderboard record write operator found"}}}) + return + default: + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Unknown leaderboard record write operator"}}}) + return + } + + handle := session.handle.Load() + params := []interface{}{uuid.NewV4().Bytes(), incoming.LeaderboardId, session.userID.Bytes(), handle, session.lang} + if incoming.Location != "" { + params = append(params, incoming.Location) + } else { + params = append(params, nil) + } + if incoming.Timezone != "" { + params = append(params, incoming.Timezone) + } else { + params = append(params, nil) + } + params = append(params, 0, scoreAbs, 1) + if len(incoming.Metadata) != 0 { + params = append(params, incoming.Metadata) + } else { + params = append(params, nil) + } + params = append(params, 0, updatedAt, invertMs(updatedAt), expiresAt, 0, scoreDelta) + + query = `INSERT INTO leaderboard_record (id, leaderboard_id, owner_id, handle, lang, location, timezone, + rank_value, score, num_score, metadata, ranked_at, updated_at, updated_at_inverse, expires_at, banned_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, COALESCE($11, '{}'), $12, $13, $14, $15, $16) + ON CONFLICT (leaderboard_id, expires_at, owner_id) + DO UPDATE SET handle = $4, lang = $5, location = COALESCE($6, leaderboard_record.location), + timezone = COALESCE($7, leaderboard_record.timezone), ` + scoreOpSql + `, num_score = leaderboard_record.num_score + 1, + metadata = COALESCE($11, leaderboard_record.metadata), updated_at = $13` + logger.Debug("Leaderboard record write", zap.String("query", query)) + res, err := p.db.Exec(query, params...) + if err != nil { + logger.Error("Could not execute leaderboard record write query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error writing leaderboard record"}}}) + return + } + if rowsAffected, _ := res.RowsAffected(); rowsAffected == 0 { + logger.Error("Unexpected row count from leaderboard record write query") + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error writing leaderboard record"}}}) + return + } + + var location sql.NullString + var timezone sql.NullString + var rankValue int64 + var score int64 + var numScore int64 + var metadata []byte + var rankedAt int64 + var bannedAt int64 + query = `SELECT location, timezone, rank_value, score, num_score, metadata, ranked_at, banned_at + FROM leaderboard_record + WHERE leaderboard_id = $1 + AND expires_at = $2 + AND owner_id = $3` + logger.Debug("Leaderboard record read", zap.String("query", query)) + err = p.db.QueryRow(query, incoming.LeaderboardId, expiresAt, session.userID.Bytes()). + Scan(&location, &timezone, &rankValue, &score, &numScore, &metadata, &rankedAt, &bannedAt) + if err != nil { + logger.Error("Could not execute leaderboard record read query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error writing leaderboard record"}}}) + return + } + + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_LeaderboardRecord{LeaderboardRecord: &TLeaderboardRecord{Record: &LeaderboardRecord{ + LeaderboardId: incoming.LeaderboardId, + OwnerId: session.userID.Bytes(), + Handle: handle, + Lang: session.lang, + Location: location.String, + Timezone: timezone.String, + Rank: rankValue, + Score: score, + NumScore: numScore, + Metadata: metadata, + RankedAt: rankedAt, + UpdatedAt: updatedAt, + ExpiresAt: expiresAt, + }}}}) +} + +func (p *pipeline) leaderboardRecordsFetch(logger zap.Logger, session *session, envelope *Envelope) { + incoming := envelope.GetLeaderboardRecordsFetch() + leaderboardIds := incoming.LeaderboardIds + if len(leaderboardIds) == 0 { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Leaderboard IDs must be present"}}}) + return + } + + limit := incoming.Limit + if limit == 0 { + limit = 10 + } else if limit < 10 || limit > 100 { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Limit must be between 10 and 100"}}}) + return + } + + var incomingCursor *leaderboardRecordFetchCursor + if len(incoming.Cursor) != 0 { + incomingCursor = &leaderboardRecordFetchCursor{} + if err := gob.NewDecoder(bytes.NewReader(incoming.Cursor)).Decode(incomingCursor); err != nil { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Invalid cursor data"}}}) + return + } + } + + // TODO for now we return all records including expired ones, change this later? + // TODO special handling of banned records? + + statements := []string{} + params := []interface{}{session.userID.Bytes()} + for _, leaderboardId := range leaderboardIds { + params = append(params, leaderboardId) + statements = append(statements, "$"+strconv.Itoa(len(params))) + } + + query := `SELECT leaderboard_id, owner_id, handle, lang, location, timezone, + rank_value, score, num_score, metadata, ranked_at, updated_at, expires_at, banned_at + FROM leaderboard_record + WHERE owner_id = $1 + AND leaderboard_id IN (` + strings.Join(statements, ", ") + `)` + + if incomingCursor != nil { + query += " AND (owner_id, leaderboard_id) > ($" + strconv.Itoa(len(params)+1) + ", $" + strconv.Itoa(len(params)+2) + ")" + params = append(params, incomingCursor.OwnerId, incomingCursor.LeaderboardId) + } + + params = append(params, limit+1) + query += " LIMIT $" + strconv.Itoa(len(params)) + + logger.Debug("Leaderboard records fetch", zap.String("query", query)) + rows, err := p.db.Query(query, params...) + if err != nil { + logger.Error("Could not execute leaderboard records fetch query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + defer rows.Close() + + leaderboardRecords := []*LeaderboardRecord{} + var outgoingCursor []byte + + var leaderboardId []byte + var ownerId []byte + var handle string + var lang string + var location sql.NullString + var timezone sql.NullString + var rankValue int64 + var score int64 + var numScore int64 + var metadata []byte + var rankedAt int64 + var updatedAt int64 + var expiresAt int64 + var bannedAt int64 + for rows.Next() { + if int64(len(leaderboardRecords)) >= limit { + cursorBuf := new(bytes.Buffer) + newCursor := &leaderboardRecordFetchCursor{ + OwnerId: ownerId, + LeaderboardId: leaderboardId, + } + if gob.NewEncoder(cursorBuf).Encode(newCursor); err != nil { + logger.Error("Error creating leaderboard records fetch cursor", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error encoding cursor"}}}) + return + } + outgoingCursor = cursorBuf.Bytes() + break + } + + err = rows.Scan(&leaderboardId, &ownerId, &handle, &lang, &location, &timezone, + &rankValue, &score, &numScore, &metadata, &rankedAt, &updatedAt, &expiresAt, &bannedAt) + if err != nil { + logger.Error("Could not scan leaderboard records fetch query results", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + + leaderboardRecords = append(leaderboardRecords, &LeaderboardRecord{ + LeaderboardId: leaderboardId, + OwnerId: ownerId, + Handle: handle, + Lang: lang, + Location: location.String, + Timezone: timezone.String, + Rank: rankValue, + Score: score, + NumScore: numScore, + Metadata: metadata, + RankedAt: rankedAt, + UpdatedAt: updatedAt, + ExpiresAt: expiresAt, + }) + } + if err = rows.Err(); err != nil { + logger.Error("Could not process leaderboard records fetch query results", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_LeaderboardRecords{LeaderboardRecords: &TLeaderboardRecords{ + Records: leaderboardRecords, + Cursor: outgoingCursor, + }}}) +} + +func (p *pipeline) leaderboardRecordsList(logger zap.Logger, session *session, envelope *Envelope) { + incoming := envelope.GetLeaderboardRecordsList() + + if len(incoming.LeaderboardId) == 0 { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Leaderboard ID must be present"}}}) + return + } + + limit := incoming.Limit + if limit == 0 { + limit = 10 + } else if limit < 10 || limit > 100 { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Limit must be between 10 and 100"}}}) + return + } + + var incomingCursor *leaderboardRecordListCursor + if len(incoming.Cursor) != 0 { + incomingCursor = &leaderboardRecordListCursor{} + if err := gob.NewDecoder(bytes.NewReader(incoming.Cursor)).Decode(incomingCursor); err != nil { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Invalid cursor data"}}}) + return + } + } + + var sortOrder int64 + var resetSchedule sql.NullString + query := "SELECT sort_order, reset_schedule FROM leaderboard WHERE id = $1" + logger.Debug("Leaderboard lookup", zap.String("query", query)) + err := p.db.QueryRow(query, incoming.LeaderboardId). + Scan(&sortOrder, &resetSchedule) + if err != nil { + logger.Error("Could not execute leaderboard records list metadata query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + + currentExpiresAt := int64(0) + if resetSchedule.Valid { + expr, err := cronexpr.Parse(resetSchedule.String) + if err != nil { + logger.Error("Could not parse leaderboard reset schedule query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + currentExpiresAt = timeToMs(expr.Next(now())) + } + + query = `SELECT id, owner_id, handle, lang, location, timezone, + rank_value, score, num_score, metadata, ranked_at, updated_at, expires_at, banned_at + FROM leaderboard_record + WHERE leaderboard_id = $1 + AND expires_at = $2` + params := []interface{}{incoming.LeaderboardId, currentExpiresAt} + + returnCursor := true + switch incoming.Filter.(type) { + case *TLeaderboardRecordsList_OwnerId: + if incomingCursor != nil { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Cursor not allowed with haystack query"}}}) + return + } + // Haystack queries are executed in a separate flow. + p.loadLeaderboardRecordsHaystack(logger, session, envelope, incoming.LeaderboardId, incoming.GetOwnerId(), currentExpiresAt, limit, sortOrder, query, params) + return + case *TLeaderboardRecordsList_OwnerIds: + if incomingCursor != nil { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Cursor not allowed with batch filter query"}}}) + return + } + if len(incoming.GetOwnerIds().OwnerIds) < 1 || len(incoming.GetOwnerIds().OwnerIds) > 100 { + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Must be 1-100 owner IDs"}}}) + return + } + statements := []string{} + for _, ownerId := range incoming.GetOwnerIds().OwnerIds { + params = append(params, ownerId) + statements = append(statements, "$"+strconv.Itoa(len(params))) + } + query += " AND owner_id IN (" + strings.Join(statements, ", ") + ")" + // Never return a cursor with this filter type. + returnCursor = false + case *TLeaderboardRecordsList_Lang: + query += " AND lang = $3" + params = append(params, incoming.GetLang()) + case *TLeaderboardRecordsList_Location: + query += " AND location = $3" + params = append(params, incoming.GetLocation()) + case *TLeaderboardRecordsList_Timezone: + query += " AND timezone = $3" + params = append(params, incoming.GetTimezone()) + case nil: + // No filter. + break + default: + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Unknown leaderboard record list filter"}}}) + return + } + + if incomingCursor != nil { + count := len(params) + if sortOrder == 0 { + // Ascending leaderboard. + query += " AND (score, updated_at, id) > ($" + strconv.Itoa(count) + + ", $" + strconv.Itoa(count+1) + + ", $" + strconv.Itoa(count+2) + ")" + params = append(params, incomingCursor.Score, incomingCursor.UpdatedAt, incomingCursor.Id) + } else { + // Descending leaderboard. + query += " AND (score, updated_at_inverse, id) < ($" + strconv.Itoa(count) + + ", $" + strconv.Itoa(count+1) + + ", $" + strconv.Itoa(count+2) + ")" + params = append(params, incomingCursor.Score, invertMs(incomingCursor.UpdatedAt), incomingCursor.Id) + } + } + + if sortOrder == 0 { + // Ascending leaderboard, lower score is better. + query += " ORDER BY score ASC, updated_at ASC" + } else { + // Descending leaderboard, higher score is better. + query += " ORDER BY score DESC, updated_at_inverse DESC" + } + + params = append(params, limit+1) + query += " LIMIT $" + strconv.Itoa(len(params)) + + logger.Debug("Leaderboard records list", zap.String("query", query)) + rows, err := p.db.Query(query, params...) + if err != nil { + logger.Error("Could not execute leaderboard records list query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + defer rows.Close() + + leaderboardRecords := []*LeaderboardRecord{} + var outgoingCursor []byte + + var id []byte + var ownerId []byte + var handle string + var lang string + var location sql.NullString + var timezone sql.NullString + var rankValue int64 + var score int64 + var numScore int64 + var metadata []byte + var rankedAt int64 + var updatedAt int64 + var expiresAt int64 + var bannedAt int64 + for rows.Next() { + if returnCursor && int64(len(leaderboardRecords)) >= limit { + cursorBuf := new(bytes.Buffer) + newCursor := &leaderboardRecordListCursor{ + Score: score, + UpdatedAt: updatedAt, + Id: id, + } + if gob.NewEncoder(cursorBuf).Encode(newCursor); err != nil { + logger.Error("Error creating leaderboard records list cursor", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error encoding cursor"}}}) + return + } + outgoingCursor = cursorBuf.Bytes() + break + } + + err = rows.Scan(&id, &ownerId, &handle, &lang, &location, &timezone, + &rankValue, &score, &numScore, &metadata, &rankedAt, &updatedAt, &expiresAt, &bannedAt) + if err != nil { + logger.Error("Could not scan leaderboard records list query results", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + + leaderboardRecords = append(leaderboardRecords, &LeaderboardRecord{ + LeaderboardId: incoming.LeaderboardId, + OwnerId: ownerId, + Handle: handle, + Lang: lang, + Location: location.String, + Timezone: timezone.String, + Rank: rankValue, + Score: score, + NumScore: numScore, + Metadata: metadata, + RankedAt: rankedAt, + UpdatedAt: updatedAt, + ExpiresAt: expiresAt, + }) + } + if err = rows.Err(); err != nil { + logger.Error("Could not process leaderboard records list query results", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + + p.normalizeAndSendLeaderboardRecords(logger, session, envelope, leaderboardRecords, outgoingCursor) +} + +func (p *pipeline) loadLeaderboardRecordsHaystack(logger zap.Logger, session *session, envelope *Envelope, leaderboardId, findOwnerId []byte, currentExpiresAt, limit, sortOrder int64, query string, params []interface{}) { + // Find the owner's record. + var id []byte + var score int64 + var updatedAt int64 + findQuery := `SELECT id, score, updated_at + FROM leaderboard_record + WHERE leaderboard_id = $1 + AND expires_at = $2 + AND owner_id = $3` + logger.Debug("Leaderboard record find", zap.String("query", findQuery)) + err := p.db.QueryRow(findQuery, leaderboardId, currentExpiresAt, findOwnerId).Scan(&id, &score, &updatedAt) + if err != nil { + // TODO handle errors other than record not found? + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_LeaderboardRecords{LeaderboardRecords: &TLeaderboardRecords{ + Records: []*LeaderboardRecord{}, + // No cursor. + }}}) + return + } + + // First half. + count := len(params) + firstQuery := query + firstParams := params + if sortOrder == 0 { + // Lower score is better, but get in reverse order from current user to get those immediately above. + firstQuery += " AND (score, updated_at_inverse, id) <= ($" + strconv.Itoa(count) + + ", $" + strconv.Itoa(count+1) + + ", $" + strconv.Itoa(count+2) + ") ORDER BY score DESC, updated_at_inverse DESC" + firstParams = append(firstParams, score, invertMs(updatedAt), id) + } else { + // Higher score is better. + firstQuery += " AND (score, updated_at, id) >= ($" + strconv.Itoa(count) + + ", $" + strconv.Itoa(count+1) + + ", $" + strconv.Itoa(count+2) + ") ORDER BY score ASC, updated_at ASC" + firstParams = append(firstParams, score, updatedAt, id) + } + firstParams = append(firstParams, int64(limit/2)) + firstQuery += " LIMIT $" + strconv.Itoa(len(firstParams)) + + logger.Debug("Leaderboard records list", zap.String("query", firstQuery)) + firstRows, err := p.db.Query(firstQuery, firstParams...) + if err != nil { + logger.Error("Could not execute leaderboard records list query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + defer firstRows.Close() + + leaderboardRecords := []*LeaderboardRecord{} + + var ownerId []byte + var handle string + var lang string + var location sql.NullString + var timezone sql.NullString + var rankValue int64 + var numScore int64 + var metadata []byte + var rankedAt int64 + var expiresAt int64 + var bannedAt int64 + for firstRows.Next() { + err = firstRows.Scan(&id, &ownerId, &handle, &lang, &location, &timezone, + &rankValue, &score, &numScore, &metadata, &rankedAt, &updatedAt, &expiresAt, &bannedAt) + if err != nil { + logger.Error("Could not scan leaderboard records list query results", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + + leaderboardRecords = append(leaderboardRecords, &LeaderboardRecord{ + LeaderboardId: leaderboardId, + OwnerId: ownerId, + Handle: handle, + Lang: lang, + Location: location.String, + Timezone: timezone.String, + Rank: rankValue, + Score: score, + NumScore: numScore, + Metadata: metadata, + RankedAt: rankedAt, + UpdatedAt: updatedAt, + ExpiresAt: expiresAt, + }) + } + if err = firstRows.Err(); err != nil { + logger.Error("Could not process leaderboard records list query results", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + + // We went 'up' on the leaderboard, so reverse the first half of records. + for left, right := 0, len(leaderboardRecords)-1; left < right; left, right = left+1, right-1 { + leaderboardRecords[left], leaderboardRecords[right] = leaderboardRecords[right], leaderboardRecords[left] + } + + // Second half. + secondQuery := query + secondParams := params + if sortOrder == 0 { + // Lower score is better. + secondQuery += " AND (score, updated_at, id) > ($" + strconv.Itoa(count) + + ", $" + strconv.Itoa(count+1) + + ", $" + strconv.Itoa(count+2) + ") ORDER BY score ASC, updated_at ASC" + secondParams = append(secondParams, score, updatedAt, id) + } else { + // Higher score is better. + secondQuery += " AND (score, updated_at_inverse, id) < ($" + strconv.Itoa(count) + + ", $" + strconv.Itoa(count+1) + + ", $" + strconv.Itoa(count+2) + ") ORDER BY score DESC, updated_at DESC" + secondParams = append(secondParams, score, invertMs(updatedAt), id) + } + secondParams = append(secondParams, limit-int64(len(leaderboardRecords))+2) + secondQuery += " LIMIT $" + strconv.Itoa(len(secondParams)) + + logger.Debug("Leaderboard records list", zap.String("query", secondQuery)) + secondRows, err := p.db.Query(secondQuery, secondParams...) + if err != nil { + logger.Error("Could not execute leaderboard records list query", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + defer secondRows.Close() + + var outgoingCursor []byte + + for secondRows.Next() { + if int64(len(leaderboardRecords)) >= limit { + cursorBuf := new(bytes.Buffer) + newCursor := &leaderboardRecordListCursor{ + Score: score, + UpdatedAt: updatedAt, + Id: id, + } + if gob.NewEncoder(cursorBuf).Encode(newCursor); err != nil { + logger.Error("Error creating leaderboard records list cursor", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error encoding cursor"}}}) + return + } + outgoingCursor = cursorBuf.Bytes() + break + } + + err = secondRows.Scan(&id, &ownerId, &handle, &lang, &location, &timezone, + &rankValue, &score, &numScore, &metadata, &rankedAt, &updatedAt, &expiresAt, &bannedAt) + if err != nil { + logger.Error("Could not scan leaderboard records list query results", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + + leaderboardRecords = append(leaderboardRecords, &LeaderboardRecord{ + LeaderboardId: leaderboardId, + OwnerId: ownerId, + Handle: handle, + Lang: lang, + Location: location.String, + Timezone: timezone.String, + Rank: rankValue, + Score: score, + NumScore: numScore, + Metadata: metadata, + RankedAt: rankedAt, + UpdatedAt: updatedAt, + ExpiresAt: expiresAt, + }) + } + if err = secondRows.Err(); err != nil { + logger.Error("Could not process leaderboard records list query results", zap.Error(err)) + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Error{&Error{Reason: "Error loading leaderboard records"}}}) + return + } + + p.normalizeAndSendLeaderboardRecords(logger, session, envelope, leaderboardRecords, outgoingCursor) +} + +func (p *pipeline) normalizeAndSendLeaderboardRecords(logger zap.Logger, session *session, envelope *Envelope, records []*LeaderboardRecord, cursor []byte) { + var bestRank int64 + for _, record := range records { + if record.Rank != 0 && record.Rank < bestRank { + bestRank = record.Rank + } + } + if bestRank != 0 { + for i := int64(0); i < int64(len(records)); i++ { + records[i].Rank = bestRank + i + } + } + + session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_LeaderboardRecords{LeaderboardRecords: &TLeaderboardRecords{ + Records: records, + Cursor: cursor, + }}}) +} + +func invertMs(ms int64) int64 { + // Subtract a millisecond timestamp from a fixed value. + // This value represents Wed, 16 Nov 5138 at about 09:46:39 UTC. + return 99999999999999 - ms +} diff --git a/server/pipeline_match.go b/server/pipeline_match.go index 9cc273181e..739068eb94 100644 --- a/server/pipeline_match.go +++ b/server/pipeline_match.go @@ -22,11 +22,16 @@ import ( func (p *pipeline) matchCreate(logger zap.Logger, session *session, envelope *Envelope) { matchID := uuid.NewV4() - p.tracker.Track(session.id, "match:"+matchID.String(), session.userID, PresenceMeta{}) + handle := session.handle.Load() + + p.tracker.Track(session.id, "match:"+matchID.String(), session.userID, PresenceMeta{ + Handle: handle, + }) self := &UserPresence{ UserId: session.userID.Bytes(), SessionId: session.id.Bytes(), + Handle: handle, } session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Match{Match: &TMatch{ @@ -51,7 +56,11 @@ func (p *pipeline) matchJoin(logger zap.Logger, session *session, envelope *Enve return } - p.tracker.Track(session.id, topic, session.userID, PresenceMeta{}) + handle := session.handle.Load() + + p.tracker.Track(session.id, topic, session.userID, PresenceMeta{ + Handle: handle, + }) userPresences := make([]*UserPresence, len(ps)+1) for i := 0; i < len(ps); i++ { @@ -59,11 +68,13 @@ func (p *pipeline) matchJoin(logger zap.Logger, session *session, envelope *Enve userPresences[i] = &UserPresence{ UserId: p.UserID.Bytes(), SessionId: p.ID.SessionID.Bytes(), + Handle: p.Meta.Handle, } } self := &UserPresence{ UserId: session.userID.Bytes(), SessionId: session.id.Bytes(), + Handle: handle, } userPresences[len(ps)] = self @@ -155,6 +166,7 @@ func (p *pipeline) matchDataSend(logger zap.Logger, session *session, envelope * Presence: &UserPresence{ UserId: session.userID.Bytes(), SessionId: session.id.Bytes(), + Handle: session.handle.Load(), }, OpCode: incoming.OpCode, Data: incoming.Data, diff --git a/server/pipeline_self.go b/server/pipeline_self.go index 5e1ee481f2..53394f4b7c 100644 --- a/server/pipeline_self.go +++ b/server/pipeline_self.go @@ -175,5 +175,10 @@ func (p *pipeline) selfUpdate(logger zap.Logger, session *session, envelope *Env return } + // Update handle in session and any presences. + if update.Handle != "" { + session.handle.Store(update.Handle) + } + session.Send(&Envelope{CollationId: envelope.CollationId}) } diff --git a/server/pipeline_topic.go b/server/pipeline_topic.go index 3fb361d5ca..26a7dac394 100644 --- a/server/pipeline_topic.go +++ b/server/pipeline_topic.go @@ -123,19 +123,31 @@ func (p *pipeline) topicJoin(logger zap.Logger, session *session, envelope *Enve return } + handle := session.handle.Load() + // Track the presence, and gather current member list. - p.tracker.Track(session.id, trackerTopic, session.userID, PresenceMeta{}) + p.tracker.Track(session.id, trackerTopic, session.userID, PresenceMeta{ + Handle: handle, + }) presences := p.tracker.ListByTopic(trackerTopic) userPresences := make([]*UserPresence, len(presences)) for i := 0; i < len(presences); i++ { - userPresences[i] = &UserPresence{UserId: presences[i].UserID.Bytes(), SessionId: presences[i].ID.SessionID.Bytes()} + userPresences[i] = &UserPresence{ + UserId: presences[i].UserID.Bytes(), + SessionId: presences[i].ID.SessionID.Bytes(), + Handle: presences[i].Meta.Handle, + } } session.Send(&Envelope{CollationId: envelope.CollationId, Payload: &Envelope_Topic{Topic: &TTopic{ Topic: topic, Presences: userPresences, - Self: &UserPresence{UserId: session.userID.Bytes(), SessionId: session.id.Bytes()}, + Self: &UserPresence{ + UserId: session.userID.Bytes(), + SessionId: session.id.Bytes(), + Handle: handle, + }, }}}) } @@ -566,14 +578,12 @@ func (p *pipeline) storeMessage(logger zap.Logger, session *session, topic *Topi } createdAt := nowMs() messageID := uuid.NewV4().Bytes() - var expiresAt int64 - var handle string - err := p.db.QueryRow(` + expiresAt := int64(0) + handle := session.handle.Load() + _, err := p.db.Exec(` INSERT INTO message (topic, topic_type, message_id, user_id, created_at, expires_at, handle, type, data) -SELECT $1, $2, $3, $4, $5, $6, handle, $7, $8 -FROM users -WHERE id = $4 -RETURNING handle`, topicBytes, topicType, messageID, session.userID.Bytes(), createdAt, expiresAt, msgType, data).Scan(&handle) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`, + topicBytes, topicType, messageID, session.userID.Bytes(), createdAt, expiresAt, handle, msgType, data) if err != nil { logger.Error("Failed to insert new message", zap.Error(err)) return nil, "", 0, 0, err diff --git a/server/presence_notifier.go b/server/presence_notifier.go index 50b2008c7f..2453d798de 100644 --- a/server/presence_notifier.go +++ b/server/presence_notifier.go @@ -166,6 +166,7 @@ func (pn *presenceNotifier) handleDiffMatch(matchID []byte, to, joins, leaves [] muJoins[i] = &UserPresence{ UserId: joins[i].UserID.Bytes(), SessionId: joins[i].ID.SessionID.Bytes(), + Handle: joins[i].Meta.Handle, } } msg.Joins = muJoins @@ -176,6 +177,7 @@ func (pn *presenceNotifier) handleDiffMatch(matchID []byte, to, joins, leaves [] muLeaves[i] = &UserPresence{ UserId: leaves[i].UserID.Bytes(), SessionId: leaves[i].ID.SessionID.Bytes(), + Handle: leaves[i].Meta.Handle, } } msg.Leaves = muLeaves @@ -196,6 +198,7 @@ func (pn *presenceNotifier) handleDiffTopic(topic *TopicId, to, joins, leaves [] tuJoins[i] = &UserPresence{ UserId: joins[i].UserID.Bytes(), SessionId: joins[i].ID.SessionID.Bytes(), + Handle: joins[i].Meta.Handle, } } msg.Joins = tuJoins @@ -206,6 +209,7 @@ func (pn *presenceNotifier) handleDiffTopic(topic *TopicId, to, joins, leaves [] tuLeaves[i] = &UserPresence{ UserId: leaves[i].UserID.Bytes(), SessionId: leaves[i].ID.SessionID.Bytes(), + Handle: leaves[i].Meta.Handle, } } msg.Leaves = tuLeaves diff --git a/server/session.go b/server/session.go index 7b87f04ef1..af6ddddadb 100644 --- a/server/session.go +++ b/server/session.go @@ -23,6 +23,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/gorilla/websocket" "github.com/satori/go.uuid" + "github.com/uber-go/atomic" "github.com/uber-go/zap" ) @@ -32,6 +33,8 @@ type session struct { config Config id uuid.UUID userID uuid.UUID + handle *atomic.String + lang string stopped bool conn *websocket.Conn pingTicker *time.Ticker @@ -39,7 +42,7 @@ type session struct { } // NewSession creates a new session which encapsulates a socket connection -func NewSession(logger zap.Logger, config Config, userID uuid.UUID, websocketConn *websocket.Conn, unregister func(s *session)) *session { +func NewSession(logger zap.Logger, config Config, userID uuid.UUID, handle string, lang string, websocketConn *websocket.Conn, unregister func(s *session)) *session { sessionID := uuid.NewV4() sessionLogger := logger.With(zap.String("uid", userID.String()), zap.String("sid", sessionID.String())) @@ -50,6 +53,8 @@ func NewSession(logger zap.Logger, config Config, userID uuid.UUID, websocketCon config: config, id: sessionID, userID: userID, + handle: atomic.NewString(handle), + lang: lang, conn: websocketConn, stopped: false, pingTicker: time.NewTicker(time.Duration(config.GetTransport().PingPeriodMs) * time.Millisecond), diff --git a/server/session_auth.go b/server/session_auth.go index 5c78b445c5..53c0728a56 100644 --- a/server/session_auth.go +++ b/server/session_auth.go @@ -111,20 +111,26 @@ func (a *authenticationService) configure() { } token := r.URL.Query().Get("token") - uid, auth := a.authenticateToken(token) + uid, handle, auth := a.authenticateToken(token) if !auth { http.Error(w, "Missing or invalid token", 401) return } + // TODO validate BCP 47 lang format + lang := r.URL.Query().Get("lang") + if lang == "" { + lang = "en" + } + conn, err := a.upgrader.Upgrade(w, r, nil) if err != nil { - //http.Error is invoked automatically from within the Upgrade func + // http.Error is invoked automatically from within the Upgrade func a.logger.Warn("Could not upgrade to websockets", zap.Error(err)) return } - a.registry.add(uid, conn, a.pipeline.processRequest) + a.registry.add(uid, handle, lang, conn, a.pipeline.processRequest) }).Methods("GET", "OPTIONS") } @@ -143,7 +149,7 @@ func (a *authenticationService) StartServer(mlogger zap.Logger) { } func (a *authenticationService) handleAuth(w http.ResponseWriter, r *http.Request, - retrieveUserID func(authReq *AuthenticateRequest) ([]byte, string, int)) { + retrieveUserID func(authReq *AuthenticateRequest) ([]byte, string, string, int)) { w.Header().Set("Content-Type", "application/octet-stream") @@ -171,7 +177,7 @@ func (a *authenticationService) handleAuth(w http.ResponseWriter, r *http.Reques return } - userID, errString, errCode := retrieveUserID(authReq) + userID, handle, errString, errCode := retrieveUserID(authReq) if errString != "" { a.logger.Debug("Could not retrieve user ID", zap.String("error", errString), zap.Int("code", errCode)) a.sendAuthError(w, errString, errCode, authReq) @@ -183,6 +189,7 @@ func (a *authenticationService) handleAuth(w http.ResponseWriter, r *http.Reques token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "uid": uid.String(), "exp": time.Now().UTC().Add(time.Duration(a.config.GetSession().TokenExpiryMs) * time.Millisecond).Unix(), + "han": handle, }) signedToken, _ := token.SignedString(a.hmacSecretByte) @@ -211,9 +218,9 @@ func (a *authenticationService) sendAuthResponse(w http.ResponseWriter, response w.Write(payload) } -func (a *authenticationService) login(authReq *AuthenticateRequest) ([]byte, string, int) { +func (a *authenticationService) login(authReq *AuthenticateRequest) ([]byte, string, string, int) { // Route to correct login handler - var loginFunc func(authReq *AuthenticateRequest) ([]byte, int64, string, int) + var loginFunc func(authReq *AuthenticateRequest) ([]byte, string, int64, string, int) switch authReq.Payload.(type) { case *AuthenticateRequest_Device: loginFunc = a.loginDevice @@ -230,211 +237,218 @@ func (a *authenticationService) login(authReq *AuthenticateRequest) ([]byte, str case *AuthenticateRequest_Custom: loginFunc = a.loginCustom default: - return nil, errorInvalidPayload, 400 + return nil, "", errorInvalidPayload, 400 } - userID, disabledAt, message, status := loginFunc(authReq) + userID, handle, disabledAt, message, status := loginFunc(authReq) if disabledAt != 0 { - return nil, "ID disabled", 401 + return nil, "", "ID disabled", 401 } - return userID, message, status + return userID, handle, message, status } -func (a *authenticationService) loginDevice(authReq *AuthenticateRequest) ([]byte, int64, string, int) { +func (a *authenticationService) loginDevice(authReq *AuthenticateRequest) ([]byte, string, int64, string, int) { deviceID := authReq.GetDevice() if deviceID == "" { - return nil, 0, "Device ID is required", 400 + return nil, "", 0, "Device ID is required", 400 } else if invalidCharsRegex.MatchString(deviceID) { - return nil, 0, "Invalid device ID, no spaces or control characters allowed", 400 + return nil, "", 0, "Invalid device ID, no spaces or control characters allowed", 400 } else if len(deviceID) < 10 || len(deviceID) > 64 { - return nil, 0, "Invalid device ID, must be 10-64 bytes", 400 + return nil, "", 0, "Invalid device ID, must be 10-64 bytes", 400 } var userID []byte + var handle string var disabledAt int64 - err := a.db.QueryRow("SELECT u.id, u.disabled_at FROM users u, user_device ud WHERE ud.id = $1 AND u.id = ud.user_id", + err := a.db.QueryRow("SELECT u.id, u.handle, u.disabled_at FROM users u, user_device ud WHERE ud.id = $1 AND u.id = ud.user_id", deviceID). - Scan(&userID, &disabledAt) + Scan(&userID, &handle, &disabledAt) if err != nil { a.logger.Warn(errorCouldNotLogin, zap.Error(err)) - return nil, 0, errorIDNotFound, 401 + return nil, "", 0, errorIDNotFound, 401 } - return userID, disabledAt, "", 200 + return userID, handle, disabledAt, "", 200 } -func (a *authenticationService) loginFacebook(authReq *AuthenticateRequest) ([]byte, int64, string, int) { +func (a *authenticationService) loginFacebook(authReq *AuthenticateRequest) ([]byte, string, int64, string, int) { accessToken := authReq.GetFacebook() if accessToken == "" { - return nil, 0, errorAccessTokenIsRequired, 400 + return nil, "", 0, errorAccessTokenIsRequired, 400 } else if invalidCharsRegex.MatchString(accessToken) { - return nil, 0, "Invalid Facebook access token, no spaces or control characters allowed", 400 + return nil, "", 0, "Invalid Facebook access token, no spaces or control characters allowed", 400 } fbProfile, err := a.socialClient.GetFacebookProfile(accessToken) if err != nil { a.logger.Warn("Could not get Facebook profile", zap.Error(err)) - return nil, 0, errorCouldNotLogin, 401 + return nil, "", 0, errorCouldNotLogin, 401 } var userID []byte + var handle string var disabledAt int64 - err = a.db.QueryRow("SELECT id, disabled_at FROM users WHERE facebook_id = $1", + err = a.db.QueryRow("SELECT id, handle, disabled_at FROM users WHERE facebook_id = $1", fbProfile.ID). - Scan(&userID, &disabledAt) + Scan(&userID, &handle, &disabledAt) if err != nil { a.logger.Warn("Could not login with Facebook profile", zap.Error(err)) - return nil, 0, errorIDNotFound, 401 + return nil, "", 0, errorIDNotFound, 401 } - return userID, disabledAt, "", 200 + return userID, handle, disabledAt, "", 200 } -func (a *authenticationService) loginGoogle(authReq *AuthenticateRequest) ([]byte, int64, string, int) { +func (a *authenticationService) loginGoogle(authReq *AuthenticateRequest) ([]byte, string, int64, string, int) { accessToken := authReq.GetGoogle() if accessToken == "" { - return nil, 0, errorAccessTokenIsRequired, 400 + return nil, "", 0, errorAccessTokenIsRequired, 400 } else if invalidCharsRegex.MatchString(accessToken) { - return nil, 0, "Invalid Google access token, no spaces or control characters allowed", 400 + return nil, "", 0, "Invalid Google access token, no spaces or control characters allowed", 400 } googleProfile, err := a.socialClient.GetGoogleProfile(accessToken) if err != nil { a.logger.Warn("Could not get Google profile", zap.Error(err)) - return nil, 0, errorCouldNotLogin, 401 + return nil, "", 0, errorCouldNotLogin, 401 } var userID []byte + var handle string var disabledAt int64 - err = a.db.QueryRow("SELECT id, disabled_at FROM users WHERE google_id = $1", + err = a.db.QueryRow("SELECT id, handle, disabled_at FROM users WHERE google_id = $1", googleProfile.ID). - Scan(&userID, &disabledAt) + Scan(&userID, &handle, &disabledAt) if err != nil { a.logger.Warn("Could not login with Google profile", zap.Error(err)) - return nil, 0, errorIDNotFound, 401 + return nil, "", 0, errorIDNotFound, 401 } - return userID, disabledAt, "", 200 + return userID, handle, disabledAt, "", 200 } -func (a *authenticationService) loginGameCenter(authReq *AuthenticateRequest) ([]byte, int64, string, int) { +func (a *authenticationService) loginGameCenter(authReq *AuthenticateRequest) ([]byte, string, int64, string, int) { gc := authReq.GetGameCenter() if gc == nil || gc.PlayerId == "" || gc.BundleId == "" || gc.Timestamp == 0 || gc.Salt == "" || gc.Signature == "" || gc.PublicKeyUrl == "" { - return nil, 0, errorInvalidPayload, 400 + return nil, "", 0, errorInvalidPayload, 400 } _, err := a.socialClient.CheckGameCenterID(gc.PlayerId, gc.BundleId, gc.Timestamp, gc.Salt, gc.Signature, gc.PublicKeyUrl) if err != nil { a.logger.Warn("Could not check Game Center profile", zap.Error(err)) - return nil, 0, errorCouldNotLogin, 401 + return nil, "", 0, errorCouldNotLogin, 401 } var userID []byte + var handle string var disabledAt int64 - err = a.db.QueryRow("SELECT id, disabled_at FROM users WHERE gamecenter_id = $1", + err = a.db.QueryRow("SELECT id, handle, disabled_at FROM users WHERE gamecenter_id = $1", gc.PlayerId). - Scan(&userID, &disabledAt) + Scan(&userID, &handle, &disabledAt) if err != nil { a.logger.Warn("Could not login with Game Center profile", zap.Error(err)) - return nil, 0, errorIDNotFound, 401 + return nil, "", 0, errorIDNotFound, 401 } - return userID, disabledAt, "", 200 + return userID, handle, disabledAt, "", 200 } -func (a *authenticationService) loginSteam(authReq *AuthenticateRequest) ([]byte, int64, string, int) { +func (a *authenticationService) loginSteam(authReq *AuthenticateRequest) ([]byte, string, int64, string, int) { if a.config.GetSocial().Steam.PublisherKey == "" || a.config.GetSocial().Steam.AppID == 0 { - return nil, 0, "Steam login not available", 401 + return nil, "", 0, "Steam login not available", 401 } ticket := authReq.GetSteam() if ticket == "" { - return nil, 0, "Steam ticket is required", 400 + return nil, "", 0, "Steam ticket is required", 400 } else if invalidCharsRegex.MatchString(ticket) { - return nil, 0, "Invalid Steam ticket, no spaces or control characters allowed", 400 + return nil, "", 0, "Invalid Steam ticket, no spaces or control characters allowed", 400 } steamProfile, err := a.socialClient.GetSteamProfile(a.config.GetSocial().Steam.PublisherKey, a.config.GetSocial().Steam.AppID, ticket) if err != nil { a.logger.Warn("Could not check Steam profile", zap.Error(err)) - return nil, 0, errorCouldNotLogin, 401 + return nil, "", 0, errorCouldNotLogin, 401 } var userID []byte + var handle string var disabledAt int64 - err = a.db.QueryRow("SELECT id, disabled_at FROM users WHERE steam_id = $1", + err = a.db.QueryRow("SELECT id, handle, disabled_at FROM users WHERE steam_id = $1", strconv.FormatUint(steamProfile.SteamID, 10)). - Scan(&userID, &disabledAt) + Scan(&userID, &handle, &disabledAt) if err != nil { a.logger.Warn("Could not login with Steam profile", zap.Error(err)) - return nil, 0, errorIDNotFound, 401 + return nil, "", 0, errorIDNotFound, 401 } - return userID, disabledAt, "", 200 + return userID, handle, disabledAt, "", 200 } -func (a *authenticationService) loginEmail(authReq *AuthenticateRequest) ([]byte, int64, string, int) { +func (a *authenticationService) loginEmail(authReq *AuthenticateRequest) ([]byte, string, int64, string, int) { email := authReq.GetEmail() if email == nil { - return nil, 0, errorInvalidPayload, 400 + return nil, "", 0, errorInvalidPayload, 400 } else if email.Email == "" { - return nil, 0, "Email address is required", 400 + return nil, "", 0, "Email address is required", 400 } else if invalidCharsRegex.MatchString(email.Email) { - return nil, 0, "Invalid email address, no spaces or control characters allowed", 400 + return nil, "", 0, "Invalid email address, no spaces or control characters allowed", 400 } else if !emailRegex.MatchString(email.Email) { - return nil, 0, "Invalid email address format", 400 + return nil, "", 0, "Invalid email address format", 400 } else if len(email.Email) < 10 || len(email.Email) > 255 { - return nil, 0, "Invalid email address, must be 10-255 bytes", 400 + return nil, "", 0, "Invalid email address, must be 10-255 bytes", 400 } var userID []byte + var handle string var hashedPassword []byte var disabledAt int64 - err := a.db.QueryRow("SELECT id, password, disabled_at FROM users WHERE email = $1", + err := a.db.QueryRow("SELECT id, handle, password, disabled_at FROM users WHERE email = $1", email.Email). - Scan(&userID, &hashedPassword, &disabledAt) + Scan(&userID, &handle, &hashedPassword, &disabledAt) if err != nil { a.logger.Warn(errorCouldNotLogin, zap.Error(err)) - return nil, 0, "Invalid credentials", 401 + return nil, "", 0, "Invalid credentials", 401 } err = bcrypt.CompareHashAndPassword(hashedPassword, []byte(email.Password)) if err != nil { a.logger.Warn("Invalid credentials", zap.Error(err)) - return nil, 0, "Invalid credentials", 401 + return nil, "", 0, "Invalid credentials", 401 } - return userID, disabledAt, "", 200 + return userID, handle, disabledAt, "", 200 } -func (a *authenticationService) loginCustom(authReq *AuthenticateRequest) ([]byte, int64, string, int) { +func (a *authenticationService) loginCustom(authReq *AuthenticateRequest) ([]byte, string, int64, string, int) { customID := authReq.GetCustom() if customID == "" { - return nil, 0, "Custom ID is required", 400 + return nil, "", 0, "Custom ID is required", 400 } else if invalidCharsRegex.MatchString(customID) { - return nil, 0, "Invalid custom ID, no spaces or control characters allowed", 400 + return nil, "", 0, "Invalid custom ID, no spaces or control characters allowed", 400 } else if len(customID) < 10 || len(customID) > 64 { - return nil, 0, "Invalid custom ID, must be 10-64 bytes", 400 + return nil, "", 0, "Invalid custom ID, must be 10-64 bytes", 400 } var userID []byte + var handle string var disabledAt int64 - err := a.db.QueryRow("SELECT id, disabled_at FROM users WHERE custom_id = $1", + err := a.db.QueryRow("SELECT id, handle, disabled_at FROM users WHERE custom_id = $1", customID). - Scan(&userID, &disabledAt) + Scan(&userID, &handle, &disabledAt) if err != nil { a.logger.Warn(errorCouldNotLogin, zap.Error(err)) - return nil, 0, errorIDNotFound, 401 + return nil, "", 0, errorIDNotFound, 401 } - return userID, disabledAt, "", 200 + return userID, handle, disabledAt, "", 200 } -func (a *authenticationService) register(authReq *AuthenticateRequest) ([]byte, string, int) { +func (a *authenticationService) register(authReq *AuthenticateRequest) ([]byte, string, string, int) { // Route to correct register handler - var registerFunc func(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, int) + var registerFunc func(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, string, int) switch authReq.Payload.(type) { case *AuthenticateRequest_Device: @@ -452,16 +466,16 @@ func (a *authenticationService) register(authReq *AuthenticateRequest) ([]byte, case *AuthenticateRequest_Custom: registerFunc = a.registerCustom default: - return nil, errorInvalidPayload, 400 + return nil, "", errorInvalidPayload, 400 } tx, err := a.db.Begin() if err != nil { a.logger.Warn("Could not register, transaction begin error", zap.Error(err)) - return nil, errorCouldNotRegister, 500 + return nil, "", errorCouldNotRegister, 500 } - userID, errorMessage, errorCode := registerFunc(tx, authReq) + userID, handle, errorMessage, errorCode := registerFunc(tx, authReq) if errorCode != 200 { if tx != nil { @@ -470,17 +484,17 @@ func (a *authenticationService) register(authReq *AuthenticateRequest) ([]byte, a.logger.Error("Could not rollback transaction", zap.Error(err)) } } - return userID, errorMessage, errorCode + return userID, handle, errorMessage, errorCode } err = tx.Commit() if err != nil { a.logger.Error("Could not commit transaction", zap.Error(err)) - return nil, errorCouldNotRegister, 500 + return nil, "", errorCouldNotRegister, 500 } a.logger.Info("Registration complete", zap.String("uid", uuid.FromBytesOrNil(userID).String())) - return userID, errorMessage, errorCode + return userID, handle, errorMessage, errorCode } func (a *authenticationService) addUserEdgeMetadata(tx *sql.Tx, userID []byte, updatedAt int64) error { @@ -488,18 +502,19 @@ func (a *authenticationService) addUserEdgeMetadata(tx *sql.Tx, userID []byte, u return err } -func (a *authenticationService) registerDevice(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, int) { +func (a *authenticationService) registerDevice(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, string, int) { deviceID := authReq.GetDevice() if deviceID == "" { - return nil, "Device ID is required", 400 + return nil, "", "Device ID is required", 400 } else if invalidCharsRegex.MatchString(deviceID) { - return nil, "Invalid device ID, no spaces or control characters allowed", 400 + return nil, "", "Invalid device ID, no spaces or control characters allowed", 400 } else if len(deviceID) < 10 || len(deviceID) > 64 { - return nil, "Invalid device ID, must be 10-64 bytes", 400 + return nil, "", "Invalid device ID, must be 10-64 bytes", 400 } updatedAt := nowMs() userID := uuid.NewV4().Bytes() + handle := a.generateHandle() res, err := tx.Exec(` INSERT INTO users (id, handle, created_at, updated_at) SELECT $1 AS id, @@ -510,49 +525,50 @@ WHERE NOT EXISTS (SELECT id FROM user_device WHERE id = $3)`, - userID, a.generateHandle(), deviceID, updatedAt) + userID, handle, deviceID, updatedAt) if err != nil { a.logger.Warn("Could not register new device profile, query error", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } if rowsAffected, _ := res.RowsAffected(); rowsAffected == 0 { a.logger.Warn("Could not register new device profile, rows affected error") - return nil, errorIDAlreadyInUse, 401 + return nil, "", errorIDAlreadyInUse, 401 } res, err = tx.Exec("INSERT INTO user_device (id, user_id) VALUES ($1, $2)", deviceID, userID) if err != nil { a.logger.Warn("Could not register, query error", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } if count, _ := res.RowsAffected(); count == 0 { - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } err = a.addUserEdgeMetadata(tx, userID, updatedAt) if err != nil { - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } - return userID, "", 200 + return userID, handle, "", 200 } -func (a *authenticationService) registerFacebook(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, int) { +func (a *authenticationService) registerFacebook(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, string, int) { accessToken := authReq.GetFacebook() if accessToken == "" { - return nil, errorAccessTokenIsRequired, 400 + return nil, "", errorAccessTokenIsRequired, 400 } else if invalidCharsRegex.MatchString(accessToken) { - return nil, "Invalid Facebook access token, no spaces or control characters allowed", 400 + return nil, "", "Invalid Facebook access token, no spaces or control characters allowed", 400 } fbProfile, err := a.socialClient.GetFacebookProfile(accessToken) if err != nil { a.logger.Warn("Could not get Facebook profile", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } updatedAt := nowMs() userID := uuid.NewV4().Bytes() + handle := a.generateHandle() res, err := tx.Exec(` INSERT INTO users (id, handle, facebook_id, created_at, updated_at) SELECT $1 AS id, @@ -564,15 +580,15 @@ WHERE NOT EXISTS (SELECT id FROM users WHERE facebook_id = $3)`, - userID, a.generateHandle(), fbProfile.ID, updatedAt) + userID, handle, fbProfile.ID, updatedAt) if err != nil { a.logger.Warn("Could not register new Facebook profile, query error", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } if rowsAffected, _ := res.RowsAffected(); rowsAffected == 0 { a.logger.Warn("Could not register new Facebook profile, rows affected error") - return nil, errorIDAlreadyInUse, 401 + return nil, "", errorIDAlreadyInUse, 401 } l := a.logger.With(zap.String("user_id", uuid.FromBytesOrNil(userID).String())) @@ -580,28 +596,29 @@ WHERE NOT EXISTS err = a.addUserEdgeMetadata(tx, userID, updatedAt) if err != nil { - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } - return userID, "", 200 + return userID, handle, "", 200 } -func (a *authenticationService) registerGoogle(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, int) { +func (a *authenticationService) registerGoogle(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, string, int) { accessToken := authReq.GetGoogle() if accessToken == "" { - return nil, errorAccessTokenIsRequired, 400 + return nil, "", errorAccessTokenIsRequired, 400 } else if invalidCharsRegex.MatchString(accessToken) { - return nil, "Invalid Google access token, no spaces or control characters allowed", 400 + return nil, "", "Invalid Google access token, no spaces or control characters allowed", 400 } googleProfile, err := a.socialClient.GetGoogleProfile(accessToken) if err != nil { a.logger.Warn("Could not get Google profile", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } updatedAt := nowMs() userID := uuid.NewV4().Bytes() + handle := a.generateHandle() res, err := tx.Exec(` INSERT INTO users (id, handle, google_id, created_at, updated_at) SELECT $1 AS id, @@ -614,41 +631,42 @@ WHERE NOT EXISTS FROM users WHERE google_id = $3)`, userID, - a.generateHandle(), + handle, googleProfile.ID, updatedAt) if err != nil { a.logger.Warn("Could not register new Google profile, query error", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } if rowsAffected, _ := res.RowsAffected(); rowsAffected == 0 { a.logger.Warn("Could not register new Google profile, rows affected error") - return nil, errorIDAlreadyInUse, 401 + return nil, "", errorIDAlreadyInUse, 401 } err = a.addUserEdgeMetadata(tx, userID, updatedAt) if err != nil { - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } - return userID, "", 200 + return userID, handle, "", 200 } -func (a *authenticationService) registerGameCenter(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, int) { +func (a *authenticationService) registerGameCenter(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, string, int) { gc := authReq.GetGameCenter() if gc == nil || gc.PlayerId == "" || gc.BundleId == "" || gc.Timestamp == 0 || gc.Salt == "" || gc.Signature == "" || gc.PublicKeyUrl == "" { - return nil, errorInvalidPayload, 400 + return nil, "", errorInvalidPayload, 400 } _, err := a.socialClient.CheckGameCenterID(gc.PlayerId, gc.BundleId, gc.Timestamp, gc.Salt, gc.Signature, gc.PublicKeyUrl) if err != nil { a.logger.Warn("Could not get Game Center profile", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } updatedAt := nowMs() userID := uuid.NewV4().Bytes() + handle := a.generateHandle() res, err := tx.Exec(` INSERT INTO users (id, handle, gamecenter_id, created_at, updated_at) SELECT $1 AS id, @@ -661,47 +679,48 @@ WHERE NOT EXISTS FROM users WHERE gamecenter_id = $3)`, userID, - a.generateHandle(), + handle, gc.PlayerId, updatedAt) if err != nil { a.logger.Warn("Could not register new Game Center profile, query error", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } if rowsAffected, _ := res.RowsAffected(); rowsAffected == 0 { a.logger.Warn("Could not register new Game Center profile, rows affected error") - return nil, errorIDAlreadyInUse, 401 + return nil, "", errorIDAlreadyInUse, 401 } err = a.addUserEdgeMetadata(tx, userID, updatedAt) if err != nil { - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } - return userID, "", 200 + return userID, handle, "", 200 } -func (a *authenticationService) registerSteam(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, int) { +func (a *authenticationService) registerSteam(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, string, int) { if a.config.GetSocial().Steam.PublisherKey == "" || a.config.GetSocial().Steam.AppID == 0 { - return nil, "Steam registration not available", 401 + return nil, "", "Steam registration not available", 401 } ticket := authReq.GetSteam() if ticket == "" { - return nil, "Steam ticket is required", 400 + return nil, "", "Steam ticket is required", 400 } else if invalidCharsRegex.MatchString(ticket) { - return nil, "Invalid Steam ticket, no spaces or control characters allowed", 400 + return nil, "", "Invalid Steam ticket, no spaces or control characters allowed", 400 } steamProfile, err := a.socialClient.GetSteamProfile(a.config.GetSocial().Steam.PublisherKey, a.config.GetSocial().Steam.AppID, ticket) if err != nil { a.logger.Warn("Could not get Steam profile", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } updatedAt := nowMs() userID := uuid.NewV4().Bytes() + handle := a.generateHandle() res, err := tx.Exec(` INSERT INTO users (id, handle, steam_id, created_at, updated_at) SELECT $1 AS id, @@ -714,47 +733,48 @@ WHERE NOT EXISTS FROM users WHERE steam_id = $3)`, userID, - a.generateHandle(), + handle, strconv.FormatUint(steamProfile.SteamID, 10), updatedAt) if err != nil { a.logger.Warn("Could not register new Steam profile, query error", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } if rowsAffected, _ := res.RowsAffected(); rowsAffected == 0 { a.logger.Warn("Could not register new Steam profile, rows affected error") - return nil, errorIDAlreadyInUse, 401 + return nil, "", errorIDAlreadyInUse, 401 } err = a.addUserEdgeMetadata(tx, userID, updatedAt) if err != nil { - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } - return userID, "", 200 + return userID, handle, "", 200 } -func (a *authenticationService) registerEmail(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, int) { +func (a *authenticationService) registerEmail(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, string, int) { email := authReq.GetEmail() if email == nil { - return nil, errorInvalidPayload, 400 + return nil, "", errorInvalidPayload, 400 } else if email.Email == "" { - return nil, "Email address is required", 400 + return nil, "", "Email address is required", 400 } else if invalidCharsRegex.MatchString(email.Email) { - return nil, "Invalid email address, no spaces or control characters allowed", 400 + return nil, "", "Invalid email address, no spaces or control characters allowed", 400 } else if len(email.Password) < 8 { - return nil, "Password must be longer than 8 characters", 400 + return nil, "", "Password must be longer than 8 characters", 400 } else if !emailRegex.MatchString(email.Email) { - return nil, "Invalid email address format", 400 + return nil, "", "Invalid email address format", 400 } else if len(email.Email) < 10 || len(email.Email) > 255 { - return nil, "Invalid email address, must be 10-255 bytes", 400 + return nil, "", "Invalid email address, must be 10-255 bytes", 400 } hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(email.Password), bcrypt.DefaultCost) updatedAt := nowMs() userID := uuid.NewV4().Bytes() + handle := a.generateHandle() res, err := tx.Exec(` INSERT INTO users (id, handle, email, password, created_at, updated_at) SELECT $1 AS id, @@ -768,40 +788,41 @@ WHERE NOT EXISTS FROM users WHERE email = $3)`, userID, - a.generateHandle(), + handle, email.Email, hashedPassword, updatedAt) if err != nil { a.logger.Warn("Could not register new email profile, query error", zap.Error(err)) - return nil, "Email already in use", 401 + return nil, "", "Email already in use", 401 } if rowsAffected, _ := res.RowsAffected(); rowsAffected == 0 { a.logger.Warn("Could not register new email profile, rows affected error") - return nil, errorIDAlreadyInUse, 401 + return nil, "", errorIDAlreadyInUse, 401 } err = a.addUserEdgeMetadata(tx, userID, updatedAt) if err != nil { - return nil, "Email already in use", 401 + return nil, "", "Email already in use", 401 } - return userID, "", 200 + return userID, handle, "", 200 } -func (a *authenticationService) registerCustom(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, int) { +func (a *authenticationService) registerCustom(tx *sql.Tx, authReq *AuthenticateRequest) ([]byte, string, string, int) { customID := authReq.GetCustom() if customID == "" { - return nil, "Custom ID is required", 400 + return nil, "", "Custom ID is required", 400 } else if invalidCharsRegex.MatchString(customID) { - return nil, "Invalid custom ID, no spaces or control characters allowed", 400 + return nil, "", "Invalid custom ID, no spaces or control characters allowed", 400 } else if len(customID) < 10 || len(customID) > 64 { - return nil, "Invalid custom ID, must be 10-64 bytes", 400 + return nil, "", "Invalid custom ID, must be 10-64 bytes", 400 } updatedAt := nowMs() userID := uuid.NewV4().Bytes() + handle := a.generateHandle() res, err := tx.Exec(` INSERT INTO users (id, handle, custom_id, created_at, updated_at) SELECT $1 AS id, @@ -814,25 +835,25 @@ WHERE NOT EXISTS FROM users WHERE custom_id = $3)`, userID, - a.generateHandle(), + handle, customID, updatedAt) if err != nil { a.logger.Warn("Could not register new custom profile, query error", zap.Error(err)) - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } if rowsAffected, _ := res.RowsAffected(); rowsAffected == 0 { a.logger.Warn("Could not register new custom profile, rows affected error") - return nil, errorIDAlreadyInUse, 401 + return nil, "", errorIDAlreadyInUse, 401 } err = a.addUserEdgeMetadata(tx, userID, updatedAt) if err != nil { - return nil, errorCouldNotRegister, 401 + return nil, "", errorCouldNotRegister, 401 } - return userID, "", 200 + return userID, handle, "", 200 } func (a *authenticationService) generateHandle() string { @@ -843,10 +864,10 @@ func (a *authenticationService) generateHandle() string { return string(b) } -func (a *authenticationService) authenticateToken(tokenString string) (uuid.UUID, bool) { +func (a *authenticationService) authenticateToken(tokenString string) (uuid.UUID, string, bool) { if tokenString == "" { a.logger.Warn("Token missing") - return uuid.Nil, false + return uuid.Nil, "", false } token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { @@ -861,14 +882,14 @@ func (a *authenticationService) authenticateToken(tokenString string) (uuid.UUID uid, uerr := uuid.FromString(claims["uid"].(string)) if uerr != nil { a.logger.Warn("Invalid user ID in token", zap.String("token", tokenString), zap.Error(uerr)) - return uuid.Nil, false + return uuid.Nil, "", false } - return uid, true + return uid, claims["han"].(string), true } } a.logger.Warn("Token invalid", zap.String("token", tokenString), zap.Error(err)) - return uuid.Nil, false + return uuid.Nil, "", false } func (a *authenticationService) Stop() { @@ -876,6 +897,14 @@ func (a *authenticationService) Stop() { a.registry.stop() } +func now() time.Time { + return time.Now().UTC() +} + func nowMs() int64 { - return int64(time.Nanosecond) * time.Now().UTC().UnixNano() / int64(time.Millisecond) + return timeToMs(now()) +} + +func timeToMs(t time.Time) int64 { + return int64(time.Nanosecond) * t.UnixNano() / int64(time.Millisecond) } diff --git a/server/session_registry.go b/server/session_registry.go index 489c28d22d..5f8542c1d2 100644 --- a/server/session_registry.go +++ b/server/session_registry.go @@ -62,8 +62,8 @@ func (a *SessionRegistry) Get(sessionID uuid.UUID) *session { return s } -func (a *SessionRegistry) add(userID uuid.UUID, conn *websocket.Conn, processRequest func(logger zap.Logger, session *session, envelope *Envelope)) { - s := NewSession(a.logger, a.config, userID, conn, a.remove) +func (a *SessionRegistry) add(userID uuid.UUID, handle string, lang string, conn *websocket.Conn, processRequest func(logger zap.Logger, session *session, envelope *Envelope)) { + s := NewSession(a.logger, a.config, userID, handle, lang, conn, a.remove) a.Lock() a.sessions[s.id] = s a.Unlock() diff --git a/server/tracker.go b/server/tracker.go index 158dfc5b83..2ea51e7be2 100644 --- a/server/tracker.go +++ b/server/tracker.go @@ -26,7 +26,9 @@ type PresenceID struct { SessionID uuid.UUID } -type PresenceMeta struct{} +type PresenceMeta struct { + Handle string +} type Presence struct { ID PresenceID @@ -43,6 +45,7 @@ type Tracker interface { Untrack(sessionID uuid.UUID, topic string, userID uuid.UUID) UntrackAll(sessionID uuid.UUID) Update(sessionID uuid.UUID, topic string, userID uuid.UUID, meta PresenceMeta) error + UpdateAll(sessionID uuid.UUID, meta PresenceMeta) // Get current total number of presences. Count() int @@ -159,6 +162,28 @@ func (t *TrackerService) Update(sessionID uuid.UUID, topic string, userID uuid.U return e } +func (t *TrackerService) UpdateAll(sessionID uuid.UUID, meta PresenceMeta) { + joins := make([]Presence, 0) + leaves := make([]Presence, 0) + t.Lock() + for pc, m := range t.values { + if pc.ID.SessionID == sessionID { + joins = append(joins, Presence{ID: pc.ID, Topic: pc.Topic, UserID: pc.UserID, Meta: meta}) + leaves = append(leaves, Presence{ID: pc.ID, Topic: pc.Topic, UserID: pc.UserID, Meta: m}) + } + } + if len(joins) != 0 { + for _, p := range joins { + t.values[presenceCompact{ID: p.ID, Topic: p.Topic, UserID: p.UserID}] = p.Meta + } + t.notifyDiffListeners( + joins, + leaves, + ) + } + t.Unlock() +} + func (t *TrackerService) Count() int { var count int t.RLock()