diff --git a/compact_str/src/lib.rs b/compact_str/src/lib.rs index d1f6a544..8a99ff29 100644 --- a/compact_str/src/lib.rs +++ b/compact_str/src/lib.rs @@ -1232,24 +1232,38 @@ impl CompactString { pub fn retain(&mut self, mut predicate: impl FnMut(char) -> bool) { // We iterate over the string, and copy character by character. - let s = self.as_mut_str(); - let mut dest_idx = 0; - let mut src_idx = 0; - while let Some(ch) = s[src_idx..].chars().next() { + struct SetLenOnDrop<'a> { + self_: &'a mut CompactString, + src_idx: usize, + dst_idx: usize, + } + + let mut g = SetLenOnDrop { + self_: self, + src_idx: 0, + dst_idx: 0, + }; + let s = g.self_.as_mut_str(); + while let Some(ch) = s[g.src_idx..].chars().next() { let ch_len = ch.len_utf8(); if predicate(ch) { // SAFETY: We know that both indices are valid, and that we don't split a char. unsafe { let p = s.as_mut_ptr(); - core::ptr::copy(p.add(src_idx), p.add(dest_idx), ch_len); + core::ptr::copy(p.add(g.src_idx), p.add(g.dst_idx), ch_len); } - dest_idx += ch_len; + g.dst_idx += ch_len; } - src_idx += ch_len; + g.src_idx += ch_len; } - // SAFETY: We know that the index is a valid position to break the string. - unsafe { self.set_len(dest_idx) }; + impl Drop for SetLenOnDrop<'_> { + fn drop(&mut self) { + // SAFETY: We know that the index is a valid position to break the string. + unsafe { self.self_.set_len(self.dst_idx) }; + } + } + drop(g); } /// Decode a bytes slice as UTF-8 string, replacing any illegal codepoints diff --git a/compact_str/src/tests.rs b/compact_str/src/tests.rs index 1a15c71c..f311ef5a 100644 --- a/compact_str/src/tests.rs +++ b/compact_str/src/tests.rs @@ -1367,6 +1367,41 @@ fn test_insert(to_compact: fn(&'static str) -> CompactString) { ); } +#[test] +#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")] +fn test_retain() { + let mut s = CompactString::from("α_β_γ"); + + s.retain(|_| true); + assert_eq!(s, "α_β_γ"); + + s.retain(|c| c != '_'); + assert_eq!(s, "αβγ"); + + s.retain(|c| c != 'β'); + assert_eq!(s, "αγ"); + + s.retain(|c| c == 'α'); + assert_eq!(s, "α"); + + s.retain(|_| false); + assert_eq!(s, ""); + + let mut s = CompactString::from("0è0"); + let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let mut count = 0; + s.retain(|_| { + count += 1; + match count { + 1 => false, + 2 => true, + _ => panic!(), + } + }); + })); + assert!(std::str::from_utf8(s.as_bytes()).is_ok()); +} + #[test] fn test_remove() { let mut control = String::from("🦄🦀hello🎶world🇺🇸");