@torch_op("aten::feature_dropout", trace_only=True)
def aten_feature_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat:
    """feature_dropout(Tensor input, float p, bool train) -> Tensor

    Drops whole feature maps (channels) instead of individual elements: a
    single keep/drop decision is drawn per ``[N, C]`` entry and broadcast over
    every remaining dimension of ``input``.

    Args:
        input: Tensor laid out as ``[N, C, ...]``. The mask shape is built
            from the first two dimensions, so a rank of at least 2 is
            assumed -- TODO confirm callers never pass rank < 2.
        p: Dropout ratio forwarded to ``op.Dropout``.
        train: Training flag forwarded to ``op.Dropout``.

    Returns:
        ``input`` unchanged when ``p == 0`` or ``train`` is falsy; otherwise
        ``input`` multiplied by the per-channel dropout mask.
    """
    # NOTE(review): this shortcut is evaluated in Python while tracing, so it
    # presumes `p` and `train` are plain Python values at that point -- confirm.
    if p == 0 or not train:
        return input

    # Mask shape is [N, C, 1, 1, ...]: keep the extent of the first two axes
    # and collapse every later axis to 1. A single element-wise Where over the
    # axis indices covers every rank >= 2, including the plain [N, C] case, and
    # avoids selecting between two shape tensors of different lengths (which
    # cannot be broadcast against each other).
    input_shape = op.Shape(input)
    rank = op.Size(input_shape)
    axis_index = op.Range(op.Constant(value_int=0), rank, op.Constant(value_int=1))
    keeps_extent = op.Less(axis_index, op.Constant(value_int=2))
    mask_shape = op.Where(keeps_extent, input_shape, op.Constant(value_ints=[1]))

    # Run Dropout on a tensor of ones so that it yields the (already rescaled)
    # mask. The ones are cast to the dtype of `input` so the final Mul sees
    # matching element types.
    one = op.CastLike(op.Constant(value_float=1.0), input)
    ones = op.Expand(one, mask_shape)
    mask, _ = op.Dropout(ones, p, train)

    # Broadcasting stretches the size-1 trailing axes of the mask over `input`.
    return op.Mul(input, mask)