Skip to content

Commit b73fb81

Browse files
amitksingh1490tusharmathautofix-ci[bot]
authored
fix(title-generation): abort background task on conversation end (#2906)
Co-authored-by: Tushar <tusharmath@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent 4bde378 commit b73fb81

File tree

1 file changed

+64
-123
lines changed

1 file changed

+64
-123
lines changed

crates/forge_app/src/hooks/title_generation.rs

Lines changed: 64 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,26 @@ use std::sync::Arc;
22

33
use async_trait::async_trait;
44
use dashmap::DashMap;
5-
use dashmap::mapref::entry::Entry;
65
use forge_domain::{
76
Conversation, ConversationId, EndPayload, EventData, EventHandle, StartPayload,
87
};
98
use tokio::sync::oneshot;
10-
use tracing::debug;
9+
use tokio::task::JoinHandle;
1110

1211
use crate::agent::AgentService;
1312
use crate::title_generator::TitleGenerator;
1413

1514
/// Per-conversation title generation state.
16-
enum TitleTask {
17-
/// A background task is running; the receiver will deliver its result.
18-
InProgress(oneshot::Receiver<Option<String>>),
19-
/// `EndPayload` has extracted the receiver and is currently awaiting it.
20-
/// Kept in the map as a sentinel so a concurrent `StartPayload` sees an
21-
/// occupied entry and does not spawn a duplicate task.
22-
Awaiting,
23-
/// Title generation has finished successfully; stores the generated title.
24-
Done(#[allow(dead_code)] String),
15+
struct TitleGenerationState {
16+
rx: oneshot::Receiver<Option<String>>,
17+
handle: JoinHandle<()>,
2518
}
2619

2720
/// Hook handler that generates a conversation title asynchronously.
2821
#[derive(Clone)]
2922
pub struct TitleGenerationHandler<S> {
3023
services: Arc<S>,
31-
title_tasks: Arc<DashMap<ConversationId, TitleTask>>,
24+
title_tasks: Arc<DashMap<ConversationId, TitleGenerationState>>,
3225
}
3326

3427
impl<S> TitleGenerationHandler<S> {
@@ -77,13 +70,11 @@ impl<S: AgentService> EventHandle<EventData<StartPayload>> for TitleGenerationHa
7770
// one task is ever spawned per conversation id.
7871
self.title_tasks.entry(conversation.id).or_insert_with(|| {
7972
let (tx, rx) = oneshot::channel();
80-
tokio::spawn(async move {
81-
let result = generator.generate().await.ok().flatten();
82-
// If the receiver was dropped (e.g. task cancelled), this is a
83-
// no-op — the send simply fails silently.
84-
let _ = tx.send(result);
73+
let handle = tokio::spawn(async move {
74+
let title = generator.generate().await.ok().flatten();
75+
let _ = tx.send(title);
8576
});
86-
TitleTask::InProgress(rx)
77+
TitleGenerationState { rx, handle }
8778
});
8879

8980
Ok(())
@@ -97,52 +88,14 @@ impl<S: AgentService> EventHandle<EventData<EndPayload>> for TitleGenerationHand
9788
_event: &EventData<EndPayload>,
9889
conversation: &mut Conversation,
9990
) -> anyhow::Result<()> {
100-
// Atomically transition InProgress → Awaiting, extracting the receiver
101-
// while keeping the entry occupied. A concurrent StartPayload sees
102-
// Occupied and skips, so no duplicate task can be spawned during the
103-
// await below.
104-
let rx = match self.title_tasks.entry(conversation.id) {
105-
Entry::Occupied(mut e) => {
106-
match std::mem::replace(e.get_mut(), TitleTask::Awaiting) {
107-
TitleTask::InProgress(rx) => rx,
108-
// Awaiting or Done: another EndPayload is already handling this.
109-
TitleTask::Done(title) => {
110-
conversation.title = Some(title);
111-
return Ok(());
112-
}
113-
other => {
114-
*e.get_mut() = other; // restore
115-
return Ok(());
116-
}
117-
}
118-
}
119-
Entry::Vacant(_) => return Ok(()),
120-
};
121-
122-
// Await the oneshot receiver. Unlike a raw JoinHandle, a oneshot
123-
// receiver never panics on poll-after-completion — it simply returns
124-
// `Err(RecvError)` if the sender was dropped.
125-
match rx.await {
126-
Ok(Some(title)) => {
127-
debug!(
128-
conversation_id = %conversation.id,
129-
title = %title,
130-
"Title generated successfully"
131-
);
132-
conversation.title = Some(title.clone());
133-
// Transition Awaiting → Done only on success.
134-
self.title_tasks
135-
.insert(conversation.id, TitleTask::Done(title));
136-
}
137-
Ok(None) => {
138-
debug!("Title generation returned None");
139-
// Remove so a future StartPayload can retry.
140-
self.title_tasks.remove(&conversation.id);
141-
}
142-
Err(_) => {
143-
debug!("Title generation task was cancelled");
144-
// Remove so a future StartPayload can retry.
145-
self.title_tasks.remove(&conversation.id);
91+
if let Some((_, entry)) = self.title_tasks.remove(&conversation.id) {
92+
let handle = &entry.handle;
93+
let rx = entry.rx;
94+
95+
if rx.is_empty() {
96+
handle.abort();
97+
} else if let Some(title) = rx.await? {
98+
conversation.title = Some(title);
14699
}
147100
}
148101

@@ -152,16 +105,18 @@ impl<S: AgentService> EventHandle<EventData<EndPayload>> for TitleGenerationHand
152105

153106
impl<S> Drop for TitleGenerationHandler<S> {
154107
fn drop(&mut self) {
155-
// Clearing the map drops all `oneshot::Receiver`s, which signals the
156-
// corresponding spawned tasks that the result is no longer needed.
157-
// The tasks will observe a closed channel on `tx.send()` and exit
158-
// gracefully — no `abort()` required.
108+
// Clearing the map drops all `JoinHandle`s (aborting the spawned
109+
// tasks) and `oneshot::Receiver`s. The tasks will observe a closed
110+
// channel on `tx.send()` and exit gracefully.
159111
self.title_tasks.clear();
160112
}
161113
}
162114

163115
#[cfg(test)]
164116
mod tests {
117+
use std::sync::Arc;
118+
use std::time::Duration;
119+
165120
use forge_domain::{
166121
Agent, ChatCompletionMessage, Context, ContextMessage, Conversation, EventValue, ModelId,
167122
ProviderId, Role, TextMessage, ToolCallContext, ToolCallFull, ToolResult,
@@ -233,100 +188,90 @@ mod tests {
233188
let (handler, mut conversation) = setup("test message");
234189
let (tx, rx) = oneshot::channel();
235190
tx.send(Some("original".to_string())).unwrap();
191+
let handle = tokio::spawn(async {});
192+
handle.abort();
236193
handler
237194
.title_tasks
238-
.insert(conversation.id, TitleTask::InProgress(rx));
239-
240-
handler
241-
.handle(&event(StartPayload), &mut conversation)
242-
.await
243-
.unwrap();
244-
245-
let (_, task) = handler.title_tasks.remove(&conversation.id).unwrap();
246-
let actual = match task {
247-
TitleTask::InProgress(rx) => rx.await.unwrap(),
248-
_ => panic!("Expected InProgress"),
249-
};
250-
assert_eq!(actual, Some("original".into()));
251-
}
252-
253-
/// A StartPayload that races with an EndPayload mid-await must not spawn a
254-
/// new task — the Awaiting sentinel keeps the entry occupied.
255-
#[tokio::test]
256-
async fn test_start_skips_if_awaiting() {
257-
let (handler, mut conversation) = setup("test message");
258-
handler
259-
.title_tasks
260-
.insert(conversation.id, TitleTask::Awaiting);
195+
.insert(conversation.id, TitleGenerationState { rx, handle });
261196

262197
handler
263198
.handle(&event(StartPayload), &mut conversation)
264199
.await
265200
.unwrap();
266201

267-
assert!(matches!(
268-
handler.title_tasks.get(&conversation.id).as_deref(),
269-
Some(TitleTask::Awaiting)
270-
));
202+
// Entry should still exist (wasn't replaced)
203+
assert!(handler.title_tasks.contains_key(&conversation.id));
271204
}
272205

273-
/// A StartPayload after generation has finished must not re-spawn.
274206
#[tokio::test]
275-
async fn test_start_skips_if_done() {
207+
async fn test_end_sets_title_from_completed_task() {
276208
let (handler, mut conversation) = setup("test message");
209+
let (tx, rx) = oneshot::channel();
210+
tx.send(Some("generated".to_string())).unwrap();
211+
let handle = tokio::spawn(async {});
212+
handle.abort();
277213
handler
278214
.title_tasks
279-
.insert(conversation.id, TitleTask::Done("existing".into()));
215+
.insert(conversation.id, TitleGenerationState { rx, handle });
280216

281217
handler
282-
.handle(&event(StartPayload), &mut conversation)
218+
.handle(&event(EndPayload), &mut conversation)
283219
.await
284220
.unwrap();
285221

286-
assert!(matches!(
287-
handler.title_tasks.get(&conversation.id).as_deref(),
288-
Some(TitleTask::Done(_))
289-
));
222+
assert_eq!(conversation.title, Some("generated".into()));
223+
// Entry should be removed after successful title generation
224+
assert!(!handler.title_tasks.contains_key(&conversation.id));
290225
}
291226

292227
#[tokio::test]
293-
async fn test_end_sets_title_from_completed_task() {
228+
async fn test_end_handles_task_cancellation() {
294229
let (handler, mut conversation) = setup("test message");
295-
let (tx, rx) = oneshot::channel();
296-
tx.send(Some("generated".to_string())).unwrap();
230+
let (tx, rx) = oneshot::channel::<Option<String>>();
231+
// Drop the sender to simulate a cancelled task.
232+
drop(tx);
233+
let handle = tokio::spawn(async {});
234+
handle.abort();
297235
handler
298236
.title_tasks
299-
.insert(conversation.id, TitleTask::InProgress(rx));
237+
.insert(conversation.id, TitleGenerationState { rx, handle });
300238

301239
handler
302240
.handle(&event(EndPayload), &mut conversation)
303241
.await
304242
.unwrap();
305243

306-
assert_eq!(conversation.title, Some("generated".into()));
307-
assert!(matches!(
308-
handler.title_tasks.get(&conversation.id).as_deref(),
309-
Some(TitleTask::Done(_))
310-
));
244+
assert!(conversation.title.is_none());
245+
assert!(!handler.title_tasks.contains_key(&conversation.id));
311246
}
312247

248+
/// When EndPayload is received, the spawned task should be aborted so it
249+
/// doesn't continue running unnecessarily.
313250
#[tokio::test]
314-
async fn test_end_handles_task_cancellation() {
251+
async fn test_end_aborts_in_progress_task() {
315252
let (handler, mut conversation) = setup("test message");
316253
let (tx, rx) = oneshot::channel::<Option<String>>();
317-
// Drop the sender to simulate a cancelled task.
318-
drop(tx);
254+
let handle = tokio::spawn(async move {
255+
// Simulate a long-running task that would block indefinitely.
256+
tokio::time::sleep(Duration::from_secs(60)).await;
257+
let _ = tx.send(None);
258+
});
259+
319260
handler
320261
.title_tasks
321-
.insert(conversation.id, TitleTask::InProgress(rx));
262+
.insert(conversation.id, TitleGenerationState { rx, handle });
322263

323264
handler
324265
.handle(&event(EndPayload), &mut conversation)
325266
.await
326267
.unwrap();
327268

328-
assert!(conversation.title.is_none());
269+
// Entry should have been removed from map
329270
assert!(!handler.title_tasks.contains_key(&conversation.id));
271+
272+
// Verify the task is no longer running by checking that the
273+
// EndPayload handler didn't hang (it completed immediately).
274+
assert!(conversation.title.is_none());
330275
}
331276

332277
/// Many concurrent StartPayload calls for the same conversation id must
@@ -354,11 +299,7 @@ mod tests {
354299
j.await.unwrap();
355300
}
356301

357-
let actual = handler
358-
.title_tasks
359-
.iter()
360-
.filter(|e| matches!(e.value(), TitleTask::InProgress(_)))
361-
.count();
362-
assert_eq!(actual, 1);
302+
// Only one task should exist in the map
303+
assert_eq!(handler.title_tasks.len(), 1);
363304
}
364305
}

0 commit comments

Comments
 (0)