#include <ATen/record_function.h>
#include <torch/nativert/executor/GraphExecutorBase.h>

#include <c10/util/Logging.h>
#include <caffe2/core/timer.h>

namespace torch::nativert {

GraphExecutorBase::GraphExecutorBase(
    const Graph& graph,
    std::vector<std::unique_ptr<OpKernel>> nodeKernels,
    const ExecutorConfig& executorConfig)
    : graph_(graph),
      nodeKernels_(std::move(nodeKernels)),
      executorConfig_(executorConfig),
      execPlan_(ExecutionPlanner{graph_}.createPlan()) {}
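
// Moves the caller-provided IValues into the execution frame, keyed by the
// graph's user-input value ids; null (unused) input slots are skipped.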
|  | 17 | + | 
|  | 18 | +void GraphExecutorBase::fillUserInputs( | 
|  | 19 | +    ExecutionFrame& frame, | 
|  | 20 | +    std::vector<c10::IValue> inputs) { | 
|  | 21 | +  RECORD_USER_SCOPE("Executor::fillUserInputs"); | 
|  | 22 | +  const auto& inputValues = graph_.userInputs(); | 
|  | 23 | +  TORCH_CHECK_EQ(inputValues.size(), inputs.size()); | 
|  | 24 | + | 
|  | 25 | +  // load user input tensor into execution frame | 
|  | 26 | +  for (size_t i = 0; i < inputValues.size(); i++) { | 
|  | 27 | +    if (inputValues[i]) { | 
|  | 28 | +      frame.setIValue(inputValues[i]->id(), std::move(inputs[i])); | 
|  | 29 | +    } | 
|  | 30 | +  } | 
|  | 31 | +} | 
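
// Profiles each node kernel in isolation. After `warmupRuns` full passes over
// every input set, each kernel is timed individually for `mainRuns` passes,
// and the accumulated per-node times are averaged into the returned
// ProfileMetrics. If `inputsList` is empty, only static counts (node types,
// prim/static-dispatch kernels) are populated and all timings stay zero.
//
// A hypothetical call site (the executor/frame/input names below are
// illustrative, not defined in this file):
//
//   ProfileMetrics metrics = executor.benchmarkIndividualNodes(
//       frame, {sampleInputs}, /*warmupRuns=*/2, /*mainRuns=*/10);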
ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes(
    ExecutionFrame& executionFrame,
    std::vector<std::vector<c10::IValue>> inputsList,
    const uint32_t warmupRuns,
    const uint32_t mainRuns) {
  // TODO: add support for memory profiling
  TORCH_CHECK(warmupRuns >= 1 && mainRuns >= 1);

  ProfileMetrics results;
  const auto numNodes = static_cast<uint32_t>(nodeKernels_.size());
  results.timePerNode.resize(numNodes, 0);
  if (inputsList.empty()) {
    // Nothing to time without inputs; report zeroed timings plus the static
    // per-node-type instance counts.
    size_t i = 0;
    for (const auto& nodeKernel : nodeKernels_) {
      std::string target(nodeKernel->node()->target());
      results.timePerNode[i] = 0;
      results.timePerNodeType[target] = 0;
      results.instancesPerNodeType[target]++;
      if (nodeKernel->hasPrimKernel()) {
        results.primNodesCount++;
        results.primNodes.insert(target);
      } else if (nodeKernel->hasStaticDispatch()) {
        results.staticDispatchNodesCount++;
        results.staticDispatchNodes.insert(target);
      }
      i++;
    }
    results.totalNodesCount = numNodes;
    for (const auto& p : results.timePerNodeType) {
      const std::string& kind = p.first;
      results.percentPerNodeType[kind] = 0;
    }
    return results;
  }

  // Warmup
  for (uint32_t i = 0; i < warmupRuns; i++) {
    for (const auto& inputs : inputsList) {
      execute(executionFrame, inputs);
    }
  }

  // Execute kernels
  caffe2::Timer timer;
  for (uint32_t i = 0; i < mainRuns; i++) {
    // Deliberately iterate by copy: the IValues are moved into the frame
    // below, so the originals must stay intact for the next run.
    for (auto inputs : inputsList) {
      const auto& inputValues = graph_.userInputs();

      TORCH_CHECK_EQ(inputValues.size(), inputs.size());
      for (size_t j = 0; j < inputValues.size(); j++) {
        executionFrame.setIValue(inputValues[j]->id(), std::move(inputs[j]));
      }
      for (NodeIndex nodeIdx = 0; nodeIdx < nodeKernels_.size(); ++nodeIdx) {
        timer.Start();
        nodeKernels_[nodeIdx]->compute(executionFrame);
        float millis = timer.MilliSeconds();
        results.timePerNode[nodeIdx] += millis;
      }
    }
  }

  // Summarize results: average each node's time across all iterations and
  // aggregate by node type.
  const float numTotalIters =
      (static_cast<float>(mainRuns) * static_cast<float>(inputsList.size()));
  for (const auto i : c10::irange(numNodes)) {
    const Node* node = nodeKernels_[i]->node();
    std::string target(node->target());
    results.timePerNode[i] /= numTotalIters;
    results.timePerNodeType[target] += results.timePerNode[i];
    results.instancesPerNodeType[target]++;
    if (nodeKernels_[i]->hasPrimKernel()) {
      results.primNodes.insert(target);
      results.primNodesCount++;
    } else if (nodeKernels_[i]->hasStaticDispatch()) {
      results.staticDispatchNodes.insert(target);
      results.staticDispatchNodesCount++;
    }
    results.totalTime += results.timePerNode[i];
  }
  results.totalNodesCount = numNodes;
  for (const auto& r : results.timePerNodeType) {
    const std::string& target = r.first;
    results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime;
  }
  return results;
}

} // namespace torch::nativert