Skip to content

Commit 993c042

Browse files
committed
Add support for removing chunks by id
1 parent 76abd3e commit 993c042

File tree

2 files changed

+121
-3
lines changed

2 files changed

+121
-3
lines changed

R/ragnar-chat.R

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
#' Eg, the identity is simply `function(self, ...) list(...)`.
1313
#' The default callback prunes the previous tool calls from the chat history and
1414
#' inserts a tool call request, so that the LLM always sees retrieval results.
15-
#'
15+
#' @param .on_retrieval A function that is called when the tool retrieves results.
16+
#' It's called with `self` (the instance of `RagnarChat`) and `results`, the results
17+
#' retrieved from the store. It's useful for applying deoverlapping, and for
18+
#' pruning repeated chunks in the context.
1619
#' @export
1720
chat_ragnar <- function(
1821
chat_fun,
@@ -23,10 +26,13 @@ chat_ragnar <- function(
2326
self$turns_prune_tool_calls(keep_last_n = 0)
2427
# inserts a new tool call request with the user's input
2528
self$turns_insert_tool_call_request(..., query = paste(..., collapse = " "))
29+
},
30+
.on_retrieval = function(self, results) {
31+
results
2632
}
2733
) {
2834
chat <- chat_fun(...)
29-
RagnarChat$new(chat, .store, .on_user_turn)
35+
RagnarChat$new(chat, .store, .on_user_turn, .on_retrieval)
3036
}
3137

3238
#' Adds extra capabilities to an `ellmer::Chat` object.
@@ -46,7 +52,13 @@ RagnarChat <- R6::R6Class(
4652
#' Eg, the identity is simply `function(self, ...) list(...)`.
4753
on_user_turn = NULL,
4854

49-
initialize = function(chat, store, on_user_turn) {
55+
#' @field on_retrieval A function that is called when the tool retrieves results.
56+
#' It's called with `self` (the instance of `RagnarChat`) and `results`, the results
57+
#' retrieved from the store. It's useful for applying deoverlapping, and for
58+
#' pruning repeated chunks in the context.
59+
on_retrieval = NULL,
60+
61+
initialize = function(chat, store, on_user_turn, on_retrieval) {
5062
self$ragnar_store <- store
5163
super$initialize(
5264
chat$get_provider(),
@@ -63,6 +75,7 @@ RagnarChat <- R6::R6Class(
6375
)
6476
self$register_tool(self$ragnar_tool_def)
6577
self$on_user_turn <- on_user_turn
78+
self$on_retrieval <- on_retrieval
6679
},
6780

6881
chat = function(..., echo = NULL) {
@@ -106,6 +119,7 @@ RagnarChat <- R6::R6Class(
106119
}
107120

108121
results |>
122+
private$callback_retrieval() |>
109123
dplyr::select(-hash) |>
110124
jsonlite::toJSON()
111125
},
@@ -271,6 +285,76 @@ RagnarChat <- R6::R6Class(
271285
self$set_turns(turns)
272286
},
273287

288+
#' @description
#' Removes chunks from the chat history by id.
#' Rewrites the LLM context, removing the chunks with the given ids. A tool
#' result whose chunks are all removed is dropped entirely; if that leaves
#' the user turn empty, the turn and the turn immediately before it (the
#' tool call request) are removed as well.
#'
#' @param chunk_ids A vector of chunk ids to remove from the chat history.
turns_remove_chunks = function(chunk_ids) {
  turns <- self$get_turns()
  drop_turn_idx <- integer(0)

  for (ti in seq_along(turns)) {
    turn <- turns[[ti]]
    # Tool results are only looked for in user turns.
    if (turn@role != "user") {
      next
    }

    contents <- turn@contents
    drop_content_idx <- integer(0)

    for (ci in seq_along(contents)) {
      content <- contents[[ci]]

      # Only string-valued results produced by the ragnar tool are rewritten.
      if (!S7::S7_inherits(content, ellmer::ContentToolResult)) {
        next
      }
      if (content@request@name != self$ragnar_tool_def@name) {
        next
      }
      if (!is.character(content@value)) {
        next
      }

      chunks <- jsonlite::fromJSON(content@value, simplifyVector = FALSE)

      # Flag the chunks with the given ids. `vapply()` keeps the result a
      # logical vector even for an empty list, and `any()` maps a missing
      # `id` to FALSE instead of a zero-length value.
      is_removed <- vapply(
        chunks,
        function(x) any(x$id %in% chunk_ids),
        logical(1)
      )

      # Nothing to remove here: leave the stored result untouched.
      if (!any(is_removed)) {
        next
      }

      chunks <- chunks[!is_removed]

      # If we have no chunks left, we remove the entire content from the list.
      if (length(chunks) == 0) {
        drop_content_idx <- c(drop_content_idx, ci)
        next
      }

      # Restore the content if some chunks remained. `auto_unbox` keeps
      # scalar fields as JSON scalars; without it every length-1 value parsed
      # above would be written back as a one-element array.
      contents[[ci]]@value <- jsonlite::toJSON(
        chunks,
        pretty = TRUE,
        auto_unbox = TRUE
      )
    }

    if (length(drop_content_idx) > 0) {
      contents <- contents[-drop_content_idx]
    }
    turn@contents <- contents

    # If we removed all contents from the turn, we remove the entire turn
    # and the turn that came before it. When `ti` is 1 the extra index is 0,
    # which negative subsetting ignores.
    if (length(turn@contents) == 0) {
      drop_turn_idx <- c(drop_turn_idx, ti, ti - 1L)
      next
    }

    turns[[ti]] <- turn
  }

  # Remove the turns that we marked for removal.
  if (length(drop_turn_idx) > 0) {
    turns <- turns[-drop_turn_idx]
  }

  self$set_turns(turns)
},
357+
274358
#' @description
275359
#' Some LLMs are lazy at tool calling, and for applications to be
276360
#' robust, it's great to append context for the LLM, even if
@@ -312,6 +396,10 @@ RagnarChat <- R6::R6Class(
312396
result <- list(result)
313397
}
314398
result
399+
},
400+
# Passes the retrieved `results` through the `on_retrieval` callback,
# calling it as `on_retrieval(self, results)` and returning its value.
# When no callback is set (the field is still NULL), `results` is returned
# unchanged instead of failing with "attempt to apply non-function".
callback_retrieval = function(results) {
  if (is.null(self$on_retrieval)) {
    return(results)
  }
  self$on_retrieval(self, results)
}
316404
)
317405
)

tests/testthat/test-ragnar-chat.R

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,33 @@ test_that("Implementing query rewriting", {
5050
tool_call_request <- chat$get_turns()[[2]]@contents[[1]]
5151
expect_equal(tool_call_request@arguments$query, "hello world")
5252
})
53+
54+
test_that("remove chunks by id works", {
  store <- test_store()
  chat <- chat_ragnar(
    ellmer::chat_openai,
    .store = store,
    .on_user_turn = function(self, ...) {
      self$turns_insert_tool_call_request(
        ...,
        query = paste(..., collapse = " ")
      )
    }
  )

  chat$chat("advanced R")
  chunks <- chat$turns_list_chunks()
  id <- chunks[[1]]$id

  # Removing a single chunk leaves the other chunks in the history.
  chat$turns_remove_chunks(id)

  chunks <- chat$turns_list_chunks()
  chunk_ids <- sapply(chunks, function(x) x$id)
  expect_false(id %in% chunk_ids)

  chat$turns_remove_chunks(chunk_ids)
  # we removed all chunks, thus the tool call request and result turns also
  # got removed
  expect_length(chat$get_turns(), 2)
  expect_length(chat$turns_list_chunks(), 0)
})

0 commit comments

Comments
 (0)