diff --git a/CMakeLists.txt b/CMakeLists.txt
index ccda3d89fb5..abd17e8aebd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -561,6 +561,42 @@ endif()
 
 target_link_libraries(codegen_internal PUBLIC LLVM_JIT)
 
+# Precompiled Headers for Top nvFuser Headers
+# Post-M8, template instantiation is reduced by 81%, making header parsing
+# a significant fraction of build cost. This PCH targets the top 10 heaviest
+# nvFuser-controllable headers by exclusive parse time (from M9 Task 4
+# analysis). Enabled by default for Release builds (provides ~50% build time
+# improvement).
+if(CMAKE_BUILD_TYPE STREQUAL "Release")
+  option(NVFUSER_USE_POLYMORPHIC_PCH "Use PCH for top nvFuser headers to reduce parse time" ON)
+else()
+  option(NVFUSER_USE_POLYMORPHIC_PCH "Use PCH for top nvFuser headers to reduce parse time" OFF)
+endif()
+
+if(NVFUSER_USE_POLYMORPHIC_PCH)
+  message(STATUS "Enabling PCH for top 10 nvFuser headers")
+  target_precompile_headers(codegen_internal PRIVATE
+    # Top 10 nvFuser headers by exclusive parse time (M9 Task 4 analysis)
+    "${NVFUSER_SRCS_DIR}/polymorphic_value.h"                          # 1675s (27.9m)
+    "${NVFUSER_ROOT}/lib/dynamic_type/src/dynamic_type/type_traits.h"  # 473.6s (7.9m)
+    "${NVFUSER_SRCS_DIR}/ir/base_nodes.h"                              # 284.5s (4.7m)
+    "${NVFUSER_SRCS_DIR}/scheduler/tools/abstract_tensor.h"            # 162.1s (2.7m)
+    "${NVFUSER_SRCS_DIR}/type.h"                                       # 81.6s (1.4m)
+    "${NVFUSER_SRCS_DIR}/ir/container.h"                               # 51.6s (0.9m)
+    "${NVFUSER_SRCS_DIR}/serde/fusion_cache_generated.h"               # 44.1s (0.7m)
+    "${NVFUSER_SRCS_DIR}/iter_visitor.h"                               # 38.2s (0.6m)
+    "${NVFUSER_SRCS_DIR}/ir/internal_nodes.h"                          # 33.3s (0.6m)
+    "${NVFUSER_SRCS_DIR}/ir/interface_nodes.h"                         # 29.6s (0.5m)
+  )
+  # Skip PCH for polymorphic_value.cpp to allow visibility override
+  # (PCH caches the type with hidden visibility)
+  set_source_files_properties(
+    "${NVFUSER_SRCS_DIR}/polymorphic_value.cpp"
+    PROPERTIES
+      SKIP_PRECOMPILE_HEADERS ON
+      COMPILE_OPTIONS "-fvisibility=default"
+  )
+endif()
+
 add_library(nvfuser_codegen SHARED $<TARGET_OBJECTS:codegen_internal>)
 
 if (BUILD_CUTLASS AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
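A note on the visibility override above: a PCH built under the library's hidden-visibility default bakes that hiddenness into the cached template state, which is why polymorphic_value.cpp must opt out of the PCH before it can usefully be compiled with `-fvisibility=default`. A minimal sketch of the distinction, assuming a GCC/Clang toolchain (names are illustrative, not nvFuser code):

```cpp
// With -fvisibility=hidden as the build default, symbols are DSO-local
// unless annotated otherwise.
__attribute__((visibility("hidden"))) int dsoLocalHelper(int x) {
  return x + 1;  // never visible outside the shared library
}

__attribute__((visibility("default"))) int exportedEntry(int x) {
  return dsoLocalHelper(x);  // visible to the dynamic linker and other DSOs
}
```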
@@ -1109,6 +1145,35 @@ function(add_test_without_main TEST_NAME TEST_SRC ADDITIONAL_LINK)
   add_executable(${TEST_NAME} ${TEST_SRC})
   set_property(TARGET ${TEST_NAME} PROPERTY CXX_STANDARD ${NVFUSER_CPP_STANDARD})
   target_compile_definitions(${TEST_NAME} PRIVATE USE_GTEST)
+
+  # PCH for test targets: All test executables share a single PCH to avoid
+  # redundant compilation. The first test target (test_nvfuser) creates the
+  # PCH, and all subsequent tests reuse it via REUSE_FROM.
+  # Note: Can't reuse from codegen_internal due to -fPIC flag difference.
+  if(NVFUSER_USE_POLYMORPHIC_PCH)
+    get_property(NVFUSER_TEST_PCH_TARGET GLOBAL PROPERTY NVFUSER_TEST_PCH_TARGET)
+    if(NOT NVFUSER_TEST_PCH_TARGET)
+      # First test target: create the PCH with top 10 nvFuser headers
+      message(STATUS "Creating shared test PCH on target: ${TEST_NAME}")
+      target_precompile_headers(${TEST_NAME} PRIVATE
+        "${NVFUSER_SRCS_DIR}/polymorphic_value.h"
+        "${NVFUSER_ROOT}/lib/dynamic_type/src/dynamic_type/type_traits.h"
+        "${NVFUSER_SRCS_DIR}/ir/base_nodes.h"
+        "${NVFUSER_SRCS_DIR}/scheduler/tools/abstract_tensor.h"
+        "${NVFUSER_SRCS_DIR}/type.h"
+        "${NVFUSER_SRCS_DIR}/ir/container.h"
+        "${NVFUSER_SRCS_DIR}/serde/fusion_cache_generated.h"
+        "${NVFUSER_SRCS_DIR}/iter_visitor.h"
+        "${NVFUSER_SRCS_DIR}/ir/internal_nodes.h"
+        "${NVFUSER_SRCS_DIR}/ir/interface_nodes.h"
+      )
+      set_property(GLOBAL PROPERTY NVFUSER_TEST_PCH_TARGET ${TEST_NAME})
+    else()
+      # Subsequent test targets: reuse existing PCH
+      target_precompile_headers(${TEST_NAME} REUSE_FROM ${NVFUSER_TEST_PCH_TARGET})
+    endif()
+  endif()
+
   target_include_directories(${TEST_NAME} PRIVATE "${NVFUSER_ROOT}")
   target_include_directories(${TEST_NAME} SYSTEM PRIVATE
     ${NVFUSER_ROOT}/third_party/googletest/googletest/include
diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp
index efb5933aee7..7fb50a1f8f2 100644
--- a/csrc/device_lower/utils.cpp
+++ b/csrc/device_lower/utils.cpp
@@ -1375,7 +1375,7 @@ std::string print(const std::monostate&) {
 }
 
 std::string print(const Projection& proj) {
-  return Projection::dispatch(
+  return Projection::dispatch<std::string>(
       [&](const auto& proj) { return print(proj); }, proj);
 }
 
@@ -1400,7 +1400,7 @@ bool related(const std::monostate&, const ValGroup& to) {
 }
 
 bool related(const Projection& proj, const ValGroup& to) {
-  return Projection::dispatch(
+  return Projection::dispatch<bool>(
       [&](const auto& proj) { return related(proj, to); }, proj);
 }
 
@@ -1430,7 +1430,7 @@ Val* extent(const std::monostate&) {
 }
 
 Val* extent(const Projection& proj) {
-  return Projection::dispatch(
+  return Projection::dispatch<Val*>(
       [&](const auto& proj) { return extent(proj); }, proj);
 }
 
@@ -1696,7 +1696,7 @@ Projection propagate(
     const ValGraph& id_graph,
     const ExprGroup& eg,
     Direction direction) {
-  return Projection::dispatch(
+  return Projection::dispatch<Projection>(
       [&](const auto& proj) {
         return propagate(proj, id_graph, eg, direction);
       },
@@ -1757,7 +1757,7 @@ Val* proveLinearAndGetStrideAfterPropagation(
 Val* proveLinearAndGetStrideAfterPropagation(
     const Projection& proj,
     const ValGroups& domain) {
-  return Projection::dispatch(
+  return Projection::dispatch<Val*>(
       [&](const auto& proj) {
         return proveLinearAndGetStrideAfterPropagation(proj, domain);
       },
@@ -2039,7 +2039,7 @@ Projection simplify(Projection projection) {
   auto simplified = projection;
   do {
     projection = simplified;
-    simplified = Projection::dispatch(
+    simplified = Projection::dispatch<Projection>(
         [&](const auto& projection) { return simplify(projection); },
         projection);
   } while (simplified.type() != projection.type() || simplified != projection);
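Every dispatch call site above gains an explicit result-type template argument (the types shown are reconstructed from the enclosing function signatures). A plausible reading: once DynamicType is only explicitly instantiated in one TU (see the polymorphic_value changes below), callers can no longer lean on machinery that deduces a common return type across all alternatives. A hedged sketch of the idiom, using std::variant rather than nvFuser's DynamicType:

```cpp
#include <string>
#include <variant>

using Node = std::variant<int, double>;

// Sketch only: passing the result type R explicitly spares the compiler from
// instantiating the callable against every alternative merely to deduce a
// common return type.
template <typename R, typename F, typename V>
R dispatch(F&& f, V&& v) {
  return std::visit(
      [&](auto&& alt) -> R { return f(alt); }, std::forward<V>(v));
}

std::string describe(const Node& n) {
  return dispatch<std::string>(
      [](const auto& x) { return std::to_string(x); }, n);
}
```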
diff --git a/csrc/multidevice/symmetric_tensor.h b/csrc/multidevice/symmetric_tensor.h
index 55860845f1b..13db3adf42c 100644
--- a/csrc/multidevice/symmetric_tensor.h
+++ b/csrc/multidevice/symmetric_tensor.h
@@ -85,11 +85,11 @@ class SymmetricTensor {
   size_t aligned_size_;
   bool are_remote_tensors_setup_ = false;
   bool is_multicast_setup_ = false;
-  CUmemGenericAllocationHandle mcast_handle_{};
-  CUdevice cu_dev_{};
+  [[maybe_unused]] CUmemGenericAllocationHandle mcast_handle_{};
+  [[maybe_unused]] CUdevice cu_dev_{};
   void* mc_ptr_{nullptr};
-  int exporter_rank_{-1};
-  int peer_fd_{-1};
+  [[maybe_unused]] int exporter_rank_{-1};
+  [[maybe_unused]] int peer_fd_{-1};
   bool is_contiguous_view_setup_ = false;
   at::Tensor contiguous_view_;
 };
diff --git a/csrc/polymorphic_value.cpp b/csrc/polymorphic_value.cpp
index 58c3c3344eb..7732518ae02 100644
--- a/csrc/polymorphic_value.cpp
+++ b/csrc/polymorphic_value.cpp
@@ -140,3 +140,18 @@ c10::IValue toIValue(const PolymorphicValue& x) {
 
 } // namespace PolymorphicValue_functions
 } // namespace nvfuser
+
+// Explicit instantiation of DynamicType for PolymorphicValue.
+// This is the single point where the template is fully instantiated.
+// Note: This file is compiled with -fvisibility=default (set in CMakeLists.txt)
+// to ensure all DynamicType symbols are exported from the shared library.
+template struct dynamic_type::DynamicType<
+    dynamic_type::Containers<std::vector>,
+    nvfuser::StructHandle,
+    nvfuser::Pointer,
+    nvfuser::Opaque,
+    at::Tensor,
+    std::complex<double>,
+    double,
+    int64_t,
+    bool>;
diff --git a/csrc/polymorphic_value.h b/csrc/polymorphic_value.h
index 49b42555d79..27df7bc21e7 100644
--- a/csrc/polymorphic_value.h
+++ b/csrc/polymorphic_value.h
@@ -544,4 +544,17 @@ c10::IValue toIValue(const PolymorphicValue& x);
 
 } // namespace nvfuser
 
+// Prevent implicit instantiation in other TUs - use explicit instantiation
+// from polymorphic_value.cpp
+extern template struct dynamic_type::DynamicType<
+    dynamic_type::Containers<std::vector>,
+    nvfuser::StructHandle,
+    nvfuser::Pointer,
+    nvfuser::Opaque,
+    at::Tensor,
+    std::complex<double>,
+    double,
+    int64_t,
+    bool>;
+
 #include
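The extern-template / explicit-instantiation pair above is the core of the change: every TU that includes polymorphic_value.h now merely parses the DynamicType declaration and defers instantiation to polymorphic_value.cpp. A minimal sketch of the pattern with a hypothetical Widget template:

```cpp
// widget.h -- hypothetical template, illustrating the pattern only.
template <typename T>
struct Widget {
  T value;
  T doubled() const { return value + value; }
};
// Every includer sees this promise and skips implicit instantiation.
extern template struct Widget<int>;

// widget.cpp -- the one TU that pays the instantiation cost (and, when built
// with -fvisibility=default, exports the resulting symbols).
template struct Widget<int>;
```

The trade-off is classic: parse time in N translation units is exchanged for one full instantiation plus a link-time dependency on the exporting TU.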
diff --git a/csrc/scheduler/matmul_ampere-.cpp b/csrc/scheduler/matmul_ampere-.cpp
index 04e9a220130..e594b506326 100644
--- a/csrc/scheduler/matmul_ampere-.cpp
+++ b/csrc/scheduler/matmul_ampere-.cpp
@@ -56,9 +56,9 @@ AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) {
 
   // Extract the constant sizes of the swizzled tile
   const int64_t tile_size_x =
-      swizzle_domain[-2]->extent()->evaluate().as<int64_t>();
+      swizzle_domain[-2].as<IterDomain*>()->extent()->evaluate().as<int64_t>();
   const int64_t tile_size_y =
-      swizzle_domain[-1]->extent()->evaluate().as<int64_t>();
+      swizzle_domain[-1].as<IterDomain*>()->extent()->evaluate().as<int64_t>();
 
   // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e.
   // half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit
diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp
index ab3c2b77049..481814f381f 100644
--- a/csrc/scheduler/mma_utils.cpp
+++ b/csrc/scheduler/mma_utils.cpp
@@ -1127,7 +1127,7 @@ AbstractTensor MmaSwizzler::scheduleMmaOutputAllocation(AbstractTensor t) {
   // Assume last 2 dims, for example [M64, N24] or [M64, N24, R]
   NVF_ERROR(t.size() >= 2);
 
-  bool has_reduction = t[-1]->isReduction();
+  bool has_reduction = t[-1].as<IterDomain*>()->isReduction();
   int64_t m_pos = has_reduction ? -3 : -2;
   int64_t n_pos = has_reduction ? -2 : -1;
 
@@ -2473,9 +2473,9 @@ std::pair analyzeSwizzleSharedMemory(
 
   // Extract the constant sizes of the swizzled tile
   const int64_t tile_size_x =
-      swizzle_domain[-2]->extent()->evaluate().as<int64_t>();
+      swizzle_domain[-2].as<IterDomain*>()->extent()->evaluate().as<int64_t>();
   const int64_t tile_size_y =
-      swizzle_domain[-1]->extent()->evaluate().as<int64_t>();
+      swizzle_domain[-1].as<IterDomain*>()->extent()->evaluate().as<int64_t>();
 
   // Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e.
   // half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit
 
@@ -2717,7 +2717,7 @@ MmaInputSmemSwizzle tmaSwizzleSharedMemory(TensorView* shared_mem_tv) {
   AbstractTensor swizzle_domain(shared_mem_tv->getLoopDomain());
   // Extract the constant sizes of the swizzled tile
   const int64_t inner_dim_size =
-      swizzle_domain[-1]->extent()->evaluate().as<int64_t>();
+      swizzle_domain[-1].as<IterDomain*>()->extent()->evaluate().as<int64_t>();
 
   auto dtype = shared_mem_tv->getDataType().value();
   const int64_t B128_elements = 128 / dataTypeSizeByte(dtype);
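These scheduler call sites switch from `operator->` to an explicit `.as<IterDomain*>()` narrowing: the caller names the expected alternative instead of relying on the wrapper's `operator->` template, which would otherwise require the full DynamicType definition. A rough stand-in using std::variant (types here are illustrative, not nvFuser's):

```cpp
#include <variant>

struct IterDomainish {  // illustrative stand-in, not nvFuser's IterDomain
  bool isReduction() const { return false; }
};

using Id = std::variant<IterDomainish*, int>;

// Naming the alternative up front means no operator-> template on the
// wrapper needs to be instantiated at the call site.
template <typename T>
T as(const Id& id) {
  return std::get<T>(id);  // throws std::bad_variant_access on a mismatch
}

bool isReduction(const Id& id) {
  return as<IterDomainish*>(id)->isReduction();
}
```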
diff --git a/csrc/scheduler/tools/abstract_tensor.h b/csrc/scheduler/tools/abstract_tensor.h
index 52801de1ff6..ec8b937a0a8 100644
--- a/csrc/scheduler/tools/abstract_tensor.h
+++ b/csrc/scheduler/tools/abstract_tensor.h
@@ -67,7 +67,8 @@ struct DispatchSplit {
       inner_result.reserve(in.size());
       for (auto i : arange(in.size())) {
         auto [outer, inner] =
-            AbstractId::dispatch((*this), in[i], factor, inner_split);
+            AbstractId::dispatch<std::pair<AbstractId, AbstractId>>(
+                (*this), in[i], factor, inner_split);
         outer_result.emplace_back(outer);
         inner_result.emplace_back(inner);
       }
@@ -119,7 +120,8 @@ struct DispatchMerge {
       std::vector<AbstractId> result;
       result.reserve(lhs.size());
       for (auto i : arange(lhs.size())) {
-        result.emplace_back(AbstractId::dispatch((*this), lhs[i], rhs[i]));
+        result.emplace_back(
+            AbstractId::dispatch<AbstractId>((*this), lhs[i], rhs[i]));
       }
       return result;
     } else if constexpr (std::is_same_v<L, std::vector<AbstractId>>) {
@@ -127,7 +129,8 @@ struct DispatchMerge {
       result.reserve(lhs.size());
       for (auto i : arange(lhs.size())) {
         result.emplace_back(
-            AbstractId::dispatch((*this), lhs[i], std::forward<RHS>(rhs)));
+            AbstractId::dispatch<AbstractId>(
+                (*this), lhs[i], std::forward<RHS>(rhs)));
       }
       return result;
     } else if constexpr (std::is_same_v<R, std::vector<AbstractId>>) {
@@ -135,7 +138,8 @@ struct DispatchMerge {
       result.reserve(rhs.size());
       for (auto i : arange(rhs.size())) {
         result.emplace_back(
-            AbstractId::dispatch((*this), std::forward<LHS>(lhs), rhs[i]));
+            AbstractId::dispatch<AbstractId>(
+                (*this), std::forward<LHS>(lhs), rhs[i]));
       }
       return result;
     } else {
@@ -198,7 +202,8 @@ struct DispatchSwizzle {
       result_y.reserve(lhs.size());
       for (auto i : arange(lhs.size())) {
         auto [out_x, out_y] =
-            AbstractId::dispatch((*this), swizzle_type, lhs[i], rhs[i]);
+            AbstractId::dispatch<std::pair<AbstractId, AbstractId>>(
+                (*this), swizzle_type, lhs[i], rhs[i]);
         result_x.emplace_back(out_x);
         result_y.emplace_back(out_y);
       }
@@ -209,8 +214,9 @@ struct DispatchSwizzle {
       result_x.reserve(lhs.size());
       result_y.reserve(lhs.size());
       for (auto i : arange(lhs.size())) {
-        auto [out_x, out_y] = AbstractId::dispatch(
-            (*this), swizzle_type, lhs[i], std::forward<RHS>(rhs));
+        auto [out_x, out_y] =
+            AbstractId::dispatch<std::pair<AbstractId, AbstractId>>(
+                (*this), swizzle_type, lhs[i], std::forward<RHS>(rhs));
         result_x.emplace_back(out_x);
         result_y.emplace_back(out_y);
       }
@@ -221,8 +227,9 @@ struct DispatchSwizzle {
       result_x.reserve(rhs.size());
      result_y.reserve(rhs.size());
       for (auto i : arange(rhs.size())) {
-        auto [out_x, out_y] = AbstractId::dispatch(
-            (*this), swizzle_type, std::forward<LHS>(lhs), rhs[i]);
+        auto [out_x, out_y] =
+            AbstractId::dispatch<std::pair<AbstractId, AbstractId>>(
+                (*this), swizzle_type, std::forward<LHS>(lhs), rhs[i]);
         result_x.emplace_back(out_x);
         result_y.emplace_back(out_y);
       }
@@ -283,7 +290,8 @@ struct DispatchLegacySwizzle {
       result_y.reserve(lhs.size());
       for (auto i : arange(lhs.size())) {
         auto [out_x, out_y] =
-            AbstractId::dispatch((*this), swizzle_type, lhs[i], rhs[i]);
+            AbstractId::dispatch<std::pair<AbstractId, AbstractId>>(
+                (*this), swizzle_type, lhs[i], rhs[i]);
         result_x.emplace_back(out_x);
         result_y.emplace_back(out_y);
       }
@@ -294,8 +302,9 @@ struct DispatchLegacySwizzle {
       result_x.reserve(lhs.size());
       result_y.reserve(lhs.size());
       for (auto i : arange(lhs.size())) {
-        auto [out_x, out_y] = AbstractId::dispatch(
-            (*this), swizzle_type, lhs[i], std::forward<RHS>(rhs));
+        auto [out_x, out_y] =
+            AbstractId::dispatch<std::pair<AbstractId, AbstractId>>(
+                (*this), swizzle_type, lhs[i], std::forward<RHS>(rhs));
         result_x.emplace_back(out_x);
         result_y.emplace_back(out_y);
       }
@@ -306,8 +315,9 @@ struct DispatchLegacySwizzle {
       result_x.reserve(rhs.size());
       result_y.reserve(rhs.size());
       for (auto i : arange(rhs.size())) {
-        auto [out_x, out_y] = AbstractId::dispatch(
-            (*this), swizzle_type, std::forward<LHS>(lhs), rhs[i]);
+        auto [out_x, out_y] =
+            AbstractId::dispatch<std::pair<AbstractId, AbstractId>>(
+                (*this), swizzle_type, std::forward<LHS>(lhs), rhs[i]);
         result_x.emplace_back(out_x);
         result_y.emplace_back(out_y);
       }
@@ -334,7 +344,7 @@ struct DispatchParallelize {
       }
     } else if constexpr (std::is_same_v<IN, std::vector<AbstractId>>) {
       for (auto& aid : in) {
-        AbstractId::dispatch((*this), parallel_type, aid);
+        AbstractId::dispatch<void>((*this), parallel_type, aid);
       }
     } else {
       NVF_CHECK(false, "Unsupported type in AbstractTensor::parallelize");
@@ -663,7 +673,8 @@ class AbstractTensorWithInfo {
       int64_t axis,
       ParallelType parallel_type) {
     axis = wrapDim(axis, (int64_t)domain_.size());
-    AbstractId::dispatch(DispatchParallelize{}, parallel_type, domain_[axis]);
+    AbstractId::dispatch<void>(
+        DispatchParallelize{}, parallel_type, domain_[axis]);
     return *this;
   }
 
@@ -674,8 +685,9 @@ class AbstractTensorWithInfo {
     NVF_ERROR(domain_.size() == info_.size());
     axis = wrapDim(axis, (int64_t)domain_.size());
 
-    auto [outer, inner] = AbstractId::dispatch(
-        DispatchSplit{}, domain_[axis], factor, inner_split);
+    auto [outer, inner] =
+        AbstractId::dispatch<std::pair<AbstractId, AbstractId>>(
+            DispatchSplit{}, domain_[axis], factor, inner_split);
     std::swap(domain_[axis], inner);
     domain_.insert(domain_.begin() + axis, outer);
 
@@ -700,8 +712,8 @@ class AbstractTensorWithInfo {
     axis_o = wrapDim(axis_o, (int64_t)domain_.size());
     axis_i = wrapDim(axis_i, (int64_t)domain_.size());
 
-    auto output =
-        AbstractId::dispatch(DispatchMerge{}, domain_[axis_o], domain_[axis_i]);
+    auto output = AbstractId::dispatch<AbstractId>(
+        DispatchMerge{}, domain_[axis_o], domain_[axis_i]);
     // axis_o is the outer input of this merge but does not
     // automatically mean it's an outer domain in this AbstractTensorWithInfo.
     auto domain_outer_pos = axis_o < axis_i ? axis_o : axis_i;
@@ -792,8 +804,9 @@ class AbstractTensorWithInfo {
     x = wrapDim(x, (int64_t)domain_.size());
     y = wrapDim(y, (int64_t)domain_.size());
 
-    auto [out_x, out_y] = AbstractId::dispatch(
-        DispatchSwizzle{}, swizzle_type, domain_[x], domain_[y]);
+    auto [out_x, out_y] =
+        AbstractId::dispatch<std::pair<AbstractId, AbstractId>>(
+            DispatchSwizzle{}, swizzle_type, domain_[x], domain_[y]);
     std::swap(domain_[x], out_x);
     std::swap(domain_[y], out_y);
 
@@ -817,8 +830,9 @@ class AbstractTensorWithInfo {
     x = wrapDim(x, (int64_t)domain_.size());
     y = wrapDim(y, (int64_t)domain_.size());
 
-    auto [out_x, out_y] = AbstractId::dispatch(
-        DispatchLegacySwizzle{}, swizzle_type, domain_[x], domain_[y]);
+    auto [out_x, out_y] =
+        AbstractId::dispatch<std::pair<AbstractId, AbstractId>>(
+            DispatchLegacySwizzle{}, swizzle_type, domain_[x], domain_[y]);
     std::swap(domain_[x], out_x);
     std::swap(domain_[y], out_y);
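abstract_tensor.h applies the same explicit-result-type idiom across every dispatch flavor: `AbstractId` for merge, `std::pair<AbstractId, AbstractId>` for split and swizzle, and `void` for parallelize. A hedged sketch covering the pair- and void-returning cases, again over std::variant rather than the real DynamicType:

```cpp
#include <utility>
#include <variant>

using Node = std::variant<int, double>;

// Same explicit-R idiom as the earlier sketch, extended to the two other
// result shapes used in abstract_tensor.h.
template <typename R, typename F>
R dispatch(F&& f, const Node& n) {
  return std::visit([&](const auto& alt) -> R { return f(alt); }, n);
}

std::pair<Node, Node> splitInHalf(const Node& n) {  // split/swizzle shape
  return dispatch<std::pair<Node, Node>>(
      [](const auto& x) { return std::make_pair(Node{x / 2}, Node{x / 2}); },
      n);
}

void touch(const Node& n) {  // parallelize shape: R = void is also legal
  dispatch<void>([](const auto&) {}, n);
}
```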
diff --git a/csrc/type.cpp b/csrc/type.cpp
index 1a55427975d..285d47a8ac2 100644
--- a/csrc/type.cpp
+++ b/csrc/type.cpp
@@ -18,6 +18,67 @@
 
 namespace nvfuser {
 
+// Implementation moved from type.h to reduce template instantiation costs.
+// Uses PolymorphicValue::for_all_types() which triggers ForAllTypes dispatch.
+DataType getDataType(const PolymorphicValue& value) {
+  std::optional<DataType> dtype = std::nullopt;
+  PolymorphicValue::for_all_types([&value, &dtype](auto _) {
+    using T = typename decltype(_)::type;
+    if constexpr (IsPrimitiveNativeType<T>::value) {
+      if (value.is<T>()) {
+        dtype = NativeTypeToDataType<T>::type;
+      }
+    } else if constexpr (std::is_same_v<T, std::vector<PolymorphicValue>>) {
+      if (value.is<T>()) {
+        const auto& vec = value.as<T>();
+        size_t size = vec.size();
+        NVF_CHECK(size > 0, "Empty array is not supported");
+        dtype =
+            ArrayType{std::make_shared<DataType>(getDataType(vec[0])), size};
+      }
+    } else if constexpr (std::is_same_v<T, Pointer>) {
+      // For pointers in polymorphic value, we only store the data size of the
+      // pointee, so it is impossible to infer the pointer type.
+      NVF_CHECK(!value.is<T>(), "Can not infer pointer type.");
+    } else if constexpr (std::is_same_v<T, StructHandle>) {
+      if (value.is<T>()) {
+        dtype = value.as<T>().type();
+      }
+    } else if constexpr (std::is_same_v<T, Opaque>) {
+      if (value.is<T>()) {
+        const auto& opaque = value.as<T>();
+        dtype = DataType(OpaqueType{
+            .type_info = opaque.any().type(), .size = opaque.size()});
+      }
+    }
+  });
+  NVF_CHECK(dtype.has_value(), "Unknown dtype for ", value.type().name());
+  return dtype.value();
+}
+
+// Implementation moved from type.h to reduce template instantiation costs.
+// Uses PolymorphicValue::for_all_types() which triggers ForAllTypes dispatch.
+PolymorphicValue castToDtype(PolymorphicValue value, const DataType& dtype) {
+  if (!value.hasValue()) {
+    return value;
+  }
+  // Cast the given value to the given data type. This enables interfaces
+  // like: IrBuilder::create<Val>(0, DataType::Double) where the value is
+  // an integer but the desired data type is double.
+  if (!hasCompatibleDataType(value, dtype)) {
+    PolymorphicValue::for_all_types([&](auto _) {
+      using T = typename decltype(_)::type;
+      if constexpr (IsPrimitiveNativeType<T>::value) {
+        if (isCompatibleDataType(NativeTypeToDataType<T>::type, dtype)) {
+          value = PolymorphicValue(static_cast<T>(value));
+        }
+      }
+      // TODO: support arrays and pointers
+    });
+  }
+  return value;
+}
+
 StructType NotImplementedStruct::type() const {
   NVF_THROW("Not implemented");
 }
diff --git a/csrc/type.h b/csrc/type.h
index 649297229b9..23723d4f2a5 100644
--- a/csrc/type.h
+++ b/csrc/type.h
@@ -414,41 +414,9 @@ DEFINE_DATATYPE_TO_NATIVE_TYPE(DataType::ComplexDouble, std::complex<double>);
 
 #undef DEFINE_DATATYPE_TO_NATIVE_TYPE
 
-inline DataType getDataType(const PolymorphicValue& value) {
-  std::optional<DataType> dtype = std::nullopt;
-  PolymorphicValue::for_all_types([&value, &dtype](auto _) {
-    using T = typename decltype(_)::type;
-    if constexpr (IsPrimitiveNativeType<T>::value) {
-      if (value.is<T>()) {
-        dtype = NativeTypeToDataType<T>::type;
-      }
-    } else if constexpr (std::is_same_v<T, std::vector<PolymorphicValue>>) {
-      if (value.is<T>()) {
-        const auto& vec = value.as<T>();
-        size_t size = vec.size();
-        NVF_CHECK(size > 0, "Empty array is not supported");
-        dtype =
-            ArrayType{std::make_shared<DataType>(getDataType(vec[0])), size};
-      }
-    } else if constexpr (std::is_same_v<T, Pointer>) {
-      // For pointers in polymorphic value, we only store the data size of the
-      // pointee, so it is impossible to infer the pointer type.
-      NVF_CHECK(!value.is<T>(), "Can not infer pointer type.");
-    } else if constexpr (std::is_same_v<T, StructHandle>) {
-      if (value.is<T>()) {
-        dtype = value.as<T>().type();
-      }
-    } else if constexpr (std::is_same_v<T, Opaque>) {
-      if (value.is<T>()) {
-        const auto& opaque = value.as<T>();
-        dtype = DataType(OpaqueType{
-            .type_info = opaque.any().type(), .size = opaque.size()});
-      }
-    }
-  });
-  NVF_CHECK(dtype.has_value(), "Unknown dtype for ", value.type().name());
-  return dtype.value();
-}
+// Get the DataType corresponding to the runtime type held in a
+// PolymorphicValue. Implementation moved to type.cpp to reduce template
+// instantiation costs.
+NVF_API DataType getDataType(const PolymorphicValue& value);
 
 inline bool isCompatibleDataType(DataType dtype, DataType dtype2) {
   if (dtype == dtype2) {
@@ -1128,28 +1096,11 @@ Pointer::Pointer(void* ptr, DataType dtype)
     : ptr_(reinterpret_cast<std::byte*>(ptr)), size_bit_(dataTypeSizeBit(dtype)) {}
 
-inline PolymorphicValue castToDtype(
+// Cast a PolymorphicValue to match the specified DataType.
+// Implementation moved to type.cpp to reduce template instantiation costs.
+NVF_API PolymorphicValue castToDtype(
     PolymorphicValue value,
-    const DataType& dtype) {
-  if (!value.hasValue()) {
-    return value;
-  }
-  // Cast the given value to the given data type. This enables interfaces
-  // like: IrBuilder::create<Val>(0, DataType::Double) where the value is
-  // an integer but the desired data type is double.
-  if (!hasCompatibleDataType(value, dtype)) {
-    PolymorphicValue::for_all_types([&](auto _) {
-      using T = typename decltype(_)::type;
-      if constexpr (IsPrimitiveNativeType<T>::value) {
-        if (isCompatibleDataType(NativeTypeToDataType<T>::type, dtype)) {
-          value = PolymorphicValue(static_cast<T>(value));
-        }
-      }
-      // TODO: support arrays and pointers
-    });
-  }
-  return value;
-}
+    const DataType& dtype);
 
 // Converts an enum to its underlying type.
 // It corresponds with std::to_underlying introduced in c++23
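Moving getDataType and castToDtype out of line trades inlining for a one-time instantiation of the for_all_types walk in type.cpp; the NVF_API annotation matters because the definitions now live behind the shared-library boundary. A before/after sketch (the macro expansion shown is an assumption for illustration, not nvFuser's exact NVF_API definition):

```cpp
// Assumed expansion; the real macro also handles Windows import/export.
#define API_EXPORT __attribute__((visibility("default")))

// Header, after the move: a plain declaration. Includers no longer
// instantiate the heavy template walk that used to live in the inline body.
API_EXPORT int heavyQuery(int key);  // stand-in for getDataType/castToDtype

// .cpp, after the move: compiled once, exported thanks to the annotation.
int heavyQuery(int key) {
  return key * 2;  // stand-in for the real for_all_types dispatch body
}
```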
diff --git a/lib/dynamic_type/benchmark/knn.cpp b/lib/dynamic_type/benchmark/knn.cpp
index 2f39fe29ee8..53c3dd96a80 100644
--- a/lib/dynamic_type/benchmark/knn.cpp
+++ b/lib/dynamic_type/benchmark/knn.cpp
@@ -143,7 +143,7 @@ static StructVecDouble kNN_Dictionary(
     sum += distances_and_values.top().second;
     distances_and_values.pop();
   }
-  return sum / k;
+  return sum / static_cast<double>(k);
 }
 
 static void kNN_Dictionary(benchmark::State& state) {
diff --git a/lib/dynamic_type/src/dynamic_type/decl.h b/lib/dynamic_type/src/dynamic_type/decl.h
new file mode 100644
index 00000000000..1d0d56067e8
--- /dev/null
+++ b/lib/dynamic_type/src/dynamic_type/decl.h
@@ -0,0 +1,1035 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <cmath>
+#include <cstddef>
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <optional>
+#include <type_traits>
+#include <typeinfo>
+#include <variant>
+
+#include "error.h"
+#include "type_traits.h"
+
+// Visibility attribute for exported symbols.
+// Static member functions need default visibility to be exported from
+// shared libraries built with -fvisibility=hidden.
+#if defined _WIN32 || defined __CYGWIN__
+#define DT_API __declspec(dllexport)
+#else
+#define DT_API __attribute__((visibility("default")))
+#endif
+
+namespace dynamic_type {
+
+// We must disable a lot of compiler warnings to make this work. The reason
+// for the need to disable these warnings is not because the code quality in
+// this file is bad, but because these apparently "bad" practices are
+// necessary. For example, if you have a dynamic type that can be either a
+// bool or a class SomeType{}, then we should support the ~ operator on it,
+// because in the C++ standard bool supports it. Usually, when people write
+// code like ~bool, they are making a mistake, and the compiler will want you
+// to use !bool instead. However, in our case here we will allow everything
+// that the C++ standard allows. The compiler should yell at the user who
+// uses DynamicType with ~ but not at us for implementing it.
+
+#if defined(__clang__)
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-comparison"
+#pragma clang diagnostic ignored "-Wbitwise-instead-of-logical"
+#pragma clang diagnostic ignored "-Wliteral-conversion"
+#pragma clang diagnostic ignored "-Wunused-lambda-capture"
+#pragma clang diagnostic ignored "-Wunknown-warning-option"
+#pragma clang diagnostic ignored "-Wbool-operation"
+#endif
+
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wbool-operation"
+// gcc, even the latest version (13.1.1), is complaining about the following
+// code:
+//   std::optional<bool> ret = std::nullopt;
+//   ...
+//   DYNAMIC_TYPE_CHECK(ret.has_value(), ...);
+//   return ret.value();
+// saying that ret.value() is used uninitialized. This complaint is totally
+// nonsense.
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+#endif
+
+template
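The excerpt ends mid-declaration, but the DT_API macro defined above follows the usual export-macro pattern. A small sketch of where such an attribute typically lands, assuming a library otherwise built with `-fvisibility=hidden` (the struct and its members are illustrative, not dynamic_type's API):

```cpp
// The annotated member keeps default visibility even when the TU is compiled
// with -fvisibility=hidden, so it stays callable across the DSO boundary.
struct Dispatcher {
  DT_API static int exportedEntry();  // exported from the shared library
  static int internalHelper();        // hidden: local to the shared library
};
```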