diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index a131d4c6ce5c..1c2634e294e1 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -2715,23 +2715,75 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           reducedShape[i] = xShape[i];
         auto reducedType =
             xType.getWithSizesAndDtype(reducedShape, *stashDtype);
+
+        // native_layer_norm preserves input dtype, so when stash_type
+        // caused a cast of x, use stashDtype for y and cast back after.
+        auto actualYType = yType;
+        if (*stashDtype != yType.getOptionalDtype()) {
+          actualYType = cast<Torch::ValueTensorType>(yType.getWithSizesAndDtype(
+              yType.getOptionalSizes(), *stashDtype));
+
+          // Also cast scale and bias to stash_type so all tensor args
+          // to native_layer_norm share the same dtype.
+          Value stashDtypeConst = Torch::getDtypeIntValueForType(
+              rewriter, binder.getLoc(), *stashDtype);
+          if (auto scaleTy =
+                  dyn_cast<Torch::ValueTensorType>(scale.getType())) {
+            auto newScaleTy = scaleTy.getWithSizesAndDtype(
+                scaleTy.getOptionalSizes(), *stashDtype);
+            scale = Torch::AtenToDtypeOp::create(
+                rewriter, binder.getLoc(), newScaleTy, scale,
+                /*dtype=*/stashDtypeConst,
+                /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+                /*memory_format=*/none);
+          }
+          if (auto bTy = dyn_cast<Torch::ValueTensorType>(b.getType())) {
+            auto newBTy =
+                bTy.getWithSizesAndDtype(bTy.getOptionalSizes(), *stashDtype);
+            b = Torch::AtenToDtypeOp::create(
+                rewriter, binder.getLoc(), newBTy, b,
+                /*dtype=*/stashDtypeConst,
+                /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+                /*memory_format=*/none);
+          }
+        }
         auto y = Torch::AtenNativeLayerNormOp::create(
-            rewriter, binder.getLoc(), yType, /*meanType=*/reducedType,
+            rewriter, binder.getLoc(), actualYType, /*meanType=*/reducedType,
             /*invStdDevType=*/reducedType, x, normalized_shape, scale, b,
             constEpsilon);
 
         int64_t numResults = binder.op->getNumResults();
         if (numResults == 1) {
-          rewriter.replaceOp(binder.op, y.getResult0());
+          Value yResult = y.getResult0();
+          if (*stashDtype != yType.getOptionalDtype()) {
+            Value yDtypeConst = Torch::getDtypeIntValueForType(
+                rewriter, binder.getLoc(), yType.getDtype());
+            yResult = Torch::AtenToDtypeOp::create(
+                rewriter, binder.getLoc(), yType, yResult,
+                /*dtype=*/yDtypeConst,
+                /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+                /*memory_format=*/none);
+          }
+          rewriter.replaceOp(binder.op, yResult);
           return success();
         }
 
+        Value yResult = y.getResult0();
         Value meanOutput = y.getResult1();
         Value varOutput = y.getResult2();
-        // Convert meanType and varType back if stash_dtype is different
+        // Convert outputs back if stash_dtype is different
         if (binder.tensorResultTypeAtIndex(meanType, 1) ||
             binder.tensorResultTypeAtIndex(invStdDevType, 2))
           return failure();
+        if (*stashDtype != yType.getOptionalDtype()) {
+          Value yDtypeConst = Torch::getDtypeIntValueForType(
+              rewriter, binder.getLoc(), yType.getDtype());
+          yResult = Torch::AtenToDtypeOp::create(
+              rewriter, binder.getLoc(), yType, yResult,
+              /*dtype=*/yDtypeConst,
+              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+              /*memory_format=*/none);
+        }
         if (*stashDtype != meanType.getOptionalDtype()) {
           Value constDtype = Torch::getDtypeIntValueForType(
               rewriter, binder.getLoc(), meanType.getDtype());
@@ -2746,7 +2798,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
               /*memory_format=*/none);
         }
-        rewriter.replaceOp(binder.op, {y.getResult0(), meanOutput, varOutput});
+        rewriter.replaceOp(binder.op, {yResult, meanOutput, varOutput});
         return success();
       });
 
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 0e9e7b1ebd61..feaf1b963b59 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -396,6 +396,48 @@ func.func @test_layer_norm_single_result
 
 // -----
 
+// Test LayerNormalization with stash_type upcasting (f16 input, stash_type=f32).
+// When stash_type differs from input dtype, x, scale, and bias must all be
+// cast to the stash dtype before calling native_layer_norm, and the result
+// must be cast back to the original dtype.
+func.func @test_layer_norm_stash_type_f16(%arg0: !torch.vtensor<[2,8,256],f16>, %arg1: !torch.vtensor<[256],f16>, %arg2: !torch.vtensor<[256],f16>) -> !torch.vtensor<[2,8,256],f16>
+  attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %0 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.stash_type = 1 : si64} : (!torch.vtensor<[2,8,256],f16>, !torch.vtensor<[256],f16>, !torch.vtensor<[256],f16>) -> !torch.vtensor<[2,8,256],f16>
+  return %0 : !torch.vtensor<[2,8,256],f16>
+}
+// CHECK-LABEL: func.func @test_layer_norm_stash_type_f16
+// CHECK-SAME: %[[X:[a-zA-Z0-9]+]]: !torch.vtensor<[2,8,256],f16>
+// CHECK-SAME: %[[SCALE:[a-zA-Z0-9]+]]: !torch.vtensor<[256],f16>
+// CHECK-SAME: %[[BIAS:[a-zA-Z0-9]+]]: !torch.vtensor<[256],f16>
+// CHECK: %[[X_CAST:.*]] = torch.aten.to.dtype %[[X]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[2,8,256],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,8,256],f32>
+// CHECK: %[[SCALE_CAST:.*]] = torch.aten.to.dtype %[[SCALE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[256],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[256],f32>
+// CHECK: %[[BIAS_CAST:.*]] = torch.aten.to.dtype %[[BIAS]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[256],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[256],f32>
+// CHECK: %[[NORM:.*]], %{{.*}}, %{{.*}} = torch.aten.native_layer_norm %[[X_CAST]], %{{.*}}, %[[SCALE_CAST]], %[[BIAS_CAST]], %{{.*}} : !torch.vtensor<[2,8,256],f32>, !torch.list<int>, !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32>, !torch.float -> !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,1],f32>
+// CHECK: torch.aten.to.dtype %[[NORM]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[2,8,256],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,8,256],f16>
+
+// -----
+
+// Test LayerNormalization with stash_type upcasting returning all 3 results.
+func.func @test_layer_norm_stash_type_f16_3results(%arg0: !torch.vtensor<[2,8,256],f16>, %arg1: !torch.vtensor<[256],f16>, %arg2: !torch.vtensor<[256],f16>) -> (!torch.vtensor<[2,8,256],f16>, !torch.vtensor<[2,8,1],f16>, !torch.vtensor<[2,8,1],f16>)
+  attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %0:3 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.stash_type = 1 : si64} : (!torch.vtensor<[2,8,256],f16>, !torch.vtensor<[256],f16>, !torch.vtensor<[256],f16>) -> (!torch.vtensor<[2,8,256],f16>, !torch.vtensor<[2,8,1],f16>, !torch.vtensor<[2,8,1],f16>)
+  return %0#0, %0#1, %0#2 : !torch.vtensor<[2,8,256],f16>, !torch.vtensor<[2,8,1],f16>, !torch.vtensor<[2,8,1],f16>
+}
+// CHECK-LABEL: func.func @test_layer_norm_stash_type_f16_3results
+// CHECK-SAME: %[[X:[a-zA-Z0-9]+]]: !torch.vtensor<[2,8,256],f16>
+// CHECK-SAME: %[[SCALE:[a-zA-Z0-9]+]]: !torch.vtensor<[256],f16>
+// CHECK-SAME: %[[BIAS:[a-zA-Z0-9]+]]: !torch.vtensor<[256],f16>
+// CHECK: %[[X_CAST:.*]] = torch.aten.to.dtype %[[X]]
+// CHECK: %[[SCALE_CAST:.*]] = torch.aten.to.dtype %[[SCALE]]
+// CHECK: %[[BIAS_CAST:.*]] = torch.aten.to.dtype %[[BIAS]]
+// CHECK: %[[NORM:.*]], %[[MEAN:.*]], %[[VAR:.*]] = torch.aten.native_layer_norm %[[X_CAST]], %{{.*}}, %[[SCALE_CAST]], %[[BIAS_CAST]], %{{.*}} : !torch.vtensor<[2,8,256],f32>, !torch.list<int>, !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32>, !torch.float -> !torch.vtensor<[2,8,256],f32>, !torch.vtensor<[2,8,1],f32>, !torch.vtensor<[2,8,1],f32>
+// CHECK: %[[Y_BACK:.*]] = torch.aten.to.dtype %[[NORM]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[2,8,256],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,8,256],f16>
+// CHECK: %[[MEAN_BACK:.*]] = torch.aten.to.dtype %[[MEAN]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[2,8,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f16>
+// CHECK: %[[VAR_BACK:.*]] = torch.aten.to.dtype %[[VAR]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.vtensor<[2,8,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,8,1],f16>
+// CHECK: return %[[Y_BACK]], %[[MEAN_BACK]], %[[VAR_BACK]]
+
+// -----
+
 // CHECK-LABEL: func.func @test_leaky_relu
 func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 16 : si64} {
   // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2