Skip to content

Commit

Permalink
Reorganize code and remove code duplication (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz authored Jun 21, 2022
2 parents f64215d + faf6d03 commit 935059b
Show file tree
Hide file tree
Showing 13 changed files with 503 additions and 522 deletions.
190 changes: 31 additions & 159 deletions dlext/include/DLExt.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,47 @@
#ifndef HOOMD_DLPACK_EXTENSION_H_
#define HOOMD_DLPACK_EXTENSION_H_

#include <type_traits>
#include <vector>

#include "SystemView.h"
#include "cxx11utils.h"
#include "dlpack/dlpack.h"

#include "hoomd/GlobalArray.h"

namespace dlext
{

using namespace hoomd;

// { // Aliases

// Raw pointer to a DLPack managed tensor, and the signature of a function
// that can be installed as such a tensor's `deleter`.
using DLManagedTensorPtr = DLManagedTensor*;
using DLManagedTensorDeleter = void (*)(DLManagedTensorPtr);

// Shorthands for the array access locations. `kOnDevice` is only declared
// when the code is compiled with ENABLE_CUDA.
using AccessLocation = access_location::Enum;
const auto kOnHost = access_location::host;
#ifdef ENABLE_CUDA
const auto kOnDevice = access_location::device;
#endif
// Owning pointer to an `ArrayHandle<T>`.
template <typename T>
using ArrayHandleUPtr = std::unique_ptr<ArrayHandle<T>>;

// Shorthands for the array access modes.
using AccessMode = access_mode::Enum;
const auto kRead = access_mode::read;
const auto kReadWrite = access_mode::readwrite;
const auto kOverwrite = access_mode::overwrite;
// } // Aliases

// { // Constants

// Bit width used for `Scalar`: 32 when `Scalar` is `float`, 64 otherwise.
constexpr uint8_t kBits = std::is_same<Scalar, float>::value ? 32 : 64;

// Pointer to a const, argument-less member function of `Object` that returns
// a const reference to an `Array<T>`.
template <template <typename> class Array, typename T, typename Object>
using PropertyGetter = const Array<T>& (Object::*)() const;
// Placeholder tensor: null data/shape/strides, device id -1, ndim -1, zeroed
// dtype, and no manager context or deleter.
// NOTE(review): presumably used as an "invalid" sentinel result -- confirm
// against the callers.
constexpr DLManagedTensor kInvalidDLManagedTensor {
DLTensor {
nullptr, // data
DLDevice { kDLExtDev, -1 }, // device
-1, // ndim
DLDataType { 0, 0, 0 }, // dtype
nullptr, // shape
nullptr, // strides
0 // byte_offset
},
nullptr, // manager_ctx
nullptr // deleter
};

// NOTE(review): `ArrayHandleUPtr` is also declared earlier in this listing;
// confirm that only one declaration remains in the real header.
template <typename T>
using ArrayHandleUPtr = std::unique_ptr<ArrayHandle<T>>;
// } // Constants

template <typename T>
struct DLDataBridge {
Expand All @@ -51,19 +62,19 @@ template <typename T>
using DLDataBridgeUPtr = std::unique_ptr<DLDataBridge<T>>;

// Deleter for tensors whose `manager_ctx` points to a `DLDataBridge<T>`:
// when `tensor` is non-null, the bridge stored in `tensor->manager_ctx` is
// destroyed with `delete`. A null `tensor` is ignored.
// NOTE(review): two signature lines appear below (`DLDataBridgeDeleter` and
// `delete_bridge`); this looks like the old and new name of the same function
// shown together -- only one of them can remain in the real header.
template <typename T>
void DLDataBridgeDeleter(DLManagedTensorPtr tensor)
void delete_bridge(DLManagedTensorPtr tensor)
{
if (tensor)
delete static_cast<DLDataBridge<T>*>(tensor->manager_ctx);
}

// No-op deleter: accepts a managed tensor and intentionally does nothing.
//
// Declared `inline` because this is a non-template function defined in a
// header; without it every translation unit including the header would emit
// its own external definition and linking would fail with duplicate symbols.
// The parameter is left unnamed to avoid unused-parameter warnings.
inline void do_not_delete(DLManagedTensorPtr /* tensor */) { }

// Type-erases a pointer to mutable data: returns the same address as `void*`.
template <typename T>
inline void* opaque(T* data)
{
    void* erased = data;
    return erased;
}

// Builds the DLPack device descriptor for the requested side: `kDLCUDA` when
// `gpu_flag` is set, `kDLCPU` otherwise, paired with the id reported by
// `sysview.get_device_id(gpu_flag)`.
inline DLDevice dldevice(const SystemView& sysview, bool gpu_flag)
{
    const auto device_type = gpu_flag ? kDLCUDA : kDLCPU;
    return DLDevice { device_type, sysview.get_device_id(gpu_flag) };
}
// Type-erases a pointer to const data: returns the same address as a
// non-const `void*`.
//
// Constness is dropped on purpose and spelled out with named casts so the
// qualifier removal is explicit and greppable. Callers must not write through
// the result unless the underlying object is actually mutable.
template <typename T>
inline void* opaque(const T* data)
{
    return const_cast<void*>(static_cast<const volatile void*>(data));
}

// Maps an element type to its `DLDataType` descriptor. Only declared here;
// each supported type provides an explicit specialization.
template <typename>
constexpr DLDataType dtype();
Expand All @@ -78,19 +89,6 @@ constexpr DLDataType dtype<int3>() { return DLDataType {kDLInt, 32, 1}; }
// `unsigned int` elements are described as `{kDLUInt, 32, 1}`
// (type code, bits, lanes).
template <>
constexpr DLDataType dtype<unsigned int>()
{
    return DLDataType { kDLUInt, 32, 1 };
}

// Number of particles to report for a property, selected by the array
// template that stores it. Only the specializations below are defined.
template <template <typename> class>
unsigned int particle_number(const SystemView& sysview);

// Properties stored in a `GlobalArray` use the local particle count.
template <>
inline unsigned int particle_number<GlobalArray>(const SystemView& sysview)
{
    const unsigned int local_count = sysview.local_particle_number();
    return local_count;
}

// Properties stored in a `GlobalVector` use the global particle count.
template <>
inline unsigned int particle_number<GlobalVector>(const SystemView& sysview)
{
    const unsigned int global_count = sysview.global_particle_number();
    return global_count;
}

// Base stride of the first tensor dimension for arrays with element type
// given by the template argument. Only declared here; each supported type
// provides an explicit specialization.
// NOTE(review): presumably expressed in units of the tensor's scalar element,
// as DLPack strides are -- confirm.
template <typename>
constexpr int64_t stride1();
template <>
Expand All @@ -104,132 +102,6 @@ constexpr int64_t stride1<int3>() { return 3; }
// `unsigned int` arrays advance one element per entry.
template <>
constexpr int64_t stride1<unsigned int>()
{
    return 1;
}

// Wraps a particle-data array as a DLPack managed tensor.
//
// The array is obtained by invoking `getter` on `sysview.particle_data()`,
// and an `ArrayHandle<T>` on it is acquired with the given access `mode`.
// The handle is handed to a heap-allocated `DLDataBridge<T>`, which also
// holds the shape and strides the tensor points to.
//
//   sysview             system to read the data from
//   getter              member function returning the array to expose
//   requested_location  where the data is wanted; replaced by `kOnHost`
//                       when `sysview.is_gpu_enabled()` is false
//   mode                access mode used to acquire the array handle
//   size2               size of the second dimension; the tensor is 1-D
//                       when this is 1 (must be >= 1)
//   offset              stored as the tensor's `byte_offset`
//   stride1_offset      added to `stride1<T>()` for the first-dimension stride
//
// Returns a pointer to the tensor embedded in the bridge. The bridge is
// released from its `unique_ptr` here; the installed deleter
// (`DLDataBridgeDeleter<T>`) is what destroys it later.
template <template <typename> class A, typename T, typename O>
DLManagedTensorPtr wrap(
    const SystemView& sysview, PropertyGetter<A, T, O> getter,
    AccessLocation requested_location, AccessMode mode,
    int64_t size2 = 1, uint64_t offset = 0, uint64_t stride1_offset = 0
) {
    assert((size2 >= 1)); // assert is a macro, so the extra parentheses are required here

    const bool has_second_dim = size2 > 1;

    // Device access is only honoured when the system actually runs on a GPU.
    auto location = sysview.is_gpu_enabled() ? requested_location : kOnHost;
#ifdef ENABLE_CUDA
    auto gpu_flag = (location == kOnDevice);
#else
    auto gpu_flag = false;
#endif

    auto handle = ArrayHandleUPtr<T>(
        new ArrayHandle<T>(INVOKE(*(sysview.particle_data()), getter)(), location, mode)
    );
    auto bridge = DLDataBridgeUPtr<T>(new DLDataBridge<T>(handle));

    // The managed tensor carries its owning bridge as context so the deleter
    // can find and destroy it.
    auto& managed = bridge->tensor;
    managed.manager_ctx = bridge.get();
    managed.deleter = DLDataBridgeDeleter<T>;

    // Shape and strides live in the bridge; the tensor only points at them.
    auto& dims = bridge->shape;
    auto& steps = bridge->strides;
    dims.push_back(particle_number<A>(sysview));
    steps.push_back(stride1<T>() + stride1_offset);
    if (has_second_dim) {
        dims.push_back(size2);
        steps.push_back(1);
    }

    auto& dltensor = managed.dl_tensor;
    dltensor.data = opaque(bridge->handle->data);
    dltensor.device = dldevice(sysview, gpu_flag);
    dltensor.dtype = dtype<T>();
    dltensor.ndim = dims.size();
    dltensor.shape = reinterpret_cast<std::int64_t*>(dims.data());
    dltensor.strides = reinterpret_cast<std::int64_t*>(steps.data());
    dltensor.byte_offset = offset;

    return &(bridge.release()->tensor);
}

// Exposes the array returned by `ParticleData::getPositions` as a DLPack
// tensor with a second dimension of size 4.
// NOTE(review): presumably x, y, z plus the particle type -- confirm.
inline DLManagedTensorPtr positions_types(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    constexpr int64_t kColumns = 4;
    return wrap(sysview, &ParticleData::getPositions, location, mode, kColumns);
}

// Exposes the array returned by `ParticleData::getVelocities` as a DLPack
// tensor with a second dimension of size 4.
// NOTE(review): presumably vx, vy, vz plus the mass -- confirm.
inline DLManagedTensorPtr velocities_masses(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    constexpr int64_t kColumns = 4;
    return wrap(sysview, &ParticleData::getVelocities, location, mode, kColumns);
}

// Exposes the array returned by `ParticleData::getOrientationArray` as a
// DLPack tensor with a second dimension of size 4.
inline DLManagedTensorPtr orientations(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    constexpr int64_t kColumns = 4;
    return wrap(sysview, &ParticleData::getOrientationArray, location, mode, kColumns);
}

// Exposes the array returned by `ParticleData::getAngularMomentumArray` as a
// DLPack tensor with a second dimension of size 4.
inline DLManagedTensorPtr angular_momenta(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    constexpr int64_t kColumns = 4;
    return wrap(sysview, &ParticleData::getAngularMomentumArray, location, mode, kColumns);
}

// Exposes the array returned by `ParticleData::getMomentsOfInertiaArray` as a
// DLPack tensor with a second dimension of size 3.
// The function name keeps its existing spelling ("intertia") so callers are
// not broken.
inline DLManagedTensorPtr moments_of_intertia(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    constexpr int64_t kColumns = 3;
    return wrap(sysview, &ParticleData::getMomentsOfInertiaArray, location, mode, kColumns);
}

// Exposes the array returned by `ParticleData::getCharges` as a
// one-dimensional DLPack tensor.
inline DLManagedTensorPtr charges(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    const auto getter = &ParticleData::getCharges;
    return wrap(sysview, getter, location, mode);
}

// Exposes the array returned by `ParticleData::getDiameters` as a
// one-dimensional DLPack tensor.
inline DLManagedTensorPtr diameters(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    const auto getter = &ParticleData::getDiameters;
    return wrap(sysview, getter, location, mode);
}

// Exposes the array returned by `ParticleData::getImages` as a DLPack tensor
// with a second dimension of size 3.
inline DLManagedTensorPtr images(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    constexpr int64_t kColumns = 3;
    return wrap(sysview, &ParticleData::getImages, location, mode, kColumns);
}

// Exposes the array returned by `ParticleData::getTags` as a one-dimensional
// DLPack tensor.
inline DLManagedTensorPtr tags(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    const auto getter = &ParticleData::getTags;
    return wrap(sysview, getter, location, mode);
}

// Exposes the array returned by `ParticleData::getRTags` as a one-dimensional
// DLPack tensor.
inline DLManagedTensorPtr rtags(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    const auto getter = &ParticleData::getRTags;
    return wrap(sysview, getter, location, mode);
}

// Exposes the array returned by `ParticleData::getNetForce` as a DLPack
// tensor with a second dimension of size 4.
inline DLManagedTensorPtr net_forces(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    constexpr int64_t kColumns = 4;
    return wrap(sysview, &ParticleData::getNetForce, location, mode, kColumns);
}

// Exposes the array returned by `ParticleData::getNetTorqueArray` as a DLPack
// tensor with a second dimension of size 4.
inline DLManagedTensorPtr net_torques(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    constexpr int64_t kColumns = 4;
    return wrap(sysview, &ParticleData::getNetTorqueArray, location, mode, kColumns);
}

// Exposes the array returned by `ParticleData::getNetVirial` as a DLPack
// tensor with a second dimension of size 6.
inline DLManagedTensorPtr net_virial(
    const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite
) {
    constexpr int64_t kColumns = 6;
    return wrap(sysview, &ParticleData::getNetVirial, location, mode, kColumns);
}


} // namespace dlext


#endif // HOOMD_DLPACK_EXTENSION_H_
82 changes: 82 additions & 0 deletions dlext/include/Sampler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// This file is part of `hoomd-dlext`, see LICENSE.md

#ifndef DLEXT_SAMPLER_H_
#define DLEXT_SAMPLER_H_

#include "SystemView.h"
#include "hoomd/HalfStepHook.h"

namespace dlext
{

//! Type of the timestep value passed to `Sampler::update` and forwarded to
//! the external callback.
using TimeStep = unsigned int;

//! Half-step hook that hands simulation data to an external callback.
//!
//! `ExternalUpdater` is the type of the stored callback. `Wrapper<Property>`
//! must provide a static `wrap(sysview, location, mode)` that produces the
//! object passed to the callback for that property.
template <typename ExternalUpdater, template <typename> class Wrapper>
class DEFAULT_VISIBILITY Sampler : public HalfStepHook {
public:
    //! Stores the system view, the callback, and the location and mode with
    //! which particle data is requested on every `update`.
    Sampler(
        SystemView sysview,
        ExternalUpdater update_callback,
        AccessLocation location,
        AccessMode mode
    );

    //! Replaces the stored system view with one built from `sysdef`.
    void setSystemDefinition(SystemDefinitionSPtr sysdef) override
    {
        _sysview = SystemView(sysdef);
    }

    //! Forwards the current data to the stored callback, using the location
    //! and mode given at construction.
    void update(TimeStep timestep) override
    {
        forward_data(_update_callback, _location, _mode, timestep);
    }

    //! Read-only access to the stored system view.
    const SystemView& system_view() const;

    //! Wraps the system positions, velocities, reverse tags, images and forces
    //! through `Wrapper` and passes them to the external function `callback`.
    //!
    //! The (non-typed) signature of `callback` is expected to be
    //!     callback(positions, velocities, rtags, images, forces, n)
    //! where `n` is an additional `TimeStep` parameter.
    //!
    //! The particle data is requested at the given `location` and access
    //! `mode`. NOTE: Forces are always requested in readwrite mode, whatever
    //! `mode` is.
    template <typename Callback>
    void forward_data(Callback callback, AccessLocation location, AccessMode mode, TimeStep n)
    {
        auto positions = Wrapper<PositionsTypes>::wrap(_sysview, location, mode);
        auto velocities = Wrapper<VelocitiesMasses>::wrap(_sysview, location, mode);
        auto reverse_tags = Wrapper<RTags>::wrap(_sysview, location, mode);
        auto box_images = Wrapper<Images>::wrap(_sysview, location, mode);
        auto forces = Wrapper<NetForces>::wrap(_sysview, location, kReadWrite);

        callback(positions, velocities, reverse_tags, box_images, forces, n);
    }

private:
    SystemView _sysview;
    ExternalUpdater _update_callback;
    AccessLocation _location;
    AccessMode _mode;
};

//! Copies the system view, callback, access location and access mode into the
//! corresponding members; no other work is done.
template <typename ExternalUpdater, template <typename> class Wrapper>
Sampler<ExternalUpdater, Wrapper>::Sampler(
    SystemView sysview,
    ExternalUpdater update_callback,
    AccessLocation location,
    AccessMode mode
)
    : _sysview { sysview }
    , _update_callback { update_callback }
    , _location { location }
    , _mode { mode }
{
}

//! Returns a const reference to the system view held by this sampler.
template <typename ExternalUpdater, template <typename> class Wrapper>
const SystemView& Sampler<ExternalUpdater, Wrapper>::system_view() const
{
    return this->_sysview;
}

} // namespace dlext

#endif // DLEXT_SAMPLER_H_
Loading

0 comments on commit 935059b

Please sign in to comment.