-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.go
220 lines (182 loc) · 6.27 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
package main
import (
"encoding/json"
"errors"
"html/template"
"time"
"github.com/ChristianSch/Theta/adapters/inbound"
"github.com/ChristianSch/Theta/adapters/outbound"
"github.com/ChristianSch/Theta/adapters/outbound/repo"
"github.com/ChristianSch/Theta/domain/models"
outboundPorts "github.com/ChristianSch/Theta/domain/ports/outbound"
"github.com/ChristianSch/Theta/domain/usecases/chat"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
)
// WebsocketMsg is the JSON frame a browser client sends over the chat
// websocket (/ws/chat/:id).
//
// Only Message is read by the server. A "HEADERS" object is accepted in the
// payload, but because Headers declares no fields its contents are discarded
// when unmarshalling.
type WebsocketMsg struct {
	Message string   `json:"message,omitempty"`
	Headers struct{} `json:"HEADERS,omitempty"`
}
// main wires up the chat server and starts it.
//
// It lists the models exposed by the Ollama service and builds the Fiber web
// server, the markdown->HTML post-processor, the in-memory conversation repo
// and the incoming-message handler. It then registers the HTTP and websocket
// routes and blocks in web.Start. Start-up failures (no LLM service, no
// models, server start error) panic, because the server cannot run without
// them.
func main() {
	log := outbound.NewZapLogger(outbound.ZapLoggerConfig{Debug: true})

	// init llms first to see if we have any models available
	ollama, err := outbound.NewOllamaLlmService(log)
	if err != nil {
		panic(err)
	}

	ollamaModels, err := ollama.ListModels()
	if err != nil {
		panic(err)
	}
	log.Debug("available ollama models", outboundPorts.LogField{Key: "models", Value: ollamaModels})

	// all models
	llmModels := ollamaModels
	// TODO: init openai

	if len(ollamaModels) == 0 {
		panic(errors.New("no models available"))
	}

	// Default the LLM service to the first available model.
	// NOTE(review): POST /chat stores the model picked in the form on the
	// conversation, but nothing in this file applies conv.Model to the
	// service. Confirm whether the message handler does.
	ollama.SetModel(ollamaModels[0])

	web := inbound.NewFiberWebServer(inbound.FiberWebServerConfig{
		Port:                5467,
		TemplatesPath:       "./infrastructure/views",
		TemplatesExtension:  ".gohtml",
		StaticResourcesPath: "./infrastructure/static",
	}, inbound.FiberWebServerAdapters{Log: log})

	// markdown 2 html post processor
	mdToHtmlPostProcessor := outbound.NewMdToHtmlLlmPostProcessor()

	// conversation repo
	convRepo := repo.NewInMemoryConversationRepo()

	msgSender := outbound.NewSendFiberWebsocketMessage(outbound.SendFiberWebsocketMessageConfig{Log: log})
	msgFormatter := outbound.NewFiberMessageFormatter(outbound.FiberMessageFormatterConfig{
		MessageTemplatePath: "./infrastructure/views/components/message.gohtml",
	})

	msgHandler := chat.NewIncomingMessageHandler(chat.IncomingMessageHandlerConfig{
		Sender:    msgSender,
		Formatter: msgFormatter,
		Llm:       ollama,
		PostProcessors: []outboundPorts.PostProcessor{
			{
				Processor: mdToHtmlPostProcessor,
				Order:     0, // first one
				Name:      mdToHtmlPostProcessor.GetName(),
			},
		},
		ConversationRepo: convRepo,
	})

	// landing page: pick a model and write the first message
	web.AddRoute("GET", "/", func(ctx interface{}) error {
		fiberCtx := ctx.(*fiber.Ctx)
		return fiberCtx.Render("new_chat", fiber.Map{
			"Title":  "",
			"Models": llmModels,
		}, "layouts/main")
	})

	web.AddRoute("GET", "/chat", func(ctx interface{}) error {
		fiberCtx := ctx.(*fiber.Ctx)
		return fiberCtx.Redirect("/")
	})

	// create new conversation
	web.AddRoute("POST", "/chat", func(ctx interface{}) error {
		fiberCtx := ctx.(*fiber.Ctx)

		// get model from form
		model := fiberCtx.FormValue("model")
		if model == "" {
			log.Error("no model specified", outboundPorts.LogField{Key: "error", Value: "no model specified"})
			return fiberCtx.Redirect("/")
		}

		// get message from form
		message := fiberCtx.FormValue("message")
		if message == "" {
			log.Error("no message specified", outboundPorts.LogField{Key: "error", Value: "no message specified"})
			return fiberCtx.Redirect("/")
		}

		// create a new conversation bound to the chosen model
		conv, err := convRepo.CreateConversation(model)
		if err != nil {
			log.Error("error while creating conversation", outboundPorts.LogField{Key: "error", Value: err})
			return err
		}

		// tell htmx to update the browser URL to the new conversation
		fiberCtx.Append("HX-Replace-Url", "/chat/"+conv.Id)
		return fiberCtx.Render("chat", fiber.Map{
			"Title":          "",
			"Models":         llmModels,
			"ConversationId": conv.Id,
			"UserMessage":    message,
		}, "layouts/empty")
	})

	// open existing conversation
	web.AddRoute("GET", "/chat/:id", func(ctx interface{}) error {
		fiberCtx := ctx.(*fiber.Ctx)
		convId := fiberCtx.Params("id")

		// get conversation
		conv, err := convRepo.GetConversation(convId)
		if err != nil {
			log.Error("error while getting conversation", outboundPorts.LogField{Key: "error", Value: err})
			return fiberCtx.Redirect("/")
		}

		renderedMessages := make([]template.HTML, 0, len(conv.Messages))
		for _, msg := range conv.Messages {
			renderedMsg, err := msgFormatter.Format(msg)
			if err != nil {
				log.Error("error while formatting message", outboundPorts.LogField{Key: "error", Value: err})
				return err
			}
			// SECURITY: template.HTML makes html/template insert this value
			// without escaping, bypassing its XSS protection. This is only safe
			// if the formatter output has already escaped/sanitized any user- or
			// LLM-supplied text. NOTE(review): verify that the formatter does so.
			renderedMessages = append(renderedMessages, template.HTML(renderedMsg))
		}

		return fiberCtx.Render("chat", fiber.Map{
			"Title":          "",
			"Model":          conv.Model,
			"ConversationId": conv.Id,
			"Messages":       renderedMessages,
		}, "layouts/main")
	})

	// websocket: receives user messages for one conversation and hands them to
	// the message handler until the client disconnects or an error occurs
	web.AddWebsocketRoute("/ws/chat/:id", func(conn interface{}) error {
		fiberConn := conn.(*websocket.Conn)
		convId := fiberConn.Params("id")
		log.Debug("handling websocket request",
			outboundPorts.LogField{Key: "path", Value: "/ws/chat/:id"},
			outboundPorts.LogField{Key: "id", Value: convId})

		// get conversation
		conv, err := convRepo.GetConversation(convId)
		if err != nil {
			log.Error("error while getting conversation", outboundPorts.LogField{Key: "error", Value: err})
			return err
		}
		log.Debug("websocket connected to conversation", outboundPorts.LogField{Key: "conversation", Value: conv})

		for {
			messageType, message, err := fiberConn.ReadMessage()
			if err != nil {
				// A client closing the tab or navigating away is the normal way
				// for this loop to end; only unexpected failures are errors.
				if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
					log.Debug("websocket closed by client", outboundPorts.LogField{Key: "id", Value: convId})
				} else {
					log.Error("error while reading message", outboundPorts.LogField{Key: "error", Value: err})
				}
				break
			}

			// message is JSON; unmarshal it into WebsocketMsg
			var wsMsg WebsocketMsg
			if err := json.Unmarshal(message, &wsMsg); err != nil {
				log.Error("error while unmarshalling message", outboundPorts.LogField{Key: "error", Value: err})
				break
			}

			log.Debug("received message",
				outboundPorts.LogField{Key: "message", Value: wsMsg.Message},
				outboundPorts.LogField{Key: "messageType", Value: messageType},
			)

			// frames without text (e.g. only headers) are ignored
			if len(wsMsg.Message) > 0 {
				msg := models.Message{
					Text:      wsMsg.Message,
					Timestamp: time.Now(),
					Type:      models.UserMessage,
				}

				// add message to conversation!
				if err := msgHandler.Handle(msg, conv, fiberConn); err != nil {
					log.Error("error while writing message", outboundPorts.LogField{Key: "error", Value: err})
					break
				}
			}
		}

		return nil
	})

	if err := web.Start(); err != nil {
		panic(err)
	}
}