diff --git a/lib/dupekit/Cargo.lock b/lib/dupekit/Cargo.lock index 6d7c2fd894..07db1f67de 100644 --- a/lib/dupekit/Cargo.lock +++ b/lib/dupekit/Cargo.lock @@ -336,15 +336,16 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.2" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" +checksum = "2468ef7d57b3fb7e16b576e8377cdbde2320c60e1491e961d11da40fc4f02a2d" dependencies = [ "arrayref", "arrayvec", "cc", "cfg-if", "constant_time_eq", + "cpufeatures", ] [[package]] @@ -379,9 +380,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.19.0" +version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" [[package]] name = "byteorder" @@ -397,9 +398,9 @@ checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "cc" -version = "1.2.48" +version = "1.2.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c481bdbf0ed3b892f6f806287d72acd515b352a4ec27a208489b8c1bc839633a" +checksum = "7a0aeaff4ff1a90589618835a598e545176939b97874f7abc7851caa0618f203" dependencies = [ "find-msvc-tools", "jobserver", @@ -446,9 +447,9 @@ dependencies = [ [[package]] name = "constant_time_eq" -version = "0.3.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" [[package]] name = "core-foundation-sys" @@ -456,6 +457,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -523,6 +533,9 @@ dependencies = [ "hex", "parquet", "pyo3", + "rand", + "rand_pcg", + "regex", "xxhash-rust", ] @@ -534,15 +547,15 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "find-msvc-tools" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" +checksum = "645cbb3a84e60b7531617d5ae4e57f7e27308f6445f5abf653209ea76dec8dff" [[package]] name = "flatbuffers" -version = "25.9.23" +version = "25.12.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09b6620799e7340ebd9968d2e0708eb82cf1971e9a16821e2091b6d6e475eed5" +checksum = "35f6839d7b3b98adde531effaf34f0c2badc6f4735d26fe74709d8e513a96ef3" dependencies = [ "bitflags", "rustc_version", @@ -648,9 +661,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", "hashbrown", @@ -658,9 +671,12 @@ dependencies = [ [[package]] name = "indoc" -version = "2.0.6" +version = 
"2.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" +checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706" +dependencies = [ + "rustversion", +] [[package]] name = "integer-encoding" @@ -670,9 +686,9 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jobserver" @@ -753,9 +769,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.175" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" [[package]] name = "libm" @@ -765,18 +781,18 @@ checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libz-rs-sys" -version = "0.5.2" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd" +checksum = "c10501e7805cee23da17c7790e59df2870c0d4043ec6d03f67d31e2b53e77415" dependencies = [ "zlib-rs", ] [[package]] name = "log" -version = "0.4.28" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lz4_flex" @@ -913,15 +929,24 @@ checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "portable-atomic" -version = "1.11.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] [[package]] name = "proc-macro2" -version = "1.0.101" +version = "1.0.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7" dependencies = [ "unicode-ident", ] @@ -989,9 +1014,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.40" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a" dependencies = [ "proc-macro2", ] @@ -1002,6 +1027,45 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + 
"libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.16", +] + +[[package]] +name = "rand_pcg" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e" +dependencies = [ + "rand_core", +] + [[package]] name = "regex" version = "1.12.2" @@ -1048,9 +1112,9 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" [[package]] name = "semver" @@ -1095,15 +1159,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.145" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ "itoa", "memchr", - "ryu", "serde", "serde_core", + "zmij", ] [[package]] @@ -1114,9 +1178,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "simd-adler32" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" [[package]] name = "simdutf8" @@ -1138,9 +1202,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.106" +version = "2.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" dependencies = [ "proc-macro2", "quote", @@ -1149,9 +1213,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.13.2" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" +checksum = "b1dd07eb858a2067e2f3c7155d54e929265c264e6f37efe3ee7a8d1b5a1dd0ba" [[package]] name = "thrift" @@ -1187,9 +1251,9 @@ checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "unicode-ident" -version = "1.0.18" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" [[package]] name = "unindent" @@ -1336,18 +1400,18 @@ checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" [[package]] name = "zerocopy" -version = "0.8.31" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" +checksum = "668f5168d10b9ee831de31933dc111a459c97ec93225beb307aed970d1372dfd" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.31" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" +checksum = "2c7962b26b0a8685668b671ee4b54d007a67d4eaf05fda79ac0ecf41e32270f1" dependencies = [ "proc-macro2", "quote", @@ -1356,9 +1420,15 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.5.2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40990edd51aae2c2b6907af74ffb635029d5788228222c4bb811e9351c0caad3" + +[[package]] +name = "zmij" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f06ae92f42f5e5c42443fd094f245eb656abf56dd7cce9b8b263236565e00f2" +checksum = "2fc5a66a20078bf1251bde995aa2fdcc4b800c70b5d92dd2c62abc5c60f679f8" [[package]] name = "zstd" diff --git a/lib/dupekit/Cargo.toml b/lib/dupekit/Cargo.toml index a0491d393f..ca33497de4 100644 --- a/lib/dupekit/Cargo.toml +++ b/lib/dupekit/Cargo.toml @@ -18,4 +18,7 @@ pyo3 = { version = "0.26", features = [ "extension-module", "abi3-py311", ] } # stable ABI with minimum Python version 3.11 +rand = "0.8" +rand_pcg = "0.3" +regex = "1.10" xxhash-rust = { version = "0.8", features = ["xxh3"] } diff --git a/lib/dupekit/README.md b/lib/dupekit/README.md index 069d9a38f6..b4398849ff 100644 --- a/lib/dupekit/README.md +++ b/lib/dupekit/README.md @@ -37,6 +37,7 @@ uv run pytest lib/dupekit/tests/bench/test_marshaling.py --run-benchmark uv run pytest lib/dupekit/tests/bench/test_batch_tuning.py --run-benchmark uv run pytest lib/dupekit/tests/bench/test_io.py --run-benchmark uv run pytest lib/dupekit/tests/bench/test_hashing.py --run-benchmark +uv run pytest lib/dupekit/tests/bench/test_minhash.py --run-benchmark ``` Note: Run separated by type of benchmark (otherwise results are mixed within one table) diff --git a/lib/dupekit/dupekit.pyi b/lib/dupekit/dupekit.pyi index 62df9b634d..9906b33e7d 100644 --- a/lib/dupekit/dupekit.pyi +++ b/lib/dupekit/dupekit.pyi @@ -141,6 +141,35 @@ class Transformation: """Projects the batch to keep only the specified columns.""" ... + @staticmethod + def CleanText(input_col: str, output_col: str) -> "Transformation": + """Normalizes text (lowercase, remove punctuation, normalize whitespace).""" + ... + + @staticmethod + def MinHash(input_col: str, output_col: str, num_perms: int, ngram_size: int, seed: int) -> "Transformation": + """Computes MinHash signatures for the input text using fused shingling/hashing. + + Args: + input_col: Column containing text. + output_col: Column to store signature (List[uint64]). + num_perms: Number of permutation functions (length of signature). + ngram_size: Size of char-ngrams for shingling. + seed: Random seed for permutation coefficients. + """ + ... + + @staticmethod + def MinHashLSH(input_col: str, output_col: str, num_bands: int) -> "Transformation": + """Computes LSH buckets from a MinHash signature. + + Args: + input_col: Column containing signatures (List[uint64]). + output_col: Column to store buckets (List[uint64]). + num_bands: Number of LSH bands. + """ + ... + def transform(batch: pa.RecordBatch, steps: list[Transformation]) -> pa.RecordBatch: ... 
diff --git a/lib/dupekit/src/lib.rs b/lib/dupekit/src/lib.rs
index 21020587b0..a2e632fc15 100644
--- a/lib/dupekit/src/lib.rs
+++ b/lib/dupekit/src/lib.rs
@@ -4,6 +4,7 @@ mod bloom;
 mod dedupe;
 mod hashing;
 mod marshaling;
+mod minhash_ops;
 mod ops;
 mod pipeline;
diff --git a/lib/dupekit/src/minhash_ops.rs b/lib/dupekit/src/minhash_ops.rs
new file mode 100644
index 0000000000..f9e05a9fc6
--- /dev/null
+++ b/lib/dupekit/src/minhash_ops.rs
@@ -0,0 +1,159 @@
+use arrow::array::{Array, ListArray, ListBuilder, StringArray, StringBuilder, UInt64Array, UInt64Builder};
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+use rand::{Rng, SeedableRng};
+use rand_pcg::Pcg64;
+use regex::Regex;
+use std::sync::Arc;
+use xxhash_rust::xxh3;
+
+/// Clean text using the SlimPajama text cleaning process.
+/// 1. Lowercase
+/// 2. Remove punctuation
+/// 3. Replace multiple whitespace with single space
+/// 4. Trim
+pub fn clean_text(arr: &StringArray) -> PyResult<Arc<StringArray>> {
+    let mut builder = StringBuilder::with_capacity(arr.len(), arr.len() * 50);
+    let whitespace_re = Regex::new(r"\s+").map_err(|e| PyValueError::new_err(e.to_string()))?;
+    let punctuation: &[char] = &[
+        '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<',
+        '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~',
+    ];
+
+    for i in 0..arr.len() {
+        if arr.is_null(i) {
+            builder.append_null();
+            continue;
+        }
+
+        let text = arr.value(i);
+        let lower = text.to_lowercase();
+        let no_punct: String = lower.chars().filter(|c| !punctuation.contains(c)).collect();
+        let normalized = whitespace_re.replace_all(&no_punct, " ");
+        builder.append_value(normalized.trim());
+    }
+
+    Ok(Arc::new(builder.finish()))
+}
+
+/// Fused operation: Shingling -> Hashing -> Permutation -> Min extraction
+pub fn compute_minhash(
+    arr: &StringArray,
+    num_perms: usize,
+    ngram_size: usize,
+    seed: u64,
+) -> PyResult<Arc<ListArray>> {
+    // Generate permutations using Duplodocus strategy (Single u128 coefficient)
+    let mut rng = Pcg64::seed_from_u64(seed);
+    let mut coeffs = Vec::with_capacity(num_perms);
+
+    for _ in 0..num_perms {
+        // Duplodocus ensures coefficients are odd to preserve properties of the permutation group
+        let mut c = rng.gen::<u128>();
+        if c % 2 == 0 {
+            c = c.wrapping_add(1);
+        }
+        coeffs.push(c);
+    }
+
+    let values_builder = UInt64Builder::with_capacity(arr.len() * num_perms);
+    let mut list_builder = ListBuilder::new(values_builder);
+
+    for i in 0..arr.len() {
+        if arr.is_null(i) {
+            list_builder.append_null();
+            continue;
+        }
+
+        let text = arr.value(i);
+        let chars: Vec<char> = text.chars().collect();
+        let mut signature = vec![u64::MAX; num_perms];
+
+        if chars.len() < ngram_size {
+            let hash = xxh3::xxh3_64(text.as_bytes()) as u128;
+            update_signature(&mut signature, hash, &coeffs);
+        } else {
+            for window in chars.windows(ngram_size) {
+                let s: String = window.iter().collect();
+                let hash = xxh3::xxh3_64(s.as_bytes()) as u128;
+                update_signature(&mut signature, hash, &coeffs);
+            }
+        }
+        list_builder.values().append_slice(&signature);
+        list_builder.append(true);
+    }
+    Ok(Arc::new(list_builder.finish()))
+}
+
+#[inline(always)]
+fn update_signature(signature: &mut [u64], hash: u128, coeffs: &[u128]) {
+    // Logic: (hash * coeff) >> 64. Similar to Duplodocus
+    for (sig_val, &coeff) in signature.iter_mut().zip(coeffs) {
+        let permuted_hash = (hash.wrapping_mul(coeff) >> 64) as u64;
+        if permuted_hash < *sig_val {
+            *sig_val = permuted_hash;
+        }
+    }
+}
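The permutation here is the multiply-shift family: the 64-bit shingle hash is widened to 128 bits, multiplied by a random odd 128-bit coefficient (odd means the multiply is a bijection mod 2^128), and the high 64 bits are kept. A pure-Python mirror of the same arithmetic, for intuition only; the Rust above is authoritative:

```python
# Python ints are arbitrary precision, so wrapping_mul becomes an
# explicit mask to 128 bits before taking the high half.
MASK_128 = (1 << 128) - 1

def permute(hash64: int, coeff: int) -> int:
    """Matches (hash as u128).wrapping_mul(coeff) >> 64 in the Rust code."""
    return ((hash64 * coeff) & MASK_128) >> 64

def update_signature(signature: list[int], hash64: int, coeffs: list[int]) -> None:
    # Track, per permutation, the minimum permuted hash over all shingles.
    for i, coeff in enumerate(coeffs):
        signature[i] = min(signature[i], permute(hash64, coeff))
```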
+
+pub fn compute_lsh(input_col: &dyn Array, num_bands: usize) -> PyResult<Arc<ListArray>> {
+    let list_arr = input_col
+        .as_any()
+        .downcast_ref::<ListArray>()
+        .ok_or_else(|| {
+            PyValueError::new_err("Input to MinHashLSH must be a ListArray of UInt64")
+        })?;
+
+    let values_arr = list_arr
+        .values()
+        .as_any()
+        .downcast_ref::<UInt64Array>()
+        .ok_or_else(|| PyValueError::new_err("Inner array must be UInt64"))?;
+
+    let out_values_builder = UInt64Builder::with_capacity(list_arr.len() * num_bands);
+    let mut out_list_builder = ListBuilder::new(out_values_builder);
+
+    for i in 0..list_arr.len() {
+        if list_arr.is_null(i) {
+            out_list_builder.append_null();
+            continue;
+        }
+
+        let start = list_arr.value_offsets()[i] as usize;
+        let end = list_arr.value_offsets()[i + 1] as usize;
+        let sig_len = end - start;
+
+        if sig_len == 0 {
+            // Empty signature
+            out_list_builder.append(true);
+            continue;
+        }
+
+        if sig_len % num_bands != 0 {
+            return Err(PyValueError::new_err(format!(
+                "Signature length {} is not divisible by num_bands {}",
+                sig_len, num_bands
+            )));
+        }
+
+        let rows_per_band = sig_len / num_bands;
+        let slice = &values_arr.values()[start..end];
+
+        for band_idx in 0..num_bands {
+            let band_start = band_idx * rows_per_band;
+            let band_end = band_start + rows_per_band;
+            let band_data = &slice[band_start..band_end];
+            let band_bytes: &[u8] = unsafe {
+                std::slice::from_raw_parts(
+                    band_data.as_ptr() as *const u8,
+                    band_data.len() * std::mem::size_of::<u64>(),
+                )
+            };
+            let bucket_hash = xxh3::xxh3_64(band_bytes);
+            out_list_builder.values().append_value(bucket_hash);
+        }
+        out_list_builder.append(true);
+    }
+
+    Ok(Arc::new(out_list_builder.finish()))
+}
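`compute_lsh` hashes each band of `r = sig_len / num_bands` signature values into one bucket, so two documents that agree on all `r` rows of any band become a candidate pair. For Jaccard similarity `s` this happens with probability `1 - (1 - s^r)^b`, the standard banding S-curve, steepest near `(1/b)^(1/r)`. A quick sketch of the curve for the b = 26, r = 11 configuration used elsewhere in this PR (the formula is standard; the printed values are illustrative):

```python
# Candidate-pair probability for banded MinHash LSH:
# P(s) = 1 - (1 - s**r) ** b, with b bands of r rows each.
b, r = 26, 11  # matches num_bands=26 and num_perms=286 (26 * 11)

def candidate_prob(s: float) -> float:
    return 1 - (1 - s**r) ** b

threshold = (1 / b) ** (1 / r)
print(f"threshold ~ {threshold:.2f}")          # ~0.74
print(f"P(0.5) = {candidate_prob(0.5):.3f}")   # ~0.013, rarely paired
print(f"P(0.9) = {candidate_prob(0.9):.3f}")   # ~1.000, almost always paired
```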
diff --git a/lib/dupekit/src/pipeline.rs b/lib/dupekit/src/pipeline.rs
index f34e0c1748..4608bd0746 100644
--- a/lib/dupekit/src/pipeline.rs
+++ b/lib/dupekit/src/pipeline.rs
@@ -1,4 +1,5 @@
 use crate::hashing::{HashAlgorithm, DEFAULT_HASH_ALGO};
+use crate::minhash_ops;
 use crate::ops;
 use arrow::array::{Array, StringBuilder};
 use arrow::datatypes::{Field, Schema};
@@ -29,6 +30,23 @@ pub enum Transformation {
     SelectColumns {
         columns: Vec<String>,
     },
+    // MinHash Pipeline Ops
+    CleanText {
+        input_col: String,
+        output_col: String,
+    },
+    MinHash {
+        input_col: String,
+        output_col: String,
+        num_perms: usize,
+        ngram_size: usize,
+        seed: u64,
+    },
+    MinHashLSH {
+        input_col: String,
+        output_col: String,
+        num_bands: usize,
+    },
 }
 #[pymethods]
 impl Transformation {
@@ -64,6 +82,47 @@ impl Transformation {
     fn select_columns(columns: Vec<String>) -> Transformation {
         Self::SelectColumns { columns }
     }
+
+    #[staticmethod]
+    #[pyo3(name = "CleanText")]
+    fn clean_text(input_col: String, output_col: String) -> Transformation {
+        Self::CleanText {
+            input_col,
+            output_col,
+        }
+    }
+
+    #[staticmethod]
+    #[pyo3(name = "MinHash")]
+    fn min_hash(
+        input_col: String,
+        output_col: String,
+        num_perms: usize,
+        ngram_size: usize,
+        seed: u64,
+    ) -> Transformation {
+        Self::MinHash {
+            input_col,
+            output_col,
+            num_perms,
+            ngram_size,
+            seed,
+        }
+    }
+
+    #[staticmethod]
+    #[pyo3(name = "MinHashLSH")]
+    fn min_hash_lsh(
+        input_col: String,
+        output_col: String,
+        num_bands: usize,
+    ) -> Transformation {
+        Self::MinHashLSH {
+            input_col,
+            output_col,
+            num_bands,
+        }
+    }
 }
 
 fn apply_transformation(batch: RecordBatch, step: &Transformation) -> PyResult<RecordBatch> {
@@ -122,6 +181,40 @@ fn apply_transformation(batch: RecordBatch, step: &Transformation) -> PyResult<RecordBatch> {
         Transformation::SelectColumns { columns } => ops::select_columns(&batch, columns),
+
+        Transformation::CleanText {
+            input_col,
+            output_col,
+        } => {
+            let input_arr = ops::get_string_array(&batch, input_col)?;
+            let clean_arr = minhash_ops::clean_text(&input_arr)?;
+            ops::add_column(&batch, output_col, clean_arr)
+        }
+
+        Transformation::MinHash {
+            input_col,
+            output_col,
+            num_perms,
+            ngram_size,
+            seed,
+        } => {
+            let input_arr = ops::get_string_array(&batch, input_col)?;
+            let signature_arr =
+                minhash_ops::compute_minhash(&input_arr, *num_perms, *ngram_size, *seed)?;
+            ops::add_column(&batch, output_col, signature_arr.into())
+        }
+
+        Transformation::MinHashLSH {
+            input_col,
+            output_col,
+            num_bands,
+        } => {
+            let input_arr = batch.column_by_name(input_col).ok_or_else(|| {
+                PyRuntimeError::new_err(format!("Column '{}' missing", input_col))
+            })?;
+            let buckets_arr = minhash_ops::compute_lsh(input_arr.as_ref(), *num_bands)?;
+            ops::add_column(&batch, output_col, buckets_arr.into())
+        }
     }
 }
diff --git a/lib/dupekit/tests/bench/conftest.py b/lib/dupekit/tests/bench/conftest.py
index 0f68a0448e..196c2c8845 100644
--- a/lib/dupekit/tests/bench/conftest.py
+++ b/lib/dupekit/tests/bench/conftest.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import pytest
+import pyarrow as pa
+import pyarrow.parquet as pq
 from typing import Any
 from huggingface_hub import hf_hub_download
@@ -35,6 +37,47 @@ def parquet_file() -> str:
     return file_path
 
 
+@pytest.fixture(scope="session")
+def small_parquet_path(tmp_path_factory: pytest.TempPathFactory, parquet_file: str) -> str:
+    """
+    Creates a smaller slice (250k rows) of the main parquet file for faster benchmarking
+    and I/O tests.
+    """
+    fn = tmp_path_factory.mktemp("data_io") / "subset.parquet"
+    pf = pq.ParquetFile(parquet_file)
+    # 250k rows is substantial enough for I/O throughput tests
+    first_batch = next(pf.iter_batches(batch_size=250_000))
+    table = pa.Table.from_batches([first_batch])
+    pq.write_table(table, fn)
+    path_str = str(fn)
+
+    # Warm up OS cache for this new file
+    with open(path_str, "rb") as f:
+        while f.read(1024**2):
+            pass
+
+    return path_str
+
+
+@pytest.fixture(scope="session")
+def in_memory_table(small_parquet_path: str) -> pa.Table:
+    """
+    Loads 250k rows into memory once. Used for marshaling and batch size tuning benchmarks.
+    """
+    return pq.read_table(small_parquet_path)
+
+
+@pytest.fixture(scope="session")
+def sample_batch(parquet_file: str) -> pa.RecordBatch:
+    """
+    Loads a single batch (10k rows) for algorithm benchmarks (hashing, dedupe logic).
+    All columns are read; the benchmarks rely on 'text' and 'id' being present.
+    """
+    pf = pq.ParquetFile(parquet_file)
+    # 'iter_batches' reads all columns by default, which includes 'text' and 'id'.
+    return next(pf.iter_batches(batch_size=10_000))
+
+
 def pytest_addoption(parser: Any) -> None:
     parser.addoption("--run-benchmark", action="store_true", default=False, help="run benchmark tests")
diff --git a/lib/dupekit/tests/bench/test_batch_tuning.py b/lib/dupekit/tests/bench/test_batch_tuning.py
index d8afa93a15..d53d7738e5 100644
--- a/lib/dupekit/tests/bench/test_batch_tuning.py
+++ b/lib/dupekit/tests/bench/test_batch_tuning.py
@@ -13,22 +13,11 @@
 # limitations under the License.
 
 import pytest
-import pyarrow.parquet as pq
 import pyarrow as pa
 from typing import Any
 import dupekit
 
 
-@pytest.fixture(scope="module")
-def in_memory_table(parquet_file: str) -> pa.Table:
-    """
-    Loads 100k rows into memory.
- """ - pf = pq.ParquetFile(parquet_file) - first_batch = next(pf.iter_batches(batch_size=100_000)) - return pa.Table.from_batches([first_batch]) - - @pytest.mark.parametrize("batch_size", [1, 128, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]) def test_arrow_batch_sizes(benchmark: Any, in_memory_table: pa.Table, batch_size: int) -> None: """ diff --git a/lib/dupekit/tests/bench/test_dedupe.py b/lib/dupekit/tests/bench/test_dedupe.py index 7bf44f6d70..d38841ee6b 100644 --- a/lib/dupekit/tests/bench/test_dedupe.py +++ b/lib/dupekit/tests/bench/test_dedupe.py @@ -15,18 +15,10 @@ import hashlib import pytest import pyarrow as pa -import pyarrow.parquet as pq from typing import Any import dupekit -@pytest.fixture(scope="module") -def sample_data(parquet_file: str) -> pa.RecordBatch: - """Loads a single large batch (10k rows) for benchmarking.""" - pf = pq.ParquetFile(parquet_file) - return next(pf.iter_batches(batch_size=10_000)) - - def build_map(hashes: list[str], ids: list[str]) -> dict[str, Any]: """ Builds a duplicate map where ~50% of items are marked as canonical and 50% as duplicates. @@ -157,23 +149,23 @@ def rust_mark_document_duplicates( @pytest.mark.parametrize("granularity", ["paragraphs", "documents"]) @pytest.mark.parametrize("backend", ["python", "rust"]) -def test_hashing(benchmark: Any, sample_data: pa.RecordBatch, granularity: str, backend: str) -> None: +def test_hashing(benchmark: Any, sample_batch: pa.RecordBatch, granularity: str, backend: str) -> None: """Benchmark the hash generation step.""" func = PROCESS_FUNCS[(granularity, backend)] benchmark.group = f"{granularity.title()}: Hash Generation" - benchmark(func, sample_data, "text", "id") + benchmark(func, sample_batch, "text", "id") @pytest.mark.parametrize("granularity", ["paragraphs", "documents"]) @pytest.mark.parametrize("backend", ["python", "rust"]) -def test_deduplication(benchmark: Any, sample_data: pa.RecordBatch, granularity: str, backend: str) -> None: +def test_deduplication(benchmark: Any, sample_batch: pa.RecordBatch, granularity: str, backend: str) -> None: """Benchmark the duplicate marking step (requires pre-calculated map).""" process_func = PROCESS_FUNCS[(granularity, "rust")] # Always use fast Rust version to build map - processed = process_func(sample_data, "text", "id") + processed = process_func(sample_batch, "text", "id") hashes, ids = extract_results(processed) dup_map = build_map(hashes, ids) mark_func = MARK_FUNCS[(granularity, backend)] benchmark.group = f"{granularity.title()}: Exact Deduplication" - benchmark(mark_func, sample_data, "text", "id", dup_map, "dups") + benchmark(mark_func, sample_batch, "text", "id", dup_map, "dups") diff --git a/lib/dupekit/tests/bench/test_hashing.py b/lib/dupekit/tests/bench/test_hashing.py index 8adaa059d0..7adeda48a4 100644 --- a/lib/dupekit/tests/bench/test_hashing.py +++ b/lib/dupekit/tests/bench/test_hashing.py @@ -16,18 +16,11 @@ from collections.abc import Callable import pytest -import pyarrow.parquet as pq +import pyarrow as pa import hashlib import dupekit -@pytest.fixture(scope="module") -def text_samples(parquet_file: str) -> list[bytes]: - table = pq.read_table(parquet_file) - texts = table["text"][:10000].to_pylist() - return [t.encode("utf-8") for t in texts] - - def _py_blake2b(text: bytes) -> bytes: return hashlib.blake2b(text).digest() @@ -43,7 +36,9 @@ def _py_blake2b(text: bytes) -> bytes: pytest.param(dupekit.hash_xxh3_64_batch, "batch", id="rust_xxh3_64_batch"), ], ) -def test_hashing_throughput(benchmark: Any, 
text_samples: list[bytes], func: Callable, mode: str) -> None: +def test_hashing_throughput(benchmark: Any, sample_batch: pa.RecordBatch, func: Callable, mode: str) -> None: + # Use the sample_batch fixture (10k rows) and convert to bytes + text_samples = [t.as_py().encode("utf-8") for t in sample_batch["text"]] def _run() -> list[Any]: if mode == "batch": diff --git a/lib/dupekit/tests/bench/test_io.py b/lib/dupekit/tests/bench/test_io.py index 03bace80d4..43bb50e1fd 100644 --- a/lib/dupekit/tests/bench/test_io.py +++ b/lib/dupekit/tests/bench/test_io.py @@ -19,38 +19,17 @@ import pytest import pyarrow.parquet as pq -import pyarrow as pa from typing import Any import dupekit -@pytest.fixture(scope="module") -def small_parquet_file(tmp_path_factory: pytest.TempPathFactory, parquet_file: str) -> str: - """ - Creates a smaller slice (250k rows) of the main parquet file for faster benchmarking. - """ - fn = tmp_path_factory.mktemp("data_io") / "subset.parquet" - pf = pq.ParquetFile(parquet_file) - first_batch = next(pf.iter_batches(batch_size=250_000)) - table = pa.Table.from_batches([first_batch]) - pq.write_table(table, fn) - path_str = str(fn) - - # Warm up OS cache - with open(path_str, "rb") as f: - while f.read(1024**2): - pass - - return path_str - - -def test_rust_native(benchmark: Any, small_parquet_file: str) -> None: +def test_rust_native(benchmark: Any, small_parquet_path: str) -> None: """ Baseline: Rust reads file from disk, parses Parquet, transforms, returns RecordBatch. """ def _run() -> int: - return len(dupekit.process_native(small_parquet_file)) + return len(dupekit.process_native(small_parquet_path)) assert benchmark(_run) > 0 @@ -62,27 +41,27 @@ def _run() -> int: pytest.param(1024, id="small"), ], ) -def test_arrow_io_pipeline(benchmark: Any, small_parquet_file: str, batch_size: int) -> None: +def test_arrow_io_pipeline(benchmark: Any, small_parquet_path: str, batch_size: int) -> None: """ Python End-to-End: Python reads file -> Stream of RecordBatches -> Rust (called per batch). Includes Parquet parsing overhead and Python loop overhead. """ def _pipeline() -> int: - batches = pq.ParquetFile(small_parquet_file).iter_batches(batch_size=batch_size) + batches = pq.ParquetFile(small_parquet_path).iter_batches(batch_size=batch_size) return sum(len(dupekit.process_arrow_batch(b)) for b in batches) assert benchmark(_pipeline) > 0 -def test_dicts_loop_io(benchmark: Any, small_parquet_file: str) -> None: +def test_dicts_loop_io(benchmark: Any, small_parquet_path: str) -> None: """ Python End-to-End: Read File -> List[dict] -> Loop calling Rust per item -> List[dict]. Slowest Python approach (Baseline for worst case). """ def _pipeline() -> int: - docs = pq.read_table(small_parquet_file).to_pylist() + docs = pq.read_table(small_parquet_path).to_pylist() return len([dupekit.process_dicts_loop(doc) for doc in docs]) assert benchmark(_pipeline) > 0 diff --git a/lib/dupekit/tests/bench/test_marshaling.py b/lib/dupekit/tests/bench/test_marshaling.py index 7fa8d4d1ce..e881ded4eb 100644 --- a/lib/dupekit/tests/bench/test_marshaling.py +++ b/lib/dupekit/tests/bench/test_marshaling.py @@ -18,22 +18,11 @@ """ import pytest -import pyarrow.parquet as pq import pyarrow as pa from typing import Any import dupekit -@pytest.fixture(scope="module") -def in_memory_table(tmp_path_factory: pytest.TempPathFactory, parquet_file: str) -> pa.Table: - """ - Loads 250k rows into memory once. 
- """ - pf = pq.ParquetFile(parquet_file) - first_batch = next(pf.iter_batches(batch_size=250_000)) - return pa.Table.from_batches([first_batch]) - - @pytest.mark.parametrize( "batch_size", [ diff --git a/lib/dupekit/tests/bench/test_minhash.py b/lib/dupekit/tests/bench/test_minhash.py new file mode 100644 index 0000000000..b8ffa65d12 --- /dev/null +++ b/lib/dupekit/tests/bench/test_minhash.py @@ -0,0 +1,36 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyarrow as pa +from typing import Any +import dupekit +from dupekit import Transformation + +# Python is slow, can't use too many rows +BENCHMARK_ROWS = 1000 + + +def rust_minhash_pipeline(batch: pa.RecordBatch) -> int: + pipeline = [ + Transformation.CleanText(input_col="text", output_col="clean"), + Transformation.MinHash(input_col="clean", output_col="sig", num_perms=286, ngram_size=5, seed=42), + Transformation.MinHashLSH(input_col="sig", output_col="buckets", num_bands=26), + ] + res = dupekit.transform(batch, pipeline) + return len(res) + + +def test_bench_rust_minhash(benchmark: Any, sample_batch: pa.RecordBatch) -> None: + batch = sample_batch.slice(length=BENCHMARK_ROWS) + benchmark(rust_minhash_pipeline, batch) diff --git a/lib/dupekit/tests/test_minhash.py b/lib/dupekit/tests/test_minhash.py new file mode 100644 index 0000000000..183ec66183 --- /dev/null +++ b/lib/dupekit/tests/test_minhash.py @@ -0,0 +1,68 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyarrow as pa +from dupekit import Transformation, transform + + +def test_clean_text(): + """Test text cleaning (lowercase, punct removal, whitespace norm).""" + text = "Hello, World! This is a test." 
+ expected = "hello world this is a test" + batch = pa.RecordBatch.from_pydict({"text": [text, None, " "]}) + pipeline = [Transformation.CleanText(input_col="text", output_col="clean")] + clean = transform(batch, pipeline)["clean"] + assert clean[0].as_py() == expected + assert clean[1].as_py() is None + assert clean[2].as_py() == "" + + +def test_minhash_dimensions(): + """Test that MinHash output has correct dimensions.""" + texts = ["doc one", "doc two"] + num_perms = 128 + batch = pa.RecordBatch.from_pydict({"text": texts}) + pipeline = [Transformation.MinHash(input_col="text", output_col="sig", num_perms=num_perms, ngram_size=3, seed=42)] + sigs = transform(batch, pipeline)["sig"] + for sig in sigs: + assert len(sig.as_py()) == num_perms + assert all(isinstance(x, int) for x in sig.as_py()) + + +def test_minhash_lsh_dimensions(): + """Test LSH banding logic.""" + num_bands = 26 + sig = list(range(286)) + batch = pa.RecordBatch.from_pydict({"sig": [sig]}, schema=pa.schema([("sig", pa.list_(pa.uint64()))])) + pipeline = [Transformation.MinHashLSH(input_col="sig", output_col="buckets", num_bands=num_bands)] + res = transform(batch, pipeline) + buckets = res["buckets"][0].as_py() + assert len(buckets) == num_bands + res2 = transform(batch, pipeline) + assert buckets == res2["buckets"][0].as_py() + + +def test_full_pipeline_determinism(): + """Test that the full MinHash pipeline produces deterministic results.""" + text = "The quick brown fox jumps over the lazy dog." + batch = pa.RecordBatch.from_pydict({"text": [text, text]}) + pipeline = [ + Transformation.CleanText(input_col="text", output_col="clean"), + Transformation.MinHash(input_col="clean", output_col="sig", num_perms=20, ngram_size=5, seed=1), + Transformation.MinHashLSH(input_col="sig", output_col="buckets", num_bands=4), + ] + res = transform(batch, pipeline) + b0 = res["buckets"][0].as_py() + b1 = res["buckets"][1].as_py() + assert b0 == b1 diff --git a/lib/marin/src/marin/processing/classification/deduplication/pipeline.py b/lib/marin/src/marin/processing/classification/deduplication/pipeline.py index 3c99011ef8..ffd5159e4d 100644 --- a/lib/marin/src/marin/processing/classification/deduplication/pipeline.py +++ b/lib/marin/src/marin/processing/classification/deduplication/pipeline.py @@ -31,7 +31,6 @@ from marin.execution.executor import THIS_OUTPUT_PATH from marin.processing.classification.deduplication.connected_components import connected_components -from marin.processing.classification.deduplication.minhash_lsh import minhash_lsh from marin.utilities.time_logger import log_time import pyarrow as pa import pyarrow.json as pa_json @@ -436,10 +435,43 @@ def compute_document_hashes(batch: pa.RecordBatch) -> pa.RecordBatch: exact_cnts = _compute_dedup_stats(duplicate_key_shards, method="exact", level="document") logger.info(str(exact_cnts)) - doc_minhash_lsh = minhash_lsh( + def compute_minhash_lsh_batches(batch: pa.RecordBatch) -> Iterator[dict]: + """ + Runs the Rust-optimized MinHash LSH pipeline on a RecordBatch. + Yields {bucket: str, id: Any} for each bucket hit. 
+ """ + pipeline = [ + Transformation.ResolveIds(text_col=config.text_field, id_col="id", output_col="resolved_id"), + Transformation.CleanText(input_col=config.text_field, output_col="clean_text"), + Transformation.MinHash( + input_col="clean_text", + output_col="signature", + num_perms=286, # 26 bands * 11 rows + ngram_size=5, + seed=42, + ), + Transformation.MinHashLSH(input_col="signature", output_col="buckets", num_bands=26), + Transformation.SelectColumns(columns=["resolved_id", "buckets"]), + ] + + result_batch = dupekit.transform(batch, pipeline) + + ids = result_batch["resolved_id"] + buckets = result_batch["buckets"] + + for doc_id, doc_buckets in zip(ids, buckets, strict=False): + if not doc_buckets.is_valid: + continue + + doc_id_val = doc_id.as_py() + for b in doc_buckets.as_py(): + yield {"bucket": str(b), "id": doc_id_val} + + doc_minhash_lsh = ( Dataset.from_list(input_files) - .flat_map(load_file) + .flat_map(lambda f: _load_batches(f, columns=[config.text_field, "id"])) .reshard(num_shards=config.processes if len(input_files) < 42 else None) + .flat_map(compute_minhash_lsh_batches) ) converged, cc_files = connected_components(doc_minhash_lsh, ctx=ctx, output_dir=f"{config.output_path}/metadata/cc") if not converged: diff --git a/tests/processing/classification/deduplication/__init__.py b/tests/processing/classification/deduplication/__init__.py new file mode 100644 index 0000000000..731b4c72e7 --- /dev/null +++ b/tests/processing/classification/deduplication/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
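For context on how the `{bucket, id}` rows yielded by `compute_minhash_lsh_batches` become duplicate clusters: `connected_components` (imported in `pipeline.py` above; its implementation is not part of this diff) effectively links every pair of ids that co-occur in a bucket. A toy union-find sketch of that idea; `toy_components` is a hypothetical helper for illustration, not the marin implementation:

```python
from collections import defaultdict

def toy_components(rows: list[dict]) -> list[set]:
    """Union ids that share a bucket, then collect the resulting groups."""
    parent: dict = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    seen_bucket: dict = {}
    for row in rows:
        if row["bucket"] in seen_bucket:
            parent[find(row["id"])] = find(seen_bucket[row["bucket"]])
        else:
            seen_bucket[row["bucket"]] = row["id"]

    groups = defaultdict(set)
    for row in rows:
        groups[find(row["id"])].add(row["id"])
    return list(groups.values())

rows = [
    {"bucket": "b17", "id": "doc-a"},
    {"bucket": "b17", "id": "doc-b"},  # collides with doc-a
    {"bucket": "b99", "id": "doc-c"},
]
print(toy_components(rows))  # two clusters: {doc-a, doc-b} and {doc-c}
```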