Skip to content

Commit

Permalink
fix(backend): let components add default values (#9591)
Browse files Browse the repository at this point in the history
* let components add default values

* address comments and add unit tests

* address comments

* backend test

* backend test 2

* backend test 3

* backend test 4

* backend test 5

* backend test 6

* do not use python component in unit tests

* shell command

* shell command

* shell command

* does not delete tmp folder

* change folder permission

* update launcher image
  • Loading branch information
Linchin authored Jun 13, 2023
1 parent a9ac0b9 commit dbebbde
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 29 deletions.
2 changes: 1 addition & 1 deletion backend/src/v2/compiler/argocompiler/argo.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func Compile(jobArg *pipelinespec.PipelineJob, kubernetesSpecArg *pipelinespec.S
templates: make(map[string]*wfapi.Template),
// TODO(chensun): release process and update the images.
driverImage: "gcr.io/ml-pipeline/kfp-driver@sha256:0ce9bf20ac9cbb21e84ff0762d5ae508d21e9c85fde2b14b51363bd1b8cd7528",
launcherImage: "gcr.io/ml-pipeline/kfp-launcher@sha256:2b844d5509a2f8713f677045695e5622b7aab57b8880159e7872c60b57fae0d9",
launcherImage: "gcr.io/ml-pipeline/kfp-launcher@sha256:80cf120abd125db84fa547640fd6386c4b2a26936e0c2b04a7d3634991a850a4",
job: job,
spec: spec,
executors: deploy.GetExecutors(),
Expand Down
119 changes: 92 additions & 27 deletions backend/src/v2/component/launcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"time"

"github.com/golang/glog"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes/timestamp"
api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client"
"github.com/kubeflow/pipelines/backend/src/v2/cacheutils"
Expand Down Expand Up @@ -62,11 +63,16 @@ type LauncherV2 struct {
options LauncherV2Options

// clients
metadataClient *metadata.Client
k8sClient *kubernetes.Clientset
metadataClient metadata.ClientInterface
k8sClient kubernetes.Interface
cacheClient *cacheutils.Client
}

// Client is the struct to hold the Kubernetes Clientset
type kubernetesClient struct {
Clientset kubernetes.Interface
}

func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, componentSpecJSON string, cmdArgs []string, opts *LauncherV2Options) (l *LauncherV2, err error) {
defer func() {
if err != nil {
Expand Down Expand Up @@ -260,7 +266,43 @@ func (l *LauncherV2) publish(
return l.metadataClient.PublishExecution(ctx, execution, outputParameters, outputArtifacts, status)
}

func executeV2(ctx context.Context, executorInput *pipelinespec.ExecutorInput, component *pipelinespec.ComponentSpec, cmd string, args []string, bucket *blob.Bucket, bucketConfig *objectstore.Config, metadataClient *metadata.Client, namespace string, k8sClient *kubernetes.Clientset) (*pipelinespec.ExecutorOutput, []*metadata.OutputArtifact, error) {
func executeV2(
ctx context.Context,
executorInput *pipelinespec.ExecutorInput,
component *pipelinespec.ComponentSpec,
cmd string,
args []string,
bucket *blob.Bucket,
bucketConfig *objectstore.Config,
metadataClient metadata.ClientInterface,
namespace string,
k8sClient kubernetes.Interface,
) (*pipelinespec.ExecutorOutput, []*metadata.OutputArtifact, error) {

// Add parameter default values to executorInput, if there is not already a user input.
// This process is done in the launcher because we let the component resolve default values internally.
// Variable executorInputWithDefault is a copy so we don't alter the original data.
executorInputWithDefault, err := addDefaultParams(executorInput, component)
if err != nil {
return nil, nil, err
}

// Fill in placeholders with runtime values.
placeholders, err := getPlaceholders(executorInputWithDefault)
if err != nil {
return nil, nil, err
}
for placeholder, replacement := range placeholders {
cmd = strings.ReplaceAll(cmd, placeholder, replacement)
}
for i := range args {
arg := args[i]
for placeholder, replacement := range placeholders {
arg = strings.ReplaceAll(arg, placeholder, replacement)
}
args[i] = arg
}

executorOutput, err := execute(ctx, executorInput, cmd, args, bucket, bucketConfig, namespace, k8sClient)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -332,30 +374,23 @@ func prettyPrint(jsonStr string) string {

const OutputMetadataFilepath = "/tmp/kfp_outputs/output_metadata.json"

func execute(ctx context.Context, executorInput *pipelinespec.ExecutorInput, cmd string, args []string, bucket *blob.Bucket, bucketConfig *objectstore.Config, namespace string, k8sClient *kubernetes.Clientset) (*pipelinespec.ExecutorOutput, error) {
func execute(
ctx context.Context,
executorInput *pipelinespec.ExecutorInput,
cmd string,
args []string,
bucket *blob.Bucket,
bucketConfig *objectstore.Config,
namespace string,
k8sClient kubernetes.Interface,
) (*pipelinespec.ExecutorOutput, error) {
if err := downloadArtifacts(ctx, executorInput, bucket, bucketConfig, namespace, k8sClient); err != nil {
return nil, err
}
if err := prepareOutputFolders(executorInput); err != nil {
return nil, err
}

// Fill in placeholders with runtime values.
placeholders, err := getPlaceholders(executorInput)
if err != nil {
return nil, err
}
for placeholder, replacement := range placeholders {
cmd = strings.ReplaceAll(cmd, placeholder, replacement)
}
for i := range args {
arg := args[i]
for placeholder, replacement := range placeholders {
arg = strings.ReplaceAll(arg, placeholder, replacement)
}
args[i] = arg
}

// Run user program.
executor := exec.Command(cmd, args...)
executor.Stdin = os.Stdin
Expand All @@ -373,7 +408,7 @@ func execute(ctx context.Context, executorInput *pipelinespec.ExecutorInput, cmd
type uploadOutputArtifactsOptions struct {
bucketConfig *objectstore.Config
bucket *blob.Bucket
metadataClient *metadata.Client
metadataClient metadata.ClientInterface
}

func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, executorOutput *pipelinespec.ExecutorOutput, opts uploadOutputArtifactsOptions) ([]*metadata.OutputArtifact, error) {
Expand Down Expand Up @@ -428,9 +463,9 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec
return outputArtifacts, nil
}

func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, defaultBucket *blob.Bucket, defaultBucketConfig *objectstore.Config, namespace string, k8sClient *kubernetes.Clientset) error {
func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, defaultBucket *blob.Bucket, defaultBucketConfig *objectstore.Config, namespace string, k8sClient kubernetes.Interface) error {
// Read input artifact metadata.
nonDefaultBuckets, err := fetchNonDefaultBuckets(ctx, executorInput.Inputs.Artifacts, defaultBucketConfig, namespace, k8sClient)
nonDefaultBuckets, err := fetchNonDefaultBuckets(ctx, executorInput.GetInputs().GetArtifacts(), defaultBucketConfig, namespace, k8sClient)
closeNonDefaultBuckets := func(buckets map[string]*blob.Bucket) {
for name, bucket := range nonDefaultBuckets {
if closeBucketErr := bucket.Close(); closeBucketErr != nil {
Expand All @@ -442,7 +477,7 @@ func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.Executor
if err != nil {
return fmt.Errorf("failed to fetch non default buckets: %w", err)
}
for name, artifactList := range executorInput.Inputs.Artifacts {
for name, artifactList := range executorInput.GetInputs().GetArtifacts() {
// TODO(neuromage): Support concat-based placholders for arguments.
if len(artifactList.Artifacts) == 0 {
continue
Expand Down Expand Up @@ -485,7 +520,13 @@ func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.Executor
return nil
}

func fetchNonDefaultBuckets(ctx context.Context, artifacts map[string]*pipelinespec.ArtifactList, defaultBucketConfig *objectstore.Config, namespace string, k8sClient *kubernetes.Clientset) (buckets map[string]*blob.Bucket, err error) {
func fetchNonDefaultBuckets(
ctx context.Context,
artifacts map[string]*pipelinespec.ArtifactList,
defaultBucketConfig *objectstore.Config,
namespace string,
k8sClient kubernetes.Interface,
) (buckets map[string]*blob.Bucket, err error) {
nonDefaultBuckets := make(map[string]*blob.Bucket)
for name, artifactList := range artifacts {
if len(artifactList.Artifacts) == 0 {
Expand Down Expand Up @@ -525,7 +566,7 @@ func getPlaceholders(executorInput *pipelinespec.ExecutorInput) (placeholders ma
placeholders["{{$}}"] = string(executorInputJSON)

// Read input artifact metadata.
for name, artifactList := range executorInput.Inputs.Artifacts {
for name, artifactList := range executorInput.GetInputs().GetArtifacts() {
if len(artifactList.Artifacts) == 0 {
continue
}
Expand Down Expand Up @@ -562,7 +603,7 @@ func getPlaceholders(executorInput *pipelinespec.ExecutorInput) (placeholders ma
}

// Prepare input parameter placeholders.
for name, parameter := range executorInput.Inputs.ParameterValues {
for name, parameter := range executorInput.GetInputs().GetParameterValues() {
key := fmt.Sprintf(`{{$.inputs.parameters['%s']}}`, name)
switch t := parameter.Kind.(type) {
case *structpb.Value_StringValue:
Expand Down Expand Up @@ -696,3 +737,27 @@ func prepareOutputFolders(executorInput *pipelinespec.ExecutorInput) error {

return nil
}

// Adds default parameter values if there is no user provided value
func addDefaultParams(
executorInput *pipelinespec.ExecutorInput,
component *pipelinespec.ComponentSpec,
) (*pipelinespec.ExecutorInput, error) {
// Make a deep copy so we don't alter the original data
executorInputWithDefaultMsg := proto.Clone(executorInput)
executorInputWithDefault, ok := executorInputWithDefaultMsg.(*pipelinespec.ExecutorInput)
if !ok {
return nil, fmt.Errorf("bug: cloned executor input message does not have expected type")
}

if executorInputWithDefault.GetInputs().GetParameterValues() == nil {
executorInputWithDefault.Inputs.ParameterValues = make(map[string]*structpb.Value)
}
for name, value := range component.GetInputDefinitions().GetParameters() {
_, hasInput := executorInputWithDefault.GetInputs().GetParameterValues()[name]
if value.GetDefaultValue() != nil && !hasInput {
executorInputWithDefault.GetInputs().GetParameterValues()[name] = value.GetDefaultValue()
}
}
return executorInputWithDefault, nil
}
92 changes: 92 additions & 0 deletions backend/src/v2/component/launcher_v2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright 2023 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package component

import (
"context"
"testing"

"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/v2/metadata"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
"github.com/stretchr/testify/assert"
"gocloud.dev/blob"
"google.golang.org/protobuf/types/known/structpb"
"k8s.io/client-go/kubernetes/fake"
)

var addNumbersComponent = &pipelinespec.ComponentSpec{
Implementation: &pipelinespec.ComponentSpec_ExecutorLabel{ExecutorLabel: "add"},
InputDefinitions: &pipelinespec.ComponentInputsSpec{
Parameters: map[string]*pipelinespec.ComponentInputsSpec_ParameterSpec{
"a": {ParameterType: pipelinespec.ParameterType_NUMBER_INTEGER, DefaultValue: structpb.NewNumberValue(5)},
"b": {ParameterType: pipelinespec.ParameterType_NUMBER_INTEGER},
},
},
OutputDefinitions: &pipelinespec.ComponentOutputsSpec{
Parameters: map[string]*pipelinespec.ComponentOutputsSpec_ParameterSpec{
"Output": {ParameterType: pipelinespec.ParameterType_NUMBER_INTEGER},
},
},
}

// Tests that launcher correctly executes the user component and successfully writes output parameters to file.
func Test_executeV2_Parameters(t *testing.T) {
tests := []struct {
name string
executorInput *pipelinespec.ExecutorInput
executorArgs []string
wantErr bool
}{
{
"happy pass",
&pipelinespec.ExecutorInput{
Inputs: &pipelinespec.ExecutorInput_Inputs{
ParameterValues: map[string]*structpb.Value{"a": structpb.NewNumberValue(1), "b": structpb.NewNumberValue(2)},
},
},
[]string{"-c", "test {{$.inputs.parameters['a']}} -eq 1 || exit 1\ntest {{$.inputs.parameters['b']}} -eq 2 || exit 1"},
false,
},
{
"use default value",
&pipelinespec.ExecutorInput{
Inputs: &pipelinespec.ExecutorInput_Inputs{
ParameterValues: map[string]*structpb.Value{"b": structpb.NewNumberValue(2)},
},
},
[]string{"-c", "test {{$.inputs.parameters['a']}} -eq 5 || exit 1\ntest {{$.inputs.parameters['b']}} -eq 2 || exit 1"},
false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
fakeKubernetesClientset := &fake.Clientset{}
fakeMetadataClient := metadata.NewFakeClient()
bucket, err := blob.OpenBucket(context.Background(), "gs://test-bucket")
assert.Nil(t, err)
bucketConfig, err := objectstore.ParseBucketConfig("gs://test-bucket/pipeline-root/")
assert.Nil(t, err)
_, _, err = executeV2(context.Background(), test.executorInput, addNumbersComponent, "sh", test.executorArgs, bucket, bucketConfig, fakeMetadataClient, "namespace", fakeKubernetesClientset)

if test.wantErr {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)

}
})
}
}
21 changes: 20 additions & 1 deletion backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2021 The Kubeflow Authors
// Copyright 2021-2023 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -76,6 +76,25 @@ var (
}
)

type ClientInterface interface {
GetPipeline(ctx context.Context, pipelineName, runID, namespace, runResource, pipelineRoot string) (*Pipeline, error)
GetDAG(ctx context.Context, executionID int64) (*DAG, error)
PublishExecution(ctx context.Context, execution *Execution, outputParameters map[string]*structpb.Value, outputArtifacts []*OutputArtifact, state pb.Execution_State) error
CreateExecution(ctx context.Context, pipeline *Pipeline, config *ExecutionConfig) (*Execution, error)
PrePublishExecution(ctx context.Context, execution *Execution, config *ExecutionConfig) (*Execution, error)
GetExecutions(ctx context.Context, ids []int64) ([]*pb.Execution, error)
GetExecution(ctx context.Context, id int64) (*Execution, error)
GetPipelineFromExecution(ctx context.Context, id int64) (*Pipeline, error)
GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline) (executionsMap map[string]*Execution, err error)
GetEventsByArtifactIDs(ctx context.Context, artifactIds []int64) ([]*pb.Event, error)
GetArtifactName(ctx context.Context, artifactId int64) (string, error)
GetArtifacts(ctx context.Context, ids []int64) ([]*pb.Artifact, error)
GetOutputArtifactsByExecutionId(ctx context.Context, executionId int64) (map[string]*OutputArtifact, error)
RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State) (*OutputArtifact, error)
GetOrInsertArtifactType(ctx context.Context, schema string) (typeID int64, err error)
FindMatchedArtifact(ctx context.Context, artifactToMatch *pb.Artifact, pipelineContextId int64) (matchedArtifact *pb.Artifact, err error)
}

// Client is an MLMD service client.
type Client struct {
svc pb.MetadataStoreServiceClient
Expand Down
Loading

0 comments on commit dbebbde

Please sign in to comment.