@@ -91,6 +91,16 @@ def __init__(
         self.pool = pool
         self._graph: torch.cuda.CUDAGraph = None

+    def make_output_buffers(self, output):
+        """Make output buffers."""
+        output_buffers = dict(logits=output)
+        return output_buffers
+
+    def slice_output(self, output_buffers: Dict[str, Any], inputs: Dict[str, Any]):
+        """Slice output."""
+        num_tokens = inputs['input_ids'].size(-1)
+        return output_buffers['logits'][:, :num_tokens]
+
     @record_function('capture_cudagraph')
     def capture(self, **kwargs):
         """Capture graph."""
@@ -102,29 +112,31 @@ def capture(self, **kwargs):
         current_stream = torch.cuda.current_stream()

         # warmup
-        self.model(**padded_kwargs)
+        warmup_output = self.model(**padded_kwargs)
+        warmup_buffers = self.make_output_buffers(warmup_output)

         self._graph = torch.cuda.CUDAGraph()
         # unsafe kernel calls in other threads might invalidate the capture,
         # so we set thread_safe capture mode here.
         with torch.cuda.graph(self._graph, pool=self.pool, stream=current_stream, capture_error_mode='thread_local'):
             output = self.model(**padded_kwargs)

-        output_buffers = dict(logits=output)
+        output_buffers = self.make_output_buffers(output)
         self.meta.output_buffers = output_buffers
+        output = self.slice_output(warmup_buffers, kwargs)
         return output

     @record_function('forward_cudagraph')
     def forward(self, **kwargs):
         """forward."""
-        num_tokens = kwargs['input_ids'].size(-1)
         assert self._graph is not None
         self.model.fill_buffers_cudagraph(self.meta, **kwargs)
         context = self.ctx_mgr.current_context()
         self.model.update_context_cudagraph(self.meta, context)
         self._graph.replay()

-        output = self.meta.output_buffers['logits'][:, :num_tokens]
+        output_buffers = self.meta.output_buffers
+        output = self.slice_output(output_buffers, kwargs)
         return output

     def __del__(self):
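
The slicing above exists because the graph is captured with inputs padded to a fixed size, so the cached `logits` buffer spans the padded token count while callers only want the first `num_tokens`. Below is a minimal stand-alone sketch of that trimming with assumed shapes (batch 1, 8 padded positions, vocab 32); it mirrors the `slice_output` helper added in this commit rather than importing the runner itself:

```python
import torch
from typing import Any, Dict


def slice_output(output_buffers: Dict[str, Any], inputs: Dict[str, Any]):
    """Trim padded logits down to the real number of input tokens."""
    num_tokens = inputs['input_ids'].size(-1)
    return output_buffers['logits'][:, :num_tokens]


# Assumed shapes for illustration only: batch=1, max_tokens=8 (padded), vocab=32.
padded_logits = torch.randn(1, 8, 32)
inputs = dict(input_ids=torch.arange(5)[None])  # only 5 real tokens

out = slice_output(dict(logits=padded_logits), inputs)
print(out.shape)  # torch.Size([1, 5, 32])
```
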
@@ -223,12 +235,14 @@ def __call__(self, **kwargs):
                                            pool=self.graph_pool_handle,
                                            model_config=self.model_config,
                                            device=self.device)
-            runner.capture(**kwargs)
+            output = runner.capture(**kwargs)
             self._runner_map[graph_key] = runner
+            # SSM would update its state during capture (warmup); replaying the graph here would lead to an unexpected second state update.
+            return output
         else:
             runner = self._runner_map[graph_key]
-        output = runner.forward(**kwargs)
-        return output
+            output = runner.forward(**kwargs)
+            return output

     @record_function('prepare_inputs_for_generation')
     def prepare_inputs_for_generation(
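
The new comment in `__call__` carries the motivation for returning the warmup output directly: a state-space model mutates its recurrent state cache during the warmup forward inside `capture()`, so replaying the freshly captured graph for that same step would apply the state update twice. The toy module below is a hypothetical illustration of that kind of in-place state mutation, not code from this repository:

```python
import torch


class ToyStatefulLayer(torch.nn.Module):
    """Toy stand-in for an SSM layer that updates a state cache in-place."""

    def __init__(self):
        super().__init__()
        # recurrent state cache, mutated on every forward
        self.register_buffer('state', torch.zeros(1))

    def forward(self, x):
        self.state += x.sum()  # in-place state update
        return x + self.state


layer = ToyStatefulLayer()
x = torch.ones(4)

layer(x)            # warmup forward inside capture(): state becomes 4
layer(x)            # an immediate replay would run the same kernels again
print(layer.state)  # tensor([8.]) -- state advanced twice for one real step
```

Later steps go through `forward()`, which replays the graph exactly once per request, so the state advances as expected from then on.
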