diff --git a/.gitignore b/.gitignore index b53cac6..11fae36 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ Cargo.lock #VS Code .vscode +.idea diff --git a/Cargo.toml b/Cargo.toml index 18685ba..919cc3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,13 +2,12 @@ name = "generalized_suffix_tree" license = "MIT" description = "Implementation of Generalized Suffix Tree using Ukkonen's algorithm in Rust" -version = "1.2.1" +version = "1.2.2" authors = ["Xun Li "] edition = "2018" repository = "https://github.com/lxfind/rust-generalized-suffix-tree" [dependencies] -mediumvec = "1.0.4" [dev-dependencies] -rand = "0.7.0" +rand = "0.7" diff --git a/src/lib.rs b/src/lib.rs index df320a7..4a8c871 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,37 +3,30 @@ mod disjoint_set; use std::collections::HashMap; -use mediumvec::{Vec32, vec32}; - -type NodeID = u32; -type StrID = u32; -type IndexType = u32; -type CharType = u8; - // Special nodes. -const ROOT: NodeID = 0; -const SINK: NodeID = 1; -const INVALID: NodeID = NodeID::max_value(); +const ROOT: u32 = 0; +const SINK: u32 = 1; +const INVALID: u32 = u32::max_value(); /// This structure represents a slice to a string. #[derive(Debug, Clone)] struct MappedSubstring { /// Unique ID of the string it's slicing, which can be used to locate the string from the tree's string storage. - str_id: StrID, + str_id: u32, /// Index of the first character of the slice. - start: IndexType, + start: u32, /// One past the index of the last character of the slice. /// e.g. when `end` is equal to `start`, this is an empty slice. /// Note that `end` here always represents a meaningful index, unlike in the original algorithm where a slice could potentially be open-ended. /// Such open-endedness allows for online construction of the tree. Here I chose to not support online construction for convenience. It's possible - /// to support it by changing `end`'s type to `Option`. 
- end: IndexType, + /// to support it by changing `end`'s type to `Option<u32>`. + end: u32, } impl MappedSubstring { - const fn new(str_id: StrID, start: IndexType, end: IndexType) -> Self { + const fn new(str_id: u32, start: u32, end: u32) -> Self { Self { str_id, start, end } } @@ -41,7 +34,7 @@ impl MappedSubstring { self.start == self.end } - const fn len(&self) -> IndexType { + const fn len(&self) -> u32 { self.end - self.start } } @@ -56,16 +49,16 @@ impl MappedSubstring { /// node represents. By doing so we avoid having an explicit edge data type. #[derive(Debug)] struct Node { - transitions: HashMap<CharType, NodeID>, + transitions: HashMap<char, u32>, - suffix_link: NodeID, + suffix_link: u32, /// The slice of the string this node represents. substr: MappedSubstring, } impl Node { - fn new(str_id: StrID, start: IndexType, end: IndexType) -> Self { + fn new(str_id: u32, start: u32, end: u32) -> Self { Self { transitions: HashMap::new(), suffix_link: INVALID, @@ -73,7 +66,7 @@ impl Node { } } - fn get_suffix_link(&self) -> NodeID { + fn get_suffix_link(&self) -> u32 { assert!(self.suffix_link != INVALID, "Invalid suffix link"); self.suffix_link } @@ -82,17 +75,17 @@ impl Node { /// A data structure used to store the current state during the Ukkonen's algorithm. struct ReferencePoint { /// The active node. - node: NodeID, + node: u32, /// The current string we are processing. - str_id: StrID, + str_id: u32, /// The active point.
- index: IndexType, + index: u32, } impl ReferencePoint { - const fn new(node: NodeID, str_id: StrID, index: IndexType) -> Self { + const fn new(node: u32, str_id: u32, index: u32) -> Self { Self { node, str_id, @@ -117,8 +110,8 @@ impl ReferencePoint { /// ``` #[derive(Debug)] pub struct GeneralizedSuffixTree { - node_storage: Vec32<Node>, - str_storage: Vec32<String>, + node_storage: Vec<Node>, + str_storage: Vec<String>, } impl Default for GeneralizedSuffixTree { @@ -131,10 +124,10 @@ impl Default for GeneralizedSuffixTree { root.suffix_link = SINK; sink.suffix_link = ROOT; - let node_storage: Vec32<Node> = vec32![root, sink]; + let node_storage: Vec<Node> = vec![root, sink]; Self { node_storage, - str_storage: Vec32::new(), + str_storage: Vec::new(), } } } @@ -149,7 +142,7 @@ impl GeneralizedSuffixTree { pub fn add_string(&mut self, mut s: String, term: char) { self.validate_string(&s, term); - let str_id = self.str_storage.len() as StrID; + let str_id = self.str_storage.len() as u32; // Add a unique terminator character to the end of the string. s.push(term); @@ -159,7 +152,7 @@ impl GeneralizedSuffixTree { } fn validate_string(&self, s: &str, term: char) { - assert!(s.len() <= IndexType::max_value() as usize); + assert!(s.len() <= u32::max_value() as usize); assert!(term.is_ascii(), "Only accept ASCII terminator"); assert!( !s.contains(term), @@ -175,7 +168,7 @@ impl GeneralizedSuffixTree { /// Find the longest common substring among all strings in the suffix. /// This function can be used when you already have a suffix tree built, - /// and would need to know the longest commmon substring. + /// and would need to know the longest common substring. /// It can be trivially extended to support longest common substring among /// `K` strings. #[must_use] @@ -184,13 +177,13 @@ impl GeneralizedSuffixTree { // prev_node stores the most recent occurance of a leaf that belongs to each string. // We use the terminator character (which uniquely represents a string) as the key.
- let mut prev_node: HashMap<CharType, NodeID> = HashMap::new(); + let mut prev_node: HashMap<char, u32> = HashMap::new(); // lca_cnt[v] means the total number of times that the lca of two nodes is node v. - let mut lca_cnt: Vec32<usize> = vec32![0; self.node_storage.len()]; + let mut lca_cnt: Vec<usize> = vec![0; self.node_storage.len()]; - let mut longest_str: (Vec32<&MappedSubstring>, IndexType) = (Vec32::new(), 0); - let mut cur_str: (Vec32<&MappedSubstring>, IndexType) = (Vec32::new(), 0); + let mut longest_str: (Vec<&MappedSubstring>, u32) = (Vec::new(), 0); + let mut cur_str: (Vec<&MappedSubstring>, u32) = (Vec::new(), 0); self.longest_common_substring_all_rec( &mut disjoint_set, &mut prev_node, @@ -202,7 +195,7 @@ impl GeneralizedSuffixTree { let mut result = String::new(); for s in longest_str.0 { - result.push_str(self.get_string_slice_short(s)); + result.push_str(&self.get_string_slice_short(s)); } result } @@ -225,11 +218,11 @@ impl GeneralizedSuffixTree { fn longest_common_substring_all_rec<'a>( &'a self, disjoint_set: &mut disjoint_set::DisjointSet, - prev_node: &mut HashMap<CharType, NodeID>, - lca_cnt: &mut Vec32<usize>, - node: NodeID, - longest_str: &mut (Vec32<&'a MappedSubstring>, IndexType), - cur_str: &mut (Vec32<&'a MappedSubstring>, IndexType), + prev_node: &mut HashMap<char, u32>, + lca_cnt: &mut Vec<usize>, + node: u32, + longest_str: &mut (Vec<&'a MappedSubstring>, u32), + cur_str: &mut (Vec<&'a MappedSubstring>, u32), ) -> (usize, usize) { let mut total_leaf = 0; let mut total_correction = 0; @@ -238,7 +231,7 @@ impl GeneralizedSuffixTree { continue; } let slice = &self.get_node(*target_node).substr; - if slice.end as usize == self.get_string(slice.str_id).len() { + if slice.end as usize == self.get_string(slice.str_id).chars().count() { // target_node is a leaf node.
total_leaf += 1; let last_ch = self.get_char(slice.str_id, slice.end - 1); @@ -270,7 +263,7 @@ impl GeneralizedSuffixTree { total_correction += lca_cnt[node as usize]; let unique_str_cnt = total_leaf - total_correction; if unique_str_cnt == self.str_storage.len() { - // This node represnets a substring that is common among all strings. + // This node represents a substring that is common among all strings. if cur_str.1 > longest_str.1 { *longest_str = cur_str.clone(); } @@ -282,13 +275,13 @@ impl GeneralizedSuffixTree { /// This function allows us compute this without adding `s` to the suffix. #[must_use] pub fn longest_common_substring_with<'a>(&self, s: &'a str) -> &'a str { - let mut longest_start: IndexType = 0; - let mut longest_len: IndexType = 0; - let mut cur_start: IndexType = 0; - let mut cur_len: IndexType = 0; - let mut node: NodeID = ROOT; + let mut longest_start: u32 = 0; + let mut longest_len: u32 = 0; + let mut cur_start: u32 = 0; + let mut cur_len: u32 = 0; + let mut node: u32 = ROOT; - let chars = s.as_bytes(); + let chars = s.chars().collect::<Vec<char>>(); let mut index = 0; let mut active_length = 0; while index < chars.len() { @@ -370,14 +363,15 @@ impl GeneralizedSuffixTree { #[must_use] fn is_suffix_or_substr(&self, s: &str, check_substr: bool) -> bool { for existing_str in &self.str_storage { + let ch = existing_str.chars().last().unwrap(); assert!( - !s.contains(existing_str.chars().last().unwrap()), + !s.contains(ch), "Queried string cannot contain terminator char" ); } let mut node = ROOT; let mut index = 0; - let chars = s.as_bytes(); + let chars = s.chars().collect::<Vec<char>>(); while index < s.len() { let target_node = self.transition(node, chars[index]); if target_node == INVALID { @@ -402,7 +396,7 @@ impl GeneralizedSuffixTree { // to look up in the current transitions to determine if we have // reached the end of any string. If needed, we are also able to // return which string the queried string is a suffix of.
- if self.transition(node, *s.as_bytes().last().unwrap()) != INVALID { + if self.transition(node, s.chars().last().unwrap()) != INVALID { is_suffix = true; break; } @@ -415,7 +409,7 @@ impl GeneralizedSuffixTree { self.print_recursive(ROOT, 0); } - fn print_recursive(&self, node: NodeID, space_count: u32) { + fn print_recursive(&self, node: u32, space_count: u32) { for target_node in self.get_node(node).transitions.values() { if *target_node == INVALID { continue; @@ -432,18 +426,17 @@ impl GeneralizedSuffixTree { } } - fn process_suffixes(&mut self, str_id: StrID) { + fn process_suffixes(&mut self, str_id: u32) { let mut active_point = ReferencePoint::new(ROOT, str_id, 0); - for i in 0..self.get_string(str_id).len() { - let mut cur_str = - MappedSubstring::new(str_id, active_point.index, (i + 1) as IndexType); + for i in 0..self.get_string(str_id).chars().count() { + let mut cur_str = MappedSubstring::new(str_id, active_point.index, (i + 1) as u32); active_point = self.update(active_point.node, &cur_str); cur_str.start = active_point.index; active_point = self.canonize(active_point.node, &cur_str); } } - fn update(&mut self, node: NodeID, cur_str: &MappedSubstring) -> ReferencePoint { + fn update(&mut self, node: u32, cur_str: &MappedSubstring) -> ReferencePoint { assert!(!cur_str.is_empty()); let mut cur_str = cur_str.clone(); @@ -461,7 +454,7 @@ impl GeneralizedSuffixTree { let mut is_endpoint = self.test_and_split(node, &split_str, last_ch, &mut r); while !is_endpoint { - let str_len = self.get_string(active_point.str_id).len() as IndexType; + let str_len = self.get_string(active_point.str_id).chars().count() as u32; let leaf_node = self.create_node_with_slice(active_point.str_id, cur_str.end - 1, str_len); self.set_transition(r, last_ch, leaf_node); @@ -483,10 +476,10 @@ impl GeneralizedSuffixTree { fn test_and_split( &mut self, - node: NodeID, + node: u32, split_str: &MappedSubstring, - ch: CharType, - r: &mut NodeID, + ch: char, + r: &mut u32, ) -> bool { if 
split_str.is_empty() { *r = node; @@ -513,7 +506,7 @@ impl GeneralizedSuffixTree { false } - fn canonize(&mut self, mut node: NodeID, cur_str: &MappedSubstring) -> ReferencePoint { + fn canonize(&mut self, mut node: u32, cur_str: &MappedSubstring) -> ReferencePoint { let mut cur_str = cur_str.clone(); loop { if cur_str.is_empty() { @@ -536,41 +529,39 @@ impl GeneralizedSuffixTree { ReferencePoint::new(node, cur_str.str_id, cur_str.start) } - fn create_node_with_slice( - &mut self, - str_id: StrID, - start: IndexType, - end: IndexType, - ) -> NodeID { + fn create_node_with_slice(&mut self, str_id: u32, start: u32, end: u32) -> u32 { let node = Node::new(str_id, start, end); self.node_storage.push(node); - (self.node_storage.len() - 1) as NodeID + (self.node_storage.len() - 1) as u32 } - fn get_node(&self, node_id: NodeID) -> &Node { + fn get_node(&self, node_id: u32) -> &Node { &self.node_storage[node_id as usize] } - fn get_node_mut(&mut self, node_id: NodeID) -> &mut Node { + fn get_node_mut(&mut self, node_id: u32) -> &mut Node { &mut self.node_storage[node_id as usize] } - fn get_string(&self, str_id: StrID) -> &str { + fn get_string(&self, str_id: u32) -> &str { &self.str_storage[str_id as usize] } - fn get_string_slice(&self, str_id: StrID, start: IndexType, end: IndexType) -> &str { - &self.get_string(str_id)[start as usize..end as usize] + fn get_string_slice(&self, str_id: u32, start: u32, end: u32) -> String { + self.get_string(str_id) + .chars() + .skip(start as usize) + .take((end - start) as usize) + .collect::<String>() } - fn get_string_slice_short(&self, slice: &MappedSubstring) -> &str { + fn get_string_slice_short(&self, slice: &MappedSubstring) -> String { self.get_string_slice(slice.str_id, slice.start, slice.end) } - fn transition(&self, node: NodeID, ch: CharType) -> NodeID { + fn transition(&self, node: u32, ch: char) -> u32 { if node == SINK { - // SINK always transition to ROOT.
return ROOT; } match self.get_node(node).transitions.get(&ch) { @@ -579,12 +570,16 @@ impl GeneralizedSuffixTree { } } - fn set_transition(&mut self, node: NodeID, ch: CharType, target_node: NodeID) { + fn set_transition(&mut self, node: u32, ch: char, target_node: u32) { self.get_node_mut(node).transitions.insert(ch, target_node); } - fn get_char(&self, str_id: StrID, index: IndexType) -> u8 { + fn get_char(&self, str_id: u32, index: u32) -> char { assert!((index as usize) < self.get_string(str_id).len()); - self.get_string(str_id).as_bytes()[index as usize] + if let Some(ch) = self.get_string(str_id).chars().nth(index as usize) { + ch + } else { + panic!("{}, {}", self.get_string(str_id), index) + } } } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 9bfc09d..9974864 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -143,4 +143,15 @@ mod tests { assert_eq!(result1.len(), result3.len()); } } + + #[test] + fn test_longest_common_substring_all_unicode() { + { + let mut tree = generalized_suffix_tree::GeneralizedSuffixTree::new(); + tree.add_string(String::from("我们爱在大自然"), '$'); + tree.add_string(String::from("爱大自然里撒欢"), '#'); + // tree.pretty_print(); + assert_eq!(tree.longest_common_substring_all(), "大自然"); + } + } }