Skip to content

Commit 2891172

Browse files
REGR: fix string contains/match methods with compiled regex with flags (#62251)
1 parent 3c14b71 commit 2891172

File tree

5 files changed

+158
-24
lines changed

5 files changed

+158
-24
lines changed

doc/source/whatsnew/v2.3.3.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ become the default string dtype in pandas 3.0. See
2222

2323
Bug fixes
2424
^^^^^^^^^
25-
-
25+
- Fix regression in :meth:`~Series.str.contains`, :meth:`~Series.str.match` and :meth:`~Series.str.fullmatch`
26+
with a compiled regex and custom flags (:issue:`62240`)
2627

2728
.. ---------------------------------------------------------------------------
2829
.. _whatsnew_233.contributors:

pandas/core/arrays/_arrow_string_mixins.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -311,29 +311,23 @@ def _str_contains(
311311

312312
def _str_match(
313313
self,
314-
pat: str | re.Pattern,
314+
pat: str,
315315
case: bool = True,
316316
flags: int = 0,
317317
na: Scalar | lib.NoDefault = lib.no_default,
318318
):
319-
if isinstance(pat, re.Pattern):
320-
# GH#61952
321-
pat = pat.pattern
322-
if isinstance(pat, str) and not pat.startswith("^"):
319+
if not pat.startswith("^"):
323320
pat = f"^{pat}"
324321
return self._str_contains(pat, case, flags, na, regex=True)
325322

326323
def _str_fullmatch(
327324
self,
328-
pat: str | re.Pattern,
325+
pat: str,
329326
case: bool = True,
330327
flags: int = 0,
331328
na: Scalar | lib.NoDefault = lib.no_default,
332329
):
333-
if isinstance(pat, re.Pattern):
334-
# GH#61952
335-
pat = pat.pattern
336-
if isinstance(pat, str) and (not pat.endswith("$") or pat.endswith("\\$")):
330+
if not pat.endswith("$") or pat.endswith("\\$"):
337331
pat = f"{pat}$"
338332
return self._str_match(pat, case, flags, na)
339333

pandas/core/arrays/string_arrow.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
ArrayLike,
5656
Dtype,
5757
NpDtype,
58+
Scalar,
5859
npt,
5960
)
6061

@@ -333,8 +334,6 @@ def astype(self, dtype, copy: bool = True):
333334
_str_startswith = ArrowStringArrayMixin._str_startswith
334335
_str_endswith = ArrowStringArrayMixin._str_endswith
335336
_str_pad = ArrowStringArrayMixin._str_pad
336-
_str_match = ArrowStringArrayMixin._str_match
337-
_str_fullmatch = ArrowStringArrayMixin._str_fullmatch
338337
_str_lower = ArrowStringArrayMixin._str_lower
339338
_str_upper = ArrowStringArrayMixin._str_upper
340339
_str_strip = ArrowStringArrayMixin._str_strip
@@ -349,6 +348,28 @@ def astype(self, dtype, copy: bool = True):
349348
_str_len = ArrowStringArrayMixin._str_len
350349
_str_slice = ArrowStringArrayMixin._str_slice
351350

351+
@staticmethod
352+
def _is_re_pattern_with_flags(pat: str | re.Pattern) -> bool:
353+
# check if `pat` is a compiled regex pattern with flags that are not
354+
# supported by pyarrow
355+
return (
356+
isinstance(pat, re.Pattern)
357+
and (pat.flags & ~(re.IGNORECASE | re.UNICODE)) != 0
358+
)
359+
360+
@staticmethod
361+
def _preprocess_re_pattern(pat: re.Pattern, case: bool) -> tuple[str, bool, int]:
362+
pattern = pat.pattern
363+
flags = pat.flags
364+
# pyarrow does not support `flags` but does support `case` -> translate IGNORECASE into case=False and strip it from the flags
365+
if flags & re.IGNORECASE:
366+
case = False
367+
flags = flags & ~re.IGNORECASE
368+
# when creating a pattern with re.compile and a string, it automatically
369+
# gets a UNICODE flag, while pyarrow assumes unicode for strings anyway
370+
flags = flags & ~re.UNICODE
371+
return pattern, case, flags
372+
352373
def _str_contains(
353374
self,
354375
pat,
@@ -357,13 +378,42 @@ def _str_contains(
357378
na=lib.no_default,
358379
regex: bool = True,
359380
):
360-
if flags:
381+
if flags or self._is_re_pattern_with_flags(pat):
361382
return super()._str_contains(pat, case, flags, na, regex)
362383
if isinstance(pat, re.Pattern):
363-
pat = pat.pattern
384+
# TODO flags passed separately by user are ignored
385+
pat, case, flags = self._preprocess_re_pattern(pat, case)
364386

365387
return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)
366388

389+
def _str_match(
390+
self,
391+
pat: str | re.Pattern,
392+
case: bool = True,
393+
flags: int = 0,
394+
na: Scalar | lib.NoDefault = lib.no_default,
395+
):
396+
if flags or self._is_re_pattern_with_flags(pat):
397+
return super()._str_match(pat, case, flags, na)
398+
if isinstance(pat, re.Pattern):
399+
pat, case, flags = self._preprocess_re_pattern(pat, case)
400+
401+
return ArrowStringArrayMixin._str_match(self, pat, case, flags, na)
402+
403+
def _str_fullmatch(
404+
self,
405+
pat: str | re.Pattern,
406+
case: bool = True,
407+
flags: int = 0,
408+
na: Scalar | lib.NoDefault = lib.no_default,
409+
):
410+
if flags or self._is_re_pattern_with_flags(pat):
411+
return super()._str_fullmatch(pat, case, flags, na)
412+
if isinstance(pat, re.Pattern):
413+
pat, case, flags = self._preprocess_re_pattern(pat, case)
414+
415+
return ArrowStringArrayMixin._str_fullmatch(self, pat, case, flags, na)
416+
367417
def _str_replace(
368418
self,
369419
pat: str | re.Pattern,

pandas/core/strings/object_array.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,7 @@ def _str_match(
262262
):
263263
if not case:
264264
flags |= re.IGNORECASE
265-
if isinstance(pat, re.Pattern):
266-
pat = pat.pattern
265+
267266
regex = re.compile(pat, flags=flags)
268267

269268
f = lambda x: regex.match(x) is not None
@@ -278,8 +277,7 @@ def _str_fullmatch(
278277
):
279278
if not case:
280279
flags |= re.IGNORECASE
281-
if isinstance(pat, re.Pattern):
282-
pat = pat.pattern
280+
283281
regex = re.compile(pat, flags=flags)
284282

285283
f = lambda x: regex.fullmatch(x) is not None

pandas/tests/strings/test_find_replace.py

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,60 @@ def test_contains_nan(any_string_dtype):
283283

284284
def test_contains_compiled_regex(any_string_dtype):
285285
# GH#61942
286-
ser = Series(["foo", "bar", "baz"], dtype=any_string_dtype)
286+
expected_dtype = (
287+
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
288+
)
289+
290+
ser = Series(["foo", "bar", "Baz"], dtype=any_string_dtype)
291+
287292
pat = re.compile("ba.")
288293
result = ser.str.contains(pat)
294+
expected = Series([False, True, False], dtype=expected_dtype)
295+
tm.assert_series_equal(result, expected)
296+
297+
# TODO this currently works for pyarrow-backed dtypes but raises for python
298+
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
299+
result = ser.str.contains(pat, case=False)
300+
expected = Series([False, True, True], dtype=expected_dtype)
301+
tm.assert_series_equal(result, expected)
302+
else:
303+
with pytest.raises(
304+
ValueError, match="cannot process flags argument with a compiled pattern"
305+
):
306+
ser.str.contains(pat, case=False)
307+
308+
pat = re.compile("ba.", flags=re.IGNORECASE)
309+
result = ser.str.contains(pat)
310+
expected = Series([False, True, True], dtype=expected_dtype)
311+
tm.assert_series_equal(result, expected)
312+
313+
# TODO should this be supported?
314+
with pytest.raises(
315+
ValueError, match="cannot process flags argument with a compiled pattern"
316+
):
317+
ser.str.contains(pat, flags=re.IGNORECASE)
318+
289319

320+
def test_contains_compiled_regex_flags(any_string_dtype):
321+
# ensure that flags other than IGNORECASE are respected
290322
expected_dtype = (
291323
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
292324
)
325+
326+
ser = Series(["foobar", "foo\nbar", "Baz"], dtype=any_string_dtype)
327+
328+
pat = re.compile("^ba")
329+
result = ser.str.contains(pat)
330+
expected = Series([False, False, False], dtype=expected_dtype)
331+
tm.assert_series_equal(result, expected)
332+
333+
pat = re.compile("^ba", flags=re.MULTILINE)
334+
result = ser.str.contains(pat)
335+
expected = Series([False, True, False], dtype=expected_dtype)
336+
tm.assert_series_equal(result, expected)
337+
338+
pat = re.compile("^ba", flags=re.MULTILINE | re.IGNORECASE)
339+
result = ser.str.contains(pat)
293340
expected = Series([False, True, True], dtype=expected_dtype)
294341
tm.assert_series_equal(result, expected)
295342

@@ -833,14 +880,36 @@ def test_match_case_kwarg(any_string_dtype):
833880

834881
def test_match_compiled_regex(any_string_dtype):
835882
# GH#61952
836-
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
837-
result = values.str.match(re.compile(r"ab"), case=False)
838883
expected_dtype = (
839884
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
840885
)
886+
887+
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
888+
889+
result = values.str.match(re.compile("ab"))
890+
expected = Series([True, False, True, False], dtype=expected_dtype)
891+
tm.assert_series_equal(result, expected)
892+
893+
# TODO this currently works for pyarrow-backed dtypes but raises for python
894+
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
895+
result = values.str.match(re.compile("ab"), case=False)
896+
expected = Series([True, True, True, True], dtype=expected_dtype)
897+
tm.assert_series_equal(result, expected)
898+
else:
899+
with pytest.raises(
900+
ValueError, match="cannot process flags argument with a compiled pattern"
901+
):
902+
values.str.match(re.compile("ab"), case=False)
903+
904+
result = values.str.match(re.compile("ab", flags=re.IGNORECASE))
841905
expected = Series([True, True, True, True], dtype=expected_dtype)
842906
tm.assert_series_equal(result, expected)
843907

908+
with pytest.raises(
909+
ValueError, match="cannot process flags argument with a compiled pattern"
910+
):
911+
values.str.match(re.compile("ab"), flags=re.IGNORECASE)
912+
844913

845914
# --------------------------------------------------------------------------------------
846915
# str.fullmatch
@@ -913,14 +982,36 @@ def test_fullmatch_case_kwarg(any_string_dtype):
913982

914983
def test_fullmatch_compiled_regex(any_string_dtype):
915984
# GH#61952
916-
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
917-
result = values.str.fullmatch(re.compile(r"ab"), case=False)
918985
expected_dtype = (
919986
np.bool_ if is_object_or_nan_string_dtype(any_string_dtype) else "boolean"
920987
)
988+
989+
values = Series(["ab", "AB", "abc", "ABC"], dtype=any_string_dtype)
990+
991+
result = values.str.fullmatch(re.compile("ab"))
992+
expected = Series([True, False, False, False], dtype=expected_dtype)
993+
tm.assert_series_equal(result, expected)
994+
995+
# TODO this currently works for pyarrow-backed dtypes but raises for python
996+
if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow":
997+
result = values.str.fullmatch(re.compile("ab"), case=False)
998+
expected = Series([True, True, False, False], dtype=expected_dtype)
999+
tm.assert_series_equal(result, expected)
1000+
else:
1001+
with pytest.raises(
1002+
ValueError, match="cannot process flags argument with a compiled pattern"
1003+
):
1004+
values.str.fullmatch(re.compile("ab"), case=False)
1005+
1006+
result = values.str.fullmatch(re.compile("ab", flags=re.IGNORECASE))
9211007
expected = Series([True, True, False, False], dtype=expected_dtype)
9221008
tm.assert_series_equal(result, expected)
9231009

1010+
with pytest.raises(
1011+
ValueError, match="cannot process flags argument with a compiled pattern"
1012+
):
1013+
values.str.fullmatch(re.compile("ab"), flags=re.IGNORECASE)
1014+
9241015

9251016
# --------------------------------------------------------------------------------------
9261017
# str.findall

0 commit comments

Comments
 (0)