diff --git a/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts b/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts index d16745d4d5..ea9d8d8cdd 100644 --- a/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts +++ b/frontend/src/__tests__/cypress/cypress/support/commands/odh.ts @@ -318,6 +318,11 @@ declare global { options: { path: { serviceName: string; apiVersion: string; registeredModelId: number } }, response: OdhResponse, ) => Cypress.Chainable) & + (( + type: 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions', + options: { path: { serviceName: string; apiVersion: string } }, + response: OdhResponse, + ) => Cypress.Chainable) & (( type: 'POST /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId/versions', options: { path: { serviceName: string; apiVersion: string; registeredModelId: number } }, diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts index b9e0075edc..f8aa8eb012 100644 --- a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelRegistry.cy.ts @@ -87,6 +87,16 @@ const initIntercepts = ({ disableModelRegistry: disableModelRegistryFeature, }), ); + cy.interceptOdh( + `GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions`, + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + }, + }, + mockModelVersionList({ items: modelVersions }), + ); cy.interceptOdh('GET /api/components', { query: { installed: 'true' } }, mockComponents()); cy.interceptK8s('POST', SelfSubjectRulesReviewModel, mockSelfSubjectRulesReview()); diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersions.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersions.cy.ts index bce9454e3c..2013b320b3 100644 --- a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersions.cy.ts +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersions.cy.ts @@ -62,6 +62,16 @@ const initIntercepts = ({ disableModelRegistry: disableModelRegistryFeature, }), ); + cy.interceptOdh( + `GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions`, + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + }, + }, + mockModelVersionList({ items: modelVersions }), + ); cy.interceptK8sList(ServiceModel, mockK8sResourceList(modelRegistries)); diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registeredModelArchive.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registeredModelArchive.cy.ts index 97b7b7616d..01d7329a92 100644 --- a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registeredModelArchive.cy.ts +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registeredModelArchive.cy.ts @@ -66,7 +66,16 @@ const initIntercepts = ({ }, }), ); - + cy.interceptOdh( + `GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions`, + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + }, + }, + mockModelVersionList({ items: modelVersions }), + ); cy.interceptK8sList( ServiceModel, mockK8sResourceList([ diff --git a/frontend/src/concepts/modelRegistry/apiHooks/__tests__/useModelVersions.spec.ts b/frontend/src/concepts/modelRegistry/apiHooks/__tests__/useModelVersions.spec.ts new file mode 100644 index 0000000000..56bdabdfc6 --- /dev/null +++ b/frontend/src/concepts/modelRegistry/apiHooks/__tests__/useModelVersions.spec.ts @@ -0,0 +1,87 @@ +import { renderHook, act } from '@testing-library/react'; +import { useModelRegistryAPI } from '~/concepts/modelRegistry/context/ModelRegistryContext'; +import useModelVersions from '~/concepts/modelRegistry/apiHooks/useModelVersions'; + +jest.mock('~/concepts/modelRegistry/context/ModelRegistryContext'); + +describe('useModelVersions', () => { + const mockListModelVersions = jest.fn(); + const mockUseModelRegistryAPI = useModelRegistryAPI as jest.Mock; + + beforeEach(() => { + jest.resetAllMocks(); + }); + + it('should return initial state when API is not available', async () => { + mockUseModelRegistryAPI.mockReturnValue({ + api: { listModelVersions: mockListModelVersions }, + apiAvailable: false, + }); + + const { result } = renderHook(() => useModelVersions()); + + expect(result.current[0]).toEqual({ items: [], size: 0, pageSize: 0, nextPageToken: '' }); + expect(result.current[1]).toBe(false); + expect(result.current[2]).toBeUndefined(); + + await act(async () => { + await result.current[3](); + }); + + expect(mockListModelVersions).not.toHaveBeenCalled(); + }); + + it('should fetch model versions when API is available', async () => { + const mockModelVersions = { + items: [ + { id: '1', name: 'Model 1' }, + { id: '2', name: 'Model 2' }, + ], + size: 2, + pageSize: 10, + nextPageToken: 'next-token', + }; + + mockUseModelRegistryAPI.mockReturnValue({ + api: { listModelVersions: mockListModelVersions }, + apiAvailable: true, + }); + + mockListModelVersions.mockResolvedValue(mockModelVersions); + + const { result } = renderHook(() => useModelVersions()); + + expect(result.current[1]).toBe(false); + + await act(async () => { + await result.current[3](); + }); + + expect(result.current[0]).toEqual(mockModelVersions); + expect(result.current[1]).toBe(true); + expect(result.current[2]).toBeUndefined(); + expect(mockListModelVersions).toHaveBeenCalledTimes(2); + }); + + it('should handle errors when fetching model versions', async () => { + const mockError = new Error('Failed to fetch model versions'); + + mockUseModelRegistryAPI.mockReturnValue({ + api: { listModelVersions: mockListModelVersions }, + apiAvailable: true, + }); + + mockListModelVersions.mockRejectedValue(mockError); + + const { result } = renderHook(() => useModelVersions()); + + await act(async () => { + await result.current[3](); + }); + + expect(result.current[0]).toEqual({ items: [], size: 0, pageSize: 0, nextPageToken: '' }); + expect(result.current[1]).toBe(false); + expect(result.current[2]).toBe(mockError); + expect(mockListModelVersions).toHaveBeenCalledTimes(2); + }); +}); diff --git a/frontend/src/concepts/modelRegistry/apiHooks/useModelVersions.ts b/frontend/src/concepts/modelRegistry/apiHooks/useModelVersions.ts new file mode 100644 index 0000000000..3d27087841 --- /dev/null +++ b/frontend/src/concepts/modelRegistry/apiHooks/useModelVersions.ts @@ -0,0 +1,28 @@ +import * as React from 'react'; +import useFetchState, { + FetchState, + FetchStateCallbackPromise, + NotReadyError, +} from '~/utilities/useFetchState'; +import { ModelVersionList } from '~/concepts/modelRegistry/types'; +import { useModelRegistryAPI } from '~/concepts/modelRegistry/context/ModelRegistryContext'; + +const useModelVersions = (): FetchState => { + const { api, apiAvailable } = useModelRegistryAPI(); + const callback = React.useCallback>( + (opts) => { + if (!apiAvailable) { + return Promise.reject(new NotReadyError('API not yet available')); + } + return api.listModelVersions(opts).then((r) => r); + }, + [api, apiAvailable], + ); + return useFetchState( + callback, + { items: [], size: 0, pageSize: 0, nextPageToken: '' }, + { initialPromisePurity: true }, + ); +}; + +export default useModelVersions; diff --git a/frontend/src/pages/modelRegistry/screens/ModelRegistry.tsx b/frontend/src/pages/modelRegistry/screens/ModelRegistry.tsx index c7bcff2df2..096c955c80 100644 --- a/frontend/src/pages/modelRegistry/screens/ModelRegistry.tsx +++ b/frontend/src/pages/modelRegistry/screens/ModelRegistry.tsx @@ -1,6 +1,7 @@ import React from 'react'; import ApplicationsPage from '~/pages/ApplicationsPage'; import useRegisteredModels from '~/concepts/modelRegistry/apiHooks/useRegisteredModels'; +import useModelVersions from '~/concepts/modelRegistry/apiHooks/useModelVersions'; import TitleWithIcon from '~/concepts/design/TitleWithIcon'; import { ProjectObjectType } from '~/concepts/design/utils'; import RegisteredModelListView from './RegisteredModels/RegisteredModelListView'; @@ -19,7 +20,16 @@ type ModelRegistryProps = Omit< >; const ModelRegistry: React.FC = ({ ...pageProps }) => { - const [registeredModels, loaded, loadError, refresh] = useRegisteredModels(); + const [registeredModels, modelsLoaded, modelsLoadError, refreshModels] = useRegisteredModels(); + const [modelVersions, versionsLoaded, versionsLoadError, refreshVersions] = useModelVersions(); + + const loaded = modelsLoaded && versionsLoaded; + const loadError = modelsLoadError || versionsLoadError; + + const refresh = React.useCallback(() => { + refreshModels(); + refreshVersions(); + }, [refreshModels, refreshVersions]); return ( = ({ ...pageProps }) => { provideChildrenPadding removeChildrenTopPadding > - + ); }; diff --git a/frontend/src/pages/modelRegistry/screens/RegisteredModels/RegisteredModelListView.tsx b/frontend/src/pages/modelRegistry/screens/RegisteredModels/RegisteredModelListView.tsx index 20e7df7e4f..2adcf25727 100644 --- a/frontend/src/pages/modelRegistry/screens/RegisteredModels/RegisteredModelListView.tsx +++ b/frontend/src/pages/modelRegistry/screens/RegisteredModels/RegisteredModelListView.tsx @@ -3,7 +3,7 @@ import { SearchInput, ToolbarFilter, ToolbarGroup, ToolbarItem } from '@patternf import { FilterIcon } from '@patternfly/react-icons'; import { useNavigate } from 'react-router'; import { SearchType } from '~/concepts/dashboard/DashboardSearchField'; -import { RegisteredModel } from '~/concepts/modelRegistry/types'; +import { ModelVersion, RegisteredModel } from '~/concepts/modelRegistry/types'; import SimpleSelect from '~/components/SimpleSelect'; import { filterRegisteredModels } from '~/pages/modelRegistry/screens/utils'; import { ModelRegistrySelectorContext } from '~/concepts/modelRegistry/context/ModelRegistrySelectorContext'; @@ -20,11 +20,13 @@ import RegisteredModelsTableToolbar from './RegisteredModelsTableToolbar'; type RegisteredModelListViewProps = { registeredModels: RegisteredModel[]; + modelVersions: ModelVersion[]; refresh: () => void; }; const RegisteredModelListView: React.FC = ({ registeredModels, + modelVersions, refresh, }) => { const navigate = useNavigate(); @@ -63,6 +65,7 @@ const RegisteredModelListView: React.FC = ({ const filteredRegisteredModels = filterRegisteredModels( unfilteredRegisteredModels, + modelVersions, search, searchType, ); diff --git a/frontend/src/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchive.tsx b/frontend/src/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchive.tsx index 9b7cc1e566..475bf3f8a7 100644 --- a/frontend/src/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchive.tsx +++ b/frontend/src/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchive.tsx @@ -5,6 +5,7 @@ import ApplicationsPage from '~/pages/ApplicationsPage'; import { ModelRegistrySelectorContext } from '~/concepts/modelRegistry/context/ModelRegistrySelectorContext'; import { filterArchiveModels } from '~/concepts/modelRegistry/utils'; import useRegisteredModels from '~/concepts/modelRegistry/apiHooks/useRegisteredModels'; +import useModelVersions from '~/concepts/modelRegistry/apiHooks/useModelVersions'; import RegisteredModelsArchiveListView from './RegisteredModelsArchiveListView'; type RegisteredModelsArchiveProps = Omit< @@ -14,7 +15,16 @@ type RegisteredModelsArchiveProps = Omit< const RegisteredModelsArchive: React.FC = ({ ...pageProps }) => { const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext); - const [registeredModels, loaded, loadError, refresh] = useRegisteredModels(); + const [registeredModels, modelsLoaded, modelsLoadError, refreshModels] = useRegisteredModels(); + const [modelVersions, versionsLoaded, versionsLoadError, refreshVersions] = useModelVersions(); + + const loaded = modelsLoaded && versionsLoaded; + const loadError = modelsLoadError || versionsLoadError; + + const refresh = React.useCallback(() => { + refreshModels(); + refreshVersions(); + }, [refreshModels, refreshVersions]); return ( = ({ ...pa > diff --git a/frontend/src/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchiveListView.tsx b/frontend/src/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchiveListView.tsx index 3d3d07e22b..2c3dbeaa46 100644 --- a/frontend/src/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchiveListView.tsx +++ b/frontend/src/pages/modelRegistry/screens/RegisteredModelsArchive/RegisteredModelsArchiveListView.tsx @@ -9,7 +9,7 @@ import { } from '@patternfly/react-core'; import { FilterIcon, SearchIcon } from '@patternfly/react-icons'; import { SearchType } from '~/concepts/dashboard/DashboardSearchField'; -import { RegisteredModel } from '~/concepts/modelRegistry/types'; +import { ModelVersion, RegisteredModel } from '~/concepts/modelRegistry/types'; import SimpleSelect from '~/components/SimpleSelect'; import { filterRegisteredModels } from '~/pages/modelRegistry/screens/utils'; import EmptyModelRegistryState from '~/pages/modelRegistry/screens/components/EmptyModelRegistryState'; @@ -18,20 +18,22 @@ import RegisteredModelsArchiveTable from './RegisteredModelsArchiveTable'; type RegisteredModelsArchiveListViewProps = { registeredModels: RegisteredModel[]; + modelVersions: ModelVersion[]; refresh: () => void; }; const RegisteredModelsArchiveListView: React.FC = ({ registeredModels: unfilteredRegisteredModels, + modelVersions, refresh, }) => { const [searchType, setSearchType] = React.useState(SearchType.KEYWORD); const [search, setSearch] = React.useState(''); - const searchTypes = [SearchType.KEYWORD, SearchType.AUTHOR]; - + const searchTypes = [SearchType.KEYWORD, SearchType.OWNER]; const filteredRegisteredModels = filterRegisteredModels( unfilteredRegisteredModels, + modelVersions, search, searchType, ); diff --git a/frontend/src/pages/modelRegistry/screens/__tests__/utils.spec.ts b/frontend/src/pages/modelRegistry/screens/__tests__/utils.spec.ts index 8cb12503f3..37c83b79a2 100644 --- a/frontend/src/pages/modelRegistry/screens/__tests__/utils.spec.ts +++ b/frontend/src/pages/modelRegistry/screens/__tests__/utils.spec.ts @@ -309,22 +309,27 @@ describe('filterRegisteredModels', () => { ]; test('filters by name', () => { - const filtered = filterRegisteredModels(registeredModels, 'Test 1', SearchType.KEYWORD); + const filtered = filterRegisteredModels(registeredModels, [], 'Test 1', SearchType.KEYWORD); expect(filtered).toEqual([registeredModels[0]]); }); test('filters by description', () => { - const filtered = filterRegisteredModels(registeredModels, 'Description2', SearchType.KEYWORD); + const filtered = filterRegisteredModels( + registeredModels, + [], + 'Description2', + SearchType.KEYWORD, + ); expect(filtered).toEqual([registeredModels[1]]); }); test('filters by owner', () => { - const filtered = filterRegisteredModels(registeredModels, 'Alice', SearchType.OWNER); + const filtered = filterRegisteredModels(registeredModels, [], 'Alice', SearchType.OWNER); expect(filtered).toEqual([registeredModels[0], registeredModels[3]]); }); test('does not filter when search is empty', () => { - const filtered = filterRegisteredModels(registeredModels, '', SearchType.KEYWORD); + const filtered = filterRegisteredModels(registeredModels, [], '', SearchType.KEYWORD); expect(filtered).toEqual(registeredModels); }); }); diff --git a/frontend/src/pages/modelRegistry/screens/utils.ts b/frontend/src/pages/modelRegistry/screens/utils.ts index 7c9a934f78..3c675c174d 100644 --- a/frontend/src/pages/modelRegistry/screens/utils.ts +++ b/frontend/src/pages/modelRegistry/screens/utils.ts @@ -114,6 +114,7 @@ export const sortModelVersionsByCreateTime = (registeredModels: ModelVersion[]): export const filterRegisteredModels = ( unfilteredRegisteredModels: RegisteredModel[], + unfilteredModelVersions: ModelVersion[], search: string, searchType: SearchType, ): RegisteredModel[] => { @@ -123,16 +124,26 @@ export const filterRegisteredModels = ( if (!search) { return true; } + const modelVersions = unfilteredModelVersions.filter((mv) => mv.registeredModelId === rm.id); switch (searchType) { case SearchType.KEYWORD: { - return ( + const matchesModel = rm.name.toLowerCase().includes(searchLower) || (rm.description && rm.description.toLowerCase().includes(searchLower)) || - getLabels(rm.customProperties).some((label) => label.toLowerCase().includes(searchLower)) + getLabels(rm.customProperties).some((label) => label.toLowerCase().includes(searchLower)); + + const matchesVersion = modelVersions.some( + (mv: ModelVersion) => + mv.name.toLowerCase().includes(searchLower) || + (mv.description && mv.description.toLowerCase().includes(searchLower)) || + getLabels(mv.customProperties).some((label) => + label.toLowerCase().includes(searchLower), + ), ); - } + return matchesModel || matchesVersion; + } case SearchType.OWNER: { return rm.owner && rm.owner.toLowerCase().includes(searchLower); }