Skip to content

Commit

Permalink
Add titles to generated compare models graphs (#5100)
Browse files Browse the repository at this point in the history
  • Loading branch information
asylves1 authored Oct 9, 2024
1 parent 3b28429 commit 3540582
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export interface ModelComparisonOperationState extends BaseState {
notebookHistory: NotebookHistory[];
hasCodeRun: boolean;
comparisonImageIds: string[];
comparisonPairs: string[][];
}

export const ModelComparisonOperation: Operation = {
Expand All @@ -26,7 +27,8 @@ export const ModelComparisonOperation: Operation = {
const init: ModelComparisonOperationState = {
notebookHistory: [],
hasCodeRun: false,
comparisonImageIds: []
comparisonImageIds: [],
comparisonPairs: []
};
return init;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
/>
<ul>
<li v-for="(image, index) in structuralComparisons" :key="index">
<label>Comparison {{ index + 1 }}</label>
<label>Comparison {{ index + 1 }}: {{ getTitle(index) }}</label>
<Image id="img" :src="image" :alt="`Structural comparison ${index + 1}`" preview />
</li>
</ul>
Expand Down Expand Up @@ -198,6 +198,7 @@ const code = ref(props.node.state.notebookHistory?.[0]?.code ?? '');
const llmThoughts = ref<any[]>([]);
const isKernelReady = ref(false);
const contextLanguage = ref<string>('python3');
const comparisonPairs = ref(props.node.state.comparisonPairs);
const initializeAceEditor = (editorInstance: any) => {
editor = editorInstance;
Expand All @@ -218,6 +219,18 @@ function updateImagesState(operationType: string, newImageId: string | null = nu
emit('update-state', state);
}
function updateComparisonTitlesState(operationType: string, pairs: string[][] | null = null) {
const state = cloneDeep(props.node.state);
if (operationType === 'add' && pairs !== null) state.comparisonPairs = pairs;
else if (operationType === 'clear') state.comparisonPairs = [];
emit('update-state', state);
}
function getTitle(index: number) {
if (!comparisonPairs.value[index]) return '';
return `${comparisonPairs.value[index][0].replaceAll('_', ' ')} VS ${comparisonPairs.value[index][1].replaceAll('_', ' ')}`;
}
function updateCodeState() {
const state = saveCodeToState(props.node, code.value, true);
emit('update-state', state);
Expand All @@ -226,6 +239,7 @@ function updateCodeState() {
function emptyImages() {
deleteImages(props.node.state.comparisonImageIds); // Delete images from S3
updateImagesState('clear'); // Then their ids can be removed from the state
updateComparisonTitlesState('clear');
structuralComparisons.value = [];
}
Expand All @@ -248,6 +262,17 @@ function runCode() {
emptyImages();
updateCodeState();
kernelManager.sendMessage('get_comparison_pairs_request', {}).register('any_get_comparison_pairs_reply', (data) => {
const pairs = data.msg.content?.return?.comparison_pairs;
const state = cloneDeep(props.node.state);
if (pairs.length) {
updateComparisonTitlesState('add', pairs);
comparisonPairs.value = pairs;
} else if (state.comparisonPairs.length) {
comparisonPairs.value = state.comparisonPairs;
}
});
kernelManager
.sendMessage('execute_request', messageContent)
.register('display_data', (data) => {
Expand Down

0 comments on commit 3540582

Please sign in to comment.