Skip to content

Commit

Permalink
Use intrusive reference counting
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed May 20, 2024
1 parent f5e1b7b commit f4bd141
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 25 deletions.
11 changes: 11 additions & 0 deletions src/wrap_cl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#define PY_ARRAY_UNIQUE_SYMBOL pyopencl_ARRAY_API

#include "wrap_cl.hpp"
#include <nanobind/intrusive/counter.inl>



Expand All @@ -49,6 +50,16 @@ static bool import_numpy_helper()

NB_MODULE(_cl, m)
{
py::intrusive_init(
[](PyObject *o) noexcept {
py::gil_scoped_acquire guard;
Py_INCREF(o);
},
[](PyObject *o) noexcept {
py::gil_scoped_acquire guard;
Py_DECREF(o);
});

if (!import_numpy_helper())
throw py::python_error();

Expand Down
23 changes: 11 additions & 12 deletions src/wrap_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ namespace pyopencl

// {{{ context

class context : public noncopyable
class context : public noncopyable, public py::intrusive_base
{
private:
cl_context m_context;
Expand Down Expand Up @@ -1415,7 +1415,7 @@ namespace pyopencl

// {{{ command_queue

class command_queue
class command_queue: public py::intrusive_base
{
private:
cl_command_queue m_queue;
Expand Down Expand Up @@ -1625,13 +1625,12 @@ namespace pyopencl
}
}

std::unique_ptr<context> get_context() const
py::ref<context> get_context() const
{
cl_context param_value;
PYOPENCL_CALL_GUARDED(clGetCommandQueueInfo,
(data(), CL_QUEUE_CONTEXT, sizeof(param_value), &param_value, 0));
return std::unique_ptr<context>(
new context(param_value, /*retain*/ true));
return py::ref<context>(new context(param_value, /*retain*/ true));
}

#if PYOPENCL_CL_VERSION < 0x1010
Expand Down Expand Up @@ -3437,12 +3436,12 @@ namespace pyopencl
{
private:
bool m_valid;
std::shared_ptr<command_queue> m_queue;
py::ref<command_queue> m_queue;
memory_object m_mem;
void *m_ptr;

public:
memory_map(std::shared_ptr<command_queue> cq, memory_object const &mem, void *ptr)
memory_map(py::ref<command_queue> cq, memory_object const &mem, void *ptr)
: m_valid(true), m_queue(cq), m_mem(mem), m_ptr(ptr)
{
}
Expand Down Expand Up @@ -3479,7 +3478,7 @@ namespace pyopencl
#ifndef PYPY_VERSION
inline
py::object enqueue_map_buffer(
std::shared_ptr<command_queue> cq,
py::ref<command_queue> cq,
memory_object_holder &buf,
cl_map_flags flags,
size_t offset,
Expand Down Expand Up @@ -3563,7 +3562,7 @@ namespace pyopencl
#ifndef PYPY_VERSION
inline
py::object enqueue_map_image(
std::shared_ptr<command_queue> cq,
py::ref<command_queue> cq,
memory_object_holder &img,
cl_map_flags flags,
py::object py_origin,
Expand Down Expand Up @@ -3697,15 +3696,15 @@ namespace pyopencl
class svm_allocation : public svm_pointer
{
private:
std::shared_ptr<context> m_context;
py::ref<context> m_context;
void *m_allocation;
size_t m_size;
command_queue_ref m_queue;
// FIXME Should maybe also allow keeping a list of events so that we can
// wait for users to finish in the case of out-of-order queues.

public:
svm_allocation(std::shared_ptr<context> const &ctx, size_t size, cl_uint alignment,
svm_allocation(py::ref<context> const &ctx, size_t size, cl_uint alignment,
cl_svm_mem_flags flags, const command_queue *queue = nullptr)
: m_context(ctx), m_size(size)
{
Expand Down Expand Up @@ -3738,7 +3737,7 @@ namespace pyopencl
}
}

svm_allocation(std::shared_ptr<context> const &ctx, void *allocation, size_t size,
svm_allocation(py::ref<context> const &ctx, void *allocation, size_t size,
const cl_command_queue queue)
: m_context(ctx), m_allocation(allocation), m_size(size)
{
Expand Down
14 changes: 12 additions & 2 deletions src/wrap_cl_part_1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,13 @@ void pyopencl_expose_part_1(py::module_ &m)

{
typedef context cls;
py::class_<cls>(m, "Context", py::dynamic_attr(), py::is_weak_referenceable())
py::class_<cls>(
m, "Context",
py::dynamic_attr(),
py::is_weak_referenceable(),
py::intrusive_ptr<cls>(
[](cls *o, PyObject *po) noexcept { o->set_self_py(po); })
)
.def(
"__init__",
[](cls *self, py::object py_devices, py::object py_properties,
Expand Down Expand Up @@ -112,7 +118,11 @@ void pyopencl_expose_part_1(py::module_ &m)
// {{{ command queue
{
typedef command_queue cls;
py::class_<cls>(m, "CommandQueue", py::dynamic_attr())
py::class_<cls>(
m, "CommandQueue",
py::dynamic_attr(),
py::intrusive_ptr<cls>(
[](cls *o, PyObject *po) noexcept { o->set_self_py(po); }) )
.def(
py::init<const context &, const device *, py::object>(),
py::arg("context"),
Expand Down
2 changes: 1 addition & 1 deletion src/wrap_cl_part_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ void pyopencl_expose_part_2(py::module_ &m)
{
typedef svm_allocation cls;
py::class_<cls, svm_pointer>(m, "SVMAllocation", py::dynamic_attr())
.def(py::init<std::shared_ptr<context>, size_t, cl_uint, cl_svm_mem_flags, const command_queue *>(),
.def(py::init<py::ref<context>, size_t, cl_uint, cl_svm_mem_flags, const command_queue *>(),
py::arg("context"),
py::arg("size"),
py::arg("alignment"),
Expand Down
2 changes: 2 additions & 0 deletions src/wrap_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/intrusive/counter.h>
#include <nanobind/intrusive/ref.h>
#include <nanobind/ndarray.h>


Expand Down
20 changes: 10 additions & 10 deletions src/wrap_mempool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ namespace pyopencl {
class buffer_allocator_base
{
protected:
std::shared_ptr<pyopencl::context> m_context;
py::ref<pyopencl::context> m_context;
cl_mem_flags m_flags;

public:
buffer_allocator_base(std::shared_ptr<pyopencl::context> const &ctx,
buffer_allocator_base(py::ref<pyopencl::context> const &ctx,
cl_mem_flags flags=CL_MEM_READ_WRITE)
: m_context(ctx), m_flags(flags)
{
Expand Down Expand Up @@ -131,7 +131,7 @@ namespace pyopencl {
typedef buffer_allocator_base super;

public:
deferred_buffer_allocator(std::shared_ptr<pyopencl::context> const &ctx,
deferred_buffer_allocator(py::ref<pyopencl::context> const &ctx,
cl_mem_flags flags=CL_MEM_READ_WRITE)
: super(ctx, flags)
{ }
Expand All @@ -158,7 +158,7 @@ namespace pyopencl {
public:
immediate_buffer_allocator(pyopencl::command_queue &queue,
cl_mem_flags flags=CL_MEM_READ_WRITE)
: super(std::shared_ptr<pyopencl::context>(queue.get_context()), flags),
: super(queue.get_context(), flags),
m_queue(queue.data(), /*retain*/ true)
{ }

Expand Down Expand Up @@ -333,13 +333,13 @@ namespace pyopencl {
typedef size_t size_type;

protected:
std::shared_ptr<pyopencl::context> m_context;
py::ref<pyopencl::context> m_context;
cl_uint m_alignment;
cl_svm_mem_flags m_flags;
pyopencl::command_queue_ref m_queue;

public:
svm_allocator(std::shared_ptr<pyopencl::context> const &ctx,
svm_allocator(py::ref<pyopencl::context> const &ctx,
cl_uint alignment=0, cl_svm_mem_flags flags=CL_MEM_READ_WRITE,
pyopencl::command_queue *queue=nullptr)
: m_context(ctx), m_alignment(alignment), m_flags(flags)
Expand Down Expand Up @@ -367,7 +367,7 @@ namespace pyopencl {
return false;
}

std::shared_ptr<pyopencl::context> context() const
py::ref<pyopencl::context> context() const
{
return m_context;
}
Expand Down Expand Up @@ -631,9 +631,9 @@ void pyopencl_expose_mempool(py::module_ &m)
py::class_<cls, pyopencl::buffer_allocator_base> wrapper(
m, "DeferredAllocator");
wrapper
.def(py::init<std::shared_ptr<pyopencl::context> const &>())
.def(py::init<py::ref<pyopencl::context> const &>())
.def(py::init<
std::shared_ptr<pyopencl::context> const &,
py::ref<pyopencl::context> const &,
cl_mem_flags>(),
py::arg("queue"), py::arg("mem_flags"))
;
Expand Down Expand Up @@ -681,7 +681,7 @@ void pyopencl_expose_mempool(py::module_ &m)
typedef pyopencl::svm_allocator cls;
py::class_<cls> wrapper(m, "SVMAllocator");
wrapper
.def(py::init<std::shared_ptr<pyopencl::context> const &, cl_uint, cl_uint, pyopencl::command_queue *>(),
.def(py::init<py::ref<pyopencl::context> const &, cl_uint, cl_uint, pyopencl::command_queue *>(),
py::arg("context"),
/* py::kw_only(), */
py::arg("alignment")=0,
Expand Down

0 comments on commit f4bd141

Please sign in to comment.