diff --git a/csi/README.md b/csi/README.md index 73f5ae72..1dc783d6 100644 --- a/csi/README.md +++ b/csi/README.md @@ -28,7 +28,7 @@ sequenceDiagram U->>+MR: Register ML Model MR-->>-U: Indexed Model U->>U: Create InferenceService CR - Note right of U: The InferenceService should
point to the model registry
indexed model, e.g.,:
model-registry://<model>/<version> + Note right of U: The InferenceService should
point to the model registry
indexed model, e.g.,:
model-registry://<model-registry-url>/<model>/<version> KC->>KC: React to InferenceService creation KC->>+MD: Create Model Deployment MD->>+MRSI: Initialization (Download Model) @@ -66,14 +66,17 @@ Which wil create the executable under `bin/mr-storage-initializer`. You can run `main.go` (without building the executable) by running: ```bash -./bin/mr-storage-initializer "model-registry://model/version" "./" +./bin/mr-storage-initializer "model-registry://model-registry-url/model/version" "./" ``` or directly running the `main.go` skipping the previous step: ```bash -make SOURCE_URI=model-registry://model/version DEST_PATH=./ run +make SOURCE_URI=model-registry://model-registry-url/model/version DEST_PATH=./ run ``` +> [!NOTE] +> `model-registry-url` is optional; if it is not provided, the value of the `MODEL_REGISTRY_BASE_URL` environment variable is used. + > [!NOTE] > A Model Registry service should be up and running at `localhost:8080`. diff --git a/csi/main.go b/csi/main.go index b51ece22..2a2daaf1 100644 --- a/csi/main.go +++ b/csi/main.go @@ -4,6 +4,7 @@ import ( "log" "os" + "github.com/kubeflow/model-registry/csi/pkg/modelregistry" "github.com/kubeflow/model-registry/csi/pkg/storage" "github.com/kubeflow/model-registry/pkg/openapi" ) @@ -38,7 +39,10 @@ func main() { cfg := openapi.NewConfiguration() cfg.Host = baseUrl cfg.Scheme = scheme - provider, err := storage.NewModelRegistryProvider(cfg) + + apiClient := modelregistry.NewAPIClient(cfg, sourceUri) + + provider, err := storage.NewModelRegistryProvider(apiClient) if err != nil { log.Fatalf("Error initiliazing model registry provider: %v", err) } diff --git a/csi/pkg/constants/constants.go b/csi/pkg/constants/constants.go new file mode 100644 index 00000000..02f01d9a --- /dev/null +++ b/csi/pkg/constants/constants.go @@ -0,0 +1,5 @@ +package constants + +import kserve "github.com/kserve/kserve/pkg/agent/storage" + +const MR kserve.Protocol = "model-registry://" diff --git a/csi/pkg/modelregistry/api_client.go b/csi/pkg/modelregistry/api_client.go new file 
mode 100644 index 00000000..5f6720d5 --- /dev/null +++ b/csi/pkg/modelregistry/api_client.go @@ -0,0 +1,41 @@ +package modelregistry + +import ( + "context" + "log" + "strings" + + "github.com/kubeflow/model-registry/csi/pkg/constants" + "github.com/kubeflow/model-registry/pkg/openapi" +) + +func NewAPIClient(cfg *openapi.Configuration, storageUri string) *openapi.APIClient { + client := openapi.NewAPIClient(cfg) + + // Parse the URI to retrieve the needed information to query model registry (modelArtifact) + mrUri := strings.TrimPrefix(storageUri, string(constants.MR)) + + tokens := strings.SplitN(mrUri, "/", 3) + + if len(tokens) < 2 { + return client + } + + newCfg := openapi.NewConfiguration() + newCfg.Host = tokens[0] + newCfg.Scheme = cfg.Scheme + + newClient := openapi.NewAPIClient(newCfg) + + if len(tokens) == 2 { + // Check if the model registry service is available + _, _, err := newClient.ModelRegistryServiceAPI.GetRegisteredModels(context.Background()).Execute() + if err != nil { + log.Printf("Falling back to base url %s for model registry service", cfg.Host) + + return client + } + } + + return newClient +} diff --git a/csi/pkg/storage/modelregistry_provider.go b/csi/pkg/storage/modelregistry_provider.go index 01049f69..0bdb4f61 100644 --- a/csi/pkg/storage/modelregistry_provider.go +++ b/csi/pkg/storage/modelregistry_provider.go @@ -2,50 +2,64 @@ package storage import ( "context" + "errors" "fmt" "log" "regexp" "strings" kserve "github.com/kserve/kserve/pkg/agent/storage" + "github.com/kubeflow/model-registry/csi/pkg/constants" "github.com/kubeflow/model-registry/pkg/openapi" ) -const MR kserve.Protocol = "model-registry://" +var ( + _ kserve.Provider = (*ModelRegistryProvider)(nil) + ErrInvalidMRURI = errors.New("invalid model registry URI, use like model-registry://{dnsName}/{registeredModelName}/{versionName}") + ErrNoVersionAssociated = errors.New("no versions associated to registered model") + ErrNoArtifactAssociated = errors.New("no artifacts 
associated to model version") + ErrNoModelArtifact = errors.New("no model artifact found for model version") + ErrModelArtifactEmptyURI = errors.New("model artifact has empty URI") + ErrNoStorageURI = errors.New("there is no storageUri supplied") + ErrNoProtocolInSTorageURI = errors.New("there is no protocol specified for the storageUri") + ErrProtocolNotSupported = errors.New("protocol not supported for storageUri") + ErrFetchingModelVersion = errors.New("error fetching model version") + ErrFetchingModelVersions = errors.New("error fetching model versions") +) type ModelRegistryProvider struct { Client *openapi.APIClient Providers map[kserve.Protocol]kserve.Provider } -func NewModelRegistryProvider(cfg *openapi.Configuration) (*ModelRegistryProvider, error) { - client := openapi.NewAPIClient(cfg) - +func NewModelRegistryProvider(client *openapi.APIClient) (*ModelRegistryProvider, error) { return &ModelRegistryProvider{ Client: client, Providers: map[kserve.Protocol]kserve.Provider{}, }, nil } -var _ kserve.Provider = (*ModelRegistryProvider)(nil) - -// storageUri formatted like model-registry://{registeredModelName}/{versionName} +// storageUri formatted like model-registry://{modelRegistryUrl}/{registeredModelName}/{versionName} func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, storageUri string) error { - log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s", modelName, storageUri, modelDir) - - // Parse the URI to retrieve the needed information to query model registry (modelArtifact) - mrUri := strings.TrimPrefix(storageUri, string(MR)) - tokens := strings.SplitN(mrUri, "/", 2) + log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s", + modelName, + storageUri, + modelDir, + ) - if len(tokens) == 0 || len(tokens) > 2 { - return fmt.Errorf("invalid model registry URI, use like model-registry://{registeredModelName}/{versionName}") + 
registeredModelName, versionName, err := p.parseModelVersion(storageUri) + if err != nil { + return err } - registeredModelName := tokens[0] - var versionName *string - if len(tokens) == 2 { - versionName = &tokens[1] - } + log.Printf("Parsed storageUri=%s as: modelRegistryUrl=%s, registeredModelName=%s, versionName=%v", + storageUri, + p.Client.GetConfig().Host, + registeredModelName, + versionName, + ) + + log.Printf("Fetching model: registeredModelName=%s, versionName=%v", registeredModelName, versionName) // Fetch the registered model model, _, err := p.Client.ModelRegistryServiceAPI.FindRegisteredModel(context.Background()).Name(registeredModelName).Execute() @@ -53,28 +67,16 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, return err } - // Fetch model version by name or latest if not specified - var version *openapi.ModelVersion - if versionName != nil { - version, _, err = p.Client.ModelRegistryServiceAPI.FindModelVersion(context.Background()).Name(*versionName).ParentResourceId(*model.Id).Execute() - if err != nil { - return err - } - } else { - versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id). - OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). - SortOrder(openapi.SORTORDER_DESC). - Execute() - if err != nil { - return err - } + log.Printf("Fetching model version: model=%v", model) - if versions.Size == 0 { - return fmt.Errorf("no versions associated to registered model %s", registeredModelName) - } - version = &versions.Items[0] + // Fetch model version by name or latest if not specified + version, err := p.fetchModelVersion(versionName, registeredModelName, model) + if err != nil { + return err } + log.Printf("Fetching model artifacts: version=%v", version) + artifacts, _, err := p.Client.ModelRegistryServiceAPI.GetModelVersionArtifacts(context.Background(), *version.Id). OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). SortOrder(openapi.SORTORDER_DESC). 
@@ -84,20 +86,20 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, } if artifacts.Size == 0 { - return fmt.Errorf("no artifacts associated to model version %s", *version.Id) + return fmt.Errorf("%w %s", ErrNoArtifactAssociated, *version.Id) } modelArtifact := artifacts.Items[0].ModelArtifact if modelArtifact == nil { - return fmt.Errorf("no model artifact found for model version %s", *version.Id) + return fmt.Errorf("%w %s", ErrNoModelArtifact, *version.Id) } // Call appropriate kserve provider based on the indexed model artifact URI if modelArtifact.Uri == nil { - return fmt.Errorf("model artifact %s has empty URI", *modelArtifact.Id) + return fmt.Errorf("%w %s", ErrModelArtifactEmptyURI, *modelArtifact.Id) } - protocol, err := extractProtocol(*modelArtifact.Uri) + protocol, err := p.extractProtocol(*modelArtifact.Uri) if err != nil { return err } @@ -110,13 +112,77 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, return provider.DownloadModel(modelDir, "", *modelArtifact.Uri) } -func extractProtocol(storageURI string) (kserve.Protocol, error) { +// Possible URIs: +// (1) model-registry://{modelName} +// (2) model-registry://{modelName}/{modelVersion} +// (3) model-registry://{modelRegistryUrl}/{modelName} +// (4) model-registry://{modelRegistryUrl}/{modelName}/{modelVersion} +func (p *ModelRegistryProvider) parseModelVersion(storageUri string) (string, *string, error) { + var versionName *string + + // Parse the URI to retrieve the needed information to query model registry (modelArtifact) + mrUri := strings.TrimPrefix(storageUri, string(constants.MR)) + + tokens := strings.SplitN(mrUri, "/", 3) + + if len(tokens) == 0 || len(tokens) > 3 { + return "", nil, ErrInvalidMRURI + } + + // Check if the first token is the host and remove it so that we reduce cases (3) and (4) to (1) and (2) + if len(tokens) >= 2 && p.Client.GetConfig().Host == tokens[0] { + tokens = tokens[1:] + } + + 
registeredModelName := tokens[0] + + if len(tokens) == 2 { + versionName = &tokens[1] + } + + return registeredModelName, versionName, nil +} + +func (p *ModelRegistryProvider) fetchModelVersion( + versionName *string, + registeredModelName string, + model *openapi.RegisteredModel, +) (*openapi.ModelVersion, error) { + if versionName != nil { + version, _, err := p.Client.ModelRegistryServiceAPI. + FindModelVersion(context.Background()). + Name(*versionName). + ParentResourceId(*model.Id). + Execute() + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersion, err) + } + + return version, nil + } + + versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id). + // OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). not supported + SortOrder(openapi.SORTORDER_DESC). + Execute() + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersions, err) + } + + if versions.Size == 0 { + return nil, fmt.Errorf("%w %s", ErrNoVersionAssociated, registeredModelName) + } + + return &versions.Items[0], nil +} + +func (*ModelRegistryProvider) extractProtocol(storageURI string) (kserve.Protocol, error) { if storageURI == "" { - return "", fmt.Errorf("there is no storageUri supplied") + return "", ErrNoStorageURI } - if !regexp.MustCompile("\\w+?://").MatchString(storageURI) { - return "", fmt.Errorf("there is no protocol specified for the storageUri") + if !regexp.MustCompile(`\w+?://`).MatchString(storageURI) { + return "", ErrNoProtocolInSTorageURI } for _, prefix := range kserve.SupportedProtocols { @@ -124,5 +190,6 @@ func extractProtocol(storageURI string) (kserve.Protocol, error) { return prefix, nil } } - return "", fmt.Errorf("protocol not supported for storageUri") + + return "", ErrProtocolNotSupported } diff --git a/csi/scripts/install_modelregistry.sh b/csi/scripts/install_modelregistry.sh index f2ec7317..3276592d 100755 --- a/csi/scripts/install_modelregistry.sh +++ 
b/csi/scripts/install_modelregistry.sh @@ -40,9 +40,11 @@ if ! kubectl get namespace "$namespace" &> /dev/null; then fi # Apply model-registry kustomize manifests echo Using model registry image: $image -cd $MR_ROOT/manifests/kustomize/base && kustomize edit set image kubeflow/model-registry:latest=${image} && cd - +cd $MR_ROOT/manifests/kustomize/base && kustomize edit set image kubeflow/model-registry:latest=${image} && \ +kustomize edit set namespace $namespace && cd - +cd $MR_ROOT/manifests/kustomize/overlays/db && kustomize edit set namespace $namespace && cd - kubectl -n $namespace apply -k "$MR_ROOT/manifests/kustomize/overlays/db" # Wait for model registry deployment -modelregistry=$(kubectl get pod -n kubeflow --selector="component=model-registry-server" --output jsonpath='{.items[0].metadata.name}') -kubectl wait --for=condition=Ready pod/$modelregistry -n $namespace --timeout=6m \ No newline at end of file +modelregistry=$(kubectl get pod -n $namespace --selector="component=model-registry-server" --output jsonpath='{.items[0].metadata.name}') +kubectl wait --for=condition=Ready pod/$modelregistry -n $namespace --timeout=6m diff --git a/csi/test/e2e_test.sh b/csi/test/e2e_test.sh index 87ec06a4..ebf6a6c4 100755 --- a/csi/test/e2e_test.sh +++ b/csi/test/e2e_test.sh @@ -42,17 +42,79 @@ fi # Apply the port forward to access the model registry NAMESPACE=${NAMESPACE:-"kubeflow"} +TESTNAMESPACE=${TESTNAMESPACE:-"test"} MR_HOSTNAME=localhost:8080 +MR_TEST_HOSTNAME=localhost:8082 MODEL_REGISTRY_SERVICE=model-registry-service - MODEL_REGISTRY_REST_PORT=$(kubectl get svc/$MODEL_REGISTRY_SERVICE -n $NAMESPACE --output jsonpath='{.spec.ports[0].targetPort}') +INGRESS_HOST="localhost:8081" +KSERVE_TEST_NAMESPACE=kserve-test + +echo "======== Preparing test environment ========" + +echo "Applying Model Registry custom storage initializer ..." 
+ +kubectl apply -f - < /dev/null; then + kubectl create namespace $KSERVE_TEST_NAMESPACE +fi + +echo "Creating dummy input data for testing ..." + +cat < "/tmp/iris-input.json" +{ + "instances": [ + [6.8, 2.8, 4.8, 1.4], + [6.0, 3.4, 4.5, 1.6] + ] +} +EOF + +echo "======== Finished preparing test environment ========" + +echo "======== Scenario 1 - Testing with default model registry service ========" kubectl port-forward -n $NAMESPACE svc/$MODEL_REGISTRY_SERVICE "8080:$MODEL_REGISTRY_REST_PORT" & pf_pid=$! wait_for_port 8080 -echo "Initializing data into Model Registry ..." +echo "Initializing data into Model Registry in ${NAMESPACE} namespace..." curl --silent -X 'POST' \ "$MR_HOSTNAME/api/model_registry/v1alpha3/registered_models" \ @@ -87,93 +149,226 @@ curl --silent -X 'POST' \ "artifactType": "model-artifact" }' -echo "======== Model Registry populated ========" - -echo "Applying Model Registry custom storage initializer ..." +echo "Starting test ..." -kubectl apply -f - < /dev/null; then - kubectl create namespace $KSERVE_TEST_NAMESPACE +kubectl wait --for=jsonpath='{.status.url}' inferenceservice/sklearn-iris-scenario-one -n $KSERVE_TEST_NAMESPACE --timeout=5m +sleep 5 + +SERVICE_HOSTNAME=$(kubectl get inferenceservice sklearn-iris-scenario-one -n $KSERVE_TEST_NAMESPACE -o jsonpath='{.status.url}' | cut -d "/" -f 3) +res_one=$(curl -s -H "Host: ${SERVICE_HOSTNAME}" -H "Content-Type: application/json" "http://${INGRESS_HOST}/v1/models/sklearn-iris-scenario-one:predict" -d @/tmp/iris-input.json) +echo "Received: $res_one" + +if [ ! "$res_one" = "{\"predictions\":[1,1]}" ]; then + echo "Prediction does not match expectation!" + echo "Printing some logs for debugging.." + kubectl logs pod/$predictor_one -n $KSERVE_TEST_NAMESPACE -c storage-initializer + kubectl logs pod/$predictor_one -n $KSERVE_TEST_NAMESPACE -c kserve-container + exit 1 +else + echo "Scenario 1 - Test succeeded!" fi +echo "Cleaning up inferenceservice sklearn-iris-scenario-one ..." 
+ +kubectl delete inferenceservice sklearn-iris-scenario-one -n $KSERVE_TEST_NAMESPACE + +echo "======== Finished Scenario 1 ========" + +echo "======== Scenario 2 - Testing with default model registry service without model version ========" + +echo "Starting test ..." + kubectl apply -n $KSERVE_TEST_NAMESPACE -f - < "/tmp/iris-input.json" -{ - "instances": [ - [6.8, 2.8, 4.8, 1.4], - [6.0, 3.4, 4.5, 1.6] - ] -} +curl --silent -X 'POST' \ + "$MR_TEST_HOSTNAME/api/model_registry/v1alpha3/registered_models" \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "description": "Iris scikit-learn model", + "name": "iris-test" +}' + +curl --silent -X 'POST' \ + "$MR_TEST_HOSTNAME/api/model_registry/v1alpha3/model_versions" \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "description": "Iris model version v1", + "name": "v1-test", + "registeredModelID": "1" +}' + +curl --silent -X 'POST' \ + "$MR_TEST_HOSTNAME/api/model_registry/v1alpha3/model_versions/2/artifacts" \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "description": "Model artifact for Iris v1", + "uri": "gs://kfserving-examples/models/sklearn/1.0/model", + "state": "UNKNOWN", + "name": "sklearn-iris-test-v1", + "modelFormatName": "sklearn", + "modelFormatVersion": "1", + "artifactType": "model-artifact" +}' + +echo "Starting test ..." + +kubectl apply -n $KSERVE_TEST_NAMESPACE -f - <