Skip to content

Commit

Permalink
Refactor imports and remove unused code in ArtifactDetails.spec.tsx a…
Browse files Browse the repository at this point in the history
…nd ArtifactsTable.spec.tsx

Add TestArtifacts component to GlobalPipelineExperimentsRoutes

Refactor ConfusionMatrix component to improve readability and styling

Add ConfusionMatrix component and buildConfusionMatrixConfig function

Add ConfusionMatrix and ROCCurve components

Remove TestArtifacts component and update related routes

Refactor ROC curve table components, update imports, and remove unnecessary code

Refactor ConfusionMatrixSelect component and add skeleton loading

Refactor ROC curve table components and update imports

Refactor ROC curve table components and add related constants

Refactor ROC curve table components and add related constants

Refactor CompareRunsMetricsSection component, update imports, and remove console.log statements

Refactor CompareRunsMetricsSection component, update imports, and add ROC curve tab

Refactor Dockerfile to use linux/amd64 platform

Refactor CompareRunsMetricsSection component, update imports, and add confusion matrix tab

Refactor CompareRunsMetricsSection component, add MetricSectionTabLabels enum, and update related types and imports

Refactor CompareRunsMetricsSection component, add MetricSectionTabLabels enum, and update related types and imports

Refactor CompareRunsMetricsSection component and add MetricSectionTabLabels enum

Refactor and remove unused code in usePipelinesUiRoute.ts and ArtifactUriLink.tsx

Refactor MLMD API hooks and remove unused code

Refactor components to improve readability and add loading spinners

Refactor components to improve readability and add loading spinners

clean up

Refactor imports and remove unused code in ArtifactDetails.spec.tsx and ArtifactsTable.spec.tsx
  • Loading branch information
Gkrumbach07 committed May 14, 2024
1 parent 424378e commit 2fd73b2
Show file tree
Hide file tree
Showing 41 changed files with 1,773 additions and 201 deletions.
22 changes: 17 additions & 5 deletions frontend/src/concepts/pipelines/apiHooks/mlmd/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Artifact, Context, ContextType, Event } from '~/third_party/mlmd';
import { Artifact, Context, ContextType, Event, Execution } from '~/third_party/mlmd';
import { PipelineRunKFv2 } from '~/concepts/pipelines/kfTypes';

export type MlmdContext = Context;

Expand All @@ -8,9 +9,20 @@ export enum MlmdContextTypes {
RUN = 'system.PipelineRun',
}

// An artifact which has associated event.
// You can retrieve artifact name from event.path.steps[0].key
export interface LinkedArtifact {
// each artifact is linked to an event
export type LinkedArtifact = {
event: Event;
artifact: Artifact;
}
};

// each execution can have multiple output artifacts
export type ExecutionArtifact = {
execution: Execution;
linkedArtifacts: LinkedArtifact[];
};

// each run has multiple executions, each execution can have multiple artifacts
export type RunArtifact = {
run: PipelineRunKFv2;
executionArtifacts: ExecutionArtifact[];
};
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import React from 'react';
import { usePipelinesAPI } from '~/concepts/pipelines/context';
import { GetArtifactTypesRequest } from '~/third_party/mlmd';
import { ArtifactType, GetArtifactTypesRequest } from '~/third_party/mlmd';
import useFetchState, { FetchState, FetchStateCallbackPromise } from '~/utilities/useFetchState';

export const useGetArtifactTypeMap = (): FetchState<Record<number, string>> => {
export const useGetArtifactTypes = (): FetchState<ArtifactType[]> => {
const { metadataStoreServiceClient } = usePipelinesAPI();

const call = React.useCallback<FetchStateCallbackPromise<Record<number, string>>>(async () => {
const call = React.useCallback<FetchStateCallbackPromise<ArtifactType[]>>(async () => {
const request = new GetArtifactTypesRequest();

const res = await metadataStoreServiceClient.getArtifactTypes(request);

const artifactTypeMap: Record<number, string> = {};
res.getArtifactTypesList().forEach((artifactType) => {
artifactTypeMap[artifactType.getId()] = artifactType.getName();
});
return artifactTypeMap;
return res.getArtifactTypesList();
}, [metadataStoreServiceClient]);

return useFetchState(call, {});
return useFetchState(call, []);
};
23 changes: 17 additions & 6 deletions frontend/src/concepts/pipelines/apiHooks/mlmd/useMlmdContext.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
import React from 'react';
import { MlmdContext, MlmdContextTypes } from '~/concepts/pipelines/apiHooks/mlmd/types';
import { usePipelinesAPI } from '~/concepts/pipelines/context';
import { GetContextByTypeAndNameRequest } from '~/third_party/mlmd';
import {
GetContextByTypeAndNameRequest,
MetadataStoreServicePromiseClient,
} from '~/third_party/mlmd';
import useFetchState, {
FetchState,
FetchStateCallbackPromise,
NotReadyError,
} from '~/utilities/useFetchState';

export const getMlmdContext = async (
client: MetadataStoreServicePromiseClient,
name: string,
type: MlmdContextTypes,
): Promise<MlmdContext | undefined> => {
const request = new GetContextByTypeAndNameRequest();
request.setTypeName(type);
request.setContextName(name);
const res = await client.getContextByTypeAndName(request);
return res.getContext();
};

/**
* A hook used to use the MLMD service and fetch the MLMD context
* If being used without name/type, this hook will throw an error
Expand All @@ -28,11 +43,7 @@ export const useMlmdContext = (
return Promise.reject(new NotReadyError('No context name'));
}

const request = new GetContextByTypeAndNameRequest();
request.setTypeName(type);
request.setContextName(name);
const res = await metadataStoreServiceClient.getContextByTypeAndName(request);
const context = res.getContext();
const context = await getMlmdContext(metadataStoreServiceClient, name, type);
if (!context) {
return Promise.reject(new Error('Cannot find specified context'));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import React from 'react';
import { Text } from '@patternfly/react-core';
import './confusionMatrix/ConfusionMatrix.scss';

export type ConfusionMatrixInput = {
annotationSpecs: {
displayName: string;
}[];
rows: { row: number[] }[];
};

export interface ConfusionMatrixConfig {
data: number[][];
labels: string[];
}

export function buildConfusionMatrixConfig(
confusionMatrix: ConfusionMatrixInput,
): ConfusionMatrixConfig {
return {
labels: confusionMatrix.annotationSpecs.map((annotation) => annotation.displayName),
data: confusionMatrix.rows.map((x) => x.row),
};
}

type ConfusionMatrixProps = {
config: ConfusionMatrixConfig;
size?: number;
};

const ConfusionMatrix: React.FC<ConfusionMatrixProps> = ({
config: { data, labels },
size = 100,
}) => {
const max = Math.max(...data.flat());

// Function to get color based on the cell value
const getColor = (value: number) => {
const opacity = value / max; // Normalize the value to get opacity
return `rgba(41, 121, 255, ${opacity})`; // Use blue color with calculated opacity
};

// Determine the size for all cells, including labels
const cellSize = `${size}px`;

// Generate the gradient for the legend
const gradientLegend = `linear-gradient(to bottom, rgba(41, 121, 255, 1) 0%, rgba(41, 121, 255, 0) 100%)`;

return (
<div className="ConfusionMatrix">
<table className="ConfusionMatrix-table">
<tbody>
{data.map((row, rowIndex) => (
<tr key={labels[rowIndex]}>
<td
className="ConfusionMatrix-labelCell"
style={{
lineHeight: cellSize,
minWidth: cellSize,
}}
>
<Text>{labels[rowIndex]}</Text>
</td>
{row.map((cell, cellIndex) => (
<td
key={labels[cellIndex] + labels[rowIndex]}
className="ConfusionMatrix-cell"
style={{
backgroundColor: getColor(cell),
color: cell / max < 0.6 ? 'black' : 'white',
height: cellSize,
minHeight: cellSize,
minWidth: cellSize,
width: cellSize,
}}
>
{cell}
</td>
))}
</tr>
))}
<tr>
<th
style={{
width: cellSize,
}}
/>
{labels.map((label, i) => (
<th key={i}>
<div
className="ConfusionMatrix-verticalMarker"
style={{
width: cellSize,
}}
>
<Text style={{ transform: `translateX(${size / 4}px) rotate(315deg)` }}>
{label}
</Text>
</div>
</th>
))}
</tr>
</tbody>
</table>
<div className="ConfusionMatrix-gradientLegendOuter">
<div
className="ConfusionMatrix-gradientLegend"
style={{
height: 0.75 * data.length * size,
background: gradientLegend,
}}
>
<div className="ConfusionMatrix-gradientLegendMaxOuter">
<span className="ConfusionMatrix-gradientLegendMaxLabel">{max}</span>
</div>
{new Array(5).fill(0).map((_, i) => (
<div
key={i}
className="ConfusionMatrix-markerLabel"
style={{
top: `${((5 - i) / 5) * 100}%`,
}}
>
<span className="ConfusionMatrix-gradientLegendMaxLabel">
{Math.floor((i / 5) * max)}
</span>
</div>
))}
</div>
<div className="ConfusionMatrix-trueLabel">True label</div>
</div>
</div>
);
};

export default ConfusionMatrix;
121 changes: 121 additions & 0 deletions frontend/src/concepts/pipelines/content/artifacts/charts/ROCCurve.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import React from 'react';
import {
Chart,
ChartAxis,
ChartGroup,
ChartLine,
ChartVoronoiContainer,
} from '@patternfly/react-charts';
import {
chart_color_blue_100 as chartColorBlue100,
chart_color_blue_200 as chartColorBlue200,
chart_color_blue_300 as chartColorBlue300,
chart_color_blue_400 as chartColorBlue400,
chart_color_blue_500 as chartColorBlue500,
chart_color_cyan_100 as chartColorCyan100,
chart_color_cyan_200 as chartColorCyan200,
chart_color_cyan_300 as chartColorCyan300,
chart_color_cyan_400 as chartColorCyan400,
chart_color_cyan_500 as chartColorCyan500,
chart_color_black_100 as chartColorBlack100,
} from '@patternfly/react-tokens';

export type ROCCurveConfig = {
index: number;
data: {
name: string;
x: number;
y: number;
index: number;
}[];
};

export const RocCurveChartColorScale = [
chartColorBlue100.value,
chartColorBlue200.value,
chartColorBlue300.value,
chartColorBlue400.value,
chartColorBlue500.value,
chartColorCyan100.value,
chartColorCyan200.value,
chartColorCyan300.value,
chartColorCyan400.value,
chartColorCyan500.value,
];

type ROCCurveProps = {
configs: ROCCurveConfig[];
maxDimension?: number;
};

const ROCCurve: React.FC<ROCCurveProps> = ({ configs, maxDimension }) => {
const width = maxDimension || 800;
const height = width;
const baseLineData = Array.from(Array(100).keys()).map((x) => ({ x: x / 100, y: x / 100 }));

return (
<div style={{ height, width }}>
<Chart
ariaDesc="ROC Curve"
ariaTitle="ROC Curve"
containerComponent={
<ChartVoronoiContainer
constrainToVisibleArea
voronoiBlacklist={['baseline']}
labels={({ datum }) => `threshold (Series #${datum.index + 1}): ${datum.name}`}
/>
}
height={height}
width={width}
padding={{ bottom: 150, left: 100, right: 50, top: 50 }}
legendAllowWrap
legendPosition="bottom-left"
legendData={configs.map((config) => ({
name: `Series #${config.index + 1}`,
symbol: {
fill: RocCurveChartColorScale[config.index % RocCurveChartColorScale.length],
type: 'square',
},
}))}
>
<ChartAxis
showGrid
dependentAxis
label="True positive rate"
tickValues={Array.from(Array(11).keys()).map((x) => x / 10)}
/>
<ChartAxis
showGrid
label="False positive rate"
tickValues={Array.from(Array(21).keys()).map((x) => x / 20)}
/>
<ChartGroup>
<ChartLine
name="baseline"
data={baseLineData}
style={{
data: {
strokeDasharray: '3,3',
stroke: chartColorBlack100.value,
},
}}
/>
{configs.map((config, idx) => (
<ChartLine
key={idx}
data={config.data}
interpolation="basis"
style={{
data: {
stroke: RocCurveChartColorScale[config.index % RocCurveChartColorScale.length],
},
}}
/>
))}
</ChartGroup>
</Chart>
</div>
);
};

export default ROCCurve;
Loading

0 comments on commit 2fd73b2

Please sign in to comment.