diff --git a/BUILD.bazel b/BUILD.bazel index 2ede05b71ad380..e9c57169944846 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -147,7 +147,11 @@ genrule( "aten/src/ATen/Functions.cpp", "aten/src/ATen/RedispatchFunctions.h", "aten/src/ATen/Operators.h", - "aten/src/ATen/Operators.cpp", + "aten/src/ATen/Operators_0.cpp", + "aten/src/ATen/Operators_1.cpp", + "aten/src/ATen/Operators_2.cpp", + "aten/src/ATen/Operators_3.cpp", + "aten/src/ATen/Operators_4.cpp", "aten/src/ATen/NativeFunctions.h", "aten/src/ATen/MetaFunctions.h", "aten/src/ATen/MetaFunctions_inl.h", diff --git a/aten/src/ATen/templates/Operators.cpp b/aten/src/ATen/templates/Operators.cpp index 0d57d1fbb6d132..696d41ad96ab67 100644 --- a/aten/src/ATen/templates/Operators.cpp +++ b/aten/src/ATen/templates/Operators.cpp @@ -2,6 +2,8 @@ #include #include +// NOTE: See the [Sharded File] comment in VariableType + namespace at { namespace _ops { ${definitions} diff --git a/tools/code_analyzer/op_deps_pass.cpp b/tools/code_analyzer/op_deps_pass.cpp index f41e032969d785..d2f5876f1bf2da 100644 --- a/tools/code_analyzer/op_deps_pass.cpp +++ b/tools/code_analyzer/op_deps_pass.cpp @@ -770,6 +770,12 @@ class OpDependency : public ModulePass { static void extractStringValue( Value* V, const std::function& CB) { + if (isa(V)) { + // UndefValue inherits from ConstantValue, but doesn't contain any data, so return without invoking CB + // See: https://llvm.org/docs/LangRef.html#undefined-values + return; + } + if (auto array = dyn_cast(V)) { // Normal case for c-style string literal and "std::basic_string". 
if (array->isCString()) { diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 2deddb75fff0c0..65e30ab34e7a1d 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -215,7 +215,7 @@ class ComputeOperators: ] @method_with_native_function - def __call__(self, f: NativeFunction) -> Optional[str]: + def __call__(self, f: NativeFunction) -> str: sig = DispatcherSignature.from_schema(f.func) name = f.func.name.unambiguous_name() call_method_name = 'call' @@ -260,7 +260,15 @@ def __call__(self, f: NativeFunction) -> Optional[str]: defns = f""" STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{str(f.func.name)}") STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}") -STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})""" +STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))}) + +// aten::{f.func} +static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{ + return c10::Dispatcher::singleton() + .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}") + .typed<{name}::schema>(); +}} +""" for is_redispatching_fn in [False, True]: if is_redispatching_fn: @@ -275,9 +283,7 @@ def __call__(self, f: NativeFunction) -> Optional[str]: defns += f""" // aten::{f.func} {sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{ - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}") - .typed<{sig.type()}>(); + static auto op = create_{name}_typed_handle(); return op.{dispatcher_call}({dispatcher_exprs_str}); }} """ @@ -1175,10 +1181,18 @@ def make_file_manager(install_dir: str) -> FileManager: 'schema_registrations': list(mapMaybe(RegisterSchema(schema_selector), native_functions)), }) - cpu_fm.write('Operators.cpp', lambda: { - 'definitions': list(mapMaybe(ComputeOperators( - Target.DEFINITION), 
native_functions)), - }) + def key_func(fn: NativeFunction) -> str: + return fn.func.name.unambiguous_name() + + cpu_fm.write_sharded( + 'Operators.cpp', + native_functions, + key_fn=key_func, + env_callable=lambda fn: { + 'definitions': [ComputeOperators(Target.DEFINITION)(fn)]}, + num_shards=5, + sharded_keys={'definitions'} + ) cpu_fm.write('Operators.h', lambda: { 'declarations': list(mapMaybe(ComputeOperators( Target.DECLARATION), native_functions)),