Shard Operators.cpp (pytorch#62185)
Summary:
Pull Request resolved: pytorch#62185

This file can take 5 minutes on its own to compile, and is the single limiting
factor for compile time of `libtorch_cpu` on a 32-core threadripper. Instead,
sharding into 5 files that take around 1 minute each cuts a full minute off the
overall build time.

This also factors out the `.findSchemaOrThrow(...).typed` step so the code can
be shared between `call` and `redispatch`.
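
Concretely, after this change each generated Operators_*.cpp shard contains per-operator code along the lines of the sketch below. This is an illustration only, hand-instantiated for the `aten::add.Tensor` overload as an example; the authoritative shape is the gen.py template further down in this diff, and the `add_Tensor` struct with its `schema` typedef is assumed to come from the generated Operators.h.

// Illustrative sketch of one operator's generated definitions (assumed instantiation
// of the template in this diff for aten::add.Tensor; string-constant definitions omitted).
#include <ATen/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at { namespace _ops {

// The findSchemaOrThrow(...).typed lookup is now factored into one helper per operator.
static C10_NOINLINE c10::TypedOperatorHandle<add_Tensor::schema> create_add_Tensor_typed_handle() {
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::add", "Tensor")
      .typed<add_Tensor::schema>();
}

// aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
at::Tensor add_Tensor::call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  static auto op = create_add_Tensor_typed_handle();
  return op.call(self, other, alpha);
}

// aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
at::Tensor add_Tensor::redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  static auto op = create_add_Tensor_typed_handle();
  return op.redispatch(dispatchKeySet, self, other, alpha);
}

}} // namespace at::_ops

Because the helper is marked C10_NOINLINE, the lookup sequence is emitted once per operator and merely called from both entry points, rather than being expanded inline in `call` and `redispatch` separately.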

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D29962049

Pulled By: albanD

fbshipit-source-id: be5df05fbea09ada0d825855f1618c25a11abbd8
peterbell10 authored and facebook-github-bot committed Aug 9, 2021
1 parent 4b9ca72 commit 93e0f3a
Showing 4 changed files with 36 additions and 10 deletions.
6 changes: 5 additions & 1 deletion BUILD.bazel
@@ -147,7 +147,11 @@ genrule(
         "aten/src/ATen/Functions.cpp",
         "aten/src/ATen/RedispatchFunctions.h",
         "aten/src/ATen/Operators.h",
-        "aten/src/ATen/Operators.cpp",
+        "aten/src/ATen/Operators_0.cpp",
+        "aten/src/ATen/Operators_1.cpp",
+        "aten/src/ATen/Operators_2.cpp",
+        "aten/src/ATen/Operators_3.cpp",
+        "aten/src/ATen/Operators_4.cpp",
         "aten/src/ATen/NativeFunctions.h",
         "aten/src/ATen/MetaFunctions.h",
         "aten/src/ATen/MetaFunctions_inl.h",
2 changes: 2 additions & 0 deletions aten/src/ATen/templates/Operators.cpp
@@ -2,6 +2,8 @@
 #include <ATen/Tensor.h>
 #include <ATen/core/dispatch/Dispatcher.h>
 
+// NOTE See [Sharded File] comment in VariableType
+
 namespace at { namespace _ops {
 
 ${definitions}
6 changes: 6 additions & 0 deletions tools/code_analyzer/op_deps_pass.cpp
@@ -770,6 +770,12 @@ class OpDependency : public ModulePass {
 
 static void extractStringValue(
     Value* V, const std::function<void(const std::string&)>& CB) {
+  if (isa<UndefValue>(V)) {
+    // UndefValue inherits from Constant, but doesn't contain any data.
+    // See: https://llvm.org/docs/LangRef.html#undefined-values
+    return;
+  }
+
   if (auto array = dyn_cast<ConstantDataArray>(V)) {
     // Normal case for c-style string literal and "std::basic_string".
     if (array->isCString()) {
32 changes: 23 additions & 9 deletions tools/codegen/gen.py
@@ -215,7 +215,7 @@ class ComputeOperators:
     ]
 
     @method_with_native_function
-    def __call__(self, f: NativeFunction) -> Optional[str]:
+    def __call__(self, f: NativeFunction) -> str:
         sig = DispatcherSignature.from_schema(f.func)
         name = f.func.name.unambiguous_name()
         call_method_name = 'call'
@@ -260,7 +260,15 @@ def __call__(self, f: NativeFunction) -> Optional[str]:
             defns = f"""
 STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{str(f.func.name)}")
 STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
-STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})"""
+STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})
+
+// aten::{f.func}
+static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
+  return c10::Dispatcher::singleton()
+      .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
+      .typed<{name}::schema>();
+}}
+"""
 
             for is_redispatching_fn in [False, True]:
                 if is_redispatching_fn:
@@ -275,9 +283,7 @@ def __call__(self, f: NativeFunction) -> Optional[str]:
                 defns += f"""
 // aten::{f.func}
 {sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
-    static auto op = c10::Dispatcher::singleton()
-        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
-        .typed<{sig.type()}>();
+    static auto op = create_{name}_typed_handle();
     return op.{dispatcher_call}({dispatcher_exprs_str});
 }}
 """
@@ -1175,10 +1181,18 @@ def make_file_manager(install_dir: str) -> FileManager:
         'schema_registrations': list(mapMaybe(RegisterSchema(schema_selector), native_functions)),
     })
 
-    cpu_fm.write('Operators.cpp', lambda: {
-        'definitions': list(mapMaybe(ComputeOperators(
-            Target.DEFINITION), native_functions)),
-    })
+    def key_func(fn: NativeFunction) -> str:
+        return fn.func.name.unambiguous_name()
+
+    cpu_fm.write_sharded(
+        'Operators.cpp',
+        native_functions,
+        key_fn=key_func,
+        env_callable=lambda fn: {
+            'definitions': [ComputeOperators(Target.DEFINITION)(fn)]},
+        num_shards=5,
+        sharded_keys={'definitions'}
+    )
     cpu_fm.write('Operators.h', lambda: {
         'declarations': list(mapMaybe(ComputeOperators(
             Target.DECLARATION), native_functions)),
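
The body of write_sharded is not shown in this diff; conceptually, the call above has to map every NativeFunction, keyed by key_func (its unambiguous operator name), onto one of the num_shards=5 output files (Operators_0.cpp through Operators_4.cpp) deterministically, so the assignment stays stable from one codegen run to the next. A minimal sketch of one such scheme, a stable string hash taken modulo the shard count, follows; the FNV-1a hash and the operator keys are illustrative stand-ins, not necessarily what write_sharded actually does.

// Conceptual sketch of deterministic shard assignment (not PyTorch code).
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// FNV-1a: a simple hash whose result is stable across runs and platforms, unlike std::hash.
static std::uint64_t stable_hash(const std::string& s) {
  std::uint64_t h = 0xcbf29ce484222325ull;  // FNV offset basis
  for (unsigned char c : s) {
    h ^= c;
    h *= 0x100000001b3ull;  // FNV prime
  }
  return h;
}

int main() {
  const int num_shards = 5;
  // Hypothetical keys, standing in for NativeFunction unambiguous names.
  const std::vector<std::string> keys = {
      "add_Tensor", "mul_Tensor", "conv2d", "relu", "sum"};
  for (const auto& key : keys) {
    const int shard = static_cast<int>(stable_hash(key) % num_shards);
    std::cout << key << " -> Operators_" << shard << ".cpp\n";
  }
  return 0;
}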
