@@ -2,33 +2,26 @@ use std::sync::Arc;
22
33use async_trait:: async_trait;
44use dashmap:: DashMap ;
5- use dashmap:: mapref:: entry:: Entry ;
65use forge_domain:: {
76 Conversation , ConversationId , EndPayload , EventData , EventHandle , StartPayload ,
87} ;
98use tokio:: sync:: oneshot;
10- use tracing :: debug ;
9+ use tokio :: task :: JoinHandle ;
1110
1211use crate :: agent:: AgentService ;
1312use crate :: title_generator:: TitleGenerator ;
1413
1514/// Per-conversation title generation state.
16- enum TitleTask {
17- /// A background task is running; the receiver will deliver its result.
18- InProgress ( oneshot:: Receiver < Option < String > > ) ,
19- /// `EndPayload` has extracted the receiver and is currently awaiting it.
20- /// Kept in the map as a sentinel so a concurrent `StartPayload` sees an
21- /// occupied entry and does not spawn a duplicate task.
22- Awaiting ,
23- /// Title generation has finished successfully; stores the generated title.
24- Done ( #[ allow( dead_code) ] String ) ,
15+ struct TitleGenerationState {
16+ rx : oneshot:: Receiver < Option < String > > ,
17+ handle : JoinHandle < ( ) > ,
2518}
2619
2720/// Hook handler that generates a conversation title asynchronously.
2821#[ derive( Clone ) ]
2922pub struct TitleGenerationHandler < S > {
3023 services : Arc < S > ,
31- title_tasks : Arc < DashMap < ConversationId , TitleTask > > ,
24+ title_tasks : Arc < DashMap < ConversationId , TitleGenerationState > > ,
3225}
3326
3427impl < S > TitleGenerationHandler < S > {
@@ -77,13 +70,11 @@ impl<S: AgentService> EventHandle<EventData<StartPayload>> for TitleGenerationHa
7770 // one task is ever spawned per conversation id.
7871 self . title_tasks . entry ( conversation. id ) . or_insert_with ( || {
7972 let ( tx, rx) = oneshot:: channel ( ) ;
80- tokio:: spawn ( async move {
81- let result = generator. generate ( ) . await . ok ( ) . flatten ( ) ;
82- // If the receiver was dropped (e.g. task cancelled), this is a
83- // no-op — the send simply fails silently.
84- let _ = tx. send ( result) ;
73+ let handle = tokio:: spawn ( async move {
74+ let title = generator. generate ( ) . await . ok ( ) . flatten ( ) ;
75+ let _ = tx. send ( title) ;
8576 } ) ;
86- TitleTask :: InProgress ( rx )
77+ TitleGenerationState { rx , handle }
8778 } ) ;
8879
8980 Ok ( ( ) )
@@ -97,52 +88,14 @@ impl<S: AgentService> EventHandle<EventData<EndPayload>> for TitleGenerationHand
9788 _event : & EventData < EndPayload > ,
9889 conversation : & mut Conversation ,
9990 ) -> anyhow:: Result < ( ) > {
100- // Atomically transition InProgress → Awaiting, extracting the receiver
101- // while keeping the entry occupied. A concurrent StartPayload sees
102- // Occupied and skips, so no duplicate task can be spawned during the
103- // await below.
104- let rx = match self . title_tasks . entry ( conversation. id ) {
105- Entry :: Occupied ( mut e) => {
106- match std:: mem:: replace ( e. get_mut ( ) , TitleTask :: Awaiting ) {
107- TitleTask :: InProgress ( rx) => rx,
108- // Awaiting or Done: another EndPayload is already handling this.
109- TitleTask :: Done ( title) => {
110- conversation. title = Some ( title) ;
111- return Ok ( ( ) ) ;
112- }
113- other => {
114- * e. get_mut ( ) = other; // restore
115- return Ok ( ( ) ) ;
116- }
117- }
118- }
119- Entry :: Vacant ( _) => return Ok ( ( ) ) ,
120- } ;
121-
122- // Await the oneshot receiver. Unlike a raw JoinHandle, a oneshot
123- // receiver never panics on poll-after-completion — it simply returns
124- // `Err(RecvError)` if the sender was dropped.
125- match rx. await {
126- Ok ( Some ( title) ) => {
127- debug ! (
128- conversation_id = %conversation. id,
129- title = %title,
130- "Title generated successfully"
131- ) ;
132- conversation. title = Some ( title. clone ( ) ) ;
133- // Transition Awaiting → Done only on success.
134- self . title_tasks
135- . insert ( conversation. id , TitleTask :: Done ( title) ) ;
136- }
137- Ok ( None ) => {
138- debug ! ( "Title generation returned None" ) ;
139- // Remove so a future StartPayload can retry.
140- self . title_tasks . remove ( & conversation. id ) ;
141- }
142- Err ( _) => {
143- debug ! ( "Title generation task was cancelled" ) ;
144- // Remove so a future StartPayload can retry.
145- self . title_tasks . remove ( & conversation. id ) ;
91+ if let Some ( ( _, entry) ) = self . title_tasks . remove ( & conversation. id ) {
92+ let handle = & entry. handle ;
93+ let rx = entry. rx ;
94+
95+ if rx. is_empty ( ) {
96+ handle. abort ( ) ;
97+ } else if let Some ( title) = rx. await ? {
98+ conversation. title = Some ( title) ;
14699 }
147100 }
148101
@@ -152,16 +105,18 @@ impl<S: AgentService> EventHandle<EventData<EndPayload>> for TitleGenerationHand
152105
153106impl < S > Drop for TitleGenerationHandler < S > {
154107 fn drop ( & mut self ) {
155- // Clearing the map drops all `oneshot::Receiver`s, which signals the
156- // corresponding spawned tasks that the result is no longer needed.
157- // The tasks will observe a closed channel on `tx.send()` and exit
158- // gracefully — no `abort()` required.
108+ // Clearing the map drops all `JoinHandle`s (aborting the spawned
109+ // tasks) and `oneshot::Receiver`s. The tasks will observe a closed
110+ // channel on `tx.send()` and exit gracefully.
159111 self . title_tasks . clear ( ) ;
160112 }
161113}
162114
163115#[ cfg( test) ]
164116mod tests {
117+ use std:: sync:: Arc ;
118+ use std:: time:: Duration ;
119+
165120 use forge_domain:: {
166121 Agent , ChatCompletionMessage , Context , ContextMessage , Conversation , EventValue , ModelId ,
167122 ProviderId , Role , TextMessage , ToolCallContext , ToolCallFull , ToolResult ,
@@ -233,100 +188,90 @@ mod tests {
233188 let ( handler, mut conversation) = setup ( "test message" ) ;
234189 let ( tx, rx) = oneshot:: channel ( ) ;
235190 tx. send ( Some ( "original" . to_string ( ) ) ) . unwrap ( ) ;
191+ let handle = tokio:: spawn ( async { } ) ;
192+ handle. abort ( ) ;
236193 handler
237194 . title_tasks
238- . insert ( conversation. id , TitleTask :: InProgress ( rx) ) ;
239-
240- handler
241- . handle ( & event ( StartPayload ) , & mut conversation)
242- . await
243- . unwrap ( ) ;
244-
245- let ( _, task) = handler. title_tasks . remove ( & conversation. id ) . unwrap ( ) ;
246- let actual = match task {
247- TitleTask :: InProgress ( rx) => rx. await . unwrap ( ) ,
248- _ => panic ! ( "Expected InProgress" ) ,
249- } ;
250- assert_eq ! ( actual, Some ( "original" . into( ) ) ) ;
251- }
252-
253- /// A StartPayload that races with an EndPayload mid-await must not spawn a
254- /// new task — the Awaiting sentinel keeps the entry occupied.
255- #[ tokio:: test]
256- async fn test_start_skips_if_awaiting ( ) {
257- let ( handler, mut conversation) = setup ( "test message" ) ;
258- handler
259- . title_tasks
260- . insert ( conversation. id , TitleTask :: Awaiting ) ;
195+ . insert ( conversation. id , TitleGenerationState { rx, handle } ) ;
261196
262197 handler
263198 . handle ( & event ( StartPayload ) , & mut conversation)
264199 . await
265200 . unwrap ( ) ;
266201
267- assert ! ( matches!(
268- handler. title_tasks. get( & conversation. id) . as_deref( ) ,
269- Some ( TitleTask :: Awaiting )
270- ) ) ;
202+ // Entry should still exist (wasn't replaced)
203+ assert ! ( handler. title_tasks. contains_key( & conversation. id) ) ;
271204 }
272205
273- /// A StartPayload after generation has finished must not re-spawn.
274206 #[ tokio:: test]
275- async fn test_start_skips_if_done ( ) {
207+ async fn test_end_sets_title_from_completed_task ( ) {
276208 let ( handler, mut conversation) = setup ( "test message" ) ;
209+ let ( tx, rx) = oneshot:: channel ( ) ;
210+ tx. send ( Some ( "generated" . to_string ( ) ) ) . unwrap ( ) ;
211+ let handle = tokio:: spawn ( async { } ) ;
212+ handle. abort ( ) ;
277213 handler
278214 . title_tasks
279- . insert ( conversation. id , TitleTask :: Done ( "existing" . into ( ) ) ) ;
215+ . insert ( conversation. id , TitleGenerationState { rx , handle } ) ;
280216
281217 handler
282- . handle ( & event ( StartPayload ) , & mut conversation)
218+ . handle ( & event ( EndPayload ) , & mut conversation)
283219 . await
284220 . unwrap ( ) ;
285221
286- assert ! ( matches!(
287- handler. title_tasks. get( & conversation. id) . as_deref( ) ,
288- Some ( TitleTask :: Done ( _) )
289- ) ) ;
222+ assert_eq ! ( conversation. title, Some ( "generated" . into( ) ) ) ;
223+ // Entry should be removed after successful title generation
224+ assert ! ( !handler. title_tasks. contains_key( & conversation. id) ) ;
290225 }
291226
292227 #[ tokio:: test]
293- async fn test_end_sets_title_from_completed_task ( ) {
228+ async fn test_end_handles_task_cancellation ( ) {
294229 let ( handler, mut conversation) = setup ( "test message" ) ;
295- let ( tx, rx) = oneshot:: channel ( ) ;
296- tx. send ( Some ( "generated" . to_string ( ) ) ) . unwrap ( ) ;
230+ let ( tx, rx) = oneshot:: channel :: < Option < String > > ( ) ;
231+ // Drop the sender to simulate a cancelled task.
232+ drop ( tx) ;
233+ let handle = tokio:: spawn ( async { } ) ;
234+ handle. abort ( ) ;
297235 handler
298236 . title_tasks
299- . insert ( conversation. id , TitleTask :: InProgress ( rx ) ) ;
237+ . insert ( conversation. id , TitleGenerationState { rx , handle } ) ;
300238
301239 handler
302240 . handle ( & event ( EndPayload ) , & mut conversation)
303241 . await
304242 . unwrap ( ) ;
305243
306- assert_eq ! ( conversation. title, Some ( "generated" . into( ) ) ) ;
307- assert ! ( matches!(
308- handler. title_tasks. get( & conversation. id) . as_deref( ) ,
309- Some ( TitleTask :: Done ( _) )
310- ) ) ;
244+ assert ! ( conversation. title. is_none( ) ) ;
245+ assert ! ( !handler. title_tasks. contains_key( & conversation. id) ) ;
311246 }
312247
248+ /// When EndPayload is received, the spawned task should be aborted so it
249+ /// doesn't continue running unnecessarily.
313250 #[ tokio:: test]
314- async fn test_end_handles_task_cancellation ( ) {
251+ async fn test_end_aborts_in_progress_task ( ) {
315252 let ( handler, mut conversation) = setup ( "test message" ) ;
316253 let ( tx, rx) = oneshot:: channel :: < Option < String > > ( ) ;
317- // Drop the sender to simulate a cancelled task.
318- drop ( tx) ;
254+ let handle = tokio:: spawn ( async move {
255+ // Simulate a long-running task that would block indefinitely.
256+ tokio:: time:: sleep ( Duration :: from_secs ( 60 ) ) . await ;
257+ let _ = tx. send ( None ) ;
258+ } ) ;
259+
319260 handler
320261 . title_tasks
321- . insert ( conversation. id , TitleTask :: InProgress ( rx ) ) ;
262+ . insert ( conversation. id , TitleGenerationState { rx , handle } ) ;
322263
323264 handler
324265 . handle ( & event ( EndPayload ) , & mut conversation)
325266 . await
326267 . unwrap ( ) ;
327268
328- assert ! ( conversation . title . is_none ( ) ) ;
269+ // Entry should have been removed from map
329270 assert ! ( !handler. title_tasks. contains_key( & conversation. id) ) ;
271+
272+ // Verify the task is no longer running by checking that the
273+ // EndPayload handler didn't hang (it completed immediately).
274+ assert ! ( conversation. title. is_none( ) ) ;
330275 }
331276
332277 /// Many concurrent StartPayload calls for the same conversation id must
@@ -354,11 +299,7 @@ mod tests {
354299 j. await . unwrap ( ) ;
355300 }
356301
357- let actual = handler
358- . title_tasks
359- . iter ( )
360- . filter ( |e| matches ! ( e. value( ) , TitleTask :: InProgress ( _) ) )
361- . count ( ) ;
362- assert_eq ! ( actual, 1 ) ;
302+ // Only one task should exist in the map
303+ assert_eq ! ( handler. title_tasks. len( ) , 1 ) ;
363304 }
364305}
0 commit comments