Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Create a FileOutput reader if the agent produce file output #391

Merged
merged 8 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ func (m *MockClient) CreateTask(_ context.Context, createTaskRequest *admin.Crea
return &admin.CreateTaskResponse{ResourceMeta: []byte{1, 2, 3, 4}}, nil
}

func (m *MockClient) GetTask(_ context.Context, _ *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) {
return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{
Literals: map[string]*flyteIdlCore.Literal{
"arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}),
},
}}}, nil
func (m *MockClient) GetTask(_ context.Context, req *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) {
if req.GetTaskType() == "bigquery_query_job_task" {
return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{
Literals: map[string]*flyteIdlCore.Literal{
"arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}),
},
}}}, nil
}
return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil
}

func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ ...grpc.CallOption) (*admin.DeleteTaskResponse, error) {
Expand Down Expand Up @@ -113,6 +116,11 @@ func TestEndToEnd(t *testing.T) {

phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter)
assert.Equal(t, true, phase.Phase().IsSuccess())

template.Type = "spark_job"
phase = tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter)
assert.Equal(t, true, phase.Phase().IsSuccess())

})

t.Run("failed to create a job", func(t *testing.T) {
Expand Down Expand Up @@ -251,7 +259,7 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext {
func newMockAgentPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "agent-service",
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task"},
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return &MockPlugin{
Plugin{
Expand Down
35 changes: 29 additions & 6 deletions go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"fmt"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flytestdlib/config"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"

Expand All @@ -19,8 +18,11 @@
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi"
"github.com/flyteorg/flytestdlib/config"
"github.com/flyteorg/flytestdlib/logger"
"github.com/flyteorg/flytestdlib/promutils"
"google.golang.org/grpc"
)
Expand Down Expand Up @@ -169,17 +171,38 @@
case admin.State_RETRYABLE_FAILURE:
return core.PhaseInfoRetryableFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil
case admin.State_SUCCEEDED:
if resource.Outputs != nil {
err := taskCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil))
if err != nil {
return core.PhaseInfoUndefined, err
}
err = writeOutput(ctx, taskCtx, resource)
if err != nil {
logger.Errorf(ctx, "Failed to write output with err %s", err.Error())
return core.PhaseInfoUndefined, err

Check warning on line 177 in go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

go/tasks/plugins/webapi/agent/plugin.go#L176-L177

Added lines #L176 - L177 were not covered by tests
}
return core.PhaseInfoSuccess(taskInfo), nil
}
return core.PhaseInfoUndefined, pluginErrors.Errorf(core.SystemErrorCode, "unknown execution phase [%v].", resource.State)
}

func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, resource *ResourceWrapper) error {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)

Check warning on line 185 in go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

go/tasks/plugins/webapi/agent/plugin.go#L184-L185

Added lines #L184 - L185 were not covered by tests
if err != nil {
return err
}

Check warning on line 188 in go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

go/tasks/plugins/webapi/agent/plugin.go#L188

Added line #L188 was not covered by tests

if taskTemplate.Interface == nil || taskTemplate.Interface.Outputs == nil || taskTemplate.Interface.Outputs.Variables == nil {
logger.Debugf(ctx, "The task declares no outputs. Skipping writing the outputs.")
return nil
}

var opReader io.OutputReader

Check warning on line 195 in go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

go/tasks/plugins/webapi/agent/plugin.go#L194-L195

Added lines #L194 - L195 were not covered by tests
if resource.Outputs != nil {
logger.Debugf(ctx, "Agent returned an output")
opReader = ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil)
} else {
logger.Debugf(ctx, "Agent didn't return any output, assuming file based outputs.")
opReader = ioutils.NewRemoteFileOutputReader(ctx, taskCtx.DataStore(), taskCtx.OutputWriter(), taskCtx.MaxDatasetSizeBytes())
}
return taskCtx.OutputWriter().Put(ctx, opReader)
hamersaw marked this conversation as resolved.
Show resolved Hide resolved
}

Check warning on line 205 in go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

go/tasks/plugins/webapi/agent/plugin.go#L202-L205

Added lines #L202 - L205 were not covered by tests
func getFinalAgent(taskType string, cfg *Config) (*Agent, error) {
if id, exists := cfg.AgentForTaskTypes[taskType]; exists {
if agent, exists := cfg.Agents[id]; exists {
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestPlugin(t *testing.T) {
assert.Equal(t, plugin.cfg.ResourceConstraints, constraints)
})

t.Run("tet newAgentPlugin", func(t *testing.T) {
t.Run("test newAgentPlugin", func(t *testing.T) {
p := newAgentPlugin()
assert.NotNil(t, p)
assert.Equal(t, "agent-service", p.ID)
Expand Down
1 change: 1 addition & 0 deletions tests/end_to_end.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i
outputWriter.OnGetOutputPath().Return(basePrefix + "/outputs.pb")
outputWriter.OnGetCheckpointPrefix().Return("/checkpoint")
outputWriter.OnGetPreviousCheckpointsPrefix().Return("/prev")
outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil)

outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
or := args.Get(1).(io.OutputReader)
Expand Down
Loading