diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 17a0e1fa4d130d..7ee0c535d26ada 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -41,11 +41,14 @@ import glob import math import os +import random import re import warnings from collections import OrderedDict from functools import partial +import numpy as np + import paddle import paddle.distributed as dist from paddle import framework, nn @@ -53,20 +56,40 @@ from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import ( SHARED_WEIGHT_SYNC_PREFIX, ) +from paddle.distributed.fleet.meta_parallel.parallel_layers.random import ( + get_rng_state_tracker, +) from paddle.distributed.fleet.utils.log_util import layer_to_str, logger from paddle.framework import core from paddle.incubate.distributed.fleet import recompute_hybrid +from ...recompute.recompute import ( + CustomStatesManager, + detach_variable, + switch_rng_state_tracker, +) from ..pp_utils.forward_backward_overlap_utils import ( ScheduleChunk, ) __all__ = [] +def reset_outputs(outputs): + # When using recompute, PyLayer marks every output with stop_gradient=True, so reset them here. + if isinstance(outputs, paddle.Tensor): + outputs.stop_gradient = False + elif isinstance(outputs, tuple) or isinstance(outputs, list): + for output_tensor in outputs: + if isinstance(output_tensor, paddle.Tensor): + output_tensor.stop_gradient = False + else: + raise ValueError(f"Unsupported type of outputs: {type(outputs)}.") + class LayerDesc: def __init__(self, layer_func, *inputs, **kwargs): self.layer_func = layer_func + self.recompute_config = kwargs.pop('recompute_config', {}) self.inputs = inputs self.kwargs = kwargs @@ -132,6 +155,9 @@ def __init__( self.shared_weight_attr = shared_weight_attr +custom_state_manager = CustomStatesManager() + + class SegmentLayers: def __init__( self, @@ -389,6 +415,7 @@ def __init__( num_virtual_pipeline_stages=None, use_cudagraph=False, use_dualpipev=False, + recompute_overlap=False, ): super().__init__() if num_stages is None and topology is None: @@ -483,6 +510,22 @@ def __init__( self.local_shared_layers = paddle.nn.LayerDict() self.local_shared_weight_attrs = {} + # initialize the state needed by recompute_overlap + self.recompute_overlap = recompute_overlap + if self.recompute_overlap: + self.recompute_idxs = [] + self.preserve_rng_states = {} + self.offload_indices = {} + self.layer_results = [] # list of {layer_idx: [detached_inputs, detached_outputs]} dicts + self.states = [] + + if self._num_virtual_pipeline_stages > 1: + self.total_recompute_idxs = [] + self.total_preserve_rng_states = [] + self.total_offload_indices = [] + self.total_layer_results = [] # one list of {layer_idx: [detached_inputs, detached_outputs]} dicts per chunk + self.total_states = [] + if self._use_dualpipev: self._start_poss = [] self._end_poss = [] @@ -921,6 +964,13 @@ def _build_chunked_layer(self): # Add the chunk to all chunks and add this chunk to the sublayer self._model_chunks.append(chunk) self.add_sublayer(str(start), chunk) + if self.recompute_overlap: + self.total_layer_results.append([]) + self.total_states.append([]) + + self.recompute_idxs = [] + self.preserve_rng_states = {} + self.offload_indices = {} paddle.set_rng_state(orig_rng_state) get_rng_state_tracker().set_states_tracker(orig_rng_tracker) @@ -973,6 +1023,17 @@ def flush_into_run_function(): for index, layer in
enumerate(self._layers_desc[start:end]): layer_index = start + index + if layer.recompute_config: + self.recompute_idxs.append(index) + self.preserve_rng_states[index] = layer.recompute_config[ + "preserve_rng_states" + ] + if "offload_indices" in layer.recompute_config: + self.offload_indices[index] = layer.recompute_config[ + "offload_indices" + ] + else: + self.offload_indices[index] = [] # NOTE(shenliang03): need set different seeds for pipeline parameters initialization. # Since the parameters of model_parallel are controlled by its own RNG_STATE_TRACKER, @@ -1070,7 +1131,10 @@ def flush_into_run_function(): else: flush_into_run_function() run_function.append(layer) - + if self.recompute_overlap and self._num_virtual_pipeline_stages > 1: + self.total_recompute_idxs.append(self.recompute_idxs) + self.total_preserve_rng_states.append(self.preserve_rng_states) + self.total_offload_indices.append(self.offload_indices) flush_into_run_function() return run_function @@ -1097,15 +1161,281 @@ def check_overlap_schedule_mode(): schedule_chunk = ScheduleChunk(nodes=nodes) return schedule_chunk + def generate_state( + self, + run_function_idx, + inputs, + ): + state = {} + preserve_rng_state = self.preserve_rng_states[run_function_idx] + + if isinstance(inputs, paddle.Tensor): + inputs = (inputs,) + state["inputs_type"] = "Tensor" + elif isinstance(inputs, dict): + # if inputs is a dict, record its keys and use its values as the inputs + state["inputs_keys"] = tuple(inputs.keys()) + inputs = tuple(inputs.values()) + state["inputs_type"] = "dict" + elif isinstance(inputs, tuple): + state["inputs_type"] = "tuple" + elif isinstance(inputs, list): + state["inputs_type"] = "list" + else: + raise ValueError( + "Inputs of types other than Tensor, tuple, list, or dict are not supported in generate_state."
+ ) + + if custom_state_manager.custom_get_state_func is None: + assert custom_state_manager.custom_set_state_func is None + custom_get_state_func = lambda x=None: None + custom_set_state_func = lambda x=None: None + else: + custom_get_state_func = custom_state_manager.custom_get_state_func + custom_set_state_func = custom_state_manager.custom_set_state_func + + state["preserve_rng_state"] = preserve_rng_state + + if preserve_rng_state: + state["fw_rng_state"] = paddle.get_rng_state() + state["fwd_rng_state_tracker"] = ( + get_rng_state_tracker().get_states_tracker() + ) + state["fwd_numpy_state"] = np.random.get_state() + state["fwd_random_state"] = random.getstate() + state["fwd_custom_state"] = custom_get_state_func() + state["custom_get_state_func"] = custom_get_state_func + state["custom_set_state_func"] = custom_set_state_func + tracer = framework._dygraph_tracer() + state["is_fw_autocast"] = ( + False if tracer._amp_level == framework.core.AmpLevel.O0 else True + ) + if tracer._amp_level == framework.core.AmpLevel.O2: + state["amp_level"] = 'O2' + elif tracer._amp_level in ( + framework.core.AmpLevel.O1, + framework.core.AmpLevel.O0, + ): + state["amp_level"] = 'O1' + else: + raise ValueError(f"unsupported amp level: {tracer._amp_level}") + + if tracer._amp_dtype == 'float16': + state["amp_dtype"] = 'float16' + elif tracer._amp_dtype in ('bfloat16', 'float32'): + state["amp_dtype"] = 'bfloat16' + else: + raise ValueError(f"unsupported amp dtype: {tracer._amp_dtype}") + state["amp_white_list"], state["amp_black_list"] = ( + tracer._get_amp_op_list() + ) + + # indices of the tensor entries in inputs + state["tensor_indices"] = [] + # the tensor entries of inputs (offloaded ones share their buffer with a pinned/CPU copy) + state["tensor_inputs"] = [] + # the non-tensor entries of inputs; positions holding tensors are filled with None + state["inputs"] = [] + for i, input_tensor in enumerate(inputs): + if paddle.is_tensor(input_tensor): + if i in self.offload_indices[run_function_idx]: + cpu_tensor = ( + input_tensor.pin_memory() + if framework.core.is_compiled_with_cuda() + else input_tensor.cpu() + ) + cpu_tensor._share_buffer_to(input_tensor) + ''' + Tensors received from the previous pp rank are computed inside a paddle.no_grad() block, so their stop_gradient is True. + To ensure that backward works properly, stop_gradient should be set to False. + ''' + # input_tensor.stop_gradient = False + state["tensor_inputs"].append(input_tensor) + state["tensor_indices"].append(i) + state["inputs"].append(None) + elif type(input_tensor) is tuple: + assert i not in self.offload_indices[run_function_idx], ( + f"offload_indices should not contain a tensor tuple at position {i}" + ) + is_tensors = [paddle.is_tensor(a) for a in input_tensor] + if all(is_tensors): + # the tuple is a tuple of tensors + tensors_stop_gradient = [ + a.stop_gradient for a in input_tensor + ] + if not all(tensors_stop_gradient) and any( + tensors_stop_gradient + ): + # tensors in the tuple have different stop_gradient values, which PyLayer doesn't support + raise ValueError( + "Recompute received a tuple containing tensors with different stop_gradient values." + ) + state["tensor_inputs"].append(input_tensor) + state["tensor_indices"].append(i) + state["inputs"].append(None) + elif any(is_tensors): + # the tuple contains tensors and non-tensor values + raise ValueError( + "Recompute received a tuple containing both tensor and non-tensor values."
+ ) + else: + state["inputs"].append(input_tensor) + else: + state["inputs"].append(input_tensor) + return state + + def load_state_and_forward(self, state, run_function): + with paddle.base.dygraph.guard(): + inputs = list(state["inputs"]) + tensor_indices = state["tensor_indices"] + tensors = state["tensor_inputs"] + for i, idx in enumerate(tensor_indices): + inputs[idx] = ( + tensors[i].to( + paddle.base.framework._current_expected_place() + ) + if i in self.offload_indices + else tensors[i] + ) + if i in self.offload_indices: + # NOTE(zhiqiu): tensor.to(device) will set stop_gradient=True, which may break the graph + inputs[idx].stop_gradient = tensors[i].stop_gradient + tracer = framework._dygraph_tracer() + tracer._has_grad = True + + preserve_rng_state = state["preserve_rng_state"] + + # NOTE: support AMP + # the auto_cast state as well as the white/black op lists need to be restored + if preserve_rng_state: + with ( + switch_rng_state_tracker( + state["fw_rng_state"], + state["fwd_rng_state_tracker"], + state["fwd_numpy_state"], + state["fwd_random_state"], + state["fwd_custom_state"], + state["custom_get_state_func"], + state["custom_set_state_func"], + ), + paddle.amp.auto_cast( + enable=state["is_fw_autocast"], + custom_white_list=state["amp_white_list"], + custom_black_list=state["amp_black_list"], + level=state["amp_level"], + dtype=state["amp_dtype"], + ), + ): + detached_inputs = detach_variable(tuple(inputs)) + if state["inputs_type"] == "dict": + # rebuild a dict from detached_inputs using the saved keys in state["inputs_keys"] + final_input = dict( + zip(state["inputs_keys"], detached_inputs) + ) + elif state["inputs_type"] == "tuple": + final_input = detached_inputs + elif state["inputs_type"] == "Tensor": + final_input = detached_inputs[0] + elif state["inputs_type"] == "list": + final_input = list(detached_inputs) + else: + raise ValueError( + "Inputs of types other than Tensor, tuple, list, or dict are not supported in load_state_and_forward." + ) + detached_outputs = run_function(final_input) + else: + with paddle.amp.auto_cast( + enable=state["is_fw_autocast"], + custom_white_list=state["amp_white_list"], + custom_black_list=state["amp_black_list"], + level=state["amp_level"], + dtype=state["amp_dtype"], + ): + detached_inputs = detach_variable(tuple(inputs)) + if state["inputs_type"] == "dict": + # rebuild a dict from detached_inputs using the saved keys in state["inputs_keys"] + final_input = dict( + zip(state["inputs_keys"], detached_inputs) + ) + elif state["inputs_type"] == "tuple": + final_input = detached_inputs + elif state["inputs_type"] == "Tensor": + final_input = detached_inputs[0] + elif state["inputs_type"] == "list": + final_input = list(detached_inputs) + else: + raise ValueError( + "Inputs of types other than Tensor, tuple, list, or dict are not supported in load_state_and_forward."
+ ) + detached_outputs = run_function(final_input) + return final_input, detached_outputs + def forward_function(self, start, end): run_function = self.run_function + if self.recompute_overlap: + state_dict = {} + layer_result = {} + recompute_idx = self.recompute_idxs + + def execute_func(*inputs, is_pipeline_last_stage): + if len(inputs) == 1: + inputs = inputs[0] + + if not self.recompute_overlap: + for idx, layer in enumerate(run_function[start:end]): + inputs = layer(inputs) + return inputs + else: + if is_pipeline_last_stage: + # The last pipeline stage does not need to re-forward even if recompute is enabled. + for idx, layer in enumerate(run_function[start:end]): + inputs = layer(inputs) + self.layer_results.append(layer_result) + return inputs + else: + no_recompute_flag = False + layer_result_idx = 0 + for idx, layer in enumerate(run_function[start:end]): + if idx in recompute_idx: + layer_result[idx] = [] + no_recompute_flag = False + layer_result_idx = idx + 1 + state = self.generate_state(idx, inputs) + state_dict[idx] = state + with paddle.no_grad(): + outputs = layer(inputs) + reset_outputs(outputs) + else: + outputs = layer(inputs) + if not no_recompute_flag: + layer_result_idx = idx + layer_result[layer_result_idx] = [ + inputs, + outputs, + ] + no_recompute_flag = True + else: + layer_result[layer_result_idx][1] = outputs + inputs = outputs + self.layer_results.append(layer_result) + self.states.append(state_dict) + return inputs - def execute_func(*x): - if len(x) == 1: - x = x[0] - for idx, layer in enumerate(run_function[start:end]): - x = layer(x) - return x + return execute_func + + def recompute_function(self, start, end): + run_function = self.run_function[start:end] + + def execute_func(*inputs): + layer_result = self.layer_results.pop(0) + state_dict = self.states.pop(0) if len(self.states) != 0 else {} + if state_dict: + if len(inputs) == 1: + inputs = inputs[0] + for idx, state in state_dict.items(): + final_input, detach_outputs = self.load_state_and_forward( + state, run_function[idx] + ) + layer_result[idx] = [final_input, detach_outputs] + return layer_result return execute_func @@ -1127,17 +1457,27 @@ def update_run_function(self, chunk_id): # But for interleave, self.run_function will keep updating to the target functions at every run.
self.run_function = model_chunk.get_run_function() + if self.recompute_overlap: + self.recompute_idxs = self.total_recompute_idxs[chunk_id] + self.layer_results = self.total_layer_results[chunk_id] + self.states = self.total_states[chunk_id] + self.preserve_rng_states = self.total_preserve_rng_states[ + chunk_id + ] + self.offload_indices = self.total_offload_indices[chunk_id] + def get_schedule_chunk(self, chunk_id): self.update_run_function(chunk_id) assert self._recompute_interval == 0 return self.build_schedule_nodes(0, len(self.run_function)) - def forward(self, input, chunk_id=None): + def forward(self, input, chunk_id=None, is_pipeline_last_stage=False): self.update_run_function(chunk_id) - if self._recompute_interval == 0: - input = self.forward_function(0, len(self.run_function))(input) + if self._recompute_interval == 0: + input = self.forward_function(0, len(self.run_function))( + input, is_pipeline_last_stage=is_pipeline_last_stage + ) else: num_layers = len(self.run_function) for start_idx in range(0, num_layers, self._recompute_interval): @@ -1152,9 +1492,41 @@ def forward(self, input, chunk_id=None): self.recompute_ctx, self.forward_function(start_idx, end_idx), *input, + is_pipeline_last_stage=is_pipeline_last_stage, + ) + else: + input = self.forward_function(start_idx, end_idx)( + *input, is_pipeline_last_stage=is_pipeline_last_stage + ) + + return input + + def recompute(self, input, chunk_id=None): + self.update_run_function(chunk_id) + + if self._recompute_interval == 0: + input = self.recompute_function(0, len(self.run_function))(input) + else: + num_layers = len(self.run_function) + for start_idx in range(0, num_layers, self._recompute_interval): + end_idx = min(start_idx + self._recompute_interval, num_layers) + funcs = self.run_function[start_idx:end_idx] + + if not isinstance(input, tuple): + input = (input,) + + if self._need_recompute(funcs, input): + input = recompute_hybrid( + self.recompute_ctx, + self.recompute_function(start_idx, end_idx), + start_idx, + *input, ) else: - input = self.forward_function(start_idx, end_idx)(*input) + input = self.recompute_function(start_idx, end_idx)( + start_idx, *input + ) return input diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 8a81fc354d7a80..e2c6cc3e783a0b 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -495,6 +495,8 @@ def __init__(self, layers, hcg, strategy): # only support user hooks during training self.user_hooks_enabled = True + self.recompute_overlap = True + def register_hook( self, location: PipelineParallelMicroStepLocations, hook: Callable ): @@ -866,6 +868,18 @@ def forward_backward_pipeline( ) output_tensor_tuple = dict_to_tuple_helper(output_tensor) + + input_buffers.append(input_tensor) + output_buffers.append(output_tensor_tuple) + + input_tensor, output_tensor = ( + input_buffers.pop(0), + output_buffers.pop(0), + ) + + if self.recompute_overlap: + recompute_result = self._layers.recompute(input_tensor) + # NOTE: `send_forward_recv_backward` is intentionally unused to # prevent hanging bugs in dynamic shape mode.
self._p2p_helper.send_forward( @@ -879,21 +893,26 @@ def forward_backward_pipeline( batch_p2p_comm=self._use_batch_p2p_comm, ) - input_buffers.append(input_tensor) - output_buffers.append(output_tensor_tuple) - if not self.is_pipeline_last_stage(): self._release_output(output_tensor_tuple) - input_tensor, output_tensor = ( - input_buffers.pop(0), - output_buffers.pop(0), - ) - self._record_stamp("B", i, '"B"', self._backward_color) - input_tensor_grad = self._backward_step( - input_tensor, output_tensor, output_tensor_grad, step_id=i - ) + if self.recompute_overlap: + input_tensor_grad = self._backward_step( + input_tensor, + output_tensor, + output_tensor_grad, + recompute_result, + step_id=i, + ) + else: + input_tensor_grad = self._backward_step( + input_tensor, + output_tensor, + output_tensor_grad, + step_id=i, + ) + self._record_stamp("B", i, '"E"', self._backward_color) if last_iter: @@ -930,15 +949,28 @@ def forward_backward_pipeline( batch_p2p_comm=self._use_batch_p2p_comm, ) + if self.recompute_overlap: + recompute_result = self._layers.recompute(input_tensor) + self._record_stamp( "B", steady_steps + i, '"B"', self._backward_color ) - input_tensor_grad = self._backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - step_id=steady_steps + i, - ) + + if self.recompute_overlap: + input_tensor_grad = self._backward_step( + input_tensor, + output_tensor, + output_tensor_grad, + recompute_result, + step_id=steady_steps + i, + ) + else: + input_tensor_grad = self._backward_step( + input_tensor, + output_tensor, + output_tensor_grad, + step_id=steady_steps + i, + ) self._record_stamp( "B", steady_steps + i, '"E"', self._backward_color ) @@ -1254,10 +1286,15 @@ def _forward_step( schedule_chunk = None if overlap_schedule_mode: schedule_chunk = self._layers.get_schedule_chunk(chunk_id=chunk_id) - output_tensor = schedule_chunk.forward(input_tensor) + output_tensor = schedule_chunk.forward( + input_tensor, + is_pipeline_last_stage=self.is_pipeline_last_stage(), + ) else: output_tensor = self._layers.forward( - input_tensor, chunk_id=chunk_id + input_tensor, + chunk_id=chunk_id, + is_pipeline_last_stage=self.is_pipeline_last_stage(), ) self.callbacks.on_location( @@ -1290,6 +1327,7 @@ def _backward_step( input_tensor, output_tensor, output_tensor_grad, + recompute_result=None, chunk_id=None, step_id=None, overlap_schedule_mode=False, @@ -1312,22 +1350,15 @@ def _backward_step( output_tensor_grad=output_tensor_grad, step_id=step_id, ) - if self.is_pipeline_last_stage(): - assert output_tensor_grad is None - if overlap_schedule_mode: - assert ( - loss_fn_node is not None and schedule_chunk is not None - ), ( - "loss_fn_node and schedule_chunk should not be None in overlap_schedule_mode" - ) - input_tensor_grad = loss_fn_node.backward( - scaler=self.scaler - ) - input_tensor_grad = schedule_chunk.backward( - input_tensor_grad - ) - else: - # In align mode, we scale the grad directly after forward + if self.recompute_overlap: + # TODO: overlap_schedule_mode is not supported yet + if self.is_pipeline_last_stage(): + if isinstance(input_tensor, paddle.Tensor): + input_tensor.stop_gradient = False + else: + for input_item in input_tensor: + input_item.stop_gradient = False + if paddle.distributed.in_auto_parallel_align_mode(): output_tensor = output_tensor / _get_align_mode_scale() if self.scaler: @@ -1336,31 +1367,8 @@ def _backward_step( ) else: paddle.autograd.backward(output_tensor) - else: - if isinstance(output_tensor, tuple): - outputs = [t for t in output_tensor if not t.stop_gradient] -
assert len(outputs) == len(output_tensor_grad) - grad_tensors = list(output_tensor_grad) - else: - outputs = [output_tensor] - grad_tensors = [output_tensor_grad] - - if overlap_schedule_mode: - assert schedule_chunk is not None, ( - "schedule_chunk should not be None in overlap_schedule_mode" - ) - input_tensor_grad = schedule_chunk.backward(grad_tensors) - else: - paddle.autograd.backward( - tensors=outputs, - grad_tensors=grad_tensors, - ) - if not overlap_schedule_mode: - # Extract input_tensor_grad from the input tensor. In overlap_schedule_mode, - # the input_tensor_grad is extracted inside the schedule_chunk. - input_tensor_grad = None - if input_tensor is not None: + input_tensor_grad = None if isinstance(input_tensor, tuple): input_tensor_grad = tuple( [ @@ -1369,8 +1377,140 @@ def _backward_step( if not t.stop_gradient ] ) + elif isinstance(input_tensor, list): + input_tensor_grad = list( + [ + t.grad + for t in input_tensor + if not t.stop_gradient + ] + ) else: input_tensor_grad = input_tensor.grad + else: + for idx in sorted(recompute_result.keys(), reverse=True): + layer_result = recompute_result[idx] + output_tensor = layer_result[1] + if isinstance(output_tensor, tuple): + outputs = [ + t for t in output_tensor if not t.stop_gradient + ] + assert len(outputs) == len(output_tensor_grad), ( + f"{len(outputs)=}, {len(output_tensor_grad)=}" + ) + grad_tensors = list(output_tensor_grad) + elif isinstance(output_tensor, list): + outputs = [ + t for t in output_tensor if not t.stop_gradient + ] + assert len(outputs) == len(output_tensor_grad), ( + f"{len(outputs)=}, {len(output_tensor_grad)=}" + ) + grad_tensors = list(output_tensor_grad) + elif isinstance(output_tensor, paddle.Tensor): + outputs = [output_tensor] + grad_tensors = [output_tensor_grad] + else: + raise ValueError( + f"Unsupported type of output_tensor: {type(output_tensor)}" + ) + + paddle.autograd.backward( + tensors=outputs, + grad_tensors=grad_tensors, + ) + input_tensor_grad = None + input_tensor = layer_result[0] + if input_tensor is not None: + if isinstance(input_tensor, tuple): + input_tensor_grad = tuple( + [ + t.grad + for t in input_tensor + if not t.stop_gradient + ] + ) + elif isinstance(input_tensor, list): + input_tensor_grad = list( + [ + t.grad + for t in input_tensor + if not t.stop_gradient + ] + ) + else: + input_tensor_grad = input_tensor.grad + output_tensor_grad = input_tensor_grad + input_tensor_grad = output_tensor_grad + else: + if self.is_pipeline_last_stage(): + assert output_tensor_grad is None + if overlap_schedule_mode: + assert ( + loss_fn_node is not None + and schedule_chunk is not None + ), ( + "loss_fn_node and schedule_chunk should not be None in overlap_schedule_mode" + ) + input_tensor_grad = loss_fn_node.backward( + scaler=self.scaler + ) + input_tensor_grad = schedule_chunk.backward( + input_tensor_grad + ) + else: + # In align mode, we scale the grad directly after forward + if paddle.distributed.in_auto_parallel_align_mode(): + output_tensor = ( + output_tensor / _get_align_mode_scale() + ) + if self.scaler: + paddle.autograd.backward( + self.scaler.scale(output_tensor) + ) + else: + paddle.autograd.backward(output_tensor) + else: + if isinstance(output_tensor, tuple): + outputs = [ + t for t in output_tensor if not t.stop_gradient + ] + assert len(outputs) == len(output_tensor_grad) + grad_tensors = list(output_tensor_grad) + else: + outputs = [output_tensor] + grad_tensors = [output_tensor_grad] + + if overlap_schedule_mode: + assert schedule_chunk is not None, ( + 
"schedule_chunk should not be None in overlap_schedule_mode" + ) + input_tensor_grad = schedule_chunk.backward( + grad_tensors + ) + + else: + tracer = framework._dygraph_tracer() + paddle.autograd.backward( + tensors=outputs, + grad_tensors=grad_tensors, + ) + + if not overlap_schedule_mode: + # Extract input_tensor_grad from the input tensor. In overlap_schedule_mode, + # the input_tensor_grad is extracted inside the schedule_chunk. + input_tensor_grad = None + if input_tensor is not None: + if isinstance(input_tensor, tuple): + input_tensor_grad = tuple( + [ + t.grad + for t in input_tensor + if not t.stop_gradient + ] + ) + else: + input_tensor_grad = input_tensor.grad if self._enable_timer: self.timers("backward_step").stop() self.callbacks.on_location(