Skip to content

Commit

Permalink
Merge pull request numpy#27648 from fengluoqiuwu/enh_string_buffer_3
Browse files Browse the repository at this point in the history
MAINT: Fix the code style to our C-Style-Guide
  • Loading branch information
ngoldbaum authored Dec 11, 2024
2 parents fc539ff + f2be23a commit f7823da
Showing 1 changed file with 55 additions and 46 deletions.
101 changes: 55 additions & 46 deletions numpy/_core/src/umath/string_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ codepoint_isupper<ENCODING::UTF8>(npy_ucs4 code)

template<ENCODING enc>
inline bool
codepoint_istitle(npy_ucs4);
codepoint_istitle(npy_ucs4 code);

template<>
inline bool
Expand Down Expand Up @@ -387,19 +387,19 @@ struct Buffer {
}

inline void
buffer_memcpy(Buffer<enc> out, size_t n_chars)
buffer_memcpy(Buffer<enc> other, size_t len)
{
if (n_chars == 0) {
if (len == 0) {
return;
}
switch (enc) {
case ENCODING::ASCII:
case ENCODING::UTF8:
// for UTF8 the length argument is a number of bytes, not characters
memcpy(out.buf, buf, n_chars);
memcpy(other.buf, buf, len);
break;
case ENCODING::UTF32:
memcpy(out.buf, buf, n_chars * sizeof(npy_ucs4));
memcpy(other.buf, buf, len * sizeof(npy_ucs4));
break;
}
}
Expand Down Expand Up @@ -460,7 +460,7 @@ struct Buffer {
}

inline size_t
num_bytes_next_character(void) {
num_bytes_next_character() {
switch (enc) {
case ENCODING::ASCII:
return 1;
Expand Down Expand Up @@ -503,6 +503,18 @@ struct Buffer {
return unary_loop<IMPLEMENTED_UNARY_FUNCTIONS::ISALPHA>();
}

inline bool
isdecimal()
{
return unary_loop<IMPLEMENTED_UNARY_FUNCTIONS::ISDECIMAL>();
}

inline bool
isdigit()
{
return unary_loop<IMPLEMENTED_UNARY_FUNCTIONS::ISDIGIT>();
}

inline bool
first_character_isspace()
{
Expand All @@ -521,12 +533,6 @@ struct Buffer {
return unary_loop<IMPLEMENTED_UNARY_FUNCTIONS::ISSPACE>();
}

inline bool
isdigit()
{
return unary_loop<IMPLEMENTED_UNARY_FUNCTIONS::ISDIGIT>();
}

inline bool
isalnum()
{
Expand All @@ -542,7 +548,7 @@ struct Buffer {
}

Buffer<enc> tmp = *this;
bool cased = 0;
bool cased = false;
for (size_t i = 0; i < len; i++) {
if (codepoint_isupper<enc>(*tmp) || codepoint_istitle<enc>(*tmp)) {
return false;
Expand All @@ -564,7 +570,7 @@ struct Buffer {
}

Buffer<enc> tmp = *this;
bool cased = 0;
bool cased = false;
for (size_t i = 0; i < len; i++) {
if (codepoint_islower<enc>(*tmp) || codepoint_istitle<enc>(*tmp)) {
return false;
Expand Down Expand Up @@ -616,12 +622,6 @@ struct Buffer {
return unary_loop<IMPLEMENTED_UNARY_FUNCTIONS::ISNUMERIC>();
}

inline bool
isdecimal()
{
return unary_loop<IMPLEMENTED_UNARY_FUNCTIONS::ISDECIMAL>();
}

inline Buffer<enc>
rstrip()
{
Expand Down Expand Up @@ -895,10 +895,12 @@ string_find(Buffer<enc> buf1, Buffer<enc> buf2, npy_int64 start, npy_int64 end)
npy_intp pos;
switch(enc) {
case ENCODING::UTF8:
pos = fastsearch(start_loc, end_loc - start_loc, buf2.buf, buf2.after - buf2.buf, -1, FAST_SEARCH);
pos = fastsearch(start_loc, end_loc - start_loc, buf2.buf,
buf2.after - buf2.buf, -1, FAST_SEARCH);
// pos is the byte index, but we need the character index
if (pos > 0) {
pos = utf8_character_index(start_loc, start_loc - buf1.buf, start, pos, buf1.after - start_loc);
pos = utf8_character_index(start_loc, start_loc - buf1.buf,
start, pos, buf1.after - start_loc);
}
break;
case ENCODING::ASCII:
Expand Down Expand Up @@ -999,10 +1001,12 @@ string_rfind(Buffer<enc> buf1, Buffer<enc> buf2, npy_int64 start, npy_int64 end)
npy_intp pos;
switch (enc) {
case ENCODING::UTF8:
pos = fastsearch(start_loc, end_loc - start_loc, buf2.buf, buf2.after - buf2.buf, -1, FAST_RSEARCH);
pos = fastsearch(start_loc, end_loc - start_loc,
buf2.buf, buf2.after - buf2.buf, -1, FAST_RSEARCH);
// pos is the byte index, but we need the character index
if (pos > 0) {
pos = utf8_character_index(start_loc, start_loc - buf1.buf, start, pos, buf1.after - start_loc);
pos = utf8_character_index(start_loc, start_loc - buf1.buf,
start, pos, buf1.after - start_loc);
}
break;
case ENCODING::ASCII:
Expand Down Expand Up @@ -1064,7 +1068,7 @@ string_count(Buffer<enc> buf1, Buffer<enc> buf2, npy_int64 start, npy_int64 end)
start_loc = (buf1 + start).buf;
end_loc = (buf1 + end).buf;
}
npy_intp count;
npy_intp count = 0;
switch (enc) {
case ENCODING::UTF8:
count = fastsearch(start_loc, end_loc - start_loc, buf2.buf,
Expand Down Expand Up @@ -1139,7 +1143,7 @@ enum class STRIPTYPE {

template <ENCODING enc>
static inline size_t
string_lrstrip_whitespace(Buffer<enc> buf, Buffer<enc> out, STRIPTYPE striptype)
string_lrstrip_whitespace(Buffer<enc> buf, Buffer<enc> out, STRIPTYPE strip_type)
{
size_t len = buf.num_codepoints();
if (len == 0) {
Expand All @@ -1154,7 +1158,7 @@ string_lrstrip_whitespace(Buffer<enc> buf, Buffer<enc> out, STRIPTYPE striptype)
size_t num_bytes = (buf.after - buf.buf);
Buffer traverse_buf = Buffer<enc>(buf.buf, num_bytes);

if (striptype != STRIPTYPE::RIGHTSTRIP) {
if (strip_type != STRIPTYPE::RIGHTSTRIP) {
while (new_start < len) {
if (!traverse_buf.first_character_isspace()) {
break;
Expand All @@ -1173,7 +1177,7 @@ string_lrstrip_whitespace(Buffer<enc> buf, Buffer<enc> out, STRIPTYPE striptype)
traverse_buf = buf + (new_stop - 1);
}

if (striptype != STRIPTYPE::LEFTSTRIP) {
if (strip_type != STRIPTYPE::LEFTSTRIP) {
while (new_stop > new_start) {
if (*traverse_buf != 0 && !traverse_buf.first_character_isspace()) {
break;
Expand Down Expand Up @@ -1202,7 +1206,7 @@ string_lrstrip_whitespace(Buffer<enc> buf, Buffer<enc> out, STRIPTYPE striptype)

template <ENCODING enc>
static inline size_t
string_lrstrip_chars(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out, STRIPTYPE striptype)
string_lrstrip_chars(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out, STRIPTYPE strip_type)
{
size_t len1 = buf1.num_codepoints();
if (len1 == 0) {
Expand All @@ -1228,9 +1232,9 @@ string_lrstrip_chars(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out, STRIPT
size_t num_bytes = (buf1.after - buf1.buf);
Buffer traverse_buf = Buffer<enc>(buf1.buf, num_bytes);

if (striptype != STRIPTYPE::RIGHTSTRIP) {
if (strip_type != STRIPTYPE::RIGHTSTRIP) {
for (; new_start < len1; traverse_buf++) {
Py_ssize_t res;
Py_ssize_t res = 0;
size_t current_point_bytes = traverse_buf.num_bytes_next_character();
switch (enc) {
case ENCODING::ASCII:
Expand All @@ -1245,7 +1249,9 @@ string_lrstrip_chars(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out, STRIPT
CheckedIndexer<char> ind(buf2.buf, len2);
res = find_char<char>(ind, len2, *traverse_buf);
} else {
res = fastsearch(buf2.buf, buf2.after - buf2.buf,traverse_buf.buf, current_point_bytes, -1, FAST_SEARCH);
res = fastsearch(buf2.buf, buf2.after - buf2.buf,
traverse_buf.buf, current_point_bytes,
-1, FAST_SEARCH);
}
break;
}
Expand All @@ -1272,10 +1278,10 @@ string_lrstrip_chars(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out, STRIPT
traverse_buf = buf1 + (new_stop - 1);
}

if (striptype != STRIPTYPE::LEFTSTRIP) {
if (strip_type != STRIPTYPE::LEFTSTRIP) {
while (new_stop > new_start) {
size_t current_point_bytes = traverse_buf.num_bytes_next_character();
Py_ssize_t res;
Py_ssize_t res = 0;
switch (enc) {
case ENCODING::ASCII:
{
Expand All @@ -1289,7 +1295,9 @@ string_lrstrip_chars(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out, STRIPT
CheckedIndexer<char> ind(buf2.buf, len2);
res = find_char<char>(ind, len2, *traverse_buf);
} else {
res = fastsearch(buf2.buf, buf2.after - buf2.buf, traverse_buf.buf, current_point_bytes, -1, FAST_RSEARCH);
res = fastsearch(buf2.buf, buf2.after - buf2.buf,
traverse_buf.buf, current_point_bytes,
-1, FAST_RSEARCH);
}
break;
}
Expand Down Expand Up @@ -1333,7 +1341,8 @@ findslice_for_replace(CheckedIndexer<char_type> buf1, npy_intp len1,
if (len2 == 1) {
return (npy_intp) find_char(buf1, len1, *buf2);
}
return (npy_intp) fastsearch(buf1.buffer, len1, buf2.buffer, len2, -1, FAST_SEARCH);
return (npy_intp) fastsearch(buf1.buffer, len1, buf2.buffer, len2,
-1, FAST_SEARCH);
}


Expand Down Expand Up @@ -1538,8 +1547,8 @@ template <ENCODING enc>
static inline npy_intp
string_pad(Buffer<enc> buf, npy_int64 width, npy_ucs4 fill, JUSTPOSITION pos, Buffer<enc> out)
{
size_t finalwidth = width > 0 ? width : 0;
if (finalwidth > PY_SSIZE_T_MAX) {
size_t final_width = width > 0 ? width : 0;
if (final_width > PY_SSIZE_T_MAX) {
npy_gil_error(PyExc_OverflowError, "padded string is too long");
return -1;
}
Expand All @@ -1555,23 +1564,23 @@ string_pad(Buffer<enc> buf, npy_int64 width, npy_ucs4 fill, JUSTPOSITION pos, Bu
len = len_codepoints;
}

if (len_codepoints >= finalwidth) {
if (len_codepoints >= final_width) {
buf.buffer_memcpy(out, len);
return (npy_intp) len;
}

size_t left, right;
if (pos == JUSTPOSITION::CENTER) {
size_t pad = finalwidth - len_codepoints;
left = pad / 2 + (pad & finalwidth & 1);
size_t pad = final_width - len_codepoints;
left = pad / 2 + (pad & final_width & 1);
right = pad - left;
}
else if (pos == JUSTPOSITION::LEFT) {
left = 0;
right = finalwidth - len_codepoints;
right = final_width - len_codepoints;
}
else {
left = finalwidth - len_codepoints;
left = final_width - len_codepoints;
right = 0;
}

Expand All @@ -1589,23 +1598,23 @@ string_pad(Buffer<enc> buf, npy_int64 width, npy_ucs4 fill, JUSTPOSITION pos, Bu
out.advance_chars_or_bytes(out.buffer_memset(fill, right));
}

return finalwidth;
return final_width;
}


template <ENCODING enc>
static inline npy_intp
string_zfill(Buffer<enc> buf, npy_int64 width, Buffer<enc> out)
{
size_t finalwidth = width > 0 ? width : 0;
size_t final_width = width > 0 ? width : 0;

npy_ucs4 fill = '0';
npy_intp new_len = string_pad(buf, width, fill, JUSTPOSITION::RIGHT, out);
if (new_len == -1) {
return -1;
}

size_t offset = finalwidth - buf.num_codepoints();
size_t offset = final_width - buf.num_codepoints();
Buffer<enc> tmp = out + offset;

npy_ucs4 c = *tmp;
Expand Down

0 comments on commit f7823da

Please sign in to comment.