Skip to content

Commit

Permalink
refactor: use sync.Map instead of mutex, add more logs
Browse files Browse the repository at this point in the history
  • Loading branch information
davidramiro committed Feb 18, 2025
1 parent 69c868c commit 0c46420
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 54 deletions.
3 changes: 2 additions & 1 deletion internal/adapters/sender/telegram.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (s *TelegramSender) SendMessageReply(ctx context.Context, chatID int64, mes
})

if err != nil {
log.Error().Err(err).Msg("failed to send text response")
return err
}
}
Expand Down Expand Up @@ -75,7 +76,7 @@ func (s *TelegramSender) SendImageFileReply(ctx context.Context, chatID int64, m

_, err := s.bot.SendPhoto(ctx, params)
if err != nil {
log.Error().Err(err).Msg("failed to send photo response")
log.Error().Err(err).Msg("failed to send file response")
return err
}

Expand Down
93 changes: 60 additions & 33 deletions internal/core/domain/commands/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package commands

import (
"context"
"errors"
"fmt"
"hsbot/internal/core/domain"
"hsbot/internal/core/port"
Expand All @@ -17,8 +18,7 @@ type ChatHandler struct {
transcriber port.Transcriber
cacheDuration time.Duration
command string
cache map[int64]*Conversation
mutex sync.Mutex
cache sync.Map
}

type Conversation struct {
Expand All @@ -36,7 +36,6 @@ func NewChatHandler(textGenerator port.TextGenerator, textSender port.TextSender
transcriber: transcriber,
cacheDuration: cacheDuration,
command: command,
cache: make(map[int64]*Conversation),
}

return h
Expand All @@ -60,45 +59,45 @@ func (h *ChatHandler) Respond(ctx context.Context, timeout time.Duration, messag

go h.textSender.SendChatAction(ctx, message.ChatID, domain.Typing)

promptText := domain.ParseCommandArgs(message.Text)
if promptText == "" {
err := h.textSender.SendMessageReply(ctx, message.ChatID, message.ID, "please input a prompt")
if err != nil {
l.Error().Err(err).Msg(domain.ErrSendingReplyFailed)
return err
}
return nil
promptText, err := h.extractPrompt(ctx, message)
if err != nil {
log.Error().Err(err).Msg("failed to extract prompt text")
}

if message.AudioURL != "" {
transcript, err := h.transcriber.GenerateFromAudio(ctx, message.AudioURL)
if err != nil {
l.Error().Err(err).Msg(domain.ErrSendingReplyFailed)
return err
}

promptText += ": " + transcript
if promptText == "" {
log.Debug().Msg(domain.ErrEmptyPrompt)
return nil
}

promptText = message.Username + ": " + promptText

h.mutex.Lock()
defer h.mutex.Unlock()

conversation, ok := h.cache[message.ChatID]
var conversation *Conversation
c, ok := h.cache.Load(message.ChatID)
if !ok {
l.Debug().Msg("new conversation")

h.cache[message.ChatID] = &Conversation{
h.cache.Store(message.ChatID, &Conversation{
chatID: message.ChatID,
exitSignal: make(chan struct{}),
})
c, _ = h.cache.Load(message.ChatID)
conversation, ok = c.(*Conversation)
if !ok {
err := errors.New("conversation type error")
l.Error().Err(err).Send()
return err
}
conversation = h.cache[message.ChatID]
} else {
conversation, ok = c.(*Conversation)
if !ok {
err := errors.New("conversation type error")
l.Error().Err(err).Send()
return err
}
l.Debug().Msg("existing conversation, stopping timer")
conversation.exitSignal <- struct{}{}
}

conversation.timestamp = time.Now()
go h.startConversationTimer(conversation)

if message.QuotedText != "" && message.ImageURL == "" {
// if there's a user message being replied to, add the previous message to the context
Expand All @@ -116,6 +115,7 @@ func (h *ChatHandler) Respond(ctx context.Context, timeout time.Duration, messag

response, err := h.textGenerator.GenerateFromPrompt(ctx, conversation.messages)
if err != nil {
l.Error().Err(err).Msg("failed to generate prompt")
conversation.messages = append(conversation.messages, domain.Prompt{Author: domain.System, Prompt: err.Error()})

err = h.textSender.SendMessageReply(ctx,
Expand All @@ -130,10 +130,9 @@ func (h *ChatHandler) Respond(ctx context.Context, timeout time.Duration, messag
return nil
}

l.Debug().Msg("reply generated")
conversation.messages = append(conversation.messages, domain.Prompt{Author: domain.System, Prompt: response})

go h.startConversationTimer(conversation)

err = h.textSender.SendMessageReply(ctx,
message.ChatID,
message.ID,
Expand All @@ -146,17 +145,45 @@ func (h *ChatHandler) Respond(ctx context.Context, timeout time.Duration, messag
return nil
}

func (h *ChatHandler) extractPrompt(ctx context.Context, message *domain.Message) (string, error) {
l := log.With().
Int("messageId", message.ID).
Int64("chatId", message.ChatID).
Str("command", h.GetCommand()).
Logger()

promptText := domain.ParseCommandArgs(message.Text)
if promptText == "" {
err := h.textSender.SendMessageReply(ctx, message.ChatID, message.ID, "please input a prompt")
if err != nil {
l.Error().Err(err).Msg(domain.ErrSendingReplyFailed)
return "", err
}
return "", nil
}

if message.AudioURL != "" {
transcript, err := h.transcriber.GenerateFromAudio(ctx, message.AudioURL)
if err != nil {
l.Error().Err(err).Msg(domain.ErrSendingReplyFailed)
return "", err
}

promptText += ": " + transcript
}

promptText = message.Username + ": " + promptText
return promptText, nil
}

func (h *ChatHandler) startConversationTimer(convo *Conversation) {
t := time.NewTimer(h.cacheDuration)

for {
select {
case <-t.C:
log.Debug().Int64("chatID", convo.chatID).Msg("clearing conversation")

h.mutex.Lock()
delete(h.cache, convo.chatID)
h.mutex.Unlock()
h.cache.Delete(convo.chatID)
return
case <-convo.exitSignal:
t.Stop()
Expand Down
118 changes: 99 additions & 19 deletions internal/core/domain/commands/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,18 @@ func TestChatHandlerClearingCache(t *testing.T) {

require.NoError(t, err)
assert.Equal(t, "mock response", ms.Message)
assert.Len(t, chatHandler.cache, 1)

c, ok := chatHandler.cache.Load(int64(1))
require.True(t, ok)

conversation, ok := c.(*Conversation)
require.True(t, ok)
assert.Len(t, conversation.messages, 2)

time.Sleep(time.Second * 4)

assert.Empty(t, chatHandler.cache)
_, ok = chatHandler.cache.Load(int64(1))
assert.False(t, ok)
}

func TestChatHandlerCache(t *testing.T) {
Expand All @@ -69,19 +76,76 @@ func TestChatHandlerCache(t *testing.T) {

require.NoError(t, err)
assert.Equal(t, "mock response", ms.Message)
assert.Len(t, chatHandler.cache, 1)

size := 0
chatHandler.cache.Range(func(_, _ interface{}) bool {
size++
return true
})
assert.Equal(t, 1, size)

err = chatHandler.Respond(context.Background(), time.Minute, &domain.Message{
ChatID: 1, ID: 2, Username: "@unit", Text: "/chat prompt2"})
require.NoError(t, err)
assert.Len(t, chatHandler.cache, 1)

assert.Len(t, chatHandler.cache[1].messages, 4)
c, ok := chatHandler.cache.Load(int64(1))
require.True(t, ok)

conversation, ok := c.(*Conversation)
require.True(t, ok)
assert.Len(t, conversation.messages, 4)

assert.Equal(t, "@unit: prompt", chatHandler.cache[1].messages[0].Prompt)
assert.Equal(t, "mock response", chatHandler.cache[1].messages[1].Prompt)
assert.Equal(t, "@unit: prompt2", chatHandler.cache[1].messages[2].Prompt)
assert.Equal(t, "mock response", chatHandler.cache[1].messages[3].Prompt)
assert.Equal(t, "@unit: prompt", conversation.messages[0].Prompt)
assert.Equal(t, "mock response", conversation.messages[1].Prompt)
assert.Equal(t, "@unit: prompt2", conversation.messages[2].Prompt)
assert.Equal(t, "mock response", conversation.messages[3].Prompt)
}

func TestChatHandlerCacheMultipleConversations(t *testing.T) {
mg := &MockTextGenerator{response: "mock response"}
ms := &MockTextSender{}
mt := &MockTranscriber{}

chatHandler := NewChatHandler(mg, ms, mt,
"/chat", time.Second*3)

assert.NotNil(t, chatHandler)

err := chatHandler.Respond(context.Background(), time.Minute, &domain.Message{
ChatID: 1, ID: 1, Username: "@unit", Text: "/chat prompt chat id 1"})

require.NoError(t, err)
assert.Equal(t, "mock response", ms.Message)

err = chatHandler.Respond(context.Background(), time.Minute, &domain.Message{
ChatID: 2, ID: 2, Username: "@unit", Text: "/chat prompt chat id 2"})
require.NoError(t, err)

size := 0
chatHandler.cache.Range(func(_, _ interface{}) bool {
size++
return true
})
assert.Equal(t, 2, size)

c1, ok := chatHandler.cache.Load(int64(1))
require.True(t, ok)

conversation1, ok := c1.(*Conversation)
require.True(t, ok)
assert.Len(t, conversation1.messages, 2)

c2, ok := chatHandler.cache.Load(int64(2))
require.True(t, ok)

conversation2, ok := c2.(*Conversation)
require.True(t, ok)
assert.Len(t, conversation2.messages, 2)

assert.Equal(t, "@unit: prompt chat id 1", conversation1.messages[0].Prompt)
assert.Equal(t, "mock response", conversation1.messages[1].Prompt)
assert.Equal(t, "@unit: prompt chat id 2", conversation2.messages[0].Prompt)
assert.Equal(t, "mock response", conversation2.messages[1].Prompt)
}

func TestChatHandlerCacheResetTimeout(t *testing.T) {
Expand All @@ -99,30 +163,46 @@ func TestChatHandlerCacheResetTimeout(t *testing.T) {

require.NoError(t, err)
assert.Equal(t, "mock response", ms.Message)
assert.Len(t, chatHandler.cache, 1)

size := 0
chatHandler.cache.Range(func(_, _ interface{}) bool {
size++
return true
})
assert.Equal(t, 1, size)

time.Sleep(time.Second * 2)

err = chatHandler.Respond(context.Background(), time.Minute, &domain.Message{
ChatID: 1, ID: 2, Username: "@unit", Text: "/chat prompt2"})
require.NoError(t, err)
assert.Len(t, chatHandler.cache, 1)

size = 0
chatHandler.cache.Range(func(_, _ interface{}) bool {
size++
return true
})
assert.Equal(t, 1, size)

time.Sleep(time.Second * 2)

err = chatHandler.Respond(context.Background(), time.Minute, &domain.Message{
ChatID: 1, ID: 2, Username: "@unit", Text: "/chat prompt3"})
require.NoError(t, err)
assert.Len(t, chatHandler.cache, 1)

assert.Len(t, chatHandler.cache[1].messages, 6)
c, ok := chatHandler.cache.Load(int64(1))
require.True(t, ok)

conversation, ok := c.(*Conversation)
require.True(t, ok)
assert.Len(t, conversation.messages, 6)

assert.Equal(t, "@unit: prompt", chatHandler.cache[1].messages[0].Prompt)
assert.Equal(t, "mock response", chatHandler.cache[1].messages[1].Prompt)
assert.Equal(t, "@unit: prompt2", chatHandler.cache[1].messages[2].Prompt)
assert.Equal(t, "mock response", chatHandler.cache[1].messages[3].Prompt)
assert.Equal(t, "@unit: prompt3", chatHandler.cache[1].messages[4].Prompt)
assert.Equal(t, "mock response", chatHandler.cache[1].messages[5].Prompt)
assert.Equal(t, "@unit: prompt", conversation.messages[0].Prompt)
assert.Equal(t, "mock response", conversation.messages[1].Prompt)
assert.Equal(t, "@unit: prompt2", conversation.messages[2].Prompt)
assert.Equal(t, "mock response", conversation.messages[3].Prompt)
assert.Equal(t, "@unit: prompt3", conversation.messages[4].Prompt)
assert.Equal(t, "mock response", conversation.messages[5].Prompt)
}

func TestGeneratorError(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions internal/core/domain/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ package domain

const (
ErrSendingReplyFailed = "failed to send reply"
ErrEmptyPrompt = "empty prompt"
)
10 changes: 9 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"os/signal"
"time"

"github.com/go-telegram/bot/models"

"github.com/rs/zerolog"

"github.com/go-telegram/bot"
Expand Down Expand Up @@ -48,7 +50,11 @@ func main() {
defer cancel()

token := viper.GetString("telegram.bot_token")
b, err := bot.New(token)
opts := []bot.Option{
bot.WithDefaultHandler(noOpHandler),
}

b, err := bot.New(token, opts...)
if err != nil {
log.Panic().Err(err).Msg("failed initializing telegram bot")
}
Expand Down Expand Up @@ -88,3 +94,5 @@ func main() {
log.Info().Msg("bot listening")
b.Start(ctx)
}

func noOpHandler(_ context.Context, _ *bot.Bot, _ *models.Update) {}

0 comments on commit 0c46420

Please sign in to comment.