Skip to content

Commit

Permalink
Merge pull request #62 from zweihander/rpc-unit-tests
Browse files Browse the repository at this point in the history
client: RPC engine
  • Loading branch information
ernado authored Dec 22, 2020
2 parents f664ae7 + edfd21a commit 32ffc7e
Show file tree
Hide file tree
Showing 10 changed files with 588 additions and 175 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ require (
go.uber.org/zap v1.16.0
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1
)
23 changes: 8 additions & 15 deletions telegram/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/gotd/td/internal/mt"
"github.com/gotd/td/internal/proto"
"github.com/gotd/td/internal/tmap"
"github.com/gotd/td/telegram/internal/rpc"
"github.com/gotd/td/tg"
"github.com/gotd/td/transport"
)
Expand Down Expand Up @@ -84,23 +85,14 @@ type Client struct {
updateHandler UpdateHandler // immutable
sessionStorage SessionStorage // immutable

// callbacks for RPC requests, protected by rpcMux.
// Key is request message id.
rpc map[int64]func(b *bin.Buffer, rpcErr error) error
rpcMux sync.Mutex
rpc *rpc.Engine

// callbacks for ack protected by ackMux
ack map[int64]func()
ackMux sync.Mutex
// ackSendChan is queue for outgoing message id's that require waiting for
// ack from server.
ackSendChan chan int64
ackBatchSize int
ackInterval time.Duration

maxRetries int
retryInterval time.Duration

// callbacks for ping results protected by pingMux.
// Key is ping id.
ping map[int64]func()
Expand Down Expand Up @@ -140,18 +132,13 @@ func NewClient(appID int, appHash string, opt Options) *Client {
cipher: crypto.NewClientCipher(opt.Random),
log: opt.Logger,
ping: map[int64]func(){},
rpc: map[int64]func(b *bin.Buffer, rpcErr error) error{},

sessionCreated: createCondOnce(),

ack: map[int64]func(){},
ackSendChan: make(chan int64),
ackInterval: opt.AckInterval,
ackBatchSize: opt.AckBatchSize,

maxRetries: opt.MaxRetries,
retryInterval: opt.RetryInterval,

ctx: clientCtx,
cancel: clientCancel,

Expand All @@ -169,6 +156,12 @@ func NewClient(appID int, appHash string, opt Options) *Client {
),
}

client.rpc = rpc.New(client.write, rpc.Config{
Logger: opt.Logger.Named("rpc"),
RetryInterval: opt.RetryInterval,
MaxRetries: opt.MaxRetries,
})

// Initializing internal RPC caller.
client.tg = tg.NewClient(client)

Expand Down
15 changes: 2 additions & 13 deletions telegram/handle_ack.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,10 @@ func (c *Client) handleAck(b *bin.Buffer) error {
if err := ack.Decode(b); err != nil {
return xerrors.Errorf("decode: %w", err)
}
c.log.With(zap.Int64s("messages", ack.MsgIds)).Debug("Ack")

c.ackMux.Lock()
defer c.ackMux.Unlock()

for _, msgID := range ack.MsgIds {
fn, found := c.ack[msgID]
if !found {
c.log.Warn("ack callback is not set", zap.Int64("message_id", msgID))
continue
}
c.log.With(zap.Int64s("messages", ack.MsgIds)).Debug("Ack")

fn()
delete(c.ack, msgID)
}
c.rpc.NotifyAcks(ack.MsgIds)

return nil
}
16 changes: 2 additions & 14 deletions telegram/handle_bad_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,27 +61,15 @@ func (c *Client) handleBadMsg(b *bin.Buffer) error {
return err
}

c.rpcMux.Lock()
f, ok := c.rpc[bad.BadMsgID]
c.rpcMux.Unlock()
if ok {
return f(b, &badMessageError{Code: bad.ErrorCode})
}

c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode})
return nil
case mt.BadServerSaltTypeID:
var bad mt.BadServerSalt
if err := bad.Decode(b); err != nil {
return err
}

c.rpcMux.Lock()
f, ok := c.rpc[bad.BadMsgID]
c.rpcMux.Unlock()
if ok {
return f(b, &badMessageError{Code: bad.ErrorCode, NewSalt: bad.NewServerSalt})
}

c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode, NewSalt: bad.NewServerSalt})
return nil
default:
return xerrors.Errorf("unknown type id 0x%d", id)
Expand Down
2 changes: 2 additions & 0 deletions telegram/handle_message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"

"github.com/gotd/td/bin"
"github.com/gotd/td/telegram/internal/rpc"

"go.uber.org/zap"
)
Expand Down Expand Up @@ -57,6 +58,7 @@ func TestClientHandleMessageCorpus(t *testing.T) {
rand: Zero{},
log: zap.NewNop(),
sessionCreated: createCondOnce(),
rpc: rpc.New(rpc.NopSend, rpc.Config{}),
}
c.sessionCreated.Done()

Expand Down
26 changes: 6 additions & 20 deletions telegram/handle_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,18 @@ func (c *Client) handleResult(b *bin.Buffer) error {
return xerrors.Errorf("error decode: %w", err)
}

c.rpcMux.Lock()
f, ok := c.rpc[res.RequestMessageID]
c.rpcMux.Unlock()
if ok {
e := &Error{
Code: rpcErr.ErrorCode,
Message: rpcErr.ErrorMessage,
}
e.extractArgument()
return f(nil, e)
e := &Error{
Code: rpcErr.ErrorCode,
Message: rpcErr.ErrorMessage,
}
e.extractArgument()

c.rpc.NotifyError(res.RequestMessageID, e)
return nil
}
if id == mt.PongTypeID {
return c.handlePong(b)
}

c.rpcMux.Lock()
f, ok := c.rpc[res.RequestMessageID]
c.rpcMux.Unlock()

if ok {
return f(b, nil)
}

c.log.Debug("Got unexpected result")
return nil
return c.rpc.NotifyResult(res.RequestMessageID, b)
}
45 changes: 45 additions & 0 deletions telegram/internal/rpc/ack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package rpc

import (
"context"

"go.uber.org/zap"
)

// NotifyAcks notifies engine about received acknowledgements.
func (e *Engine) NotifyAcks(ids []int64) {
for _, id := range ids {
e.mux.Lock()
cb, ok := e.ack[id]
e.mux.Unlock()

if !ok {
e.log.Warn("ack callback not set", zap.Int64("msg_id", id))
continue
}

cb()
}
}

// waitAck blocks until acknowledgement on message id is received.
func (e *Engine) waitAck(ctx context.Context, id int64) error {
got := make(chan struct{})

e.mux.Lock()
e.ack[id] = func() { close(got) }
e.mux.Unlock()

defer func() {
e.mux.Lock()
delete(e.ack, id)
e.mux.Unlock()
}()

select {
case <-ctx.Done():
return ctx.Err()
case <-got:
return nil
}
}
Loading

0 comments on commit 32ffc7e

Please sign in to comment.