Skip to content

Commit

Permalink
Add pass to process input and output alias hint info (#1273)
Browse files Browse the repository at this point in the history
add pass to process input and output alias info
  • Loading branch information
eedalong committed Jan 10, 2024
1 parent 67c3242 commit ddcc491
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docker/scripts/install-python.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function install_python() {
fi
ln -s /usr/bin/python3 /usr/bin/python
python3 -m pip install --upgrade pip
python3 -m pip install cpython virtualenv numpy
python3 -m pip install virtualenv numpy
}

function install_venv() {
Expand Down
1 change: 0 additions & 1 deletion pytorch_blade/scripts/pip/requirements-dev-1.13.1+cpu.txt
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
cpython
pyyaml
typing_extensions
numpy
Expand Down
1 change: 0 additions & 1 deletion pytorch_blade/scripts/pip/requirements-dev-1.13.1+cu116.txt
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
cpython
pyyaml
typing_extensions
numpy
Expand Down
28 changes: 28 additions & 0 deletions tao_compiler/mlir/disc/BUILD
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,33 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "disc_input_output_alias",
srcs = ["transforms/disc_input_output_alias.cc"],
hdrs = [
"transforms/passes.h",
"transforms/rewriters.h",
],
deps = [
":pass_details",
":placement_utils",
":lmhlo_disc",
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TransformUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)

cc_library(
name = "disc_assign_memory_space",
srcs = ["transforms/disc_assign_memory_space.cc"],
Expand Down Expand Up @@ -2304,6 +2331,7 @@ cc_library(
":disc_lower_quantize_and_dequantize",
":disc_transform_weight_data_layout_for_weight_only_quant",
":disc_lower_to_library_call",
":disc_input_output_alias",
":disc_math_approximation",
":disc_memref_canonicalizer",
":disc_outline_cpu_kernel",
Expand Down
2 changes: 1 addition & 1 deletion tao_compiler/mlir/disc/disc_compiler.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);

pm.addPass(disc_ral::createDiscInputOutputAliasPass());
pm.addPass(mlir::createInlinerPass());

// TODO(disc): Lower HLO shape constraints instead of eliding them here.
pm.addNestedPass<FuncOp>(disc_ral::createDiscMhloDecompositionRewriterPass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscRemoveShapeConstraintsPass());
Expand Down
164 changes: 164 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright 2021 The BladeDISC Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <stdexcept>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/disc/IR/hlo_disc_ops.h"
#include "mlir/disc/transforms/PassDetail.h"
#include "mlir/disc/transforms/rewriters.h"

namespace mlir {
namespace disc_ral {

// Reads the "input_output_alias_params" / "input_output_alias_outputs"
// entries of `main`'s `tf.entry_function` dictionary attribute. Both entries
// are comma-separated lists of integer indices; the parsed indices are
// appended to `params` and `outputs` respectively.
//
// Returns false when the attribute (or either key) is absent or not a string
// attribute. Throws std::invalid_argument when an entry cannot be converted
// to an integer.
bool ParseInputOutputAliasInfo(func::FuncOp main, std::vector<int>& params,
                               std::vector<int>& outputs) {
  auto dict_attr = main->getAttrOfType<DictionaryAttr>("tf.entry_function");
  if (!dict_attr) {
    return false;
  }

  const std::string inputOutputAliasParamsKey = "input_output_alias_params";
  const std::string inputOutputAliasOutputsKey = "input_output_alias_outputs";

  // BUGFIX: the original checked the params key twice, so a module carrying
  // only the params attribute slipped past this guard and crashed below when
  // the missing outputs attribute was dereferenced.
  if (!dict_attr.get(inputOutputAliasParamsKey) ||
      !dict_attr.get(inputOutputAliasOutputsKey)) {
    return false;
  }

  auto param_str =
      dict_attr.get(inputOutputAliasParamsKey).dyn_cast<mlir::StringAttr>();
  auto outputs_str =
      dict_attr.get(inputOutputAliasOutputsKey).dyn_cast<mlir::StringAttr>();
  // dyn_cast yields null for non-string attribute values; treat that the
  // same as "no alias info" instead of dereferencing null below.
  if (!param_str || !outputs_str) {
    return false;
  }

  SmallVector<StringRef, 4> parsed_params, parsed_outputs;
  param_str.getValue().split(parsed_params, ',', /*MaxSplit=*/-1,
                             /*KeepEmpty=*/false);
  outputs_str.getValue().split(parsed_outputs, ',', /*MaxSplit=*/-1,
                               /*KeepEmpty=*/false);

  // Convert one comma-separated piece list to ints, rethrowing with a
  // message that names the offending attribute key.
  auto parseIndices = [](const SmallVector<StringRef, 4>& pieces,
                         const std::string& key, std::vector<int>& dst) {
    for (StringRef str : pieces) {
      try {
        dst.push_back(std::stoi(str.str()));
      } catch (const std::invalid_argument&) {
        throw std::invalid_argument("An invalid value " + str.str() +
                                    " is received when converting index in " +
                                    key + " to int value");
      }
    }
  };

  parseIndices(parsed_params, inputOutputAliasParamsKey, params);
  parseIndices(parsed_outputs, inputOutputAliasOutputsKey, outputs);

  return true;
}

// Pass that turns the input/output alias hints attached to the module's
// `main` function into explicit mhlo_disc.args_mutation ops so later
// bufferization stages can reuse input buffers for outputs.
struct DiscInputOutputAliasPass
    : public DiscInputOutputAliasPassBase<DiscInputOutputAliasPass> {
  using DiscInputOutputAliasPassBase<
      DiscInputOutputAliasPass>::DiscInputOutputAliasPassBase;

  void getDependentDialects(DialectRegistry& registry) const override {
    // The pass creates mhlo_disc.args_mutation ops.
    registry.insert<mhlo_disc::MhloDiscDialect>();
  }

 public:
  DiscInputOutputAliasPass() {}

  void runOnOperation() override {
    ModuleOp module = getOperation();
    auto main_func = module.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main_func) {
      signalPassFailure();
      return;
    }

    // Parse the alias hints from the tf.entry_function attribute.
    std::vector<int> params_index, outputs_index;
    try {
      if (!ParseInputOutputAliasInfo(main_func, params_index, outputs_index)) {
        // No alias info present: nothing to do.
        return;
      }
    } catch (const std::invalid_argument& e) {
      main_func.emitOpError() << e.what();
      // BUGFIX: the original only emitted the error and then kept running
      // with partially-parsed indices; fail the pass and stop instead.
      signalPassFailure();
      return;
    }

    if (params_index.size() != outputs_index.size()) {
      main_func.emitOpError()
          << "input_output_alias_params and input_output_alias_outputs should "
             "have same number of index";
      signalPassFailure();
      // BUGFIX: missing return — the pass previously continued after
      // signalling failure.
      return;
    }

    OpBuilder builder(main_func.getBody());
    // BUGFIX: use the generic terminator. A func.func body is terminated by
    // func.return (see the pass's FileCheck test), so the original
    // cast<mhlo::ReturnOp> was an invalid cast; only getOperands() is needed.
    Operation* returnOp = main_func.getBody().back().getTerminator();

    // Input arguments and returned values of the main function.
    auto params = main_func.getArguments();
    auto outputs = returnOp->getOperands();

    // Insert one mhlo_disc::ArgsMutationOp per (output, param) alias pair.
    for (size_t i = 0; i < params_index.size(); ++i) {
      int param_idx = params_index[i];
      int output_idx = outputs_index[i];
      // BUGFIX: validate the user-provided indices before indexing.
      if (param_idx < 0 || static_cast<size_t>(param_idx) >= params.size() ||
          output_idx < 0 ||
          static_cast<size_t>(output_idx) >= outputs.size()) {
        main_func.emitOpError()
            << "input/output alias index out of range: param " << param_idx
            << ", output " << output_idx;
        signalPassFailure();
        return;
      }
      Value param = params[param_idx];
      Value output = outputs[output_idx];
      if (output == param) {
        // Output is already the aliased argument itself; nothing to mark.
        continue;
      }
      // DISC now only support one-hop buffer sharing: the output must be
      // directly produced by an op consuming the aliased argument.
      Operation* defineOp = output.getDefiningOp();
      if (!defineOp) {
        // BUGFIX: the output may be a different block argument, in which
        // case there is no defining op — the original dereferenced null here.
        continue;
      }
      for (const auto& value : defineOp->getOperands()) {
        if (param == value) {
          builder.setInsertionPointAfterValue(output);
          builder.create<mhlo_disc::ArgsMutationOp>(main_func.getLoc(), output,
                                                    param);
          break;
        }
      }
    }
  }
};

// Factory for the pass that materializes input/output alias hints as
// mhlo_disc.args_mutation ops.
std::unique_ptr<OperationPass<ModuleOp>> createDiscInputOutputAliasPass() {
  auto pass = std::make_unique<DiscInputOutputAliasPass>();
  return pass;
}

} // namespace disc_ral
} // namespace mlir
5 changes: 5 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_passes.td
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,8 @@ def DiscEraseBufferDeallocationPass : Pass<"disc-erase-buffer-deallocation", "ml
let summary = "Erase dealloc op for GPU func ops.";
let constructor = "createDiscEraseBufferDeallocationPass()";
}

def DiscInputOutputAliasPass : Pass<"disc-input-output-alias", "ModuleOp"> {
let summary = "Input and output alias information for buffer reuse";
let constructor = "createDiscInputOutputAliasPass()";
}
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/transforms/passes.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ createDiscDuplicateComputationAfterFusionPass();
std::unique_ptr<OperationPass<gpu::GPUFuncOp>>
createDiscEraseBufferDeallocationPass();

// Insert ArgsMutationOp for buffer reuse
std::unique_ptr<OperationPass<ModuleOp>> createDiscInputOutputAliasPass();
} // namespace disc_ral
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: disc-opt -disc-input-output-alias \
// RUN: %s -o - | FileCheck %s

// CHECK-LABEL: main
func.func @main(%arg0: tensor<200x200xf32>, %arg1: tensor<200x200xf32>) -> (tensor<200x200xf32>, tensor<200x200xf32>) attributes {tf.entry_function = {input_output_alias_outputs = "0,1", input_output_alias_params = "0,1", input_placements = "gpu,gpu", output_placements = "gpu,gpu"}} {
// CHECK: %0 = mhlo.add %arg1, %arg0 : tensor<200x200xf32>
%0 = mhlo.add %arg1, %arg0 : tensor<200x200xf32>
// CHECK: "mhlo_disc.args_mutation"(%0, %arg1) : (tensor<200x200xf32>, tensor<200x200xf32>) -> ()
// CHECK: return %arg0, %0 : tensor<200x200xf32>, tensor<200x200xf32>
return %arg0, %0 : tensor<200x200xf32>, tensor<200x200xf32>
}

0 comments on commit ddcc491

Please sign in to comment.