@@ -2,12 +2,12 @@ package agent
2
2
3
3
import (
4
4
"context"
5
+ "database/sql"
5
6
"encoding/json"
6
7
"fmt"
7
8
"log"
8
9
9
10
"github.com/invopop/jsonschema"
10
- "github.com/jackc/pgx/v5/pgtype"
11
11
openai "github.com/sashabaranov/go-openai"
12
12
"github.com/semanser/ai-coder/assets"
13
13
"github.com/semanser/ai-coder/config"
@@ -137,7 +137,7 @@ func NextTask(args AgentPrompt) *database.Task {
137
137
if task .Type .String == "input" {
138
138
messages = append (messages , openai.ChatCompletionMessage {
139
139
Role : openai .ChatMessageRoleUser ,
140
- Content : string ( task .Args ) ,
140
+ Content : task .Args . String ,
141
141
})
142
142
}
143
143
@@ -149,7 +149,7 @@ func NextTask(args AgentPrompt) *database.Task {
149
149
ID : task .ToolCallID .String ,
150
150
Function : openai.FunctionCall {
151
151
Name : task .Type .String ,
152
- Arguments : string ( task .Args ) ,
152
+ Arguments : task .Args . String ,
153
153
},
154
154
Type : openai .ToolTypeFunction ,
155
155
},
@@ -212,7 +212,7 @@ func NextTask(args AgentPrompt) *database.Task {
212
212
}
213
213
214
214
task := database.Task {
215
- Type : database .StringToPgText (tool .Function .Name ),
215
+ Type : database .StringToNullString (tool .Function .Name ),
216
216
}
217
217
218
218
switch tool .Function .Name {
@@ -227,16 +227,16 @@ func NextTask(args AgentPrompt) *database.Task {
227
227
log .Printf ("Failed to marshal terminal args, asking user: %v" , err )
228
228
return defaultAskTask ("There was an error running the terminal command" )
229
229
}
230
- task .Args = args
230
+ task .Args = database . StringToNullString ( string ( args ))
231
231
232
232
// Sometimes the model returns an empty string for the message
233
233
msg := string (params .Message )
234
234
if msg == "" {
235
235
msg = params .Input
236
236
}
237
237
238
- task .Message = database .StringToPgText (msg )
239
- task .Status = database .StringToPgText ("in_progress" )
238
+ task .Message = database .StringToNullString (msg )
239
+ task .Status = database .StringToNullString ("in_progress" )
240
240
241
241
case "browser" :
242
242
params , err := extractArgs (tool .Function .Arguments , & BrowserArgs {})
@@ -249,11 +249,8 @@ func NextTask(args AgentPrompt) *database.Task {
249
249
log .Printf ("Failed to marshal browser args, asking user: %v" , err )
250
250
return defaultAskTask ("There was an error opening the browser" )
251
251
}
252
- task .Args = args
253
- task .Message = pgtype.Text {
254
- String : string (params .Message ),
255
- Valid : true ,
256
- }
252
+ task .Args = database .StringToNullString (string (args ))
253
+ task .Message = database .StringToNullString (string (params .Message ))
257
254
case "code" :
258
255
params , err := extractArgs (tool .Function .Arguments , & CodeArgs {})
259
256
if err != nil {
@@ -265,11 +262,8 @@ func NextTask(args AgentPrompt) *database.Task {
265
262
log .Printf ("Failed to marshal code args, asking user: %v" , err )
266
263
return defaultAskTask ("There was an error reading or updating the file" )
267
264
}
268
- task .Args = args
269
- task .Message = pgtype.Text {
270
- String : string (params .Message ),
271
- Valid : true ,
272
- }
265
+ task .Args = database .StringToNullString (string (args ))
266
+ task .Message = database .StringToNullString (string (params .Message ))
273
267
case "ask" :
274
268
params , err := extractArgs (tool .Function .Arguments , & AskArgs {})
275
269
if err != nil {
@@ -281,11 +275,8 @@ func NextTask(args AgentPrompt) *database.Task {
281
275
log .Printf ("Failed to marshal ask args, asking user: %v" , err )
282
276
return defaultAskTask ("There was an error asking the user for additional information" )
283
277
}
284
- task .Args = args
285
- task .Message = pgtype.Text {
286
- String : string (params .Message ),
287
- Valid : true ,
288
- }
278
+ task .Args = database .StringToNullString (string (args ))
279
+ task .Message = database .StringToNullString (string (params .Message ))
289
280
case "done" :
290
281
params , err := extractArgs (tool .Function .Arguments , & DoneArgs {})
291
282
if err != nil {
@@ -296,28 +287,22 @@ func NextTask(args AgentPrompt) *database.Task {
296
287
if err != nil {
297
288
return defaultAskTask ("There was an error marking the task as done" )
298
289
}
299
- task .Args = args
300
- task .Message = pgtype.Text {
301
- String : string (params .Message ),
302
- Valid : true ,
303
- }
290
+ task .Args = database .StringToNullString (string (args ))
291
+ task .Message = database .StringToNullString (string (params .Message ))
304
292
}
305
293
306
- task .ToolCallID = pgtype.Text {
307
- String : tool .ID ,
308
- Valid : true ,
309
- }
294
+ task .ToolCallID = database .StringToNullString (tool .ID )
310
295
311
296
return & task
312
297
}
313
298
314
299
func defaultAskTask (message string ) * database.Task {
315
300
task := database.Task {
316
- Type : database .StringToPgText ("ask" ),
301
+ Type : database .StringToNullString ("ask" ),
317
302
}
318
303
319
- task .Args = [] byte ("{}" )
320
- task .Message = pgtype. Text {
304
+ task .Args = database . StringToNullString ("{}" )
305
+ task .Message = sql. NullString {
321
306
String : fmt .Sprintf ("%s. What should I do next?" , message ),
322
307
Valid : true ,
323
308
}
0 commit comments