Skip to content

Commit

Permalink
Add deployment modal for model registry
Browse files Browse the repository at this point in the history
  • Loading branch information
DaoDaoNoCode committed Aug 7, 2024
1 parent 457d01d commit f523086
Show file tree
Hide file tree
Showing 20 changed files with 793 additions and 159 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@ import * as React from 'react';
import { Alert, Bullseye } from '@patternfly/react-core';
import { SupportedArea, conditionalArea } from '~/concepts/areas';
import { MODEL_REGISTRY_DEFAULT_NAMESPACE } from '~/concepts/modelRegistry/const';
import useModelRegistryAPIState, { ModelRegistryAPIState } from './useModelRegistryAPIState';
import { useTemplates } from '~/api';
import { useDashboardNamespace } from '~/redux/selectors';
import { useContextResourceData } from '~/utilities/useContextResourceData';
import useTemplateOrder from '~/pages/modelServing/customServingRuntimes/useTemplateOrder';
import { ContextResourceData, CustomWatchK8sResult } from '~/types';
import { TemplateKind } from '~/k8sTypes';
import { DEFAULT_CONTEXT_DATA, DEFAULT_LIST_WATCH_RESULT } from '~/utilities/const';
import useTemplateDisablement from '~/pages/modelServing/customServingRuntimes/useTemplateDisablement';
import {
hasServerTimedOut,
isModelRegistryAvailable,
useModelRegistryNamespaceCR,
} from './useModelRegistryNamespaceCR';
import useModelRegistryAPIState, { ModelRegistryAPIState } from './useModelRegistryAPIState';

export type ModelRegistryContextType = {
hasCR: boolean;
Expand All @@ -17,6 +25,9 @@ export type ModelRegistryContextType = {
ignoreTimedOut: () => void;
refreshState: () => Promise<undefined>;
refreshAPIState: () => void;
servingRuntimeTemplates: CustomWatchK8sResult<TemplateKind[]>;
servingRuntimeTemplateOrder: ContextResourceData<string>;
servingRuntimeTemplateDisablement: ContextResourceData<string>;
};

type ModelRegistryContextProviderProps = {
Expand All @@ -33,12 +44,16 @@ export const ModelRegistryContext = React.createContext<ModelRegistryContextType
ignoreTimedOut: () => undefined,
refreshState: async () => undefined,
refreshAPIState: () => undefined,
servingRuntimeTemplates: DEFAULT_LIST_WATCH_RESULT,
servingRuntimeTemplateOrder: DEFAULT_CONTEXT_DATA,
servingRuntimeTemplateDisablement: DEFAULT_CONTEXT_DATA,
});

export const ModelRegistryContextProvider = conditionalArea<ModelRegistryContextProviderProps>(
SupportedArea.MODEL_REGISTRY,
true,
)(({ children, modelRegistryName }) => {
const { dashboardNamespace } = useDashboardNamespace();
const state = useModelRegistryNamespaceCR(MODEL_REGISTRY_DEFAULT_NAMESPACE, modelRegistryName);
const [modelRegistryCR, crLoaded, crLoadError, refreshCR] = state;
const isCRReady = isModelRegistryAvailable(state);
Expand All @@ -49,6 +64,14 @@ export const ModelRegistryContextProvider = conditionalArea<ModelRegistryContext
setDisableTimeout(true);
}, []);

const servingRuntimeTemplates = useTemplates(dashboardNamespace);
const servingRuntimeTemplateOrder = useContextResourceData<string>(
useTemplateOrder(dashboardNamespace),
);
const servingRuntimeTemplateDisablement = useContextResourceData<string>(
useTemplateDisablement(dashboardNamespace),
);

const hostPath = modelRegistryName ? `/api/service/modelregistry/${modelRegistryName}` : null;

const [apiState, refreshAPIState] = useModelRegistryAPIState(hostPath);
Expand Down Expand Up @@ -78,6 +101,9 @@ export const ModelRegistryContextProvider = conditionalArea<ModelRegistryContext
ignoreTimedOut,
refreshState,
refreshAPIState,
servingRuntimeTemplates,
servingRuntimeTemplateOrder,
servingRuntimeTemplateDisablement,
}}
>
{children}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { ModelVersion, ModelState } from '~/concepts/modelRegistry/types';
import { getPatchBodyForModelVersion } from '~/pages/modelRegistry/screens/utils';
import { ModelRegistrySelectorContext } from '~/concepts/modelRegistry/context/ModelRegistrySelectorContext';
import { modelVersionArchiveDetailsUrl } from '~/pages/modelRegistry/screens/routeUtils';
import DeployRegisteredModelModal from '~/pages/modelRegistry/screens/components/DeployRegisteredModelModal';

interface ModelVersionsDetailsHeaderActionsProps {
mv: ModelVersion;
Expand All @@ -21,6 +22,7 @@ const ModelVersionsDetailsHeaderActions: React.FC<ModelVersionsDetailsHeaderActi
const navigate = useNavigate();
const [isOpenActionDropdown, setOpenActionDropdown] = React.useState(false);
const [isArchiveModalOpen, setIsArchiveModalOpen] = React.useState(false);
const [isDeployModalOpen, setIsDeployModalOpen] = React.useState(false);
const tooltipRef = React.useRef<HTMLButtonElement>(null);

return (
Expand Down Expand Up @@ -48,9 +50,8 @@ const ModelVersionsDetailsHeaderActions: React.FC<ModelVersionsDetailsHeaderActi
id="deploy-button"
aria-label="Deploy version"
key="deploy-button"
onClick={() => undefined}
onClick={() => setIsDeployModalOpen(true)}
ref={tooltipRef}
isDisabled // TODO This feature is currently disabled but will be enabled in a future PR post-summit release.
>
Deploy
</DropdownItem>
Expand All @@ -65,6 +66,11 @@ const ModelVersionsDetailsHeaderActions: React.FC<ModelVersionsDetailsHeaderActi
</DropdownItem>
</DropdownList>
</Dropdown>
<DeployRegisteredModelModal
onCancel={() => setIsDeployModalOpen(false)}
isOpen={isDeployModalOpen}
modelVersion={mv}
/>
<ArchiveModelVersionModal
onCancel={() => setIsArchiveModalOpen(false)}
onSubmit={() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { ArchiveModelVersionModal } from '~/pages/modelRegistry/screens/componen
import { ModelRegistryContext } from '~/concepts/modelRegistry/context/ModelRegistryContext';
import { getPatchBodyForModelVersion } from '~/pages/modelRegistry/screens/utils';
import { RestoreModelVersionModal } from '~/pages/modelRegistry/screens/components/RestoreModelVersionModal';
import DeployRegisteredModelModal from '~/pages/modelRegistry/screens/components/DeployRegisteredModelModal';

type ModelVersionsTableRowProps = {
modelVersion: ModelVersion;
Expand All @@ -30,6 +31,7 @@ const ModelVersionsTableRow: React.FC<ModelVersionsTableRowProps> = ({
const { preferredModelRegistry } = React.useContext(ModelRegistrySelectorContext);
const [isArchiveModalOpen, setIsArchiveModalOpen] = React.useState(false);
const [isRestoreModalOpen, setIsRestoreModalOpen] = React.useState(false);
const [isDeployModalOpen, setIsDeployModalOpen] = React.useState(false);
const { apiState } = React.useContext(ModelRegistryContext);

const actions = isArchiveRow
Expand All @@ -42,8 +44,7 @@ const ModelVersionsTableRow: React.FC<ModelVersionsTableRowProps> = ({
: [
{
title: 'Deploy',
// TODO: Implement functionality for onClick. This will be added in another PR
onClick: () => undefined,
onClick: () => setIsDeployModalOpen(true),
},
{
title: 'Archive version',
Expand Down Expand Up @@ -105,6 +106,11 @@ const ModelVersionsTableRow: React.FC<ModelVersionsTableRowProps> = ({
isOpen={isArchiveModalOpen}
modelVersionName={mv.name}
/>
<DeployRegisteredModelModal
onCancel={() => setIsDeployModalOpen(false)}
isOpen={isDeployModalOpen}
modelVersion={mv}
/>
<RestoreModelVersionModal
onCancel={() => setIsRestoreModalOpen(false)}
onSubmit={() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ const RegisteredModelTableRow: React.FC<RegisteredModelTableRowProps> = ({
},
]
: [
{
title: 'Deploy',
isDisabled: true,
// TODO: Implement functionality for onClick. This will be added in another PR
onClick: () => undefined,
},
{
title: 'Archive model',
onClick: () => setIsArchiveModalOpen(true),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import React from 'react';
import { uriToObjectStorageFields } from '~/concepts/modelRegistry/utils';
import { LabeledDataConnection } from '~/pages/modelServing/screens/types';
import { AwsKeys } from '~/pages/projects/dataConnections/const';
import { convertAWSSecretData } from '~/pages/projects/screens/detail/data-connections/utils';
import { DataConnection } from '~/pages/projects/types';

const useLabeledDataConnections = (
modelArtifactUri: string | undefined,
dataConnections: DataConnection[] = [],
): {
dataConnections: LabeledDataConnection[];
path: string;
hasParseError?: boolean;
} =>
React.useMemo(() => {
if (!modelArtifactUri) {
return {
dataConnections: dataConnections.map((dataConnection) => ({ dataConnection })),
path: '',
};
}
const storageFields = uriToObjectStorageFields(modelArtifactUri);
if (!storageFields) {
return {
dataConnections: dataConnections.map((dataConnection) => ({ dataConnection })),
path: '',
hasParseError: true,
};
}
const labeledDataConnections = dataConnections.map((dataConnection) => {
const awsData = convertAWSSecretData(dataConnection);
const bucket = awsData.find((data) => data.key === AwsKeys.AWS_S3_BUCKET)?.value;
const endpoint = awsData.find((data) => data.key === AwsKeys.S3_ENDPOINT)?.value;
const region = awsData.find((data) => data.key === AwsKeys.DEFAULT_REGION)?.value;
if (
bucket === storageFields.bucket &&
endpoint === storageFields.endpoint &&
region === storageFields.region
) {
return { dataConnection, isRecommended: true };
}
return { dataConnection };
});
return { dataConnections: labeledDataConnections, path: storageFields.path };
}, [dataConnections, modelArtifactUri]);

export default useLabeledDataConnections;
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import useServingRuntimes from '~/pages/modelServing/useServingRuntimes';
import { ServingRuntimePlatform } from '~/types';

const useProjectErrorForRegisteredModel = (
projectName?: string,
platform?: ServingRuntimePlatform,
): Error | undefined => {
const [servingRuntimes, loaded, loadError] = useServingRuntimes(projectName);

// If project is not selected, there is no error
if (!projectName) {
return undefined;
}

// If the platform is not selected
if (!platform) {
return new Error('Cannot deploy the model until you select a model serving platform');
}

// If the platform is MULTI but it doesn't have a server
if (platform === ServingRuntimePlatform.MULTI) {
if (loadError) {
return loadError;
}
if (loaded && servingRuntimes.length === 0) {
return new Error('Cannot deploy the model until you configure a model server');
}
}

return undefined;
};

export default useProjectErrorForRegisteredModel;
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import React from 'react';
import useModelArtifactsByVersionId from '~/concepts/modelRegistry/apiHooks/useModelArtifactsByVersionId';
import useRegisteredModelById from '~/concepts/modelRegistry/apiHooks/useRegisteredModelById';
import { ModelVersion } from '~/concepts/modelRegistry/types';

export type RegisteredModelContext = {
modelName: string;
modelFormat?: string;
modelArtifactUri?: string;
modelArtifactStorageKey?: string;
};

const useRegisteredModelContext = (modelVersion: ModelVersion): RegisteredModelContext => {
const [registeredModel] = useRegisteredModelById(modelVersion.registeredModelId);
const [modelArtifactList] = useModelArtifactsByVersionId(modelVersion.id);

const registeredModelContext = React.useMemo(() => {
if (modelArtifactList.size === 0) {
return {
modelName: `${registeredModel?.name} - ${modelVersion.name}`,
};
}
const modelArtifact = modelArtifactList.items[0];
return {
modelName: `${registeredModel?.name} - ${modelVersion.name}`,
modelFormat: modelArtifact.modelFormatName
? `${modelArtifact.modelFormatName} - ${modelArtifact.modelFormatVersion}`
: undefined,
modelArtifactUri: modelArtifact.uri,
modelArtifactStorageKey: modelArtifact.storageKey,
};
}, [modelArtifactList.items, modelArtifactList.size, modelVersion.name, registeredModel?.name]);

return registeredModelContext;
};

export default useRegisteredModelContext;
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import React from 'react';
import { ProjectKind } from '~/k8sTypes';
import useLabeledDataConnections from '~/pages/modelRegistry/screens/RegisteredModels/useLabeledDataConnections';
import { RegisteredModelContext } from '~/pages/modelRegistry/screens/RegisteredModels/useRegisteredModelContext';
import {
CreatingInferenceServiceObject,
InferenceServiceStorageType,
LabeledDataConnection,
} from '~/pages/modelServing/screens/types';
import { AwsKeys, EMPTY_AWS_SECRET_DATA } from '~/pages/projects/dataConnections/const';
import useDataConnections from '~/pages/projects/screens/detail/data-connections/useDataConnections';
import { DataConnection, UpdateObjectAtPropAndValue } from '~/pages/projects/types';

const useRegisteredModelContextForModal = (
projectContext: { currentProject: ProjectKind; dataConnections: DataConnection[] } | undefined,
createData: CreatingInferenceServiceObject,
setCreateData: UpdateObjectAtPropAndValue<CreatingInferenceServiceObject>,
registeredModelContext?: RegisteredModelContext,
): [LabeledDataConnection[], boolean, Error | undefined] => {
const [fetchedDataConnections, dataConnectionsLoaded, dataConnectionsLoadError] =
useDataConnections(projectContext ? undefined : createData.project);
const allDataConnections = projectContext?.dataConnections || fetchedDataConnections;
const { dataConnections, path, hasParseError } = useLabeledDataConnections(
registeredModelContext?.modelArtifactUri,
allDataConnections,
);

React.useEffect(() => {
if (registeredModelContext) {
setCreateData('name', registeredModelContext.modelName);
const recommendedDataConnections = dataConnections.filter(
(dataConnection) => dataConnection.isRecommended,
);

if (!registeredModelContext.modelArtifactUri || hasParseError) {
setCreateData('storage', {
awsData: EMPTY_AWS_SECRET_DATA,
dataConnection: '',
path,
type: InferenceServiceStorageType.EXISTING_STORAGE,
});
} else if (recommendedDataConnections.length === 0) {
setCreateData('storage', {
awsData: [
...EMPTY_AWS_SECRET_DATA,
{ key: AwsKeys.NAME, value: registeredModelContext.modelArtifactStorageKey || '' },
],
dataConnection: '',
path,
type: InferenceServiceStorageType.NEW_STORAGE,
});
} else if (recommendedDataConnections.length === 1) {
setCreateData('storage', {
awsData: EMPTY_AWS_SECRET_DATA,
dataConnection: recommendedDataConnections[0].dataConnection.data.metadata.name,
path,
type: InferenceServiceStorageType.EXISTING_STORAGE,
});
} else {
setCreateData('storage', {
awsData: EMPTY_AWS_SECRET_DATA,
dataConnection: '',
path,
type: InferenceServiceStorageType.EXISTING_STORAGE,
});
}
}
}, [dataConnections, hasParseError, path, registeredModelContext, setCreateData]);

return [dataConnections, dataConnectionsLoaded, dataConnectionsLoadError];
};

export default useRegisteredModelContextForModal;
Loading

0 comments on commit f523086

Please sign in to comment.