From 98f4e45de3ae878d694ef417977a395b4880fcfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1ximo=20Cuadros?= Date: Thu, 23 Apr 2020 21:16:20 +0200 Subject: [PATCH] make RegExp thread-safe by default and code cleanup --- regex.go | 242 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 164 insertions(+), 78 deletions(-) diff --git a/regex.go b/regex.go index cbb647c..00f0c0e 100644 --- a/regex.go +++ b/regex.go @@ -14,7 +14,6 @@ import ( "errors" "fmt" "io" - "log" "runtime" "strconv" "sync" @@ -45,24 +44,40 @@ type Regexp struct { errorBuf *C.char matchData *MatchData namedGroupInfo NamedGroupInfo + mutex *sync.Mutex } // NewRegexp creates and initializes a new Regexp with the given pattern and option. -func NewRegexp(pattern string, option int) (re *Regexp, err error) { +func NewRegexp(pattern string, option int) (*Regexp, error) { + re, err := initRegexp(&Regexp{pattern: pattern, encoding: C.ONIG_ENCODING_UTF8}, option) + if err != nil { + return nil, err + } + + re.mutex = new(sync.Mutex) + return re, nil +} + +// NewRegexpNonThreadsafe creates and initializes a new Regexp with the given +// pattern and option. The resulting regexp is not thread-safe. +func NewRegexpNonThreadsafe(pattern string, option int) (*Regexp, error) { return initRegexp(&Regexp{pattern: pattern, encoding: C.ONIG_ENCODING_UTF8}, option) } // NewRegexpASCII is equivalent to NewRegexp, but with the encoding restricted to ASCII. -func NewRegexpASCII(pattern string, option int) (re *Regexp, err error) { +func NewRegexpASCII(pattern string, option int) (*Regexp, error) { return initRegexp(&Regexp{pattern: pattern, encoding: C.ONIG_ENCODING_ASCII}, option) } func initRegexp(re *Regexp, option int) (*Regexp, error) { var err error + patternCharPtr := C.CString(re.pattern) defer C.free(unsafe.Pointer(patternCharPtr)) + mutex.Lock() defer mutex.Unlock() + errorCode := C.NewOnigRegex(patternCharPtr, C.int(len(re.pattern)), C.int(option), &re.regex, &re.region, &re.encoding, &re.errorInfo, &re.errorBuf) if errorCode != C.ONIG_NORMAL { err = errors.New(C.GoString(re.errorBuf)) @@ -77,6 +92,7 @@ func initRegexp(re *Regexp, option int) (*Regexp, error) { re.namedGroupInfo = re.getNamedGroupInfo() runtime.SetFinalizer(re, (*Regexp).Free) } + return re, err } @@ -113,7 +129,22 @@ func MustCompileASCII(str string) *Regexp { return regexp } +func (re *Regexp) lock() { + if re.mutex != nil { + re.mutex.Lock() + } +} + +func (re *Regexp) unlock() { + if re.mutex != nil { + re.mutex.Unlock() + } +} + func (re *Regexp) Free() { + re.lock() + defer re.unlock() + mutex.Lock() if re.regex != nil { C.onig_free(re.regex) @@ -134,47 +165,56 @@ func (re *Regexp) Free() { } } -func (re *Regexp) getNamedGroupInfo() (namedGroupInfo NamedGroupInfo) { +func (re *Regexp) getNamedGroupInfo() NamedGroupInfo { numNamedGroups := int(C.onig_number_of_names(re.regex)) //when any named capture exisits, there is no numbered capture even if there are unnamed captures - if numNamedGroups > 0 { - namedGroupInfo = make(map[string]int) - //try to get the names - bufferSize := len(re.pattern) * 2 - nameBuffer := make([]byte, bufferSize) - groupNumbers := make([]int32, numNamedGroups) - bufferPtr := unsafe.Pointer(&nameBuffer[0]) - numbersPtr := unsafe.Pointer(&groupNumbers[0]) - length := int(C.GetCaptureNames(re.regex, bufferPtr, (C.int)(bufferSize), (*C.int)(numbersPtr))) - if length > 0 { - namesAsBytes := bytes.Split(nameBuffer[:length], ([]byte)(";")) - if len(namesAsBytes) != numNamedGroups { - log.Fatalf("the number of named groups (%d) does not match the number names found (%d)\n", numNamedGroups, len(namesAsBytes)) - } - for i, nameAsBytes := range namesAsBytes { - name := string(nameAsBytes) - namedGroupInfo[name] = int(groupNumbers[i]) - } - } else { - log.Fatalf("could not get the capture group names from %q", re.String()) - } + if numNamedGroups == 0 { + return nil } - return + + namedGroupInfo := make(map[string]int) + + //try to get the names + bufferSize := len(re.pattern) * 2 + nameBuffer := make([]byte, bufferSize) + groupNumbers := make([]int32, numNamedGroups) + bufferPtr := unsafe.Pointer(&nameBuffer[0]) + numbersPtr := unsafe.Pointer(&groupNumbers[0]) + + length := int(C.GetCaptureNames(re.regex, bufferPtr, (C.int)(bufferSize), (*C.int)(numbersPtr))) + if length == 0 { + panic(fmt.Errorf("could not get the capture group names from %q", re.String())) + } + + namesAsBytes := bytes.Split(nameBuffer[:length], ([]byte)(";")) + if len(namesAsBytes) != numNamedGroups { + panic(fmt.Errorf( + "the number of named groups (%d) does not match the number names found (%d)", + numNamedGroups, len(namesAsBytes), + )) + } + + for i, nameAsBytes := range namesAsBytes { + name := string(nameAsBytes) + namedGroupInfo[name] = int(groupNumbers[i]) + } + + return namedGroupInfo } -func (re *Regexp) groupNameToId(name string) (id int) { +func (re *Regexp) groupNameToId(name string) int { if re.namedGroupInfo == nil { - id = ONIGERR_UNDEFINED_NAME_REFERENCE - } else { - id = re.namedGroupInfo[name] + return ONIGERR_UNDEFINED_NAME_REFERENCE } - return + + return re.namedGroupInfo[name] } -func (re *Regexp) processMatch(numCaptures int) (match []int32) { +func (re *Regexp) processMatch(numCaptures int) []int32 { if numCaptures <= 0 { panic("cannot have 0 captures when processing a match") } + matchData := re.matchData return matchData.indexes[matchData.count][:numCaptures*2] } @@ -184,10 +224,16 @@ func (re *Regexp) ClearMatchData() { matchData.count = 0 } -func (re *Regexp) find(b []byte, n int, offset int) (match []int) { +func (re *Regexp) find(b []byte, n int, offset int) []int { + re.lock() + defer re.unlock() + + var match []int + if n == 0 { b = []byte{0} } + ptr := unsafe.Pointer(&b[0]) matchData := re.matchData capturesPtr := unsafe.Pointer(&(matchData.indexes[matchData.count][0])) @@ -198,17 +244,33 @@ func (re *Regexp) find(b []byte, n int, offset int) (match []int) { if numCaptures <= 0 { panic("cannot have 0 captures when processing a match") } + match2 := matchData.indexes[matchData.count][:numCaptures*2] match = make([]int, len(match2)) for i := range match2 { match[i] = int(match2[i]) } + numCapturesInPattern := int32(C.onig_number_of_captures(re.regex)) + 1 if numCapturesInPattern != numCaptures { - log.Fatalf("expected %d captures but got %d\n", numCapturesInPattern, numCaptures) + panic(fmt.Errorf("expected %d captures but got %d", numCapturesInPattern, numCaptures)) } } - return + + return re.copySlice(match) +} + +func (re *Regexp) copySlice(indices []int) (result []int) { + if re.mutex == nil { + return indices + } + + if indices != nil { + result = make([]int, len(indices)) + copy(result, indices) + } + + return result } func getCapture(b []byte, beg int, end int) []byte { @@ -219,21 +281,27 @@ func getCapture(b []byte, beg int, end int) []byte { } func (re *Regexp) match(b []byte, n int, offset int) bool { + re.lock() + defer re.unlock() + re.ClearMatchData() if n == 0 { b = []byte{0} } + ptr := unsafe.Pointer(&b[0]) pos := int(C.SearchOnigRegex((ptr), C.int(n), C.int(offset), C.int(ONIG_OPTION_DEFAULT), re.regex, re.region, re.errorInfo, (*C.char)(nil), (*C.int)(nil), (*C.int)(nil))) return pos >= 0 } -func (re *Regexp) findAll(b []byte, n int) (matches [][]int) { +func (re *Regexp) findAll(b []byte, n int) [][]int { + var matches [][]int re.ClearMatchData() if n < 0 { n = len(b) } + matchData := re.matchData offset := 0 for offset <= n { @@ -241,25 +309,28 @@ func (re *Regexp) findAll(b []byte, n int) (matches [][]int) { length := len(matchData.indexes[0]) matchData.indexes = append(matchData.indexes, make([]int32, length)) } - if match := re.find(b, n, offset); len(match) > 0 { - matchData.count += 1 - //move offset to the ending index of the current match and prepare to find the next non-overlapping match - offset = match[1] - //if match[0] == match[1], it means the current match does not advance the search. we need to exit the loop to avoid getting stuck here. - if match[0] == match[1] { - if offset < n && offset >= 0 { - //there are more bytes, so move offset by a word - _, width := utf8.DecodeRune(b[offset:]) - offset += width - } else { - //search is over, exit loop - break - } - } - } else { + + match := re.find(b, n, offset) + if len(match) == 0 { break } + + matchData.count++ + //move offset to the ending index of the current match and prepare to find the next non-overlapping match + offset = match[1] + //if match[0] == match[1], it means the current match does not advance the search. we need to exit the loop to avoid getting stuck here. + if match[0] == match[1] { + if offset < n && offset >= 0 { + //there are more bytes, so move offset by a word + _, width := utf8.DecodeRune(b[offset:]) + offset += width + } else { + //search is over, exit loop + break + } + } } + matches2 := matchData.indexes[:matchData.count] matches = make([][]int, len(matches2)) for i, v := range matches2 { @@ -268,7 +339,8 @@ func (re *Regexp) findAll(b []byte, n int) (matches [][]int) { matches[i][j] = int(v2) } } - return + + return matches } func (re *Regexp) FindIndex(b []byte) []int { @@ -277,6 +349,7 @@ func (re *Regexp) FindIndex(b []byte) []int { if len(match) == 0 { return nil } + return match[:2] } @@ -285,21 +358,21 @@ func (re *Regexp) Find(b []byte) []byte { if loc == nil { return nil } + return getCapture(b, loc[0], loc[1]) } func (re *Regexp) FindString(s string) string { - b := []byte(s) - mb := re.Find(b) + mb := re.Find([]byte(s)) if mb == nil { return "" } + return string(mb) } func (re *Regexp) FindStringIndex(s string) []int { - b := []byte(s) - return re.FindIndex(b) + return re.FindIndex([]byte(s)) } func (re *Regexp) FindAllIndex(b []byte, n int) [][]int { @@ -307,6 +380,7 @@ func (re *Regexp) FindAllIndex(b []byte, n int) [][]int { if len(matches) == 0 { return nil } + return matches } @@ -328,6 +402,7 @@ func (re *Regexp) FindAllString(s string, n int) []string { if matches == nil { return nil } + matchStrings := make([]string, 0, len(matches)) for _, match := range matches { m := getCapture(b, match[0], match[1]) @@ -342,46 +417,45 @@ func (re *Regexp) FindAllString(s string, n int) []string { } func (re *Regexp) FindAllStringIndex(s string, n int) [][]int { - b := []byte(s) - return re.FindAllIndex(b, n) -} - -func (re *Regexp) findSubmatchIndex(b []byte) (match []int) { - re.ClearMatchData() - match = re.find(b, len(b), 0) - return + return re.FindAllIndex([]byte(s), n) } func (re *Regexp) FindSubmatchIndex(b []byte) []int { - match := re.findSubmatchIndex(b) + re.ClearMatchData() + match := re.find(b, len(b), 0) if len(match) == 0 { return nil } + return match } func (re *Regexp) FindSubmatch(b []byte) [][]byte { - match := re.findSubmatchIndex(b) + match := re.FindSubmatchIndex(b) if match == nil { return nil } + length := len(match) / 2 if length == 0 { return nil } + results := make([][]byte, 0, length) for i := 0; i < length; i++ { results = append(results, getCapture(b, match[2*i], match[2*i+1])) } + return results } func (re *Regexp) FindStringSubmatch(s string) []string { b := []byte(s) - match := re.findSubmatchIndex(b) + match := re.FindSubmatchIndex(b) if match == nil { return nil } + length := len(match) / 2 if length == 0 { return nil @@ -396,12 +470,12 @@ func (re *Regexp) FindStringSubmatch(s string) []string { results = append(results, string(cap)) } } + return results } func (re *Regexp) FindStringSubmatchIndex(s string) []int { - b := []byte(s) - return re.FindSubmatchIndex(b) + return re.FindSubmatchIndex([]byte(s)) } func (re *Regexp) FindAllSubmatchIndex(b []byte, n int) [][]int { @@ -409,6 +483,7 @@ func (re *Regexp) FindAllSubmatchIndex(b []byte, n int) [][]int { if len(matches) == 0 { return nil } + return matches } @@ -417,6 +492,7 @@ func (re *Regexp) FindAllSubmatch(b []byte, n int) [][][]byte { if len(matches) == 0 { return nil } + allCapturedBytes := make([][][]byte, 0, len(matches)) for _, match := range matches { length := len(match) / 2 @@ -424,6 +500,7 @@ func (re *Regexp) FindAllSubmatch(b []byte, n int) [][][]byte { for i := 0; i < length; i++ { capturedBytes = append(capturedBytes, getCapture(b, match[2*i], match[2*i+1])) } + allCapturedBytes = append(allCapturedBytes, capturedBytes) } @@ -432,10 +509,12 @@ func (re *Regexp) FindAllSubmatch(b []byte, n int) [][][]byte { func (re *Regexp) FindAllStringSubmatch(s string, n int) [][]string { b := []byte(s) + matches := re.findAll(b, n) if len(matches) == 0 { return nil } + allCapturedStrings := make([][]string, 0, len(matches)) for _, match := range matches { length := len(match) / 2 @@ -448,14 +527,15 @@ func (re *Regexp) FindAllStringSubmatch(s string, n int) [][]string { capturedStrings = append(capturedStrings, string(cap)) } } + allCapturedStrings = append(allCapturedStrings, capturedStrings) } + return allCapturedStrings } func (re *Regexp) FindAllStringSubmatchIndex(s string, n int) [][]int { - b := []byte(s) - return re.FindAllSubmatchIndex(b, n) + return re.FindAllSubmatchIndex([]byte(s), n) } func (re *Regexp) Match(b []byte) bool { @@ -463,11 +543,13 @@ func (re *Regexp) Match(b []byte) bool { } func (re *Regexp) MatchString(s string) bool { - b := []byte(s) - return re.Match(b) + return re.Match([]byte(s)) } func (re *Regexp) NumSubexp() int { + re.lock() + defer re.unlock() + return (int)(C.onig_number_of_captures(re.regex)) } @@ -532,6 +614,7 @@ func (re *Regexp) replaceAll(src, repl []byte, replFunc func([]byte, []byte, map if len(matches) == 0 { return src } + dest := make([]byte, 0, srcLen) for i, match := range matches { length := len(match) / 2 @@ -579,14 +662,15 @@ func (re *Regexp) ReplaceAllString(src, repl string) string { } func (re *Regexp) ReplaceAllStringFunc(src string, repl func(string) string) string { - srcB := []byte(src) - destB := re.replaceAll(srcB, []byte(""), func(_ []byte, matchBytes []byte, _ map[string][]byte) []byte { + return string(re.replaceAll([]byte(src), []byte(""), func(_ []byte, matchBytes []byte, _ map[string][]byte) []byte { return []byte(repl(string(matchBytes))) - }) - return string(destB) + })) } func (re *Regexp) String() string { + re.lock() + defer re.unlock() + return re.pattern } @@ -651,6 +735,7 @@ func (re *Regexp) Gsub(src, repl string) string { srcBytes := ([]byte)(src) replBytes := ([]byte)(repl) replaced := re.replaceAll(srcBytes, replBytes, fillCapturedValues) + return string(replaced) } @@ -664,5 +749,6 @@ func (re *Regexp) GsubFunc(src string, replFunc func(string, map[string]string) matchString := string(matchBytes) return ([]byte)(replFunc(matchString, capturedStrings)) }) + return string(replaced) }