Skip to content

Commit

Permalink
Support for CUDA compilation of the AD forward mode.
Browse files Browse the repository at this point in the history
Merge pull request #44 from feature/cudaSupport.
Reviewed-by: Johannes Blühdorn <[email protected]>
  • Loading branch information
MaxSagebaum committed Oct 12, 2023
2 parents eee1b5e + 9f6d669 commit abc6c7d
Show file tree
Hide file tree
Showing 29 changed files with 1,236 additions and 131 deletions.
1 change: 1 addition & 0 deletions documentation/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Changelog {#Changelog}
- Features:
* New helper for adding Enzyme-generated derivative functions to the tape. See \ref Example_24_Enzyme_external_function_helper.
* Recover primal values from primal values tapes in ExternalFunctionHelper.
* Forward AD type for CUDA kernels.

- Bugfix:
* Uninitialized values in external function helper.
Expand Down
4 changes: 3 additions & 1 deletion include/codi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

#include "codi/config.h"
#include "codi/expressions/activeType.hpp"
#include "codi/expressions/activeTypeStatelessTape.hpp"
#include "codi/expressions/activeTypeWrapper.hpp"
#include "codi/expressions/immutableActiveType.hpp"
#include "codi/expressions/real/allOperators.hpp"
Expand Down Expand Up @@ -242,11 +243,12 @@ namespace codi {
* This is the scalar version which does not use a vector mode.
*/
using JacobianComputationScalarType = RealReverseIndex;

}

#include "codi/tools/helpers/evaluationHelper.hpp"

#include "codi/tools/cuda/codiCUDA.hpp"

#if CODI_EnableOpenMP
#include "codi/tools/parallel/openmp/codiOpenMP.hpp"
#endif
Expand Down
17 changes: 13 additions & 4 deletions include/codi/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <stddef.h>
#include <stdint.h>

#include "tools/cuda/cudaFunctionAttributes.hpp"
#include "misc/exceptions.hpp"

/** @file */
Expand Down Expand Up @@ -339,6 +340,9 @@ namespace codi {
/// @name Macro definitions
/// @{

/// Attributes for all CoDiPack functions.
#define CODI_FunctionAttributes CODI_CUDAFunctionAttributes

#ifndef CODI_AnnotateBranchLikelihood
/// See codi::Config::AnnotateBranchLikelihood.
#define CODI_AnnotateBranchLikelihood CODI_HasCpp20
Expand Down Expand Up @@ -398,16 +402,21 @@ namespace codi {
#endif
#if CODI_ForcedInlines
#if defined(__INTEL_COMPILER) | defined(_MSC_VER)
#define CODI_INLINE __forceinline
#define CODI_INLINE CODI_FunctionAttributes __forceinline
#define CODI_INLINE_NO_FA __forceinline
#elif defined(__GNUC__)
#define CODI_INLINE inline __attribute__((always_inline))
#define CODI_INLINE CODI_FunctionAttributes inline __attribute__((always_inline))
#define CODI_INLINE_NO_FA inline __attribute__((always_inline))
#else
#warning Could not determine compiler for forced inline definitions. Using inline.
#define CODI_INLINE inline
#define CODI_INLINE CODI_FunctionAttributes inline
#define CODI_INLINE_NO_FA inline
#endif
#else
/// See codi::Config::ForcedInlines.
#define CODI_INLINE inline
#define CODI_INLINE CODI_FunctionAttributes inline
/// See codi::Config::ForcedInlines.
#define CODI_INLINE_NO_FA inline
#endif
/// Force inlining instead of using the heuristics from the compiler.
bool constexpr ForcedInlines = CODI_ForcedInlines;
Expand Down
12 changes: 6 additions & 6 deletions include/codi/expressions/activeTypeBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ namespace codi {
Base::init(Real(), EventHints::Statement::Passive);
}

/// Constructor from a passive value.
///
/// Stores \c value as the primal value, value-initializes the identifier and then calls Base::init with the
/// EventHints::Statement::Passive hint.
///
/// NOTE(review): the defaulted template parameter presumably removes this overload via SFINAE when Real is already
/// the passive type, so that it does not clash with the constructor taking Real const& - confirm against
/// RealTraits::EnableIfNotPassiveReal.
template<typename U = Real, typename = RealTraits::EnableIfNotPassiveReal<U>>
CODI_INLINE ActiveTypeBase(PassiveReal const& value) : primalValue(value), identifier() {
Base::init(value, EventHints::Statement::Passive);
}

/// Constructor
CODI_INLINE ActiveTypeBase(ActiveTypeBase const& v) : primalValue(), identifier() {
Base::init(v.getValue(), EventHints::Statement::Copy);
Expand All @@ -108,12 +114,6 @@ namespace codi {
Base::init(value, EventHints::Statement::Passive);
}

/// Constructor from a passive value.
///
/// Stores \c value as the primal value, value-initializes the identifier and then calls Base::init with the
/// EventHints::Statement::Passive hint.
///
/// NOTE(review): the defaulted template parameter presumably removes this overload via SFINAE when Real is already
/// the passive type, so that it does not clash with the constructor taking Real const& - confirm against
/// RealTraits::EnableIfNotPassiveReal.
template<typename U = Real, typename = RealTraits::EnableIfNotPassiveReal<U>>
CODI_INLINE ActiveTypeBase(PassiveReal const& value) : primalValue(value), identifier() {
Base::init(value, EventHints::Statement::Passive);
}

/// Constructor
template<typename Rhs>
CODI_INLINE ActiveTypeBase(ExpressionInterface<Real, Rhs> const& rhs) : primalValue(), identifier() {
Expand Down
153 changes: 153 additions & 0 deletions include/codi/expressions/activeTypeStatelessTape.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* CoDiPack, a Code Differentiation Package
*
* Copyright (C) 2015-2023 Chair for Scientific Computing (SciComp), University of Kaiserslautern-Landau
* Homepage: http://www.scicomp.uni-kl.de
* Contact: Prof. Nicolas R. Gauger ([email protected])
*
* Lead developers: Max Sagebaum, Johannes Blühdorn (SciComp, University of Kaiserslautern-Landau)
*
* This file is part of CoDiPack (http://www.scicomp.uni-kl.de/software/codi).
*
* CoDiPack is free software: you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* CoDiPack is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty
* of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
*
* See the GNU General Public License for more details.
* You should have received a copy of the GNU
* General Public License along with CoDiPack.
* If not, see <http://www.gnu.org/licenses/>.
*
* For other licensing options please contact us.
*
* Authors:
* - SciComp, University of Kaiserslautern-Landau:
* - Max Sagebaum
* - Johannes Blühdorn
* - Former members:
* - Tim Albring
*/
#pragma once

#include "../config.h"
#include "../misc/macros.hpp"
#include "../tapes/interfaces/fullTapeInterface.hpp"
#include "../traits/realTraits.hpp"
#include "assignmentOperators.hpp"
#include "incrementOperators.hpp"
#include "lhsExpressionInterface.hpp"

/** \copydoc codi::Namespace */
namespace codi {

/**
* @brief Represents a concrete lvalue in the CoDiPack expression tree whose tape carries no persistent state.
*
* See also LhsExpressionInterface.
*
* This active type does not work with a fixed tape. Instead, getTape() constructs a new temporary tape on every call
* and returns it by value. In particular, tapes for this active type cannot have a persistent state, and the tape
* type has to be default constructible.
*
* Each object stores one primal value (Real) and one identifier (Identifier); both are accessible through value() and
* getIdentifier().
*
* @tparam T_Tape The tape that manages all expressions created with this type.
*/
template<typename T_Tape>
struct ActiveTypeStatelessTape : public LhsExpressionInterface<typename T_Tape::Real, typename T_Tape::Gradient, T_Tape,
ActiveTypeStatelessTape<T_Tape>>,
public AssignmentOperators<T_Tape, ActiveTypeStatelessTape<T_Tape>>,
public IncrementOperators<T_Tape, ActiveTypeStatelessTape<T_Tape>> {
public:

using Tape = CODI_DD(T_Tape, CODI_DEFAULT_TAPE); ///< See ActiveTypeStatelessTape.

using Real = typename Tape::Real; ///< See LhsExpressionInterface.
using PassiveReal = RealTraits::PassiveReal<Real>; ///< Basic computation type.
using Identifier = typename Tape::Identifier; ///< See LhsExpressionInterface.
using Gradient = typename Tape::Gradient; ///< See LhsExpressionInterface.

using Base =
LhsExpressionInterface<Real, Gradient, T_Tape, ActiveTypeStatelessTape<T_Tape>>; ///< Base class abbreviation.

private:

Real primalValue;       ///< Primal value of this lvalue, returned by value().
Identifier identifier;  ///< Identifier of this lvalue, returned by getIdentifier().

public:

/// @brief Default constructor.
/// @details Defaulted, i.e., it performs no explicit initialization of the primal value or the identifier and does
/// not call Base::init. The CUDA compiler has problems when this function is annotated with \c __device__,
/// therefore CODI_INLINE_NO_FA (the inline macro without CODI_FunctionAttributes) is used instead of
/// CODI_INLINE.
constexpr CODI_INLINE_NO_FA ActiveTypeStatelessTape() = default;

/// @brief Constructor from a passive value (implicit conversion).
/// @details Stores \c value as the primal value and value-initializes the identifier. In contrast to the other
/// non-default constructors, Base::init is not called and nothing is stored on a tape.
constexpr CODI_INLINE ActiveTypeStatelessTape(PassiveReal const& value) : primalValue(value), identifier() {}

/// @brief Copy constructor.
/// @details Value-initializes both members, calls Base::init with the primal value of \c v and the
/// EventHints::Statement::Copy hint, and then lets a temporary tape from getTape() store \c v into this object.
CODI_INLINE ActiveTypeStatelessTape(ActiveTypeStatelessTape const& v) : primalValue(), identifier() {
Base::init(v.getValue(), EventHints::Statement::Copy);
getTape().store(*this, v);
}

/// @brief Constructor from an arbitrary expression.
/// @details Value-initializes both members, calls Base::init with the primal value of the expression and the
/// EventHints::Statement::Expression hint, and then lets a temporary tape from getTape() store the
/// expression into this object.
template<typename Rhs>
CODI_INLINE ActiveTypeStatelessTape(ExpressionInterface<Real, Rhs> const& rhs) : primalValue(), identifier() {
Base::init(rhs.cast().getValue(), EventHints::Statement::Expression);
getTape().store(*this, rhs.cast());
}

/*******************************************************************************/
/// @name Assignment operators (all forwarding to the base class)
/// @{

/// Copy assignment. Forwards to the assignment operator of the base class (LhsExpressionInterface) and returns
/// a reference to this object.
CODI_INLINE ActiveTypeStatelessTape& operator=(ActiveTypeStatelessTape const& v) {
static_cast<Base&>(*this) = static_cast<Base const&>(v);
return *this;
}

// Make the remaining assignment operators of the base class visible; they would otherwise be hidden by the copy
// assignment operator declared above.
using Base::operator=;

/// @}
/*******************************************************************************/
/// @name Implementation of ExpressionInterface
/// @{

using StoreAs = ActiveTypeStatelessTape const&; ///< \copydoc codi::ExpressionInterface::StoreAs
using ActiveResult = ActiveTypeStatelessTape; ///< \copydoc codi::ExpressionInterface::ActiveResult

/// @}
/*******************************************************************************/
/// @name Implementation of LhsExpressionInterface
/// @{

/// \copydoc codi::LhsExpressionInterface::getIdentifier()
CODI_INLINE Identifier& getIdentifier() {
return identifier;
}

/// \copydoc codi::LhsExpressionInterface::getIdentifier() const
CODI_INLINE Identifier const& getIdentifier() const {
return identifier;
}

/// \copydoc codi::LhsExpressionInterface::value()
CODI_INLINE Real& value() {
return primalValue;
}

/// \copydoc codi::LhsExpressionInterface::value() const
CODI_INLINE Real const& value() const {
return primalValue;
}

/// \copydoc codi::LhsExpressionInterface::getTape()
///
/// This implementation returns a newly default-constructed Tape by value on every call; no tape object is shared
/// between calls.
static CODI_INLINE Tape getTape() {
return Tape();
}

/// @}
};
}
4 changes: 2 additions & 2 deletions include/codi/expressions/binaryExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ namespace codi {

/// Constructor
template<typename RealA, typename RealB>
explicit BinaryExpression(ExpressionInterface<RealA, ArgA> const& argA,
ExpressionInterface<RealB, ArgB> const& argB)
CODI_INLINE explicit BinaryExpression(ExpressionInterface<RealA, ArgA> const& argA,
ExpressionInterface<RealB, ArgB> const& argB)
: argA(argA.cast()),
argB(argB.cast()),
result(Operation::primal(this->argA.getValue(), this->argB.getValue())) {}
Expand Down
5 changes: 3 additions & 2 deletions include/codi/expressions/lhsExpressionInterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "../config.h"
#include "../misc/eventSystem.hpp"
#include "../misc/macros.hpp"
#include "../misc/toConst.hpp"
#include "../tapes/interfaces/fullTapeInterface.hpp"
#include "../traits/expressionTraits.hpp"
#include "../traits/realTraits.hpp"
Expand Down Expand Up @@ -67,7 +68,7 @@ namespace codi {
using Tape = CODI_DD(T_Tape, CODI_DEFAULT_TAPE); ///< See LhsExpressionInterface.
using Impl = CODI_DD(T_Impl, LhsExpressionInterface); ///< See LhsExpressionInterface.

using Base = ExpressionInterface<T_Real, T_Impl>;
using Base = ExpressionInterface<T_Real, T_Impl>; ///< Base class abbreviation.

using Identifier = typename Tape::Identifier; ///< See GradientAccessTapeInterface.
using PassiveReal = RealTraits::PassiveReal<Real>; ///< Basic computation type.
Expand Down Expand Up @@ -107,7 +108,7 @@ namespace codi {

/// Get the gradient of this lvalue from the tape.
CODI_INLINE Gradient const& gradient() const {
return const_cast<Tape const&>(Impl::getTape()).gradient(cast().getIdentifier());
return toConst(Impl::getTape()).gradient(cast().getIdentifier());
}

/// Get the gradient of this lvalue from the tape.
Expand Down
53 changes: 53 additions & 0 deletions include/codi/expressions/real/binaryOperators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,15 @@ namespace codi {
using std::copysign;
using std::fmax;
using std::fmin;
using std::fmod;
using std::frexp;
using std::hypot;
using std::ldexp;
using std::max;
using std::min;
using std::pow;
using std::remainder;
using std::trunc;

/// BinaryOperation implementation for atan2
template<typename T_Real>
Expand Down Expand Up @@ -311,6 +313,50 @@ namespace codi {

#define OPERATION_LOGIC OperationCopysign
#define FUNCTION copysignf
#include "binaryOverloads.tpp"

/// BinaryOperation implementation for fmod.
///
/// The primal is fmod(argA, argB). The partial derivative with respect to argA is the constant one. The partial
/// derivative with respect to argB is -trunc(argA / argB), except for a zero divisor, where zero is returned.
template<typename T_Real>
struct OperationFmod : public BinaryOperation<T_Real> {
  public:

    using Real = CODI_DD(T_Real, double);  ///< See BinaryOperation.

    /// \copydoc codi::BinaryOperation::primal()
    template<typename ArgA, typename ArgB>
    static CODI_INLINE Real primal(ArgA const& argA, ArgB const& argB) {
      return fmod(argA, argB);
    }

    /// \copydoc codi::BinaryOperation::gradientA()
    template<typename ArgA, typename ArgB>
    static CODI_INLINE RealTraits::PassiveReal<Real> gradientA(ArgA const& argA, ArgB const& argB,
                                                               Real const& result) {
      // None of the inputs is needed, the derivative with respect to the first argument is always one.
      CODI_UNUSED(argA, argB, result);

      return RealTraits::PassiveReal<Real>(1.0);
    }

    /// \copydoc codi::BinaryOperation::gradientB()
    template<typename ArgA, typename ArgB>
    static CODI_INLINE RealTraits::PassiveReal<Real> gradientB(ArgA const& argA, ArgB const& argB,
                                                               Real const& result) {
      CODI_UNUSED(result);

      // Guard against the division by zero below: a zero divisor yields a zero derivative.
      if (RealTraits::getPassiveValue(argB) == 0.0) {
        return RealTraits::PassiveReal<Real>(0.0);
      }

      // fmod(a, b) = a - trunc(a / b) * b, hence d/db = -trunc(a / b).
      return -trunc(RealTraits::getPassiveValue(argA / argB));
    }
};

#define OPERATION_LOGIC OperationFmod
#define FUNCTION fmod
#include "binaryOverloads.tpp"

#define OPERATION_LOGIC OperationFmod
#define FUNCTION fmodf
#include "binaryOverloads.tpp"

/// BinaryOperation implementation for frexp
Expand Down Expand Up @@ -614,6 +660,10 @@ namespace codi {
};
#define OPERATION_LOGIC OperationPow
#define FUNCTION pow
#include "binaryOverloads.tpp"

#define OPERATION_LOGIC OperationPow
#define FUNCTION powf
#include "binaryOverloads.tpp"

/// BinaryOperation implementation for remainder
Expand Down Expand Up @@ -688,6 +738,8 @@ namespace std {
using codi::copysignf;
using codi::fmax;
using codi::fmin;
using codi::fmod;
using codi::fmodf;
using codi::frexp;
using codi::hypot;
using codi::hypotf;
Expand All @@ -696,6 +748,7 @@ namespace std {
using codi::max;
using codi::min;
using codi::pow;
using codi::powf;
using codi::remainder;
using codi::swap;
}
2 changes: 1 addition & 1 deletion include/codi/expressions/referenceActiveType.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ namespace codi {
}

/// \copydoc codi::LhsExpressionInterface::getTape()
static CODI_INLINE Tape& getTape() {
static CODI_INLINE decltype(Type::getTape()) getTape() {
return Type::getTape();
}
};
Expand Down
Loading

0 comments on commit abc6c7d

Please sign in to comment.