diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 4cb4db3c6..6fa14d5c3 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -199,20 +199,7 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { msgToCreate.Content = "" } - // Attach passive reply msg_id and msg_seq if available. - if v, ok := c.lastMsgID.Load(msg.ChatID); ok { - if msgID, ok := v.(string); ok && msgID != "" { - msgToCreate.MsgID = msgID - - // Increment msg_seq atomically for multi-part replies. - if counterVal, ok := c.msgSeqCounters.Load(msg.ChatID); ok { - if counter, ok := counterVal.(*atomic.Uint64); ok { - seq := counter.Add(1) - msgToCreate.MsgSeq = uint32(seq) - } - } - } - } + c.applyReplyContext(msg.ChatID, msgToCreate) // Sanitize URLs in group messages to avoid QQ's URL blacklist rejection. if chatKind == "group" { @@ -244,6 +231,22 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return nil } +func (c *QQChannel) applyReplyContext(chatID string, msgToCreate *dto.MessageToCreate) { + if v, ok := c.lastMsgID.Load(chatID); ok { + if msgID, ok := v.(string); ok && msgID != "" { + msgToCreate.MsgID = msgID + + // Increment msg_seq atomically for multi-part replies. + if counterVal, ok := c.msgSeqCounters.Load(chatID); ok { + if counter, ok := counterVal.(*atomic.Uint64); ok { + seq := counter.Add(1) + msgToCreate.MsgSeq = uint32(seq) + } + } + } + } +} + // StartTyping implements channels.TypingCapable. // It sends an InputNotify (msg_type=6) immediately and re-sends every 8 seconds. // The returned stop function is idempotent and cancels the goroutine. @@ -316,71 +319,60 @@ func (c *QQChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) chatKind := c.getChatKind(msg.ChatID) for _, part := range msg.Parts { - // If the ref is already an HTTP(S) URL, use it directly. - mediaURL := part.Ref - if !isHTTPURL(mediaURL) { - // Try resolving through media store. - store := c.GetMediaStore() - if store == nil { - logger.WarnCF("qq", "QQ media requires HTTP/HTTPS URL, no media store available", map[string]any{ - "ref": part.Ref, - }) - continue + if isHTTPURL(part.Ref) { + richMedia := &dto.RichMediaMessage{ + FileType: qqFileType(part.Type), + URL: part.Ref, + SrvSendMsg: true, } - resolved, err := store.Resolve(part.Ref) - if err != nil { - logger.ErrorCF("qq", "Failed to resolve media ref", map[string]any{ - "ref": part.Ref, - "error": err.Error(), - }) - continue + var sendErr error + if chatKind == "group" { + _, sendErr = c.api.PostGroupMessage(ctx, msg.ChatID, richMedia) + } else { + _, sendErr = c.api.PostC2CMessage(ctx, msg.ChatID, richMedia) } - if !isHTTPURL(resolved) { - logger.WarnCF("qq", "QQ media requires HTTP/HTTPS URL, local files not supported", map[string]any{ - "ref": part.Ref, - "resolved": resolved, + if sendErr != nil { + logger.ErrorCF("qq", "Failed to send remote media", map[string]any{ + "type": part.Type, + "chat_id": msg.ChatID, + "error": sendErr.Error(), }) - continue + return fmt.Errorf("qq send media: %w", channels.ErrTemporary) } - - mediaURL = resolved + continue } - // Map part type to QQ file type: 1=image, 2=video, 3=audio, 4=file. - var fileType uint64 - switch part.Type { - case "image": - fileType = 1 - case "video": - fileType = 2 - case "audio": - fileType = 3 - default: - fileType = 4 // file + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("qq send media: media store not configured for local media ref %q", part.Ref) } - richMedia := &dto.RichMediaMessage{ - FileType: fileType, - URL: mediaURL, - SrvSendMsg: true, + resolved, err := store.Resolve(part.Ref) + if err != nil { + return fmt.Errorf("qq send media: resolve local media ref %q: %w", part.Ref, err) } - var sendErr error - if chatKind == "group" { - _, sendErr = c.api.PostGroupMessage(ctx, msg.ChatID, richMedia) - } else { - _, sendErr = c.api.PostC2CMessage(ctx, msg.ChatID, richMedia) + fileInfo, err := c.uploadLocalMedia(ctx, chatKind, msg.ChatID, part.Type, part.Filename, resolved) + if err != nil { + logger.ErrorCF("qq", "Failed to upload local media", map[string]any{ + "type": part.Type, + "chat_id": msg.ChatID, + "ref": part.Ref, + "resolved": resolved, + "error": err.Error(), + }) + return fmt.Errorf("qq send media: %w", err) } - if sendErr != nil { - logger.ErrorCF("qq", "Failed to send media", map[string]any{ + if err := c.sendUploadedMedia(ctx, chatKind, msg.ChatID, part, fileInfo); err != nil { + logger.ErrorCF("qq", "Failed to send uploaded media", map[string]any{ "type": part.Type, "chat_id": msg.ChatID, - "error": sendErr.Error(), + "error": err.Error(), }) - return fmt.Errorf("qq send media: %w", channels.ErrTemporary) + return err } } diff --git a/pkg/channels/qq/qq_test.go b/pkg/channels/qq/qq_test.go index 3ceee0d09..a7068bfff 100644 --- a/pkg/channels/qq/qq_test.go +++ b/pkg/channels/qq/qq_test.go @@ -2,13 +2,20 @@ package qq import ( "context" + "encoding/json" + "errors" + "os" + "path/filepath" "testing" "time" "github.com/tencent-connect/botgo/dto" + "github.com/tencent-connect/botgo/openapi" + "github.com/tencent-connect/botgo/openapi/options" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/media" ) func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) { @@ -42,3 +49,222 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) { t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456") } } + +type fakeQQAPI struct { + openapi.OpenAPI + + transportCalls []transportCall + transportResp []byte + transportErr error + + groupMessages []dto.APIMessage + groupErr error +} + +type transportCall struct { + method string + url string + body any +} + +func (f *fakeQQAPI) Transport(ctx context.Context, method, url string, body any) ([]byte, error) { + f.transportCalls = append(f.transportCalls, transportCall{ + method: method, + url: url, + body: body, + }) + return f.transportResp, f.transportErr +} + +func (f *fakeQQAPI) PostGroupMessage( + ctx context.Context, + groupID string, + msg dto.APIMessage, + opt ...options.Option, +) (*dto.Message, error) { + f.groupMessages = append(f.groupMessages, msg) + if f.groupErr != nil { + return nil, f.groupErr + } + return &dto.Message{}, nil +} + +func TestSendMedia_LocalFileUploadsThenSendsRichMediaMessage(t *testing.T) { + tmpDir := t.TempDir() + pdfPath := filepath.Join(tmpDir, "report.pdf") + if err := os.WriteFile(pdfPath, []byte("%PDF-1.4 test"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + store := media.NewFileMediaStore() + ref, err := store.Store(pdfPath, media.MediaMeta{ + Filename: "report.pdf", + ContentType: "application/pdf", + Source: "test", + }, "scope") + if err != nil { + t.Fatalf("store media: %v", err) + } + + uploadedFileInfo := []byte("uploaded-file-info") + respBody, err := json.Marshal(struct { + FileInfo []byte `json:"file_info"` + }{ + FileInfo: uploadedFileInfo, + }) + if err != nil { + t.Fatalf("marshal response: %v", err) + } + + api := &fakeQQAPI{transportResp: respBody} + ch := &QQChannel{ + BaseChannel: channels.NewBaseChannel("qq", nil, bus.NewMessageBus(), nil), + api: api, + } + ch.SetRunning(true) + ch.SetMediaStore(store) + ch.chatType.Store("group-1", "group") + + err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "group-1", + Parts: []bus.MediaPart{ + {Ref: ref, Type: "file"}, + }, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + + if len(api.transportCalls) != 1 { + t.Fatalf("transport call count = %d, want 1", len(api.transportCalls)) + } + + payload, ok := api.transportCalls[0].body.(map[string]any) + if !ok { + t.Fatalf("transport body type = %T, want map[string]any", api.transportCalls[0].body) + } + switch v := payload["file_type"].(type) { + case int: + if v != 4 { + t.Fatalf("file_type = %v, want 4", payload["file_type"]) + } + case float64: + if v != 4 { + t.Fatalf("file_type = %v, want 4", payload["file_type"]) + } + case uint64: + if v != 4 { + t.Fatalf("file_type = %v, want 4", payload["file_type"]) + } + default: + t.Fatalf("file_type type = %T, want numeric 4", payload["file_type"]) + } + if payload["srv_send_msg"] != false { + t.Fatalf("srv_send_msg = %v, want false", payload["srv_send_msg"]) + } + fileData, _ := payload["file_data"].(string) + if fileData == "" { + t.Fatal("file_data is empty, want base64-encoded local file contents") + } + if payload["file_name"] != "report.pdf" { + t.Fatalf("file_name = %v, want %q", payload["file_name"], "report.pdf") + } + + if len(api.groupMessages) != 1 { + t.Fatalf("group message count = %d, want 1", len(api.groupMessages)) + } + + msg, ok := api.groupMessages[0].(*dto.MessageToCreate) + if !ok { + t.Fatalf("group message type = %T, want *dto.MessageToCreate", api.groupMessages[0]) + } + if msg.MsgType != dto.RichMediaMsg { + t.Fatalf("msg_type = %v, want %v", msg.MsgType, dto.RichMediaMsg) + } + if msg.Media == nil { + t.Fatal("msg.Media is nil, want uploaded file_info") + } + if string(msg.Media.FileInfo) != string(uploadedFileInfo) { + t.Fatalf("file_info = %q, want %q", string(msg.Media.FileInfo), string(uploadedFileInfo)) + } +} + +func TestSendMedia_LocalFileUploadFailureReturnsError(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "notes.txt") + if err := os.WriteFile(filePath, []byte("hello"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + store := media.NewFileMediaStore() + ref, err := store.Store(filePath, media.MediaMeta{ + Filename: "notes.txt", + ContentType: "text/plain", + Source: "test", + }, "scope") + if err != nil { + t.Fatalf("store media: %v", err) + } + + api := &fakeQQAPI{transportErr: errors.New("upload failed")} + ch := &QQChannel{ + BaseChannel: channels.NewBaseChannel("qq", nil, bus.NewMessageBus(), nil), + api: api, + } + ch.SetRunning(true) + ch.SetMediaStore(store) + ch.chatType.Store("group-1", "group") + + err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "group-1", + Parts: []bus.MediaPart{ + {Ref: ref, Type: "file"}, + }, + }) + if err == nil { + t.Fatal("SendMedia() error = nil, want upload failure") + } + if len(api.groupMessages) != 0 { + t.Fatalf("group message count = %d, want 0 after upload failure", len(api.groupMessages)) + } +} + +func TestSendMedia_RemoteURLStillUsesRichMediaDirectSend(t *testing.T) { + api := &fakeQQAPI{} + ch := &QQChannel{ + BaseChannel: channels.NewBaseChannel("qq", nil, bus.NewMessageBus(), nil), + api: api, + } + ch.SetRunning(true) + ch.chatType.Store("group-1", "group") + + err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{ + ChatID: "group-1", + Parts: []bus.MediaPart{ + { + Ref: "https://example.com/report.pdf", + Type: "file", + }, + }, + }) + if err != nil { + t.Fatalf("SendMedia() error = %v", err) + } + if len(api.transportCalls) != 0 { + t.Fatalf("transport call count = %d, want 0 for remote URL", len(api.transportCalls)) + } + if len(api.groupMessages) != 1 { + t.Fatalf("group message count = %d, want 1", len(api.groupMessages)) + } + + msg, ok := api.groupMessages[0].(*dto.RichMediaMessage) + if !ok { + t.Fatalf("group message type = %T, want *dto.RichMediaMessage", api.groupMessages[0]) + } + if msg.URL != "https://example.com/report.pdf" { + t.Fatalf("URL = %q, want %q", msg.URL, "https://example.com/report.pdf") + } + if !msg.SrvSendMsg { + t.Fatal("SrvSendMsg = false, want true") + } +} diff --git a/pkg/channels/qq/upload.go b/pkg/channels/qq/upload.go new file mode 100644 index 000000000..6001fa3b4 --- /dev/null +++ b/pkg/channels/qq/upload.go @@ -0,0 +1,111 @@ +package qq + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/tencent-connect/botgo/constant" + "github.com/tencent-connect/botgo/dto" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" +) + +func qqFileType(mediaType string) uint64 { + switch mediaType { + case "image": + return 1 + case "video": + return 2 + case "audio": + return 3 + default: + return 4 + } +} + +func (c *QQChannel) uploadLocalMedia( + ctx context.Context, + chatKind string, + chatID string, + partType string, + filename string, + localPath string, +) ([]byte, error) { + content, err := os.ReadFile(localPath) + if err != nil { + return nil, fmt.Errorf("read local media: %w", err) + } + + filename = strings.TrimSpace(filename) + if filename == "" { + filename = filepath.Base(localPath) + } else { + filename = filepath.Base(filename) + } + + payload := map[string]any{ + "file_type": qqFileType(partType), + "file_name": filename, + "srv_send_msg": false, + "file_data": base64.StdEncoding.EncodeToString(content), + } + + respBody, err := c.api.Transport(ctx, http.MethodPost, c.uploadURL(chatKind, chatID), payload) + if err != nil { + return nil, fmt.Errorf("upload local media: %w", err) + } + + var uploaded dto.Message + if err := json.Unmarshal(respBody, &uploaded); err != nil { + return nil, fmt.Errorf("decode upload response: %w", err) + } + if len(uploaded.FileInfo) == 0 { + return nil, fmt.Errorf("upload local media: empty file_info") + } + + return uploaded.FileInfo, nil +} + +func (c *QQChannel) uploadURL(chatKind, chatID string) string { + if chatKind == "group" { + return fmt.Sprintf("%s/v2/groups/%s/files", constant.APIDomain, chatID) + } + return fmt.Sprintf("%s/v2/users/%s/files", constant.APIDomain, chatID) +} + +func (c *QQChannel) sendUploadedMedia( + ctx context.Context, + chatKind string, + chatID string, + part bus.MediaPart, + fileInfo []byte, +) error { + msgToCreate := &dto.MessageToCreate{ + Content: part.Caption, + MsgType: dto.RichMediaMsg, + Media: &dto.MediaInfo{ + FileInfo: fileInfo, + }, + } + + c.applyReplyContext(chatID, msgToCreate) + + var err error + if chatKind == "group" { + _, err = c.api.PostGroupMessage(ctx, chatID, msgToCreate) + } else { + _, err = c.api.PostC2CMessage(ctx, chatID, msgToCreate) + } + if err != nil { + return fmt.Errorf("qq send uploaded media: %w", channels.ErrTemporary) + } + + return nil +}