Visibility update for profiler_impl #23267

Open · wants to merge 1 commit into base: main
75 changes: 75 additions & 0 deletions third_party/stablehlo/temporary.patch
@@ -50,6 +50,43 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplific
}

// -----
@@ -1908,6 +1917,19 @@

// -----

+// CHECK-LABEL: @side_effecting_custom_call
+func.func @side_effecting_custom_call(%arg0: tensor<0xf32>) -> (tensor<0xf32>, tensor<0xf32>) {
+ // CHECK: %[[CST:.*]] = stablehlo.constant dense<> : tensor<0xf32>
+ // CHECK-NEXT: %[[CC:.*]] = stablehlo.custom_call @foo(%arg0) {api_version = 0 : i32, has_side_effect = true} : (tensor<0xf32>) -> tensor<0xf32>
+ %0 = stablehlo.custom_call @foo(%arg0) {api_version = 0 : i32, has_side_effect = true} : (tensor<0xf32>) -> tensor<0xf32>
+ // CHECK-NOT: stablehlo.custom_call{{.*}}has_side_effect = false
+ %1 = stablehlo.custom_call @foo(%arg0) {api_version = 0 : i32, has_side_effect = false} : (tensor<0xf32>) -> tensor<0xf32>
+ // CHECK: return %[[CC]], %[[CST]]
+ return %0, %1 : tensor<0xf32>, tensor<0xf32>
+}
+
+// -----
+
/////////
// Generic Shape Ops

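For context on the CHECK lines above: when a zero-extent result can be folded, the simplification replaces the op with an empty constant of the result type. A minimal sketch of building that replacement (the helper name and setup are illustrative, not part of the patch):

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "stablehlo/dialect/StablehloOps.h"

// Builds a `stablehlo.constant dense<> : tensor<0xf32>`-style replacement
// for a zero-extent result type, as expected by the CHECK lines above.
mlir::Value buildEmptyConstant(mlir::OpBuilder &b, mlir::Location loc,
                               mlir::RankedTensorType resultTy) {
  auto empty = mlir::DenseElementsAttr::get(
      resultTy, llvm::ArrayRef<mlir::Attribute>{});
  return b.create<mlir::stablehlo::ConstantOp>(loc, empty);
}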
diff --ruN a/stablehlo/stablehlo/transforms/optimization/Passes.h b/stablehlo/stablehlo/transforms/optimization/Passes.h
--- stablehlo/stablehlo/transforms/optimization/Passes.h
+++ stablehlo/stablehlo/transforms/optimization/Passes.h
@@ -50,6 +50,13 @@
MLIRContext *context,
bool foldFloat = false,
PatternBenefit benefit = 1);
+
+/// Some workloads in XLA import StableHLO from HLO. Since there are a few
+/// differences in HLO (no implicit captures, lots of tuples, etc.), this
+/// set of patterns brings the imported HLO back to a more canonical form
+/// without applying a full set of graph simplifications.
+void populateStablehloHloImportCanonicalizationPatterns(
+ MLIRContext *context, RewritePatternSet *patterns);
} // namespace stablehlo
} // namespace mlir
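A usage sketch (not part of this change): a downstream cleanup can consume the new entry point on its own, without pulling in the full aggressive-simplification pattern set. The wrapper function below is hypothetical:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/transforms/optimization/Passes.h"

// Hypothetical helper: apply only the HLO-import canonicalization patterns.
mlir::LogicalResult cleanUpAfterHloImport(mlir::func::FuncOp func) {
  mlir::MLIRContext *context = func.getContext();
  mlir::RewritePatternSet patterns(context);
  mlir::stablehlo::populateStablehloHloImportCanonicalizationPatterns(
      context, &patterns);
  return mlir::applyPatternsGreedily(func, std::move(patterns));
}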

diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
@@ -68,4 +105,42 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
return rewriter.notifyMatchFailure(op, "operand is not empty tensor");

if (resultTy.hasStaticShape()) {
@@ -1399,6 +1403,12 @@
return rewriter.notifyMatchFailure(op, "not stablehlo");
if (isa<ConstantOp>(op))
return rewriter.notifyMatchFailure(op, "op is empty constant");
+
+ // Skip ops that have memory effects, similar to XLA's zero-extent
+ // simplification; replacing these doesn't save any computation.
+ auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+ if (effectInterface && !effectInterface.hasNoEffect())
+ return rewriter.notifyMatchFailure(op, "op has memory effect");

// If the result is a zero-extent tensor, replace the whole op with an empty
// constant.
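The new guard queries MemoryEffectOpInterface directly, so the skip is not specific to custom calls: any op reporting an effect (for example, a custom_call with has_side_effect = true, as in the test above) is left in place. The same check in isolation (a sketch; the related helper mlir::isMemoryEffectFree also exists, but it additionally recurses into regions, which the patch avoids by querying the interface):

#include "mlir/Interfaces/SideEffectInterfaces.h"

// Mirrors the guard above: true if the op implements the interface and
// reports at least one memory effect, in which case it must not be
// replaced by a constant.
bool reportsMemoryEffect(mlir::Operation *op) {
  auto iface = mlir::dyn_cast<mlir::MemoryEffectOpInterface>(op);
  return iface && !iface.hasNoEffect();
}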
@@ -1528,6 +1538,12 @@
DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);
}

+void populateStablehloHloImportCanonicalizationPatterns(
+ MLIRContext *context, RewritePatternSet *patterns) {
+ patterns->add<TupleIsRepacking, TupleIsUnpacked, WhileOpImplicitCapture>(
+ context);
+}
+
std::unique_ptr<Pass> createStablehloAggressiveSimplificationPass(
GreedyRewriteConfig config) {
return std::make_unique<StablehloAggressiveSimplificationPass>(config);
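For reference, the three registered patterns rewrite as follows (paraphrased from the pattern names and the pass description; illustrative, not normative):

// TupleIsRepacking:
//   tuple(get_tuple_element(%t, 0), get_tuple_element(%t, 1), ...) -> %t
// TupleIsUnpacked:
//   get_tuple_element(tuple(%x0, %x1, ...), i) -> %xi
// WhileOpImplicitCapture:
//   constants threaded explicitly through stablehlo.while operands become
//   implicit captures, shrinking the loop carry.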
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
@@ -411,7 +411,7 @@
// GetTupleElementOp

// Pattern: get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i
-def : Pat<(StableHLO_GetTupleElementOp (StableHLO_TupleOp:$tuple $operands), $idx),
+def TupleIsUnpacked : Pat<(StableHLO_GetTupleElementOp (StableHLO_TupleOp:$tuple $operands), $idx),
(GetOperandN $tuple, $idx)>;

////////
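Naming the DRR pattern is what makes the new populate function possible: mlir-tblgen emits one C++ pattern class per named `Pat`, whereas anonymous patterns are reachable only through the generated populateWithGenerated(...) catch-all. A sketch, assuming the generated patterns header is in scope:

// The named patterns can now be registered selectively.
patterns->add<TupleIsRepacking, TupleIsUnpacked>(context);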

3 changes: 1 addition & 2 deletions xla/mlir_hlo/BUILD
@@ -1087,8 +1087,7 @@ cc_library(
"stablehlo_ext/transforms/sdy_refine_shapes.cpp",
"stablehlo_ext/transforms/stablehlo_add_quant_dequant_conv.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_entry_function_tuples.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_from_hlo_import.cpp",
"stablehlo_ext/transforms/stablehlo_legalize_quant_composite.cpp",
"stablehlo_ext/transforms/stablehlo_prepare_for_hlo_export.cpp",
"stablehlo_ext/transforms/stablehlo_refine_shapes.cpp",
28 changes: 20 additions & 8 deletions xla/mlir_hlo/stablehlo_ext/transforms/passes.td
@@ -52,18 +52,30 @@ def StablehloPrepareForHloExportPass : Pass<"stablehlo-ext-prepare-for-hlo-expor
}];
}

def StablehloFlattenTuplePass : Pass<"stablehlo-ext-flatten-tuple", "func::FuncOp"> {
let summary = "Flatten tuples in operands and results of operators that "
"support both tuple and variadic type.";
}
def StablehloCanonicalizeFromHloImportPass : Pass<"stablehlo-ext-canonicalize-from-hlo-import", "mlir::func::FuncOp"> {
let summary = "Simplify StableHLO imported from HLO";

let dependentDialects = ["stablehlo::StablehloDialect"];

let description = [{
This pass simplifies StableHLO imported from HLO. It applies a subset of
the graph-simplification patterns and is intended to bring the imported
HLO back to a more canonical form without running the full set of graph
simplifications.

Namely, this pass:
* Simplifies tuples, undoing `tuple(get_tuple_element)` and
`get_tuple_element(tuple)`.
* Converts explicitly captured WhileOp constants to implicit captures.
* Flattens tuples in operands and results of operators that support both
tuple and variadic type.
* Flattens tuples in the entry function of the module.
}];

def StablehloFlattenEntryFunctionTuplesPass : Pass<"stablehlo-ext-expand-flatten-entry-function-tuples", "ModuleOp"> {
let summary = "Flatten HLO tuple for the entry function of the module.";
let options = [
Option<"entryFunctionNameOption", "entry-function", "std::string",
/*default=*/"", "the name of entry function of the module">,
/*default=*/[{"main"}], "the name of entry function of the module">,
];
let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
}
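A sketch of scheduling the merged pass from a textual pipeline. The pass anchors on func.func per the definition above; the nesting shown and the helper name are assumptions, and entry-function now defaults to "main", so it is spelled out only for illustration:

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Parse an mlir-opt-style pipeline string; returns failure on a bad spec.
mlir::FailureOr<mlir::OpPassManager> buildImportCleanupPipeline() {
  return mlir::parsePassPipeline(
      "builtin.module(func.func("
      "stablehlo-ext-canonicalize-from-hlo-import{entry-function=main}))");
}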

def StablehloLegalizeQuantCompositePass : Pass<"stablehlo-ext-legalize-quant-composite", "ModuleOp"> {
xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_from_hlo_import.cpp
@@ -13,41 +13,128 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for flattening tuples in HLO ops.
// This file implements logic for simplifying StableHLO imported from HLO.

#include <cassert>
#include <memory>
#include <iterator>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/optimization/Passes.h"
#include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc

#define DEBUG_TYPE "stablehlo-ext-canonicalize-from-hlo-import"

namespace mlir {
namespace stablehlo_ext {

#define GEN_PASS_DEF_STABLEHLOFLATTENTUPLEPASS
#define GEN_PASS_DEF_STABLEHLOCANONICALIZEFROMHLOIMPORTPASS
#include "stablehlo_ext/transforms/passes.h.inc"

namespace {

/////////////
// Flatten Tuples in entry computation

// Expands the stablehlo.tuple used in the return op and updates the
// function signature accordingly.
void expandTupledTensorInReturnOp(func::FuncOp func) {
FunctionType oldFuncType = func.getFunctionType();
// Update input signatures.
// We will flatten the tuples for the function inputs as well.
// So if an input is a tuple, it will be flattened and repacked as follows:
// func_1(%arg0: tuple<input1, input2>) =>
//
// func_1(%arg0: <input1>, %arg1: <input2>) {
// %0 = stablehlo.tuple(%arg0, %arg1)
// }
SmallVector<Type, 4> expandedInputTypes;
SmallVector<BlockArgument, 20> funcArguments(func.getArguments().begin(),
func.getArguments().end());
for (auto argument : funcArguments) {
auto type = argument.getType();
auto tupleType = mlir::dyn_cast_or_null<TupleType>(type);
if (!tupleType) {
expandedInputTypes.push_back(type);
} else {
// We need to
// 1) expand the tuple
// 2) insert a new tuple
// 3) rewire the new tuple
int originalArgumentIndex = argument.getArgNumber();
int argumentIndex = originalArgumentIndex;
SmallVector<Value, 4> flattenedOperands;
// Insert the flattened element arguments after the original tuple argument.
Location loc = func.getBody().getLoc();
for (auto flattenedType : tupleType.getTypes()) {
expandedInputTypes.push_back(flattenedType);
func.insertArgument(++argumentIndex, flattenedType, {}, loc);
flattenedOperands.push_back(func.getArgument(argumentIndex));
}

// Construct a new tuple and rewire it.
OpBuilder builder(func.getBody());
builder.setInsertionPointToStart(&func.getBody().front());
auto newTuple =
builder.create<stablehlo::TupleOp>(loc, tupleType, flattenedOperands);
func.getArgument(originalArgumentIndex).replaceAllUsesWith(newTuple);

// Now that the original argument has been rewired, we can safely erase it.
func.eraseArgument(originalArgumentIndex);
}
}

// Update output signatures.
auto returnOp = cast<mlir::func::ReturnOp>(func.getBody().back().back());
OpBuilder builder(returnOp);

// Expand all tuples in old return operands.
SmallVector<Value, 4> expandedReturnOperands;
SmallVector<Type, 4> expandedResultTypes;
for (auto value : returnOp.getOperands()) {
if (auto tupleTy = mlir::dyn_cast<TupleType>(value.getType())) {
llvm::copy(tupleTy.getTypes(), std::back_inserter(expandedResultTypes));
for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) {
expandedReturnOperands.push_back(
builder.createOrFold<stablehlo::GetTupleElementOp>(
value.getLoc(), ty, value, index));
}
} else {
expandedReturnOperands.push_back(value);
expandedResultTypes.push_back(value.getType());
}
}

if (returnOp.getOperands() == expandedReturnOperands) return;

builder.create<mlir::func::ReturnOp>(returnOp.getLoc(),
expandedReturnOperands);
returnOp.erase();
auto newFuncType = FunctionType::get(oldFuncType.getContext(),
expandedInputTypes, expandedResultTypes);
func.setType(newFuncType);
}
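To make the rewiring concrete, the net effect on a tupled entry function looks like the following (an illustration, not a test from this change; nested tuples require repeated application, which is why runOnOperation below loops until no tuple types remain):

// Before (as imported from HLO):
//   func.func @main(%arg0: tuple<tensor<4xf32>, tensor<i32>>)
//       -> tuple<tensor<4xf32>> {
//     ...
//     return %ret : tuple<tensor<4xf32>>
//   }
// After expandTupledTensorInReturnOp:
//   func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> tensor<4xf32> {
//     %0 = stablehlo.tuple %arg0, %arg1 : tuple<tensor<4xf32>, tensor<i32>>
//     ...
//     %elt = stablehlo.get_tuple_element %ret[0]
//         : (tuple<tensor<4xf32>>) -> tensor<4xf32>
//     return %elt : tensor<4xf32>
//   }
// The boundary tuple/get_tuple_element ops are then simplified away by the
// TupleIsRepacking/TupleIsUnpacked patterns where their uses allow.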

/////////////
// Flatten Tuples in Custom Calls

// Calculates the flattened types of a value.
void flattenTupleType(Value value, llvm::SmallVectorImpl<Type> &types) {
if (!mlir::isa<TupleType>(value.getType())) {
@@ -132,27 +219,46 @@ struct FlattenCustomCallOp : public OpRewritePattern<stablehlo::CustomCallOp> {
}
};
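FlattenCustomCallOp (its body is elided in this view) performs the analogous rewrite at custom-call boundaries, roughly as follows (illustration only, not a test from this change):

// Before:
//   %t = stablehlo.custom_call @foo(%arg0)
//       : (tensor<f32>) -> tuple<tensor<f32>, tensor<i32>>
// After flattening, the call returns variadic results and the tuple is
// rebuilt at the boundary (then simplified away where uses allow):
//   %r:2 = stablehlo.custom_call @foo(%arg0)
//       : (tensor<f32>) -> (tensor<f32>, tensor<i32>)
//   %t = stablehlo.tuple %r#0, %r#1 : tuple<tensor<f32>, tensor<i32>>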

class StablehloFlattenTuplePass
: public impl::StablehloFlattenTuplePassBase<StablehloFlattenTuplePass> {
public:
// Simplify a model after HLO import.
struct StablehloCanonicalizeFromHloImportPass
: public impl::StablehloCanonicalizeFromHloImportPassBase<
StablehloCanonicalizeFromHloImportPass> {
using StablehloCanonicalizeFromHloImportPassBase::
StablehloCanonicalizeFromHloImportPassBase;

void runOnOperation() override {
// If this is the entry function, flatten its tuple inputs and results
func::FuncOp func = getOperation();
if (func.getName() == entryFunctionNameOption.getValue()) {
// Iteratively expand tuples until none remain (tuples may be nested).
while (
llvm::any_of(llvm::concat<const Type>(func.getArgumentTypes(),
func.getResultTypes()),
[](Type type) { return mlir::isa<TupleType>(type); })) {
expandTupledTensorInReturnOp(func);
}
}

// Flatten tuples in the function body
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<FlattenCustomCallOp>(context);
stablehlo::populateStablehloHloImportCanonicalizationPatterns(context,
&patterns);

// Apply patterns without folding
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.fold = false;
config.cseConstants = false;
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
config))) {
if (failed(applyPatternsGreedily(func, std::move(patterns), config)))
signalPassFailure();
}
}
};

} // namespace
} // end namespace

} // namespace stablehlo_ext
} // namespace mlir