From 3b5f55d235c065d3914d79831dbc76c0baaa1b17 Mon Sep 17 00:00:00 2001 From: Igor Shishkin Date: Sun, 30 Jun 2024 11:20:32 +0300 Subject: [PATCH] Optimize hash calculation Signed-off-by: Igor Shishkin --- database/legacy/database.go | 9 +- go.mod | 2 + go.sum | 4 + main.go | 149 ++++++++++++++------------------ operations.go | 65 ++++++-------- utils/concurrent/writer.go | 55 ++++++++++++ utils/concurrent/writer_test.go | 61 +++++++++++++ 7 files changed, 217 insertions(+), 128 deletions(-) create mode 100644 utils/concurrent/writer.go create mode 100644 utils/concurrent/writer_test.go diff --git a/database/legacy/database.go b/database/legacy/database.go index b91267c..397d306 100644 --- a/database/legacy/database.go +++ b/database/legacy/database.go @@ -9,6 +9,7 @@ import ( "time" "github.com/fatih/color" + "github.com/pkg/errors" ) // DataObject is a file object in JSON database @@ -52,25 +53,25 @@ func NewDatabase(path string) (*Database, error) { Data: make(map[string]*DataObject), }) if err != nil { - return nil, fmt.Errorf("Error marshaling initial JSON: %s", err) + return nil, errors.Errorf("Error marshaling initial JSON: %s", err) } err = ioutil.WriteFile(path, js, 0644) if err != nil { - return nil, fmt.Errorf("Error creating schema: %s", err) + return nil, errors.Errorf("Error creating schema: %s", err) } } fp, err := os.Open(path) if err != nil { - return nil, fmt.Errorf("Error opening file: %s", err) + return nil, errors.Errorf("Error opening file: %s", err) } defer fp.Close() decoder := json.NewDecoder(fp) err = decoder.Decode(&database.Schema) if err != nil { - return nil, fmt.Errorf("Error decoding JSON data: %s", err) + return nil, errors.Errorf("Error decoding JSON data: %s", err) } return &database, nil diff --git a/go.mod b/go.mod index bae2449..7d9a5c9 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,8 @@ require ( github.com/cosiner/flag v0.5.2 github.com/fatih/color v1.17.0 github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.9.0 + golang.org/x/sync v0.7.0 gopkg.in/cheggaaa/pb.v1 v1.0.28 ) diff --git a/go.sum b/go.sum index 1ce5094..bfb253d 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= @@ -27,6 +29,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= diff --git a/main.go b/main.go index 98238e5..59b92e7 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "flag" "fmt" "log" @@ -9,19 +8,17 @@ import ( "path/filepath" "regexp" "sort" - "sync" "sync/atomic" "time" "github.com/fatih/color" + "golang.org/x/sync/errgroup" "gopkg.in/cheggaaa/pb.v1" database "github.com/teran/checksum/database/legacy" ) var ( - wg sync.WaitGroup - appVersion = "No version specified(probably trunk build)" buildTimestamp = "0000-00-00T00:00:00Z" @@ -60,7 +57,6 @@ func main() { } if !cfg.GenerateChecksumOnly { - sem := make(chan bool, cfg.Concurrency) var bar *pb.ProgressBar if cfg.Progressbar { bar = pb.New(db.Count()) @@ -80,94 +76,83 @@ func main() { } sort.Strings(keys) - for _, key := range keys { - sem <- true - wg.Add(1) - go func(file string, obj *database.DataObject) { - if cfg.Progressbar { - defer func() { - bar.Increment() - }() - } - defer func() { - <-sem - }() - defer wg.Done() - - if _, err := os.Stat(file); os.IsNotExist(err) { - if !cfg.SkipMissed { - fmt.Printf("%s %s\n", color.RedString("[MISS]"), file) - } + wg := &errgroup.Group{} + wg.SetLimit(cfg.Concurrency) - if cfg.DeleteMissed { - fmt.Printf("%s DeleteMissed requested: deleting file `%s` from database\n", color.BlueString("[NOTE]"), file) - db.DeleteOne(file) - atomic.AddUint64(&cntDeleted, 1) + for _, key := range keys { + wg.Go(func(file string, obj *database.DataObject) func() error { + return func() error { + if cfg.Progressbar { + defer func() { + bar.Increment() + }() } - atomic.AddUint64(&cntMissed, 1) - return - } + if _, err := os.Stat(file); os.IsNotExist(err) { + if !cfg.SkipMissed { + fmt.Printf("%s %s\n", color.RedString("[MISS]"), file) + } - isChanged := false + if cfg.DeleteMissed { + fmt.Printf("%s DeleteMissed requested: deleting file `%s` from database\n", color.BlueString("[NOTE]"), file) + db.DeleteOne(file) + atomic.AddUint64(&cntDeleted, 1) + } - if obj.Length == 0 { - obj.Length = flength(file) - isChanged = true - } + atomic.AddUint64(&cntMissed, 1) + return nil + } - data, err := readFile(file) - if err != nil { - log.Fatalf("error reading data: %s", err) - } + isChanged := false - if obj.SHA1 == "" { - obj.SHA1, err = SHA1(bytes.NewReader(data)) - if err != nil { - log.Fatalf("error calculating SHA1: %s", err) + if obj.Length == 0 { + obj.Length = flength(file) + isChanged = true } - isChanged = true - } + if obj.SHA1 == "" || obj.SHA256 == "" { + sha1, sha256, err := generateActualChecksum(file) + if err != nil { + return err + } - if obj.SHA256 == "" { - obj.SHA256, err = SHA256(bytes.NewReader(data)) - if err != nil { - log.Fatalf("error calculating SHA256: %s", err) - } + obj.SHA1 = sha1 + obj.SHA256 = sha256 - isChanged = true - } + isChanged = true + } - res := verify(file, obj.Length, obj.SHA1, obj.SHA256) + res := verify(file, obj.Length, obj.SHA1, obj.SHA256) - if isChanged { - db.WriteOne(file, &database.DataObject{ - Length: obj.Length, - SHA1: obj.SHA1, - SHA256: obj.SHA256, - Modified: time.Now().UTC(), - }) - } + if isChanged { + db.WriteOne(file, &database.DataObject{ + Length: obj.Length, + SHA1: obj.SHA1, + SHA256: obj.SHA256, + Modified: time.Now().UTC(), + }) + } - if res { - if !cfg.SkipOk { - fmt.Printf("%s %s\n", color.GreenString("[ OK ]"), file) + if res { + if !cfg.SkipOk { + fmt.Printf("%s %s\n", color.GreenString("[ OK ]"), file) + } + atomic.AddUint64(&cntPassed, 1) + return nil } - atomic.AddUint64(&cntPassed, 1) - return - } - if !cfg.SkipFailed { - fmt.Printf("%s %s\n", color.RedString("[FAIL]"), file) + if !cfg.SkipFailed { + fmt.Printf("%s %s\n", color.RedString("[FAIL]"), file) + } + atomic.AddUint64(&cntFailed, 1) + return nil } - atomic.AddUint64(&cntFailed, 1) - }(key, objects[key]) + }(key, objects[key])) } - for i := 0; i < cap(sem); i++ { - sem <- true + err = wg.Wait() + if err != nil { + log.Fatalf("error handling threads") } - wg.Wait() if cfg.Progressbar { bar.Finish() @@ -179,24 +164,16 @@ func main() { if cfg.DataDir != "" { fmt.Printf("%s Checking for new files on %s\n", color.CyanString("[INFO]"), cfg.DataDir) + // TODO: check data dir for existence + err = filepath.Walk(cfg.DataDir, func(path string, info os.FileInfo, err error) error { if info.IsDir() { return nil } if isApplicable(path) { - data, err := readFile(path) - if err != nil { - log.Fatalf("error reading file: %s", err) - } - - sha1, err := SHA1(bytes.NewReader(data)) - if err != nil { - log.Fatalf("error calculating SHA1: %s", err) - } - - sha256, err := SHA256(bytes.NewReader(data)) + sha1, sha256, err := generateActualChecksum(path) if err != nil { - log.Fatalf("error calculating SHA256: %s", err) + return err } db.WriteOne(path, &database.DataObject{ diff --git a/operations.go b/operations.go index a5e1d5f..321a971 100644 --- a/operations.go +++ b/operations.go @@ -1,17 +1,19 @@ package main import ( - "bytes" + "context" "crypto/sha1" "crypto/sha256" + "encoding/hex" "fmt" "io" - "io/ioutil" "log" "os" "path/filepath" "runtime" "strings" + + "github.com/teran/checksum/utils/concurrent" ) func completeArgs(word string) { @@ -28,63 +30,50 @@ func completeArgs(word string) { }, " ")) } -func readFile(fn string) ([]byte, error) { - fp, err := os.Open(fn) +func flength(filename string) int64 { + stat, err := os.Stat(filename) if err != nil { - return nil, err + log.Fatal(err) } - defer fp.Close() - return ioutil.ReadAll(fp) + return stat.Size() } -// SHA256 ... -func SHA256(rd io.Reader) (string, error) { - h := sha256.New() - _, err := io.Copy(h, rd) +func generateActualChecksum(filename string) (sha1sum string, sha256sum string, err error) { + fi, err := os.Stat(filename) if err != nil { - return "", err + return "", "", err } - return fmt.Sprintf("%x", h.Sum(nil)), nil -} - -// SHA1 ... -func SHA1(rd io.Reader) (string, error) { - h := sha1.New() - _, err := io.Copy(h, rd) + fp, err := os.Open(filename) if err != nil { - return "", err + return "", "", err } + defer fp.Close() - return fmt.Sprintf("%x", h.Sum(nil)), nil -} + sha1hasher := sha1.New() + sha256hasher := sha256.New() -func flength(filename string) int64 { - stat, err := os.Stat(filename) + w, err := concurrent.NewConcurrentMultiWriter(context.TODO(), sha1hasher, sha256hasher) if err != nil { - log.Fatal(err) + return "", "", err } - return stat.Size() -} - -func verify(path string, length int64, sha1, sha256 string) bool { - data, err := readFile(path) + n, err := io.Copy(w, fp) if err != nil { - log.Printf("error reading file: %s", err) - return false + return "", "", err } - actSHA1, err := SHA1(bytes.NewReader(data)) - if err != nil { - log.Printf("error calculating SHA1: %s", err) - return false + if n != fi.Size() { + return "", "", io.ErrShortWrite } - actSHA256, err := SHA256(bytes.NewReader(data)) + return hex.EncodeToString(sha1hasher.Sum(nil)), hex.EncodeToString(sha256hasher.Sum(nil)), nil +} + +func verify(path string, length int64, sha1, sha256 string) bool { + actSHA1, actSHA256, err := generateActualChecksum(path) if err != nil { - log.Printf("error calculating SHA256: %s", err) return false } diff --git a/utils/concurrent/writer.go b/utils/concurrent/writer.go new file mode 100644 index 0000000..9ea5651 --- /dev/null +++ b/utils/concurrent/writer.go @@ -0,0 +1,55 @@ +package concurrent + +import ( + "context" + "io" + "runtime" + + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" +) + +type writer struct { + ctx context.Context + writers []io.Writer +} + +func (cw *writer) Write(p []byte) (n int, err error) { + g, _ := errgroup.WithContext(cw.ctx) + g.SetLimit(runtime.NumCPU()) + + for idx, w := range cw.writers { + g.Go(func(idx int, w io.Writer) func() error { + return func() error { + n, err := w.Write(p) + if err != nil { + return errors.Wrapf(err, "error writing to channel #%d", idx) + } + + if n != len(p) { + return errors.Wrapf(io.ErrShortWrite, "error writing to channel #%d", idx) + } + return nil + } + }(idx, w)) + } + + return len(p), g.Wait() +} + +func NewConcurrentMultiWriter(ctx context.Context, writers ...io.Writer) (io.Writer, error) { + w := make([]io.Writer, len(writers)) + + n := copy(w, writers) + if n != len(writers) { + return nil, errors.Errorf( + "unexpected copy amount: expected %d copied %d. Looks like internal error or memory corruption", + len(writers), n, + ) + } + + return &writer{ + ctx: ctx, + writers: w, + }, nil +} diff --git a/utils/concurrent/writer_test.go b/utils/concurrent/writer_test.go new file mode 100644 index 0000000..19c5426 --- /dev/null +++ b/utils/concurrent/writer_test.go @@ -0,0 +1,61 @@ +package concurrent + +import ( + "bytes" + "context" + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +func (s *concurrentWriterTestSuite) TestWrite() { + sample := func() []byte { + out := make([]byte, 1024) + for i := range out { + out[i] = 0xaf + } + return out + }() + + writers := func() []*bytes.Buffer { + out := make([]*bytes.Buffer, 10) + for i := range out { + out[i] = &bytes.Buffer{} + } + return out + }() + + w, err := NewConcurrentMultiWriter(context.TODO(), func(buffers ...*bytes.Buffer) (out []io.Writer) { + for _, wr := range buffers { + out = append(out, wr) + } + return + }(writers...)...) + s.Require().NoError(err) + + n, err := w.Write(sample) + s.Require().NoError(err) + s.Require().Equal(len(sample), n) + + for i, buf := range writers { + s.T().Run(fmt.Sprintf("buffer #%d", i), func(t *testing.T) { + r := require.New(t) + + r.Equal(sample, buf.Bytes()) + }) + } +} + +// ======================================================================== +// Test suite setup +// ======================================================================== +type concurrentWriterTestSuite struct { + suite.Suite +} + +func TestConcurrentWriterTestSuite(t *testing.T) { + suite.Run(t, &concurrentWriterTestSuite{}) +}