diff --git a/trie.go b/trie.go index c59bb40..46e8bf2 100644 --- a/trie.go +++ b/trie.go @@ -3,6 +3,7 @@ package trie import ( "bufio" "bytes" + "encoding/base64" "encoding/gob" "errors" "fmt" @@ -142,6 +143,60 @@ func (t *Trie) PrintDump() { t.Root.PrintDump() } +/* +ToBase64String returns a string representing the Base64 encoded Trie +*/ +func (t *Trie) ToBase64String() (encoded string, err error) { + + t.Root.Lock() + entries := t.Members() + t.Root.Unlock() + + encoded = "" + + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + if err = enc.Encode(entries); err != nil { + err = errors.New(fmt.Sprintf("Could encode Trie entries for base 64 encoding: %v", err)) + return + } + + encoded = base64.StdEncoding.EncodeToString(buf.Bytes()) + + return +} + +/* +FromBase64String returns a new Trie from a Base64 encoded version +*/ +func FromBase64String(encoded string) (tr *Trie, err error) { + + tr = NewTrie() + decoded, err := base64.StdEncoding.DecodeString(encoded) + + entries := new([]*MemberInfo) + + buf := bytes.NewReader(decoded) + dec := gob.NewDecoder(buf) + if err = dec.Decode(entries); err != nil { + if err == io.EOF && entries == nil { + log.Println("Nothing to decode. Seems the file is empty.") + err = nil + } else { + err = errors.New(fmt.Sprintf("Decoding error: %v", err)) + return + } + } + + for _, mi := range *entries { + b := tr.Add(mi.Value) + b.Count = mi.Count + } + + return + +} + /* DumpToFile dumps all values into a slice of strings and writes that to a file using encoding/gob. diff --git a/trie_test.go b/trie_test.go index 72af81a..1c7a133 100644 --- a/trie_test.go +++ b/trie_test.go @@ -936,6 +936,71 @@ func TestTrieDumpToFileLoadFromFile(t *testing.T) { } } + +func TestTrieEncodeB64DecodeB64(t *testing.T) { + tr := NewTrie() + var prefix = "prefix" + var words []string + var str []byte + var insert string + var n = 0 + for n < 100 { + i := 0 + str = []byte{} + for i < 10 { + rn := 0 + for rn < 97 { + rn = rand.Intn(123) + } + str = append(str, byte(rn)) + i++ + } + if rand.Intn(2) == 1 { + insert = prefix + string(str) + } else { + insert = string(str) + } + words = append(words, insert) + tr.Add(insert) + if rand.Intn(2) == 1 { + tr.Add(insert) + } + n++ + } + encoded, err := tr.ToBase64String() + + loadedTrie, err := FromBase64String(encoded) + if err != nil { + t.Errorf("Failed to load Trie from encoded string") + } + for _, w := range words { + // t.Logf("Checking for %s", w) + if !loadedTrie.Has(w) { + t.Errorf("Expected to find %s", w) + } + } + + trMembers := set.NewStringSet(tr.MembersList()...) + loadedTrieMembers := set.NewStringSet(loadedTrie.MembersList()...) + + t.Log("trMembers.IsEqual(loadedTrieMembers):", trMembers.IsEqual(loadedTrieMembers)) + + diff := trMembers.Difference(loadedTrieMembers) + if diff.Len() > 0 { + t.Error("Dump() of the original and the LoadFromFile() version of the Trie are different.") + } + + // check counts + for _, mi := range tr.Members() { + _, count := loadedTrie.HasCount(mi.Value) + if count != mi.Count { + t.Errorf("Count for member %s differs: orig was %v, restored trie has %v", mi.Value, mi.Count, count) + } + } + +} + + func TestTrieLoadFromFileEmpty(t *testing.T) { loadedTrie, err := LoadFromFile("testfiles/empty") if err != nil {