Skip to content

Commit

Permalink
Merge pull request #10 from tdakkota/feature/client-enhancement
Browse files Browse the repository at this point in the history
Add codegeneration for update handlers
  • Loading branch information
ernado authored Dec 13, 2020
2 parents 9bd72eb + a36e190 commit 6d01989
Show file tree
Hide file tree
Showing 12 changed files with 1,714 additions and 177 deletions.
106 changes: 43 additions & 63 deletions cmd/gotdecho/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,71 +5,20 @@ import (
"fmt"
"os"
"os/signal"
"path"
"path/filepath"
"strconv"
"syscall"
"time"

"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/net/proxy"
"golang.org/x/xerrors"

"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
)

type updateHandler struct {
log *zap.Logger
}

func (h updateHandler) handle(ctx context.Context, client telegram.UpdateClient, updates *tg.Updates) error {
// This wll be required to send message back.
users := map[int]*tg.User{}
for _, u := range updates.Users {
user, ok := u.(*tg.User)
if !ok {
continue
}
users[user.ID] = user
}

for _, update := range updates.Updates {
switch u := update.(type) {
case *tg.UpdateNewMessage:
switch m := u.Message.(type) {
case *tg.Message:
switch peer := m.PeerID.(type) {
case *tg.PeerUser:
user := users[peer.UserID]
h.log.With(
zap.String("text", m.Message),
zap.Int("user_id", user.ID),
zap.String("user_first_name", user.FirstName),
zap.String("username", user.Username),
).Info("Got message")

randomID, err := client.RandInt64()
if err != nil {
return err
}
p := &tg.InputPeerUser{
UserID: user.ID,
AccessHash: user.AccessHash,
}
return client.SendMessage(ctx, &tg.MessagesSendMessageRequest{
RandomID: randomID,
Message: m.Message,
Peer: p,
})
}
}
default:
h.log.With(zap.String("update_type", fmt.Sprintf("%T", u))).Info("Ignoring update")
}
}
return nil
}

func run(ctx context.Context) error {
logger, _ := zap.NewDevelopment(zap.IncreaseLevel(zapcore.DebugLevel))
defer func() { _ = logger.Sync() }()
Expand All @@ -90,31 +39,62 @@ func run(ctx context.Context) error {
if err != nil {
return err
}
sessionDir := path.Join(home, ".td")
sessionDir := filepath.Join(home, ".td")
if err := os.MkdirAll(sessionDir, 0600); err != nil {
return err
}

dispatcher := tg.NewUpdateDispatcher()
// Creating connection.
dialCtx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
client, err := telegram.Dial(dialCtx, telegram.Options{
client := telegram.NewClient(appID, appHash, telegram.Options{
Logger: logger,
SessionStorage: &telegram.FileSessionStorage{
Path: path.Join(sessionDir, "session.json"),
Path: filepath.Join(sessionDir, "session.json"),
},

// Grab these from https://my.telegram.org/apps.
// Never share it or hardcode!
AppID: appID,
AppHash: appHash,
Dialer: telegram.DialFunc(proxy.Dial),
UpdateHandler: dispatcher.Handle,
})

UpdateHandler: updateHandler{log: logger}.handle,
dispatcher.OnNewMessage(func(ctx tg.UpdateContext, u *tg.UpdateNewMessage) error {
switch m := u.Message.(type) {
case *tg.Message:
switch peer := m.PeerID.(type) {
case *tg.PeerUser:
user := ctx.Users[peer.UserID]
logger.With(
zap.String("text", m.Message),
zap.Int("user_id", user.ID),
zap.String("user_first_name", user.FirstName),
zap.String("username", user.Username),
).Info("Got message")

randomID, err := client.RandInt64()
if err != nil {
return err
}
p := &tg.InputPeerUser{
UserID: user.ID,
AccessHash: user.AccessHash,
}
return client.SendMessage(ctx, &tg.MessagesSendMessageRequest{
RandomID: randomID,
Message: m.Message,
Peer: p,
})
}
}

return nil
})

err = client.Connect(ctx)
if err != nil {
return xerrors.Errorf("failed to dial: %w", err)
return xerrors.Errorf("failed to connect: %w", err)
}
logger.Info("Dialed")
logger.Info("Client started.")

auth, err := client.AuthStatus(dialCtx)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ require (
github.com/gotd/tl v0.2.0
github.com/stretchr/testify v1.6.1
go.uber.org/zap v1.16.0
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1
)
84 changes: 84 additions & 0 deletions internal/gen/_template/handlers.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
{{ define "handlers" }}
{{ $pkg := $.Package }}
{{ template "header" $ }}

type handler = func(UpdateContext, UpdateClass) error

type UpdateDispatcher struct {
handlers map[int]handler
}

func NewUpdateDispatcher() UpdateDispatcher {
return UpdateDispatcher{
handlers: map[int]handler{},
}
}

type UpdateContext struct {
context.Context

Users map[int]*User
Chats map[int]*Chat
init bool
}

func (u *UpdateContext) lazyInitFromUpdates(updates *Updates) {
if u.init {
return
}

u.init = true
u.Users = make(map[int]*User, len(updates.Users))
for _, class := range updates.Users {
user, ok := class.(*User)
if !ok {
continue
}
u.Users[user.ID] = user
}

u.Chats = make(map[int]*Chat, len(updates.Chats))
for _, class := range updates.Chats {
chat, ok := class.(*Chat)
if !ok {
continue
}
u.Chats[chat.ID] = chat
}
}

func (u UpdateDispatcher) Handle(ctx context.Context, updates *Updates) error {
uctx := UpdateContext{
Context: ctx,
}

for _, update := range updates.Updates {
uctx.lazyInitFromUpdates(updates)
switch update.(type) {
{{- range $s:= $.Structs }}{{ if eq $s.Interface "UpdateClass" }}
case *{{ $s.Name }}:
if handler, ok := u.handlers[{{ $s.Name }}TypeID]; ok {
if err := handler(uctx, update); err != nil {
return err
}
}
{{- end }}{{ end }}
}
}
return nil
}

{{- range $s:= $.Structs }}{{ if eq $s.Interface "UpdateClass" }}
{{ $eventName := trimPrefix $s.Name "Update"}}
// {{ $eventName }}Handler is a {{ $eventName }} event handler.
type {{ $eventName }}Handler func(ctx UpdateContext, update *{{ $s.Name }}) error

// On{{ $eventName }} sets {{ $eventName }} handler.
func (u UpdateDispatcher) On{{ $eventName }}(handler {{ $eventName }}Handler) {
u.handlers[{{ $s.Name }}TypeID] = func(ctx UpdateContext, update UpdateClass) error {
return handler(ctx, update.(*{{ $s.Name }}))
}
}
{{- end }}{{ end }}

{{ end }}
31 changes: 27 additions & 4 deletions internal/gen/internal/bindata.go

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions internal/gen/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ import (

func Template() *template.Template {
tmpl := template.New("templates").Funcs(template.FuncMap{
"trim": strings.TrimSpace,
"lower": strings.ToLower,
"trim": strings.TrimSpace,
"lower": strings.ToLower,
"trimPrefix": strings.TrimPrefix,
"hasPrefix": strings.HasPrefix,
})
tmpl = template.Must(tmpl.Parse(string(internal.MustAsset("_template/utils.tmpl"))))
tmpl = template.Must(tmpl.Parse(string(internal.MustAsset("_template/header.tmpl"))))
tmpl = template.Must(tmpl.Parse(string(internal.MustAsset("_template/registry.tmpl"))))
tmpl = template.Must(tmpl.Parse(string(internal.MustAsset("_template/client.tmpl"))))
tmpl = template.Must(tmpl.Parse(string(internal.MustAsset("_template/main.tmpl"))))
tmpl = template.Must(tmpl.Parse(string(internal.MustAsset("_template/handlers.tmpl"))))
return tmpl
}
8 changes: 8 additions & 0 deletions internal/gen/write_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ func (g *Generator) WriteSource(fs FileSystem, pkgName string, t *template.Templ
}

cfg := config{
Package: pkgName,
Structs: g.structs,
}
if err := generate("handlers", "tl_handlers_gen.go", cfg); err != nil {
return err
}

cfg = config{
Package: pkgName,
Registry: g.registry,
}
Expand Down
Loading

0 comments on commit 6d01989

Please sign in to comment.