Skip to content

Commit

Permalink
Add CopyFrom / CopyTo to fsys.Files (#151)
Browse files Browse the repository at this point in the history
Signed-off-by: Kimmo Lehto <[email protected]>
  • Loading branch information
kke authored Jan 8, 2024
1 parent 10f783d commit 7642b01
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 14 deletions.
18 changes: 18 additions & 0 deletions pkg/rigfs/bytecounter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package rigfs

// ByteCounter is a simple io.Writer that counts the number of bytes written to it, to be used in
// conjunction with io.MultiWriter / io.TeeReader
type ByteCounter struct {
count int64
}

// Write implements io.Writer
func (bc *ByteCounter) Write(p []byte) (int, error) {
bc.count += int64(len(p))
return len(p), nil
}

// Count returns the number of bytes written to the ByteCounter
func (bc *ByteCounter) Count() int64 {
return bc.count
}
44 changes: 36 additions & 8 deletions pkg/rigfs/posixfsys.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

var (
_ fs.File = (*PosixFile)(nil)
_ File = (*PosixFile)(nil)
_ fs.ReadDirFile = (*PosixDir)(nil)
_ fs.FS = (*PosixFsys)(nil)
errInvalid = errors.New("invalid")
Expand Down Expand Up @@ -199,26 +200,53 @@ func (f *PosixFile) Write(p []byte) (int, error) {
return written, nil
}

// Copy copies the remote file at src to the local file at dst
func (f *PosixFile) Copy(dst io.Writer) (int64, error) {
// CopyTo copies the remote file to the writer dst
func (f *PosixFile) CopyTo(dst io.Writer) (int64, error) {
if f.isEOF {
return 0, io.EOF
}
if !f.isReadable() {
return 0, f.pathErr("copy", fmt.Errorf("%w: file %s is not open for reading", fs.ErrClosed, f.path))
return 0, f.pathErr(OpCopyTo, fmt.Errorf("%w: file %s is not open for reading", fs.ErrClosed, f.path))
}
bs, skip, count := f.ddParams(f.pos, int(f.size-f.pos))
errbuf := bytes.NewBuffer(nil)
cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=%s bs=%d skip=%d count=%d", shellescape.Quote(f.path), bs, skip, count), nil, dst, errbuf, f.fsys.opts...)
counter := &ByteCounter{}
writer := io.MultiWriter(dst, counter)
cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=%s bs=%d skip=%d count=%d", shellescape.Quote(f.path), bs, skip, count), nil, writer, errbuf, f.fsys.opts...)
if err != nil {
return 0, f.pathErr("copy", fmt.Errorf("failed to execute dd: %w (%s)", err, errbuf.String()))
return 0, f.pathErr(OpCopyTo, fmt.Errorf("failed to execute dd: %w (%s)", err, errbuf.String()))
}
if err := cmd.Wait(); err != nil {
return 0, f.pathErr("copy", fmt.Errorf("dd: %w (%s)", err, errbuf.String()))
return 0, f.pathErr(OpCopyTo, fmt.Errorf("dd: %w (%s)", err, errbuf.String()))
}
f.pos = f.size
f.pos += counter.Count()
f.isEOF = true
return f.size - f.pos, nil
return counter.Count(), nil
}

// CopyFrom copies the local reader src to the remote file
func (f *PosixFile) CopyFrom(src io.Reader) (int64, error) {
if !f.isWritable() {
return 0, f.pathErr(OpCopyFrom, fmt.Errorf("%w: file %s is not open for writing", fs.ErrClosed, f.path))
}
if err := f.fsys.Truncate(f.Name(), f.pos); err != nil {
return 0, f.pathErr(OpCopyFrom, fmt.Errorf("truncate: %w", err))
}
counter := &ByteCounter{}
tee := io.NopCloser(io.TeeReader(src, counter))
errbuf := bytes.NewBuffer(nil)

cmd, err := f.fsys.conn.ExecStreams(fmt.Sprintf("dd if=/dev/stdin of=%s bs=%d seek=%d conv=notrunc", shellescape.Quote(f.path), f.fsBlockSize(), f.pos), tee, io.Discard, errbuf, f.fsys.opts...)
if err != nil {
return 0, f.pathErr(OpCopyFrom, fmt.Errorf("exec dd: %w", err))
}
if err := cmd.Wait(); err != nil {
return 0, f.pathErr(OpCopyFrom, fmt.Errorf("dd: %w: %s", err, errbuf.String()))
}

f.pos += counter.Count()
f.size = f.pos
return counter.Count(), nil
}

// Close closes the file, rendering it unusable for I/O. It returns an error, if any.
Expand Down
7 changes: 7 additions & 0 deletions pkg/rigfs/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@ type connection interface {
ExecStreams(cmd string, stdin io.ReadCloser, stdout io.Writer, stderr io.Writer, opts ...exec.Option) (exec.Waiter, error)
}

// Copier is a file-like struct that can copy data to and from io.Reader and io.Writer
type Copier interface {
CopyFrom(src io.Reader) (int64, error)
CopyTo(dst io.Writer) (int64, error)
}

// File is a file in the remote filesystem
type File interface {
fs.File
io.Seeker
io.ReadCloser
io.Writer
Copier
}

// Fsys is a filesystem on the remote host
Expand Down
9 changes: 9 additions & 0 deletions pkg/rigfs/windir.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rigfs
import (
"encoding/json"
"fmt"
"io"
"io/fs"
"os"

Expand Down Expand Up @@ -32,6 +33,14 @@ func (f *winDir) Write(_ []byte) (int, error) {
return 0, f.pathErr("write", fmt.Errorf("%w: is a directory", fs.ErrInvalid))
}

func (f *winDir) CopyTo(_ io.Writer) (int64, error) {
return 0, f.pathErr("write", fmt.Errorf("%w: is a directory", fs.ErrInvalid))
}

func (f *winDir) CopyFrom(_ io.Reader) (int64, error) {
return 0, f.pathErr("write", fmt.Errorf("%w: is a directory", fs.ErrInvalid))
}

func (f *winDir) Close() error {
if f.closed {
return f.pathErr("close", fs.ErrClosed)
Expand Down
32 changes: 32 additions & 0 deletions pkg/rigfs/winfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,38 @@ func (f *winFile) Read(p []byte) (int, error) {
return total, nil
}

// CopyTo copies the remote file to the provided io.Writer.
func (f *winFile) CopyTo(dst io.Writer) (int64, error) {
if f.closed {
return 0, f.pathErr(OpCopyTo, fs.ErrClosed)
}
resp, err := f.command("r -1")
if err != nil {
return 0, f.pathErr(OpCopyTo, fmt.Errorf("read: %w", err))
}
if resp.N == 0 {
return 0, f.pathErr(OpCopyTo, io.EOF)
}
total := int64(0)
for total < resp.N {
n, err := io.CopyN(dst, f.stdout, resp.N-total)
total += n
if err != nil {
return total, f.pathErr(OpCopyTo, fmt.Errorf("copy: %w", err))
}
}
return total, nil
}

// CopyFrom copies the provided io.Reader to the remote file.
func (f *winFile) CopyFrom(src io.Reader) (int64, error) {
n, err := io.Copy(f, src)
if err != nil {
return n, f.pathErr(OpCopyFrom, fmt.Errorf("io.copy: %w", err))
}
return n, nil
}

func fAccess(flags int) string {
switch {
case flags&(os.O_WRONLY|os.O_TRUNC|os.O_APPEND) != 0:
Expand Down
14 changes: 8 additions & 6 deletions pkg/rigfs/withname.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package rigfs
import "io/fs"

const (
OpClose = "close" // OpClose Close operation
OpOpen = "open" // OpOpen Open operation
OpRead = "read" // OpRead Read operation
OpSeek = "seek" // OpSeek Seek operation
OpStat = "stat" // OpStat Stat operation
OpWrite = "write" // OpWrite Write operation
OpClose = "close" // OpClose Close operation
OpOpen = "open" // OpOpen Open operation
OpRead = "read" // OpRead Read operation
OpSeek = "seek" // OpSeek Seek operation
OpStat = "stat" // OpStat Stat operation
OpWrite = "write" // OpWrite Write operation
OpCopyTo = "copy-to" // OpCopyTo CopyTo operation
OpCopyFrom = "copy-from" // OpCopyFrom CopyFrom operation
)

type withPath struct {
Expand Down
57 changes: 57 additions & 0 deletions test/rig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,63 @@ func (s *FsysSuite) TestReadWriteFile() {
}
}

func (s *FsysSuite) TestReadWriteFileCopy() {
for _, testFileSize := range []int64{
int64(500), // less than one block on most filesystems
int64(1 << (10 * 2)), // exactly 1MB
int64(4096), // exactly one block on most filesystems
int64(4097), // plus 1
} {
s.Run(fmt.Sprintf("File size %d", testFileSize), func() {
fn := s.TempPath()

origin := io.LimitReader(rand.Reader, testFileSize)
shasum := sha256.New()
reader := io.TeeReader(origin, shasum)

defer func() {
_ = s.fsys.Remove(fn)
}()
s.Run("Write file", func() {
f, err := s.fsys.OpenFile(fn, os.O_CREATE|os.O_WRONLY, 0644)
s.Require().NoError(err)
n, err := f.CopyFrom(reader)
s.Require().NoError(err)
s.Equal(testFileSize, n)
s.Require().NoError(f.Close())
})

s.Run("Verify file size", func() {
stat, err := s.fsys.Stat(fn)
s.Require().NoError(err)
s.Equal(testFileSize, stat.Size())
})

s.Run("Verify file sha256", func() {
sum, err := s.fsys.Sha256(fn)
s.Require().NoError(err)
s.Equal(hex.EncodeToString(shasum.Sum(nil)), sum)
})

readSha := sha256.New()
s.Run("Read file", func() {
fsf, err := s.fsys.Open(fn)
s.Require().NoError(err)
f, ok := fsf.(rigfs.File)
s.Require().True(ok)
n, err := f.CopyTo(readSha)
s.Require().NoError(err)
s.Equal(testFileSize, n)
s.Require().NoError(f.Close())
})

s.Run("Verify read file sha256", func() {
s.Equal(shasum.Sum(nil), readSha.Sum(nil))
})
})
}
}

type RepeatReader struct {
data []byte
}
Expand Down

0 comments on commit 7642b01

Please sign in to comment.