Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add support upload collections #54

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions data/common.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// This module contains Flyte CoPilot related code.
// Currently it only has 2 utilities - a downloader and an uploader.
// Usage Downloader:
// downloader := NewDownloader(...)
// downloader.DownloadInputs(...) // will recursively download all inputs
//
// downloader := NewDownloader(...)
// downloader.DownloadInputs(...) // will recursively download all inputs
//
// Usage uploader:
// uploader := NewUploader(...)
Expand Down
39 changes: 39 additions & 0 deletions data/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path"
"path/filepath"
"reflect"
"strings"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/futures"
Expand Down Expand Up @@ -56,6 +57,39 @@ func (u Uploader) handleSimpleType(_ context.Context, t core.SimpleType, filePat
return coreutils.MakeLiteralForSimpleType(t, string(b))
}

// handleCollectionType reads the file at filePath and converts its contents into a
// collection literal whose elements are built with the collection element type of t.
//
// The file is expected to hold a comma-separated list of values, e.g. "1, 2, 3".
// Whitespace surrounding each element (including a trailing newline) is ignored,
// while whitespace inside an element is preserved. A file that is empty or holds
// only whitespace yields an empty collection.
//
// An error is returned when t carries no collection element type, when the path
// is not readable, is a directory, or is larger than maxPrimitiveSize, or when
// any element cannot be converted to the element type.
func (u Uploader) handleCollectionType(_ context.Context, t *core.LiteralType, filePath string) (*core.Literal, error) {
	elemType := t.GetCollectionType()
	if elemType == nil {
		return nil, fmt.Errorf("expected collection type, found [%s]", t.String())
	}

	fpath, info, err := IsFileReadable(filePath, true)
	if err != nil {
		return nil, err
	}
	if info.IsDir() {
		return nil, fmt.Errorf("expected file for type [%s], found dir at path [%s]", t.String(), filePath)
	}
	if info.Size() > maxPrimitiveSize {
		return nil, fmt.Errorf("maximum allowed filesize is [%d], but found [%d]", maxPrimitiveSize, info.Size())
	}
	raw, err := ioutil.ReadFile(fpath)
	if err != nil {
		return nil, err
	}

	// Splitting an empty string would produce a single empty element, so only
	// split when there is actual content.
	var parts []string
	if content := strings.TrimSpace(string(raw)); content != "" {
		parts = strings.Split(content, ",")
	}

	literals := make([]*core.Literal, 0, len(parts))
	for i, part := range parts {
		// Trim per element rather than stripping every space from the file, so
		// values that legitimately contain spaces are not mangled.
		lv, err := coreutils.MakeLiteralForType(elemType, strings.TrimSpace(part))
		if err != nil {
			return nil, fmt.Errorf("converting element [%d] of collection at path [%s]: %w", i, filePath, err)
		}
		literals = append(literals, lv)
	}

	return &core.Literal{
		Value: &core.Literal_Collection{
			Collection: &core.LiteralCollection{
				Literals: literals,
			},
		},
	}, nil
}

func (u Uploader) handleBlobType(ctx context.Context, localPath string, toPath storage.DataReference) (*core.Literal, error) {
fpath, info, err := IsFileReadable(localPath, true)
if err != nil {
Expand Down Expand Up @@ -158,6 +192,10 @@ func (u Uploader) RecursiveUpload(ctx context.Context, vars *core.VariableMap, f
varFutures[varName] = futures.NewAsyncFuture(childCtx, func(ctx2 context.Context) (interface{}, error) {
return u.handleSimpleType(ctx2, varType.GetSimple(), varPath)
})
case *core.LiteralType_CollectionType:
varFutures[varName] = futures.NewAsyncFuture(childCtx, func(ctx2 context.Context) (interface{}, error) {
return u.handleCollectionType(ctx2, varType, varPath)
})
default:
return fmt.Errorf("currently CoPilot uploader does not support [%s], system error", varType)
}
Expand All @@ -178,6 +216,7 @@ func (u Uploader) RecursiveUpload(ctx context.Context, vars *core.VariableMap, f
return fmt.Errorf("IllegalState, expected core.Literal, received [%s]", reflect.TypeOf(v))
}
outputs.Literals[k] = l
logger.Infof(ctx, "llll [%s]", l)
logger.Infof(ctx, "Var [%s] completed", k)
}

Expand Down
37 changes: 37 additions & 0 deletions data/upload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,41 @@ func TestUploader_RecursiveUpload(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, string(data), string(b), "content dont match")
})

// upload-collection checks that a file holding a comma-separated list of
// integers is uploaded as a collection literal with one integer scalar per
// element, in file order.
t.Run("upload-collection", func(t *testing.T) {
	tmpDir, err := ioutil.TempDir(tmpFolderLocation, tmpPrefix)
	assert.NoError(t, err)
	defer func() {
		assert.NoError(t, os.RemoveAll(tmpDir))
	}()

	// Build the element type as a pointer so the protobuf message is never copied by value.
	elemType := &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}
	vmap := &core.VariableMap{
		Variables: map[string]*core.Variable{
			"y": {
				Type: &core.LiteralType{Type: &core.LiteralType_CollectionType{CollectionType: elemType}},
			},
		},
	}

	data := []byte("1, 2, 3, 4")
	varPath := path.Join(tmpDir, "y")
	assert.NoError(t, ioutil.WriteFile(varPath, data, os.ModePerm))
	t.Logf("Written to %s", varPath)

	store, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
	assert.NoError(t, err)

	outputRef := storage.DataReference("output")
	rawRef := storage.DataReference("raw")
	u := NewUploader(context.TODO(), store, core.DataLoadingConfig_JSON, core.IOStrategy_UPLOAD_ON_EXIT, "error")
	assert.NoError(t, u.RecursiveUpload(context.TODO(), vmap, tmpDir, outputRef, rawRef))

	outputs := &core.LiteralMap{}
	assert.NoError(t, store.ReadProtobuf(context.TODO(), outputRef, outputs))
	assert.Len(t, outputs.Literals, 1)
	assert.NotNil(t, outputs.Literals["y"])
	assert.NotNil(t, outputs.Literals["y"].GetCollection())

	// Only index into the result once its length is confirmed, so a short
	// result fails the assertion instead of panicking.
	items := outputs.Literals["y"].GetCollection().GetLiterals()
	want := []int64{1, 2, 3, 4}
	if assert.Len(t, items, len(want)) {
		for i, w := range want {
			assert.NotNil(t, items[i].GetScalar())
			assert.Equal(t, w, items[i].GetScalar().GetPrimitive().GetInteger(), "element %d does not match", i)
		}
	}
})
}