From 457d01d6e5345f039b660e011fba27d82b25391b Mon Sep 17 00:00:00 2001 From: Mike Turley Date: Wed, 7 Aug 2024 06:09:36 -0400 Subject: [PATCH] API integration for Register Model form (#3065) * Implement registerModel form submission (WIP) Signed-off-by: Mike Turley * Construct URI for object storage mode Signed-off-by: Mike Turley * Fix page layout and microcopy, and add sourceModelFormatVersion field Signed-off-by: Mike Turley * Move uri utils and add unit tests for them Signed-off-by: Mike Turley * Cypress test for register button Signed-off-by: Mike Turley * Stub for register model page cypress tests Signed-off-by: Mike Turley * Add full cypress tests for submitting register form Signed-off-by: Mike Turley * Fix TODO Signed-off-by: Mike Turley * Fix imports Signed-off-by: Mike Turley * Fix tests Signed-off-by: Mike Turley * Sanitize inputs in uri utils, fix breadcrumb bar Signed-off-by: Mike Turley * Fix tests Signed-off-by: Mike Turley * Fix another test Signed-off-by: Mike Turley --------- Signed-off-by: Mike Turley --- .../cypress/cypress/pages/modelRegistry.ts | 4 + .../pages/modelRegistry/registerModelPage.ts | 41 ++ .../cypress/cypress/support/commands/odh.ts | 16 + .../mocked/modelRegistry/modelRegistry.cy.ts | 122 +++-- .../mocked/modelRegistry/registerModel.cy.ts | 235 +++++++++ .../modelRegistry/__tests__/custom.spec.ts | 82 +++ frontend/src/api/modelRegistry/custom.ts | 33 ++ .../modelRegistry/__tests__/utils.spec.ts | 122 +++++ .../context/useModelRegistryAPIState.tsx | 4 + frontend/src/concepts/modelRegistry/types.ts | 16 +- frontend/src/concepts/modelRegistry/utils.ts | 39 ++ .../screens/RegisterModel/RegisterModel.tsx | 469 +++++++++--------- .../RegisterModel/useRegisterModelData.ts | 52 +- .../screens/RegisterModel/utils.ts | 89 ++++ 14 files changed, 1030 insertions(+), 294 deletions(-) create mode 100644 frontend/src/__tests__/cypress/cypress/pages/modelRegistry/registerModelPage.ts create mode 100644 frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerModel.cy.ts create mode 100644 frontend/src/concepts/modelRegistry/__tests__/utils.spec.ts create mode 100644 frontend/src/concepts/modelRegistry/utils.ts create mode 100644 frontend/src/pages/modelRegistry/screens/RegisterModel/utils.ts diff --git a/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts b/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts index 4c96700e9e..88688563d9 100644 --- a/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts +++ b/frontend/src/__tests__/cypress/cypress/pages/modelRegistry.ts @@ -167,6 +167,10 @@ class ModelRegistry { findModelVersionsTableFilter() { return cy.findByTestId('model-versions-table-filter'); } + + findRegisterModelButton() { + return cy.findByRole('button', { name: 'Register model' }); + } } export const modelRegistry = new ModelRegistry(); diff --git a/frontend/src/__tests__/cypress/cypress/pages/modelRegistry/registerModelPage.ts b/frontend/src/__tests__/cypress/cypress/pages/modelRegistry/registerModelPage.ts new file mode 100644 index 0000000000..d420cdff09 --- /dev/null +++ b/frontend/src/__tests__/cypress/cypress/pages/modelRegistry/registerModelPage.ts @@ -0,0 +1,41 @@ +export enum FormFieldSelector { + MODEL_NAME = '#model-name', + MODEL_DESCRIPTION = '#model-description', + VERSION_NAME = '#version-name', + VERSION_DESCRIPTION = '#version-description', + SOURCE_MODEL_FORMAT = '#source-model-format', + SOURCE_MODEL_FORMAT_VERSION = '#source-model-format-version', + LOCATION_TYPE_OBJECT_STORAGE = '#location-type-object-storage', + LOCATION_ENDPOINT = '#location-endpoint', + LOCATION_BUCKET = '#location-bucket', + LOCATION_REGION = '#location-region', + LOCATION_PATH = '#location-path', + LOCATION_TYPE_URI = '#location-type-uri', + LOCATION_URI = '#location-uri', +} + +class RegisterModelPage { + visit() { + const preferredModelRegistry = 'modelregistry-sample'; + cy.visitWithLogin(`/modelRegistry/${preferredModelRegistry}/registerModel`); + this.wait(); + } + + private wait() { + const preferredModelRegistry = 'modelregistry-sample'; + cy.findByTestId('app-page-title').should('exist'); + cy.findByTestId('app-page-title').contains('Register model'); + cy.findByText(`Model registry - ${preferredModelRegistry}`).should('exist'); + cy.testA11y(); + } + + findFormField(selector: FormFieldSelector) { + return cy.get(selector); + } + + findSubmitButton() { + return cy.findByTestId('create-button'); + } +} + +export const registerModelPage = new RegisterModelPage(); diff --git a/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts b/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts index c02ff019ba..51838eb9a3 100644 --- a/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts +++ b/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts @@ -2,6 +2,7 @@ import type { K8sResourceListResult } from '@openshift/dynamic-plugin-sdk-utils' import type { GenericStaticResponse, RouteHandlerController } from 'cypress/types/net-stubbing'; import type { BaseMetricCreationResponse, BaseMetricListResponse } from '~/api'; import type { + ModelArtifact, ModelArtifactList, ModelVersion, ModelVersionList, @@ -294,11 +295,21 @@ declare global { options: { path: { serviceName: string; apiVersion: string } }, response: OdhResponse, ) => Cypress.Chainable) & + (( + type: 'POST /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models', + options: { path: { serviceName: string; apiVersion: string } }, + response: OdhResponse, + ) => Cypress.Chainable) & (( type: 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId/versions', options: { path: { serviceName: string; apiVersion: string; registeredModelId: number } }, response: OdhResponse, ) => Cypress.Chainable) & + (( + type: 'POST /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId/versions', + options: { path: { serviceName: string; apiVersion: string; registeredModelId: number } }, + response: OdhResponse, + ) => Cypress.Chainable) & (( type: 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId', options: { path: { serviceName: string; apiVersion: string; registeredModelId: number } }, @@ -321,6 +332,11 @@ declare global { options: { path: { serviceName: string; apiVersion: string; modelVersionId: number } }, response: OdhResponse, ) => Cypress.Chainable) & + (( + type: 'POST /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId/artifacts', + options: { path: { serviceName: string; apiVersion: string; modelVersionId: number } }, + response: OdhResponse, + ) => Cypress.Chainable) & (( type: 'PATCH /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId', options: { path: { serviceName: string; apiVersion: string; modelVersionId: number } }, diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts index 1e110dff25..24c8dd0a70 100644 --- a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts @@ -101,7 +101,7 @@ const initIntercepts = ({ ); }; -describe('Model Registry', () => { +describe('Model Registry core', () => { it('Model Registry Disabled in the cluster', () => { initIntercepts({ disableModelRegistryFeature: true, @@ -134,57 +134,81 @@ describe('Model Registry', () => { modelRegistry.shouldregisteredModelsEmpty(); }); - it('Registered model table', () => { - initIntercepts({ - disableModelRegistryFeature: false, + describe('Registered model table', () => { + beforeEach(() => { + initIntercepts({ disableModelRegistryFeature: false }); + modelRegistry.visit(); }); + it('Renders row contents', () => { + const registeredModelRow = modelRegistry.getRow('Fraud detection model'); + registeredModelRow.findName().contains('Fraud detection model'); + registeredModelRow + .findDescription() + .contains( + 'A machine learning model trained to detect fraudulent transactions in financial data', + ); + registeredModelRow.findOwner().contains('Author 1'); + + // Label popover + registeredModelRow.findLabelPopoverText().contains('2 more'); + registeredModelRow.findLabelPopoverText().click(); + registeredModelRow.shouldContainsPopoverLabels([ + 'Machine learning', + 'Next data to be overflow', + ]); + }); + + it('Renders labels in modal', () => { + const registeredModelRow2 = modelRegistry.getRow('Label modal'); + registeredModelRow2.findLabelModalText().contains('6 more'); + registeredModelRow2.findLabelModalText().click(); + labelModal.shouldContainsModalLabels([ + 'Testing label', + 'Financial', + 'Financial data', + 'Fraud detection', + 'Machine learning', + 'Next data to be overflow', + 'Label x', + 'Label y', + 'Label z', + ]); + labelModal.findModalSearchInput().type('Financial'); + labelModal.shouldContainsModalLabels(['Financial', 'Financial data']); + labelModal.findCloseModal().click(); + }); + + it('Sorts by model name', () => { + modelRegistry.findRegisteredModelTableHeaderButton('Model name').should(be.sortAscending); + modelRegistry.findRegisteredModelTableHeaderButton('Model name').click(); + modelRegistry.findRegisteredModelTableHeaderButton('Model name').should(be.sortDescending); + }); + + it('Filters by keyword', () => { + modelRegistry.findTableSearch().type('Fraud detection model'); + modelRegistry.findTableRows().should('have.length', 1); + modelRegistry.findTableRows().contains('Fraud detection model'); + }); + }); +}); + +describe('Register Model button', () => { + it('Navigates to register page from empty state', () => { + initIntercepts({ disableModelRegistryFeature: false, registeredModels: [] }); modelRegistry.visit(); + modelRegistry.findRegisterModelButton().click(); + cy.findByTestId('app-page-title').should('exist'); + cy.findByTestId('app-page-title').contains('Register model'); + cy.findByText('Model registry - modelregistry-sample').should('exist'); + }); - const registeredModelRow = modelRegistry.getRow('Fraud detection model'); - registeredModelRow.findName().contains('Fraud detection model'); - registeredModelRow - .findDescription() - .contains( - 'A machine learning model trained to detect fraudulent transactions in financial data', - ); - registeredModelRow.findOwner().contains('Author 1'); - - // Label popover - registeredModelRow.findLabelPopoverText().contains('2 more'); - registeredModelRow.findLabelPopoverText().click(); - registeredModelRow.shouldContainsPopoverLabels([ - 'Machine learning', - 'Next data to be overflow', - ]); - - // Label modal - const registeredModelRow2 = modelRegistry.getRow('Label modal'); - registeredModelRow2.findLabelModalText().contains('6 more'); - registeredModelRow2.findLabelModalText().click(); - labelModal.shouldContainsModalLabels([ - 'Testing label', - 'Financial', - 'Financial data', - 'Fraud detection', - 'Machine learning', - 'Next data to be overflow', - 'Label x', - 'Label y', - 'Label z', - ]); - labelModal.findModalSearchInput().type('Financial'); - labelModal.shouldContainsModalLabels(['Financial', 'Financial data']); - labelModal.findCloseModal().click(); - - // sort by modelName - modelRegistry.findRegisteredModelTableHeaderButton('Model name').should(be.sortAscending); - modelRegistry.findRegisteredModelTableHeaderButton('Model name').click(); - modelRegistry.findRegisteredModelTableHeaderButton('Model name').should(be.sortDescending); - - // filtering by keyword - modelRegistry.findTableSearch().type('Fraud detection model'); - modelRegistry.findTableRows().should('have.length', 1); - modelRegistry.findTableRows().contains('Fraud detection model'); + it('Navigates to register page from table toolbar', () => { + initIntercepts({ disableModelRegistryFeature: false }); + modelRegistry.visit(); + modelRegistry.findRegisterModelButton().click(); + cy.findByTestId('app-page-title').should('exist'); + cy.findByTestId('app-page-title').contains('Register model'); + cy.findByText('Model registry - modelregistry-sample').should('exist'); }); }); diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerModel.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerModel.cy.ts new file mode 100644 index 0000000000..e108df61f0 --- /dev/null +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerModel.cy.ts @@ -0,0 +1,235 @@ +import { mockDashboardConfig, mockDscStatus, mockK8sResourceList } from '~/__mocks__'; +import { mockDsciStatus } from '~/__mocks__/mockDsciStatus'; +import { StackCapability, StackComponent } from '~/concepts/areas/types'; +import { ModelRegistryModel } from '~/__tests__/cypress/cypress/utils/models'; +import { + FormFieldSelector, + registerModelPage, +} from '~/__tests__/cypress/cypress/pages/modelRegistry/registerModelPage'; +import { mockModelRegistry } from '~/__mocks__/mockModelRegistry'; +import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; +import { mockModelVersion } from '~/__mocks__/mockModelVersion'; +import { mockModelArtifact } from '~/__mocks__/mockModelArtifact'; +import { + ModelArtifactState, + ModelState, + type RegisteredModel, + type ModelVersion, + type ModelArtifact, +} from '~/concepts/modelRegistry/types'; + +const MODEL_REGISTRY_API_VERSION = 'v1alpha3'; + +const initIntercepts = () => { + cy.interceptOdh( + 'GET /api/config', + mockDashboardConfig({ + disableModelRegistry: false, + }), + ); + cy.interceptOdh( + 'GET /api/dsc/status', + mockDscStatus({ + installedComponents: { + [StackComponent.MODEL_REGISTRY]: true, + [StackComponent.MODEL_MESH]: true, + }, + }), + ); + cy.interceptOdh( + 'GET /api/dsci/status', + mockDsciStatus({ + requiredCapabilities: [StackCapability.SERVICE_MESH, StackCapability.SERVICE_MESH_AUTHZ], + }), + ); + + // TODO replace these with a mock list of services when https://github.com/opendatahub-io/odh-dashboard/pull/3034 is merged + cy.interceptK8sList( + ModelRegistryModel, + mockK8sResourceList([mockModelRegistry({ name: 'modelregistry-sample' })]), + ); + cy.interceptK8s(ModelRegistryModel, mockModelRegistry({ name: 'modelregistry-sample' })); + + cy.interceptOdh( + 'POST /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + }, + }, + mockRegisteredModel({ id: '1', name: 'Test model name' }), + ).as('createRegisteredModel'); + + cy.interceptOdh( + 'POST /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId/versions', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockModelVersion({ id: '2', name: 'Test version name' }), + ).as('createModelVersion'); + + cy.interceptOdh( + 'POST /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId/artifacts', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 2, + }, + }, + mockModelArtifact(), + ).as('createModelArtifact'); +}; + +describe('Register model page', () => { + beforeEach(() => { + initIntercepts(); + registerModelPage.visit(); + }); + + it('Disables submit until required fields are filled in object storage mode', () => { + registerModelPage.findSubmitButton().should('be.disabled'); + registerModelPage.findFormField(FormFieldSelector.MODEL_NAME).type('Test model name'); + registerModelPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerModelPage.findFormField(FormFieldSelector.LOCATION_TYPE_OBJECT_STORAGE).click(); + registerModelPage + .findFormField(FormFieldSelector.LOCATION_ENDPOINT) + .type('http://s3.amazonaws.com/'); + registerModelPage.findFormField(FormFieldSelector.LOCATION_BUCKET).type('test-bucket'); + registerModelPage + .findFormField(FormFieldSelector.LOCATION_PATH) + .type('demo-models/flan-t5-small-caikit'); + registerModelPage.findSubmitButton().should('be.enabled'); + }); + + it('Creates expected resources on submit in object storage mode', () => { + registerModelPage.findFormField(FormFieldSelector.MODEL_NAME).type('Test model name'); + registerModelPage + .findFormField(FormFieldSelector.MODEL_DESCRIPTION) + .type('Test model description'); + registerModelPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerModelPage + .findFormField(FormFieldSelector.VERSION_DESCRIPTION) + .type('Test version description'); + registerModelPage.findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT).type('caikit'); + registerModelPage.findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT_VERSION).type('1'); + registerModelPage.findFormField(FormFieldSelector.LOCATION_TYPE_OBJECT_STORAGE).click(); + registerModelPage + .findFormField(FormFieldSelector.LOCATION_ENDPOINT) + .type('http://s3.amazonaws.com/'); + registerModelPage.findFormField(FormFieldSelector.LOCATION_BUCKET).type('test-bucket'); + registerModelPage.findFormField(FormFieldSelector.LOCATION_REGION).type('us-east-1'); + registerModelPage + .findFormField(FormFieldSelector.LOCATION_PATH) + .type('demo-models/flan-t5-small-caikit'); + + registerModelPage.findSubmitButton().click(); + + cy.wait('@createRegisteredModel').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test model name', + description: 'Test model description', + customProperties: {}, + state: ModelState.LIVE, + } satisfies Partial); + }); + cy.wait('@createModelVersion').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test version name', + description: 'Test version description', + customProperties: {}, + state: ModelState.LIVE, + author: 'test-user', + registeredModelId: '1', + } satisfies Partial); + }); + cy.wait('@createModelArtifact').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test model name-Test version name-artifact', + description: 'Test version description', + customProperties: {}, + state: ModelArtifactState.LIVE, + author: 'test-user', + modelFormatName: 'caikit', + modelFormatVersion: '1', + uri: 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + artifactType: 'model-artifact', + } satisfies Partial); + }); + + cy.url().should('include', '/modelRegistry/modelregistry-sample/registeredModels/1'); + }); + + it('Disables submit until required fields are filled in URI mode', () => { + registerModelPage.findSubmitButton().should('be.disabled'); + registerModelPage.findFormField(FormFieldSelector.MODEL_NAME).type('Test model name'); + registerModelPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerModelPage.findFormField(FormFieldSelector.LOCATION_TYPE_URI).click(); + registerModelPage + .findFormField(FormFieldSelector.LOCATION_URI) + .type( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + ); + registerModelPage.findSubmitButton().should('be.enabled'); + }); + + it('Creates expected resources on submit in URI mode', () => { + registerModelPage.findFormField(FormFieldSelector.MODEL_NAME).type('Test model name'); + registerModelPage + .findFormField(FormFieldSelector.MODEL_DESCRIPTION) + .type('Test model description'); + registerModelPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerModelPage + .findFormField(FormFieldSelector.VERSION_DESCRIPTION) + .type('Test version description'); + registerModelPage.findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT).type('caikit'); + registerModelPage.findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT_VERSION).type('1'); + registerModelPage.findFormField(FormFieldSelector.LOCATION_TYPE_URI).click(); + registerModelPage + .findFormField(FormFieldSelector.LOCATION_URI) + .type( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + ); + + registerModelPage.findSubmitButton().click(); + + cy.wait('@createRegisteredModel').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test model name', + description: 'Test model description', + customProperties: {}, + state: ModelState.LIVE, + } satisfies Partial); + }); + cy.wait('@createModelVersion').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test version name', + description: 'Test version description', + customProperties: {}, + state: ModelState.LIVE, + author: 'test-user', + registeredModelId: '1', + } satisfies Partial); + }); + cy.wait('@createModelArtifact').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test model name-Test version name-artifact', + description: 'Test version description', + customProperties: {}, + state: ModelArtifactState.LIVE, + author: 'test-user', + modelFormatName: 'caikit', + modelFormatVersion: '1', + uri: 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + artifactType: 'model-artifact', + } satisfies Partial); + }); + + cy.url().should('include', '/modelRegistry/modelregistry-sample/registeredModels/1'); + }); +}); diff --git a/frontend/src/api/modelRegistry/__tests__/custom.spec.ts b/frontend/src/api/modelRegistry/__tests__/custom.spec.ts index 5efcfcce41..200a6f1ba5 100644 --- a/frontend/src/api/modelRegistry/__tests__/custom.spec.ts +++ b/frontend/src/api/modelRegistry/__tests__/custom.spec.ts @@ -16,6 +16,8 @@ import { patchModelVersion, patchRegisteredModel, getModelArtifactsByModelVersion, + createModelVersionForRegisteredModel, + createModelArtifactForModelVersion, } from '~/api/modelRegistry/custom'; import { MODEL_REGISTRY_API_VERSION } from '~/concepts/modelRegistry/const'; @@ -104,6 +106,40 @@ describe('createModelVersion', () => { }); }); +describe('createModelVersionForRegisteredModel', () => { + it('should call proxyCREATE and handleModelRegistryFailures to create model version for a model', () => { + expect( + createModelVersionForRegisteredModel('hostPath')(K8sAPIOptionsMock, '1', { + description: 'test', + externalID: '1', + author: 'test author', + registeredModelId: '1', + name: 'test new model version', + state: ModelState.LIVE, + customProperties: {}, + }), + ).toBe(mockResultPromise); + expect(proxyCREATEMock).toHaveBeenCalledTimes(1); + expect(proxyCREATEMock).toHaveBeenCalledWith( + 'hostPath', + `/api/model_registry/${MODEL_REGISTRY_API_VERSION}/registered_models/1/versions`, + { + description: 'test', + externalID: '1', + author: 'test author', + registeredModelId: '1', + name: 'test new model version', + state: ModelState.LIVE, + customProperties: {}, + }, + {}, + K8sAPIOptionsMock, + ); + expect(handleModelRegistryFailuresMock).toHaveBeenCalledTimes(1); + expect(handleModelRegistryFailuresMock).toHaveBeenCalledWith(mockProxyPromise); + }); +}); + describe('createModelArtifact', () => { it('should call proxyCREATE and handleModelRegistryFailures to create model artifact', () => { expect( @@ -119,6 +155,7 @@ describe('createModelArtifact', () => { modelFormatVersion: 'testmodelFormatVersion', serviceAccountName: 'testserviceAccountname', customProperties: {}, + artifactType: 'model-artifact', }), ).toBe(mockResultPromise); expect(proxyCREATEMock).toHaveBeenCalledTimes(1); @@ -137,6 +174,51 @@ describe('createModelArtifact', () => { modelFormatVersion: 'testmodelFormatVersion', serviceAccountName: 'testserviceAccountname', customProperties: {}, + artifactType: 'model-artifact', + }, + {}, + K8sAPIOptionsMock, + ); + expect(handleModelRegistryFailuresMock).toHaveBeenCalledTimes(1); + expect(handleModelRegistryFailuresMock).toHaveBeenCalledWith(mockProxyPromise); + }); +}); + +describe('createModelArtifactForModelVersion', () => { + it('should call proxyCREATE and handleModelRegistryFailures to create model artifact for version', () => { + expect( + createModelArtifactForModelVersion('hostPath')(K8sAPIOptionsMock, '2', { + description: 'test', + externalID: 'test', + uri: 'test-uri', + state: ModelArtifactState.LIVE, + name: 'test-name', + modelFormatName: 'test-modelformatname', + storageKey: 'teststoragekey', + storagePath: 'teststoragePath', + modelFormatVersion: 'testmodelFormatVersion', + serviceAccountName: 'testserviceAccountname', + customProperties: {}, + artifactType: 'model-artifact', + }), + ).toBe(mockResultPromise); + expect(proxyCREATEMock).toHaveBeenCalledTimes(1); + expect(proxyCREATEMock).toHaveBeenCalledWith( + 'hostPath', + `/api/model_registry/${MODEL_REGISTRY_API_VERSION}/model_versions/2/artifacts`, + { + description: 'test', + externalID: 'test', + uri: 'test-uri', + state: ModelArtifactState.LIVE, + name: 'test-name', + modelFormatName: 'test-modelformatname', + storageKey: 'teststoragekey', + storagePath: 'teststoragePath', + modelFormatVersion: 'testmodelFormatVersion', + serviceAccountName: 'testserviceAccountname', + customProperties: {}, + artifactType: 'model-artifact', }, {}, K8sAPIOptionsMock, diff --git a/frontend/src/api/modelRegistry/custom.ts b/frontend/src/api/modelRegistry/custom.ts index 8282cc51ff..3604ed6357 100644 --- a/frontend/src/api/modelRegistry/custom.ts +++ b/frontend/src/api/modelRegistry/custom.ts @@ -39,6 +39,22 @@ export const createModelVersion = opts, ), ); +export const createModelVersionForRegisteredModel = + (hostpath: string) => + ( + opts: K8sAPIOptions, + registeredModelId: string, + data: CreateModelVersionData, + ): Promise => + handleModelRegistryFailures( + proxyCREATE( + hostpath, + `/api/model_registry/${MODEL_REGISTRY_API_VERSION}/registered_models/${registeredModelId}/versions`, + data, + {}, + opts, + ), + ); export const createModelArtifact = (hostPath: string) => @@ -53,6 +69,23 @@ export const createModelArtifact = ), ); +export const createModelArtifactForModelVersion = + (hostPath: string) => + ( + opts: K8sAPIOptions, + modelVersionId: string, + data: CreateModelArtifactData, + ): Promise => + handleModelRegistryFailures( + proxyCREATE( + hostPath, + `/api/model_registry/${MODEL_REGISTRY_API_VERSION}/model_versions/${modelVersionId}/artifacts`, + data, + {}, + opts, + ), + ); + export const getRegisteredModel = (hostPath: string) => (opts: K8sAPIOptions, registeredModelId: string): Promise => diff --git a/frontend/src/concepts/modelRegistry/__tests__/utils.spec.ts b/frontend/src/concepts/modelRegistry/__tests__/utils.spec.ts new file mode 100644 index 0000000000..314277b795 --- /dev/null +++ b/frontend/src/concepts/modelRegistry/__tests__/utils.spec.ts @@ -0,0 +1,122 @@ +import { + ObjectStorageFields, + objectStorageFieldsToUri, + uriToObjectStorageFields, +} from '~/concepts/modelRegistry/utils'; + +describe('objectStorageFieldsToUri', () => { + it('converts fields to URI with all fields present', () => { + const uri = objectStorageFieldsToUri({ + endpoint: 'http://s3.amazonaws.com/', + bucket: 'test-bucket', + region: 'us-east-1', + path: 'demo-models/flan-t5-small-caikit', + }); + expect(uri).toEqual( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + ); + }); + + it('converts fields to URI with region missing', () => { + const uri = objectStorageFieldsToUri({ + endpoint: 'http://s3.amazonaws.com/', + bucket: 'test-bucket', + path: 'demo-models/flan-t5-small-caikit', + }); + expect(uri).toEqual( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F', + ); + }); + + it('converts fields to URI with region empty', () => { + const uri = objectStorageFieldsToUri({ + endpoint: 'http://s3.amazonaws.com/', + bucket: 'test-bucket', + region: '', + path: 'demo-models/flan-t5-small-caikit', + }); + expect(uri).toEqual( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F', + ); + }); + + it('falls back to null if endpoint is empty', () => { + const uri = objectStorageFieldsToUri({ + endpoint: '', + bucket: 'test-bucket', + region: 'us-east-1', + path: 'demo-models/flan-t5-small-caikit', + }); + expect(uri).toEqual(null); + }); + + it('falls back to null if bucket is empty', () => { + const uri = objectStorageFieldsToUri({ + endpoint: 'http://s3.amazonaws.com/', + bucket: '', + region: 'us-east-1', + path: 'demo-models/flan-t5-small-caikit', + }); + expect(uri).toEqual(null); + }); + + it('falls back to null if path is empty', () => { + const uri = objectStorageFieldsToUri({ + endpoint: 'http://s3.amazonaws.com/', + bucket: 'test-bucket', + region: 'us-east-1', + path: '', + }); + expect(uri).toEqual(null); + }); +}); + +describe('uriToObjectStorageFields', () => { + it('converts URI to fields with all params present', () => { + const fields = uriToObjectStorageFields( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + ); + expect(fields).toEqual({ + endpoint: 'http://s3.amazonaws.com/', + bucket: 'test-bucket', + region: 'us-east-1', + path: 'demo-models/flan-t5-small-caikit', + } satisfies ObjectStorageFields); + }); + + it('converts URI to fields with region missing', () => { + const fields = uriToObjectStorageFields( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F', + ); + expect(fields).toEqual({ + endpoint: 'http://s3.amazonaws.com/', + bucket: 'test-bucket', + path: 'demo-models/flan-t5-small-caikit', + region: undefined, + } satisfies ObjectStorageFields); + }); + + it('falls back to null if endpoint is missing', () => { + const fields = uriToObjectStorageFields('s3://test-bucket/demo-models/flan-t5-small-caikit'); + expect(fields).toBeNull(); + }); + + it('falls back to null if path is missing', () => { + const fields = uriToObjectStorageFields( + 's3://test-bucket/?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + ); + expect(fields).toBeNull(); + }); + + it('falls back to null if bucket is missing', () => { + const fields = uriToObjectStorageFields( + 's3://?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + ); + expect(fields).toBeNull(); + }); + + it('falls back to null if the URI is malformed', () => { + const fields = uriToObjectStorageFields('test-bucket/demo-models/flan-t5-small-caikit'); + expect(fields).toBeNull(); + }); +}); diff --git a/frontend/src/concepts/modelRegistry/context/useModelRegistryAPIState.tsx b/frontend/src/concepts/modelRegistry/context/useModelRegistryAPIState.tsx index ee5e5627ac..64c09c29ad 100644 --- a/frontend/src/concepts/modelRegistry/context/useModelRegistryAPIState.tsx +++ b/frontend/src/concepts/modelRegistry/context/useModelRegistryAPIState.tsx @@ -3,7 +3,9 @@ import { APIState } from '~/concepts/proxy/types'; import { ModelRegistryAPIs } from '~/concepts/modelRegistry/types'; import { createModelArtifact, + createModelArtifactForModelVersion, createModelVersion, + createModelVersionForRegisteredModel, createRegisteredModel, getListModelArtifacts, getListModelVersions, @@ -28,7 +30,9 @@ const useModelRegistryAPIState = ( (path: string) => ({ createRegisteredModel: createRegisteredModel(path), createModelVersion: createModelVersion(path), + createModelVersionForRegisteredModel: createModelVersionForRegisteredModel(path), createModelArtifact: createModelArtifact(path), + createModelArtifactForModelVersion: createModelArtifactForModelVersion(path), getRegisteredModel: getRegisteredModel(path), getModelVersion: getModelVersion(path), getModelArtifact: getModelArtifact(path), diff --git a/frontend/src/concepts/modelRegistry/types.ts b/frontend/src/concepts/modelRegistry/types.ts index e93a9d8f72..d601a97409 100644 --- a/frontend/src/concepts/modelRegistry/types.ts +++ b/frontend/src/concepts/modelRegistry/types.ts @@ -145,7 +145,7 @@ export type CreateModelVersionData = Omit< export type CreateModelArtifactData = Omit< ModelArtifact, - 'lastUpdateTimeSinceEpoch' | 'createTimeSinceEpoch' | 'id' | 'artifactType' + 'lastUpdateTimeSinceEpoch' | 'createTimeSinceEpoch' | 'id' >; export type ModelRegistryListParams = { @@ -170,11 +170,23 @@ export type CreateModelVersion = ( data: CreateModelVersionData, ) => Promise; +export type CreateModelVersionForRegisteredModel = ( + opts: K8sAPIOptions, + registeredModelId: string, + data: CreateModelVersionData, +) => Promise; + export type CreateModelArtifact = ( opts: K8sAPIOptions, data: CreateModelArtifactData, ) => Promise; +export type CreateModelArtifactForModelVersion = ( + opts: K8sAPIOptions, + modelVersionId: string, + data: CreateModelArtifactData, +) => Promise; + export type GetRegisteredModel = ( opts: K8sAPIOptions, registeredModelId: string, @@ -227,7 +239,9 @@ export type PatchModelArtifact = ( export type ModelRegistryAPIs = { createRegisteredModel: CreateRegisteredModel; createModelVersion: CreateModelVersion; + createModelVersionForRegisteredModel: CreateModelVersionForRegisteredModel; createModelArtifact: CreateModelArtifact; + createModelArtifactForModelVersion: CreateModelArtifactForModelVersion; getRegisteredModel: GetRegisteredModel; getModelVersion: GetModelVersion; getModelArtifact: GetModelArtifact; diff --git a/frontend/src/concepts/modelRegistry/utils.ts b/frontend/src/concepts/modelRegistry/utils.ts new file mode 100644 index 0000000000..b9d39754d4 --- /dev/null +++ b/frontend/src/concepts/modelRegistry/utils.ts @@ -0,0 +1,39 @@ +export type ObjectStorageFields = { + endpoint: string; + bucket: string; + region?: string; + path: string; +}; + +export const objectStorageFieldsToUri = (fields: ObjectStorageFields): string | null => { + const { endpoint, bucket, region, path } = fields; + if (!endpoint || !bucket || !path) { + return null; + } + const searchParams = new URLSearchParams(); + searchParams.set('endpoint', endpoint); + if (region) { + searchParams.set('defaultRegion', region); + } + return `s3://${bucket}/${path}?${searchParams.toString()}`; +}; + +export const uriToObjectStorageFields = (uri: string): ObjectStorageFields | null => { + try { + const urlObj = new URL(uri); + // Some environments include the first token after the protocol (our bucket) in the pathname and some have it as the hostname + const [bucket, ...pathSplit] = `${urlObj.hostname}/${urlObj.pathname}` + .split('/') + .filter(Boolean); + const path = pathSplit.join('/'); + const searchParams = new URLSearchParams(urlObj.search); + const endpoint = searchParams.get('endpoint'); + const region = searchParams.get('defaultRegion'); + if (endpoint && bucket && path) { + return { endpoint, bucket, region: region || undefined, path }; + } + return null; + } catch { + return null; + } +}; diff --git a/frontend/src/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx b/frontend/src/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx index f3724e6707..bfb563fb95 100644 --- a/frontend/src/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx +++ b/frontend/src/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx @@ -13,6 +13,7 @@ import { HelperTextItem, InputGroupItem, InputGroupText, + PageSection, Radio, Split, SplitItem, @@ -23,41 +24,40 @@ import { TextInput, Title, } from '@patternfly/react-core'; -import { useParams } from 'react-router'; +import spacing from '@patternfly/react-styles/css/utilities/Spacing/spacing'; +import { useParams, useNavigate } from 'react-router'; import { Link } from 'react-router-dom'; import ApplicationsPage from '~/pages/ApplicationsPage'; -import useRegisterModelData from './useRegisterModelData'; +import { ModelRegistryContext } from '~/concepts/modelRegistry/context/ModelRegistryContext'; +import { useAppSelector } from '~/redux/hooks'; +import { useRegisterModelData, ModelLocationType } from './useRegisterModelData'; +import { registerModel } from './utils'; const RegisterModel: React.FC = () => { const { modelRegistry: mrName } = useParams(); - const [ - { - modelRegistryName, - modelName, - modelDescription, - versionName, - versionDescription, - sourceModelFormat, - modelLocationType, - modelLocationEndpoint, - modelLocationBucket, - modelLocationRegion, - modelLocationPath, - modelLocationURI, - }, - setData, - resetData, - ] = useRegisterModelData(mrName); + const navigate = useNavigate(); + const [formData, setData] = useRegisterModelData(); + const { + modelName, + modelDescription, + versionName, + versionDescription, + sourceModelFormat, + sourceModelFormatVersion, + modelLocationType, + modelLocationEndpoint, + modelLocationBucket, + modelLocationRegion, + modelLocationPath, + modelLocationURI, + } = formData; const [loading, setIsLoading] = React.useState(false); const [error, setError] = React.useState(undefined); - enum ModelLocationType { - ObjectStorage = 'Object storage', - URI = 'URI', - } + const { apiState } = React.useContext(ModelRegistryContext); + const author = useAppSelector((state) => state.user || ''); const isSubmitDisabled = - !modelRegistryName || !modelName || !versionName || loading || @@ -68,23 +68,15 @@ const RegisterModel: React.FC = () => { const handleSubmit = () => { setIsLoading(true); setError(undefined); - //TODO: implement submit calls/logic. remove console log and alert - alert('This functionality is not yet implemented'); - /* eslint-disable-next-line no-console */ - console.log({ - modelRegistryName, - modelName, - modelDescription, - versionName, - versionDescription, - sourceModelFormat, - modelLocationType, - modelLocationEndpoint, - modelLocationBucket, - modelLocationRegion, - modelLocationPath, - modelLocationURI, - }); + + registerModel(apiState, formData, author) + .then(({ registeredModel }) => { + navigate(`/modelRegistry/${mrName}/registeredModels/${registeredModel.id}`); + }) + .catch((e: Error) => { + setIsLoading(false); + setError(e); + }); }; return ( @@ -93,206 +85,233 @@ const RegisterModel: React.FC = () => { description="Create a new model and register a first version of the new model." breadcrumb={ - Model registry} /> ( - {modelRegistryName} - )} + render={() => Model registry - {mrName}} /> Register model } loaded empty={false} - provideChildrenPadding > -
- - - - + + + + - - - - - Model info - - Configure the model info you want to create. - - - } - > - - setData('modelName', value)} - /> - - -