Skip to content

Commit 4cd1674

Browse files
committed
support nd arrays for PyArray
Signed-off-by: Jade Abraham <[email protected]>
1 parent 08cfc10 commit 4cd1674

File tree

1 file changed

+88
-43
lines changed

1 file changed

+88
-43
lines changed

Diff for: modules/packages/Python.chpl

+88-43
Original file line numberDiff line numberDiff line change
@@ -1111,12 +1111,6 @@ module Python {
11111111
if isSubtype(t, Array?) {
11121112
compilerError("Cannot create an Array from an existing PyObject");
11131113
}
1114-
if isSubtype(t, PyArray?) {
1115-
if isGeneric(t) {
1116-
compilerError("Cannot get a generic PyArray, try specifying the eltType like 'PyArray(int)'");
1117-
}
1118-
return new t(this, obj);
1119-
}
11201114

11211115
for converter in this.converters {
11221116
if converter.handlesType(t) {
@@ -1174,6 +1168,8 @@ module Python {
11741168
return new t(this, "<unknown>", obj);
11751169
} else if isSubtype(t, Value?) {
11761170
return new t(this, obj);
1171+
} else if isSubtype(t, PyObjectPtr) {
1172+
return obj;
11771173
} else if t == NoneType {
11781174
// returning NoneType can be used to ignore a return value
11791175
// but if its not actually None, we still need to decrement the reference count
@@ -1563,6 +1559,10 @@ module Python {
15631559
return new ImportError(str);
15641560
} else if PyErr_GivenExceptionMatches(exc, PyExc_KeyError) != 0 {
15651561
return new KeyError(str);
1562+
} else if PyErr_GivenExceptionMatches(exc, PyExc_BufferError) != 0 {
1563+
return new BufferError(str);
1564+
} else if PyErr_GivenExceptionMatches(exc, PyExc_NotImplementedError) != 0 {
1565+
return new NotImplementedError(str);
15661566
} else {
15671567
throw new PythonException(str);
15681568
}
@@ -1587,6 +1587,15 @@ module Python {
15871587
}
15881588
}
15891589

1590+
/*
1591+
Represents a NotImplementedError in the Python code
1592+
*/
1593+
class NotImplementedError: PythonException {
1594+
proc init(in message: string) {
1595+
super.init(message);
1596+
}
1597+
}
1598+
15901599
/*
15911600
Represents a KeyError in the Python code
15921601
*/
@@ -2776,22 +2785,13 @@ module Python {
27762785
`numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>`_.
27772786
*/
27782787
class PyArray: Value {
2779-
type eltType;
27802788
@chpldoc.nodoc
27812789
var view: Py_buffer;
2782-
@chpldoc.nodoc
2783-
var itemSize: Py_ssize_t;
2784-
2785-
@chpldoc.nodoc
2786-
proc init(type eltType) {
2787-
this.eltType = eltType;
2788-
}
27892790

27902791
@chpldoc.nodoc
2791-
proc init(type eltType, in interpreter: borrowed Interpreter,
2792+
proc init(in interpreter: borrowed Interpreter,
27922793
in obj: PyObjectPtr, isOwned: bool = true) {
27932794
super.init(interpreter, obj, isOwned=isOwned);
2794-
this.eltType = eltType;
27952795
}
27962796
@chpldoc.nodoc
27972797
proc postinit() throws {
@@ -2801,45 +2801,79 @@ module Python {
28012801
if PyObject_CheckBuffer(this.getPyObject()) == 0 {
28022802
throw new ChapelException("Object does not support buffer protocol");
28032803
}
2804-
var flags: c_int =
2805-
PyBUF_SIMPLE | PyBUF_WRITABLE | PyBUF_FORMAT | PyBUF_ND;
2806-
if PyObject_GetBuffer(this.getPyObject(), c_ptrTo(this.view), flags) == -1 {
2804+
const flags = PyBUF_SIMPLE | PyBUF_WRITABLE | PyBUF_FORMAT | PyBUF_ND;
2805+
if PyObject_GetBuffer(this.getPyObject(),
2806+
c_ptrTo(this.view), flags) == -1 {
28072807
this.interpreter.checkException();
28082808
// this.check should have raised an exception, if it didn't, raise one
28092809
throw new BufferError("Failed to get buffer");
28102810
}
2811-
2812-
if this.view.ndim > 1 {
2813-
throw new ChapelException("Only 1D arrays are currently supported");
2814-
}
2815-
2816-
this.itemSize = PyBuffer_SizeFromFormat(this.view.format);
2817-
if this.itemSize == -1 {
2818-
if this.view.shape != nil {
2819-
this.itemSize = this.view.itemsize;
2820-
} else {
2821-
// disregard itemsize, use 1
2822-
this.itemSize = 1;
2823-
}
2811+
if this.view.ndim == 0 {
2812+
throw new ChapelException("0-dimensional arrays are not supported");
28242813
}
2825-
2826-
if !checkFormatWithEltType(this.view.format, this.itemSize, this.eltType) {
2827-
throw new ChapelException("Format does not match element type");
2814+
if this.view.shape == nil {
2815+
throw new ChapelException("Shape is required for arrays");
28282816
}
2829-
28302817
}
28312818

2819+
@chpldoc.nodoc
28322820
proc deinit() {
28332821
var g = PyGILState_Ensure();
28342822
defer PyGILState_Release(g);
28352823

28362824
PyBuffer_Release(c_ptrTo(this.view));
28372825
}
28382826

2839-
proc array: [] {
2840-
var buf = view.buf: c_ptr(this.eltType);
2841-
var size = view.len / this.itemSize;
2842-
var res = makeArrayFromPtr(buf, size);
2827+
2828+
/*
2829+
The 'rank' of the Python array, also known as the number of dimensions of
2830+
the array.
2831+
2832+
:returns: The number of rank of the array.
2833+
*/
2834+
proc rank: int do return this.view.ndim;
2835+
/* Alias of :proc:`~PyArray.rank`. */
2836+
proc ndim: int do return this.view.ndim;
2837+
2838+
proc shape(param rank = 1): rank*int throws {
2839+
var s: rank*int;
2840+
assert(this.view.shape != nil); // checked in postinit
2841+
for param i in 0..<rank {
2842+
s(i) = this.view.shape(i);
2843+
}
2844+
return s;
2845+
}
2846+
2847+
@chpldoc.nodoc
2848+
proc computeItemSize() {
2849+
var itemSize = PyBuffer_SizeFromFormat(this.view.format);
2850+
if itemSize == -1 {
2851+
assert(this.view.shape != nil); // checked in postinit
2852+
itemSize = this.view.itemsize;
2853+
}
2854+
return itemSize;
2855+
}
2856+
2857+
proc array(type eltType, param rank = 1): [] throws {
2858+
if !checkFormatWithEltType(this.view.format, computeItemSize(), eltType) {
2859+
throw new ChapelException("Format does not match element type");
2860+
}
2861+
var buf = this.view.buf: c_ptr(eltType);
2862+
var ndim = this.view.ndim;
2863+
2864+
if ndim != rank {
2865+
throw new ChapelException(
2866+
"Python array of rank " + ndim:string +
2867+
" cannot be converted to a Chapel array of rank " + rank:string);
2868+
}
2869+
2870+
var ranges: rank * range(int, boundKind.both, strideKind.any);
2871+
for param i in 0..<rank {
2872+
ranges(i) = 0.. # this.view.shape(i);
2873+
}
2874+
2875+
var dom = chpl__buildDomainExpr((...ranges), false);
2876+
var res = makeArrayFromPtr(buf, dom);
28432877
return res;
28442878
}
28452879
}
@@ -3093,6 +3127,7 @@ module Python {
30933127
PyDict implements writeSerializable;
30943128
PySet implements writeSerializable;
30953129
Array implements writeSerializable;
3130+
PyArray implements writeSerializable;
30963131
NoneType implements writeSerializable;
30973132

30983133
@chpldoc.nodoc
@@ -3131,6 +3166,10 @@ module Python {
31313166
override proc Array.serialize(writer, ref serializer) throws do
31323167
writer.write(this:string);
31333168

3169+
@chpldoc.nodoc
3170+
override proc PyArray.serialize(writer, ref serializer) throws do
3171+
writer.write(this:string);
3172+
31343173
@chpldoc.nodoc
31353174
proc NoneType.serialize(writer, ref serializer) throws do
31363175
writer.write(this:string);
@@ -3306,6 +3345,8 @@ module Python {
33063345
exc: PyObjectPtr): c_int;
33073346
extern const PyExc_ImportError: PyObjectPtr;
33083347
extern const PyExc_KeyError: PyObjectPtr;
3348+
extern const PyExc_BufferError: PyObjectPtr;
3349+
extern const PyExc_NotImplementedError: PyObjectPtr;
33093350

33103351
/*
33113352
Values
@@ -3327,18 +3368,22 @@ module Python {
33273368
extern proc PyBytes_FromStringAndSize(s: c_ptrConst(c_char),
33283369
size: Py_ssize_t): PyObjectPtr;
33293370

3330-
proc Py_None: PyObjectPtr {
3371+
inline proc Py_None: PyObjectPtr {
33313372
extern proc chpl_Py_None(): PyObjectPtr;
33323373
return chpl_Py_None();
33333374
}
3334-
proc Py_True: PyObjectPtr {
3375+
inline proc Py_True: PyObjectPtr {
33353376
extern proc chpl_Py_True(): PyObjectPtr;
33363377
return chpl_Py_True();
33373378
}
3338-
proc Py_False: PyObjectPtr {
3379+
inline proc Py_False: PyObjectPtr {
33393380
extern proc chpl_Py_False(): PyObjectPtr;
33403381
return chpl_Py_False();
33413382
}
3383+
inline proc Py_NotImplemented: PyObjectPtr {
3384+
extern proc chpl_Py_NotImplemented(): PyObjectPtr;
3385+
return chpl_Py_NotImplemented();
3386+
}
33423387

33433388
/*
33443389
Sequences

0 commit comments

Comments
 (0)