diff --git a/csi/README.md b/csi/README.md
index 73f5ae72..1dc783d6 100644
--- a/csi/README.md
+++ b/csi/README.md
@@ -28,7 +28,7 @@ sequenceDiagram
U->>+MR: Register ML Model
MR-->>-U: Indexed Model
U->>U: Create InferenceService CR
- Note right of U: The InferenceService should
point to the model registry
indexed model, e.g.,:
model-registry:///
+ Note right of U: The InferenceService should
point to the model registry
indexed model, e.g.,:
model-registry:////
KC->>KC: React to InferenceService creation
KC->>+MD: Create Model Deployment
MD->>+MRSI: Initialization (Download Model)
@@ -66,14 +66,17 @@ Which wil create the executable under `bin/mr-storage-initializer`.
You can run `main.go` (without building the executable) by running:
```bash
-./bin/mr-storage-initializer "model-registry://model/version" "./"
+./bin/mr-storage-initializer "model-registry://model-registry-url/model/version" "./"
```
or directly running the `main.go` skipping the previous step:
```bash
-make SOURCE_URI=model-registry://model/version DEST_PATH=./ run
+make SOURCE_URI=model-registry://model-registry-url/model/version DEST_PATH=./ run
```
+> [!NOTE]
+> `model-registry-url` is optional, if not provided the value of `MODEL_REGISTRY_BASE_URL` env variable will be used.
+
> [!NOTE]
> A Model Registry service should be up and running at `localhost:8080`.
diff --git a/csi/main.go b/csi/main.go
index b51ece22..2a2daaf1 100644
--- a/csi/main.go
+++ b/csi/main.go
@@ -4,6 +4,7 @@ import (
"log"
"os"
+ "github.com/kubeflow/model-registry/csi/pkg/modelregistry"
"github.com/kubeflow/model-registry/csi/pkg/storage"
"github.com/kubeflow/model-registry/pkg/openapi"
)
@@ -38,7 +39,10 @@ func main() {
cfg := openapi.NewConfiguration()
cfg.Host = baseUrl
cfg.Scheme = scheme
- provider, err := storage.NewModelRegistryProvider(cfg)
+
+ apiClient := modelregistry.NewAPIClient(cfg, sourceUri)
+
+ provider, err := storage.NewModelRegistryProvider(apiClient)
if err != nil {
log.Fatalf("Error initiliazing model registry provider: %v", err)
}
diff --git a/csi/pkg/constants/constants.go b/csi/pkg/constants/constants.go
new file mode 100644
index 00000000..02f01d9a
--- /dev/null
+++ b/csi/pkg/constants/constants.go
@@ -0,0 +1,5 @@
+package constants
+
+import kserve "github.com/kserve/kserve/pkg/agent/storage"
+
+const MR kserve.Protocol = "model-registry://"
diff --git a/csi/pkg/modelregistry/api_client.go b/csi/pkg/modelregistry/api_client.go
new file mode 100644
index 00000000..5f6720d5
--- /dev/null
+++ b/csi/pkg/modelregistry/api_client.go
@@ -0,0 +1,41 @@
+package modelregistry
+
+import (
+ "context"
+ "log"
+ "strings"
+
+ "github.com/kubeflow/model-registry/csi/pkg/constants"
+ "github.com/kubeflow/model-registry/pkg/openapi"
+)
+
+func NewAPIClient(cfg *openapi.Configuration, storageUri string) *openapi.APIClient {
+ client := openapi.NewAPIClient(cfg)
+
+ // Parse the URI to retrieve the needed information to query model registry (modelArtifact)
+ mrUri := strings.TrimPrefix(storageUri, string(constants.MR))
+
+ tokens := strings.SplitN(mrUri, "/", 3)
+
+ if len(tokens) < 2 {
+ return client
+ }
+
+ newCfg := openapi.NewConfiguration()
+ newCfg.Host = tokens[0]
+ newCfg.Scheme = cfg.Scheme
+
+ newClient := openapi.NewAPIClient(newCfg)
+
+ if len(tokens) == 2 {
+ // Check if the model registry service is available
+ _, _, err := newClient.ModelRegistryServiceAPI.GetRegisteredModels(context.Background()).Execute()
+ if err != nil {
+ log.Printf("Falling back to base url %s for model registry service", cfg.Host)
+
+ return client
+ }
+ }
+
+ return newClient
+}
diff --git a/csi/pkg/storage/modelregistry_provider.go b/csi/pkg/storage/modelregistry_provider.go
index 01049f69..0bdb4f61 100644
--- a/csi/pkg/storage/modelregistry_provider.go
+++ b/csi/pkg/storage/modelregistry_provider.go
@@ -2,50 +2,64 @@ package storage
import (
"context"
+ "errors"
"fmt"
"log"
"regexp"
"strings"
kserve "github.com/kserve/kserve/pkg/agent/storage"
+ "github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)
-const MR kserve.Protocol = "model-registry://"
+var (
+ _ kserve.Provider = (*ModelRegistryProvider)(nil)
+ ErrInvalidMRURI = errors.New("invalid model registry URI, use like model-registry://{dnsName}/{registeredModelName}/{versionName}")
+ ErrNoVersionAssociated = errors.New("no versions associated to registered model")
+ ErrNoArtifactAssociated = errors.New("no artifacts associated to model version")
+ ErrNoModelArtifact = errors.New("no model artifact found for model version")
+ ErrModelArtifactEmptyURI = errors.New("model artifact has empty URI")
+ ErrNoStorageURI = errors.New("there is no storageUri supplied")
+ ErrNoProtocolInSTorageURI = errors.New("there is no protocol specified for the storageUri")
+ ErrProtocolNotSupported = errors.New("protocol not supported for storageUri")
+ ErrFetchingModelVersion = errors.New("error fetching model version")
+ ErrFetchingModelVersions = errors.New("error fetching model versions")
+)
type ModelRegistryProvider struct {
Client *openapi.APIClient
Providers map[kserve.Protocol]kserve.Provider
}
-func NewModelRegistryProvider(cfg *openapi.Configuration) (*ModelRegistryProvider, error) {
- client := openapi.NewAPIClient(cfg)
-
+func NewModelRegistryProvider(client *openapi.APIClient) (*ModelRegistryProvider, error) {
return &ModelRegistryProvider{
Client: client,
Providers: map[kserve.Protocol]kserve.Provider{},
}, nil
}
-var _ kserve.Provider = (*ModelRegistryProvider)(nil)
-
-// storageUri formatted like model-registry://{registeredModelName}/{versionName}
+// storageUri formatted like model-registry://{modelRegistryUrl}/{registeredModelName}/{versionName}
func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, storageUri string) error {
- log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s", modelName, storageUri, modelDir)
-
- // Parse the URI to retrieve the needed information to query model registry (modelArtifact)
- mrUri := strings.TrimPrefix(storageUri, string(MR))
- tokens := strings.SplitN(mrUri, "/", 2)
+ log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s",
+ modelName,
+ storageUri,
+ modelDir,
+ )
- if len(tokens) == 0 || len(tokens) > 2 {
- return fmt.Errorf("invalid model registry URI, use like model-registry://{registeredModelName}/{versionName}")
+ registeredModelName, versionName, err := p.parseModelVersion(storageUri)
+ if err != nil {
+ return err
}
- registeredModelName := tokens[0]
- var versionName *string
- if len(tokens) == 2 {
- versionName = &tokens[1]
- }
+ log.Printf("Parsed storageUri=%s as: modelRegistryUrl=%s, registeredModelName=%s, versionName=%v",
+ storageUri,
+ p.Client.GetConfig().Host,
+ registeredModelName,
+ versionName,
+ )
+
+ log.Printf("Fetching model: registeredModelName=%s, versionName=%v", registeredModelName, versionName)
// Fetch the registered model
model, _, err := p.Client.ModelRegistryServiceAPI.FindRegisteredModel(context.Background()).Name(registeredModelName).Execute()
@@ -53,28 +67,16 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
return err
}
- // Fetch model version by name or latest if not specified
- var version *openapi.ModelVersion
- if versionName != nil {
- version, _, err = p.Client.ModelRegistryServiceAPI.FindModelVersion(context.Background()).Name(*versionName).ParentResourceId(*model.Id).Execute()
- if err != nil {
- return err
- }
- } else {
- versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
- OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
- SortOrder(openapi.SORTORDER_DESC).
- Execute()
- if err != nil {
- return err
- }
+ log.Printf("Fetching model version: model=%v", model)
- if versions.Size == 0 {
- return fmt.Errorf("no versions associated to registered model %s", registeredModelName)
- }
- version = &versions.Items[0]
+ // Fetch model version by name or latest if not specified
+ version, err := p.fetchModelVersion(versionName, registeredModelName, model)
+ if err != nil {
+ return err
}
+ log.Printf("Fetching model artifacts: version=%v", version)
+
artifacts, _, err := p.Client.ModelRegistryServiceAPI.GetModelVersionArtifacts(context.Background(), *version.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
SortOrder(openapi.SORTORDER_DESC).
@@ -84,20 +86,20 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
}
if artifacts.Size == 0 {
- return fmt.Errorf("no artifacts associated to model version %s", *version.Id)
+ return fmt.Errorf("%w %s", ErrNoArtifactAssociated, *version.Id)
}
modelArtifact := artifacts.Items[0].ModelArtifact
if modelArtifact == nil {
- return fmt.Errorf("no model artifact found for model version %s", *version.Id)
+ return fmt.Errorf("%w %s", ErrNoModelArtifact, *version.Id)
}
// Call appropriate kserve provider based on the indexed model artifact URI
if modelArtifact.Uri == nil {
- return fmt.Errorf("model artifact %s has empty URI", *modelArtifact.Id)
+ return fmt.Errorf("%w %s", ErrModelArtifactEmptyURI, *modelArtifact.Id)
}
- protocol, err := extractProtocol(*modelArtifact.Uri)
+ protocol, err := p.extractProtocol(*modelArtifact.Uri)
if err != nil {
return err
}
@@ -110,13 +112,77 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
return provider.DownloadModel(modelDir, "", *modelArtifact.Uri)
}
-func extractProtocol(storageURI string) (kserve.Protocol, error) {
+// Possible URIs:
+// (1) model-registry://{modelName}
+// (2) model-registry://{modelName}/{modelVersion}
+// (3) model-registry://{modelRegistryUrl}/{modelName}
+// (4) model-registry://{modelRegistryUrl}/{modelName}/{modelVersion}
+func (p *ModelRegistryProvider) parseModelVersion(storageUri string) (string, *string, error) {
+ var versionName *string
+
+ // Parse the URI to retrieve the needed information to query model registry (modelArtifact)
+ mrUri := strings.TrimPrefix(storageUri, string(constants.MR))
+
+ tokens := strings.SplitN(mrUri, "/", 3)
+
+ if len(tokens) == 0 || len(tokens) > 3 {
+ return "", nil, ErrInvalidMRURI
+ }
+
+ // Check if the first token is the host and remove it so that we reduce cases (3) and (4) to (1) and (2)
+ if len(tokens) >= 2 && p.Client.GetConfig().Host == tokens[0] {
+ tokens = tokens[1:]
+ }
+
+ registeredModelName := tokens[0]
+
+ if len(tokens) == 2 {
+ versionName = &tokens[1]
+ }
+
+ return registeredModelName, versionName, nil
+}
+
+func (p *ModelRegistryProvider) fetchModelVersion(
+ versionName *string,
+ registeredModelName string,
+ model *openapi.RegisteredModel,
+) (*openapi.ModelVersion, error) {
+ if versionName != nil {
+ version, _, err := p.Client.ModelRegistryServiceAPI.
+ FindModelVersion(context.Background()).
+ Name(*versionName).
+ ParentResourceId(*model.Id).
+ Execute()
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersion, err)
+ }
+
+ return version, nil
+ }
+
+ versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
+ // OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). not supported
+ SortOrder(openapi.SORTORDER_DESC).
+ Execute()
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersions, err)
+ }
+
+ if versions.Size == 0 {
+ return nil, fmt.Errorf("%w %s", ErrNoVersionAssociated, registeredModelName)
+ }
+
+ return &versions.Items[0], nil
+}
+
+func (*ModelRegistryProvider) extractProtocol(storageURI string) (kserve.Protocol, error) {
if storageURI == "" {
- return "", fmt.Errorf("there is no storageUri supplied")
+ return "", ErrNoStorageURI
}
- if !regexp.MustCompile("\\w+?://").MatchString(storageURI) {
- return "", fmt.Errorf("there is no protocol specified for the storageUri")
+ if !regexp.MustCompile(`\w+?://`).MatchString(storageURI) {
+ return "", ErrNoProtocolInSTorageURI
}
for _, prefix := range kserve.SupportedProtocols {
@@ -124,5 +190,6 @@ func extractProtocol(storageURI string) (kserve.Protocol, error) {
return prefix, nil
}
}
- return "", fmt.Errorf("protocol not supported for storageUri")
+
+ return "", ErrProtocolNotSupported
}
diff --git a/csi/scripts/install_modelregistry.sh b/csi/scripts/install_modelregistry.sh
index f2ec7317..3276592d 100755
--- a/csi/scripts/install_modelregistry.sh
+++ b/csi/scripts/install_modelregistry.sh
@@ -40,9 +40,11 @@ if ! kubectl get namespace "$namespace" &> /dev/null; then
fi
# Apply model-registry kustomize manifests
echo Using model registry image: $image
-cd $MR_ROOT/manifests/kustomize/base && kustomize edit set image kubeflow/model-registry:latest=${image} && cd -
+cd $MR_ROOT/manifests/kustomize/base && kustomize edit set image kubeflow/model-registry:latest=${image} && \
+kustomize edit set namespace $namespace && cd -
+cd $MR_ROOT/manifests/kustomize/overlays/db && kustomize edit set namespace $namespace && cd -
kubectl -n $namespace apply -k "$MR_ROOT/manifests/kustomize/overlays/db"
# Wait for model registry deployment
-modelregistry=$(kubectl get pod -n kubeflow --selector="component=model-registry-server" --output jsonpath='{.items[0].metadata.name}')
-kubectl wait --for=condition=Ready pod/$modelregistry -n $namespace --timeout=6m
\ No newline at end of file
+modelregistry=$(kubectl get pod -n $namespace --selector="component=model-registry-server" --output jsonpath='{.items[0].metadata.name}')
+kubectl wait --for=condition=Ready pod/$modelregistry -n $namespace --timeout=6m
diff --git a/csi/test/e2e_test.sh b/csi/test/e2e_test.sh
index 87ec06a4..ebf6a6c4 100755
--- a/csi/test/e2e_test.sh
+++ b/csi/test/e2e_test.sh
@@ -42,17 +42,79 @@ fi
# Apply the port forward to access the model registry
NAMESPACE=${NAMESPACE:-"kubeflow"}
+TESTNAMESPACE=${TESTNAMESPACE:-"test"}
MR_HOSTNAME=localhost:8080
+MR_TEST_HOSTNAME=localhost:8082
MODEL_REGISTRY_SERVICE=model-registry-service
-
MODEL_REGISTRY_REST_PORT=$(kubectl get svc/$MODEL_REGISTRY_SERVICE -n $NAMESPACE --output jsonpath='{.spec.ports[0].targetPort}')
+INGRESS_HOST="localhost:8081"
+KSERVE_TEST_NAMESPACE=kserve-test
+
+echo "======== Preparing test environment ========"
+
+echo "Applying Model Registry custom storage initializer ..."
+
+kubectl apply -f - < /dev/null; then
+ kubectl create namespace $KSERVE_TEST_NAMESPACE
+fi
+
+echo "Creating dummy input data for testing ..."
+
+cat < "/tmp/iris-input.json"
+{
+ "instances": [
+ [6.8, 2.8, 4.8, 1.4],
+ [6.0, 3.4, 4.5, 1.6]
+ ]
+}
+EOF
+
+echo "======== Finished preparing test environment ========"
+
+echo "======== Scenario 1 - Testing with default model registry service ========"
kubectl port-forward -n $NAMESPACE svc/$MODEL_REGISTRY_SERVICE "8080:$MODEL_REGISTRY_REST_PORT" &
pf_pid=$!
wait_for_port 8080
-echo "Initializing data into Model Registry ..."
+echo "Initializing data into Model Registry in ${NAMESPACE} namespace..."
curl --silent -X 'POST' \
"$MR_HOSTNAME/api/model_registry/v1alpha3/registered_models" \
@@ -87,93 +149,226 @@ curl --silent -X 'POST' \
"artifactType": "model-artifact"
}'
-echo "======== Model Registry populated ========"
-
-echo "Applying Model Registry custom storage initializer ..."
+echo "Starting test ..."
-kubectl apply -f - < /dev/null; then
- kubectl create namespace $KSERVE_TEST_NAMESPACE
+kubectl wait --for=jsonpath='{.status.url}' inferenceservice/sklearn-iris-scenario-one -n $KSERVE_TEST_NAMESPACE --timeout=5m
+sleep 5
+
+SERVICE_HOSTNAME=$(kubectl get inferenceservice sklearn-iris-scenario-one -n $KSERVE_TEST_NAMESPACE -o jsonpath='{.status.url}' | cut -d "/" -f 3)
+res_one=$(curl -s -H "Host: ${SERVICE_HOSTNAME}" -H "Content-Type: application/json" "http://${INGRESS_HOST}/v1/models/sklearn-iris-scenario-one:predict" -d @/tmp/iris-input.json)
+echo "Received: $res_one"
+
+if [ ! "$res_one" = "{\"predictions\":[1,1]}" ]; then
+ echo "Prediction does not match expectation!"
+ echo "Printing some logs for debugging.."
+ kubectl logs pod/$predictor_one -n $KSERVE_TEST_NAMESPACE -c storage-initializer
+ kubectl logs pod/$predictor_one -n $KSERVE_TEST_NAMESPACE -c kserve-container
+ exit 1
+else
+ echo "Scenario 1 - Test succeeded!"
fi
+echo "Cleaning up inferenceservice sklearn-iris-scenario-one ..."
+
+kubectl delete inferenceservice sklearn-iris-scenario-one -n $KSERVE_TEST_NAMESPACE
+
+echo "======== Finished Scenario 1 ========"
+
+echo "======== Scenario 2 - Testing with default model registry service without model version ========"
+
+echo "Starting test ..."
+
kubectl apply -n $KSERVE_TEST_NAMESPACE -f - < "/tmp/iris-input.json"
-{
- "instances": [
- [6.8, 2.8, 4.8, 1.4],
- [6.0, 3.4, 4.5, 1.6]
- ]
-}
+curl --silent -X 'POST' \
+ "$MR_TEST_HOSTNAME/api/model_registry/v1alpha3/registered_models" \
+ -H 'accept: application/json' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "description": "Iris scikit-learn model",
+ "name": "iris-test"
+}'
+
+curl --silent -X 'POST' \
+ "$MR_TEST_HOSTNAME/api/model_registry/v1alpha3/model_versions" \
+ -H 'accept: application/json' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "description": "Iris model version v1",
+ "name": "v1-test",
+ "registeredModelID": "1"
+}'
+
+curl --silent -X 'POST' \
+ "$MR_TEST_HOSTNAME/api/model_registry/v1alpha3/model_versions/2/artifacts" \
+ -H 'accept: application/json' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "description": "Model artifact for Iris v1",
+ "uri": "gs://kfserving-examples/models/sklearn/1.0/model",
+ "state": "UNKNOWN",
+ "name": "sklearn-iris-test-v1",
+ "modelFormatName": "sklearn",
+ "modelFormatVersion": "1",
+ "artifactType": "model-artifact"
+}'
+
+echo "Starting test ..."
+
+kubectl apply -n $KSERVE_TEST_NAMESPACE -f - <