Skip to content

Commit

Permalink
4330 replace loss chart with vega chart (#4447)
Browse files Browse the repository at this point in the history
  • Loading branch information
jryu01 authored Aug 15, 2024
1 parent 303877e commit 46a232d
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 70 deletions.
15 changes: 10 additions & 5 deletions packages/client/hmi-client/src/components/widgets/VegaChart.vue
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
{{ renderErrorMessage }}
</p>
<div ref="vegaContainer"></div>
<footer>
<footer v-if="$slots.footer">
<slot name="footer" />
</footer>
</div>
Expand All @@ -15,7 +15,7 @@ import embed, { Result, VisualizationSpec } from 'vega-embed';
import { Config as VgConfig } from 'vega';
import { Config as VlConfig } from 'vega-lite';
import { ref, watch, toRaw, isRef, isReactive, isProxy } from 'vue';
import { ref, watch, toRaw, isRef, isReactive, isProxy, computed } from 'vue';
export type Config = VgConfig | VlConfig;
Expand All @@ -41,6 +41,8 @@ const props = withDefaults(
);
const vegaContainer = ref<HTMLElement>();
const vegaVisualization = ref<Result>();
const view = computed(() => vegaVisualization.value?.view);
const renderErrorMessage = ref<String>();
const emit = defineEmits<{
Expand Down Expand Up @@ -97,17 +99,16 @@ async function updateVegaVisualization(container: HTMLElement, visualizationSpec
actions: props.areEmbedActionsVisible === false ? false : undefined
}
);
const { view } = vegaVisualization.value;
props.intervalSelectionSignalNames.forEach((signalName) => {
view.addSignalListener(signalName, (name, valueRange: { [fieldName: string]: [number, number] }) => {
view.value!.addSignalListener(signalName, (name, valueRange: { [fieldName: string]: [number, number] }) => {
if (valueRange === undefined || Object.keys(valueRange).length === 0) {
emit('update-interval-selection', name, null);
return;
}
emit('update-interval-selection', name, valueRange);
});
});
view.addEventListener('click', (_event, item) => {
view.value!.addEventListener('click', (_event, item) => {
emit('chart-click', item?.datum ?? null);
});
} catch (e) {
Expand All @@ -123,6 +124,10 @@ watch([vegaContainer, () => props.visualizationSpec], () => {
const spec = deepToRaw(props.visualizationSpec);
updateVegaVisualization(vegaContainer.value, spec);
});
defineExpose({
view
});
</script>
<style scoped>
.vega-chart-container {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,18 @@
<!-- Loss chart -->
<h5>Loss</h5>
<div ref="drilldownLossPlot" class="loss-chart" />
<div ref="lossChartContainer">
<vega-chart
v-if="lossValues.length > 0 || showSpinner"
ref="lossChartRef"
:are-embed-actions-visible="true"
:visualization-spec="lossChartSpec"
/>
</div>
<!-- Variable charts -->
<div v-if="!showSpinner" class="form-section">
<section v-if="modelConfig && csvAsset" ref="outputPanel">
<section ref="outputPanel" v-if="modelConfig && csvAsset">
<h5>Parameters</h5>
<tera-chart-control
:chart-config="{ selectedRun: 'fixme', selectedVariable: selectedParameters }"
Expand Down Expand Up @@ -223,14 +230,15 @@
<script setup lang="ts">
import _ from 'lodash';
import * as vega from 'vega';
import { csvParse, autoType, mean, variance } from 'd3';
import { computed, onMounted, ref, shallowRef, watch } from 'vue';
import Button from 'primevue/button';
import DataTable from 'primevue/datatable';
import Dropdown from 'primevue/dropdown';
import Column from 'primevue/column';
import TeraInputNumber from '@/components/widgets/tera-input-number.vue';
import { CalibrateMap, renderLossGraph, setupDatasetInput, setupModelInput } from '@/services/calibrate-workflow';
import { CalibrateMap, setupDatasetInput, setupModelInput } from '@/services/calibrate-workflow';
import TeraDrilldown from '@/components/drilldown/tera-drilldown.vue';
import TeraDrilldownSection from '@/components/drilldown/tera-drilldown-section.vue';
import TeraDrilldownPreview from '@/components/drilldown/tera-drilldown-preview.vue';
Expand Down Expand Up @@ -334,16 +342,12 @@ const cancelRunId = computed(
);
const currentDatasetFileName = ref<string>();
const drilldownLossPlot = ref<HTMLElement>();
const runResult = ref<DataArray>([]);
const runResultPre = ref<DataArray>([]);
const runResultSummary = ref<DataArray>([]);
const runResultSummaryPre = ref<DataArray>([]);
const previewChartWidth = ref(120);
const showSpinner = ref(false);
let lossValues: { [key: string]: number }[] = [];
const mapping = ref<CalibrateMap[]>(props.node.state.mapping);
Expand Down Expand Up @@ -399,6 +403,8 @@ const disableRunButton = computed(
);
const selectedOutputId = ref<string>();
const lossChartContainer = ref(null);
const lossChartSize = computed(() => drilldownChartSize(lossChartContainer.value));
const outputPanel = ref(null);
const chartSize = computed(() => drilldownChartSize(outputPanel.value));
Expand Down Expand Up @@ -460,18 +466,18 @@ const preparedCharts = computed(() => {
}
charts[variable] = createForecastChart(
{
dataset: result,
data: result,
variables: [`${pyciemssMap[variable]}:pre`, pyciemssMap[variable]],
timeField: 'timepoint_id',
groupField: 'sample_id'
},
{
dataset: resultSummary,
data: resultSummary,
variables: [`${pyciemssMap[variable]}_mean:pre`, `${pyciemssMap[variable]}_mean`],
timeField: 'timepoint_id'
},
{
dataset: groundTruth,
data: groundTruth,
variables: datasetVariables,
timeField: datasetTimeField as string,
groupField: 'sample_id'
Expand Down Expand Up @@ -526,6 +532,29 @@ const preparedDistributionCharts = computed(() => {
return charts;
});
const LOSS_CHART_DATA_SOURCE = 'lossData'; // Name of the streaming data source
const lossChartRef = ref<InstanceType<typeof VegaChart>>();
const lossChartSpec = ref();
const lossValues = ref<{ [key: string]: number }[]>([]);
const updateLossChartSpec = (data: string | Record<string, any>[]) => {
lossChartSpec.value = createForecastChart(
null,
{
data: Array.isArray(data) ? data : { name: data },
variables: ['loss'],
timeField: 'iter'
},
null,
{
title: '',
width: lossChartSize.value.width,
height: 100,
xAxisTitle: 'Solver iterations',
yAxisTitle: 'Loss'
}
);
};
const runCalibrate = async () => {
if (!modelConfigId.value || !datasetId.value || !currentDatasetFileName.value) return;
Expand All @@ -538,7 +567,7 @@ const runCalibrate = async () => {
}
// Reset loss buffer
lossValues = [];
lossValues.value = [];
const state = _.cloneDeep(props.node.state);
Expand Down Expand Up @@ -577,14 +606,10 @@ const runCalibrate = async () => {
};
const messageHandler = (event: ClientEvent<any>) => {
lossValues.push({ iter: lossValues.length, loss: event.data.loss });
if (drilldownLossPlot.value) {
renderLossGraph(drilldownLossPlot.value, lossValues, {
width: previewChartWidth.value,
height: 120
});
}
if (!lossChartRef.value?.view) return;
const data = { iter: lossValues.value.length, loss: event.data.loss };
lossChartRef.value.view.change(LOSS_CHART_DATA_SOURCE, vega.changeset().insert(data)).resize().run();
lossValues.value.push(data);
};
const onSelection = (id: string) => {
Expand Down Expand Up @@ -645,11 +670,6 @@ async function getAutoMapping() {
}
onMounted(async () => {
// Get sizing
if (drilldownLossPlot.value) {
previewChartWidth.value = drilldownLossPlot.value.offsetWidth;
}
// Model configuration input
const { modelConfiguration, modelOptions, modelPartUnits, modelPartTypes } = await setupModelInput(
modelConfigId.value
Expand Down Expand Up @@ -683,9 +703,11 @@ watch(
(id) => {
if (id === '') {
showSpinner.value = false;
updateLossChartSpec(lossValues.value);
unsubscribeToUpdateMessages([id], ClientEventType.SimulationPyciemss, messageHandler);
} else {
showSpinner.value = true;
updateLossChartSpec(LOSS_CHART_DATA_SOURCE);
subscribeToUpdateMessages([id], ClientEventType.SimulationPyciemss, messageHandler);
}
},
Expand All @@ -702,16 +724,11 @@ watch(
// Fetch saved intermediate state
const simulationObj = await getSimulation(props.node.state.calibrationId);
if (simulationObj?.updates) {
lossValues = simulationObj?.updates.map((d, i) => ({
lossValues.value = simulationObj?.updates.map((d, i) => ({
iter: i,
loss: d.data.loss
}));
if (drilldownLossPlot.value) {
renderLossGraph(drilldownLossPlot.value, lossValues, {
width: previewChartWidth.value,
height: 120
});
}
updateLossChartSpec(lossValues.value);
}
const state = props.node.state;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
:visualization-spec="preparedCharts[index]"
/>
</template>
<div v-else ref="drilldownLossPlot" class="loss-chart" />
<vega-chart v-else-if="lossChartSpec" :are-embed-actions-visible="false" :visualization-spec="lossChartSpec" />
<tera-progress-spinner v-if="inProgressCalibrationId" :font-size="2" is-centered style="height: 100%">
<div>{{ props.node.state.currentProgress }}%</div>
Expand Down Expand Up @@ -42,7 +42,7 @@ import {
} from '@/services/models/simulation-service';
import { getModelConfigurationById, createModelConfiguration } from '@/services/model-configurations';
import { getModelByModelConfigurationId, getUnitsFromModelParts } from '@/services/model';
import { renderLossGraph, setupDatasetInput } from '@/services/calibrate-workflow';
import { setupDatasetInput } from '@/services/calibrate-workflow';
import { nodeMetadata, nodeOutputLabel } from '@/components/workflow/util';
import { logger } from '@/utils/logger';
import { Poller, PollerState } from '@/api/api';
Expand Down Expand Up @@ -80,23 +80,34 @@ const runResult = ref<DataArray>([]);
const runResultPre = ref<DataArray>([]);
const runResultSummary = ref<DataArray>([]);
const runResultSummaryPre = ref<DataArray>([]);
const drilldownLossPlot = ref<HTMLElement>();
const csvAsset = shallowRef<CsvAsset | undefined>(undefined);
const areInputsFilled = computed(() => props.node.inputs[0].value && props.node.inputs[1].value);
const inProgressCalibrationId = computed(() => props.node.state.inProgressCalibrationId);
const chartSize = { width: 180, height: 120 };
let lossValues: { [key: string]: number }[] = [];
function drawLossGraph() {
if (drilldownLossPlot.value) {
renderLossGraph(drilldownLossPlot.value, lossValues, {
width: 200,
height: 120
});
}
}
const lossChartSpec = ref();
const updateLossChartSpec = (data: Record<string, any>[]) => {
lossChartSpec.value = createForecastChart(
null,
{
data,
variables: ['loss'],
timeField: 'iter'
},
null,
{
title: '',
xAxisTitle: 'Solver iterations',
yAxisTitle: 'Loss',
...chartSize
}
);
};
async function updateLossChartWithSimulation() {
if (props.node.active) {
Expand All @@ -106,7 +117,7 @@ async function updateLossChartWithSimulation() {
iter: i,
loss: d.data.loss
}));
drawLossGraph();
updateLossChartSpec(lossValues);
}
}
}
Expand Down Expand Up @@ -156,30 +167,29 @@ const preparedCharts = computed(() => {
return createForecastChart(
{
dataset: result,
data: result,
variables: [`${pyciemssMap[variable]}:pre`, pyciemssMap[variable]],
timeField: 'timepoint_id',
groupField: 'sample_id'
},
{
dataset: resultSummary,
data: resultSummary,
variables: [`${pyciemssMap[variable]}_mean:pre`, `${pyciemssMap[variable]}_mean`],
timeField: 'timepoint_id'
},
{
dataset: groundTruth,
data: groundTruth,
variables: datasetVariables,
timeField: datasetTimeField as string
},
{
title: '',
width: 180,
height: 120,
legend: true,
translationMap: reverseMap,
xAxisTitle: modelVarUnits.value._time || 'Time',
yAxisTitle: modelVarUnits.value[variable] || '',
colorscheme: ['#AAB3C6', '#1B8073']
colorscheme: ['#AAB3C6', '#1B8073'],
...chartSize
}
);
});
Expand All @@ -197,7 +207,7 @@ const pollResult = async (runId: string) => {
iter: i,
loss: d.data.loss
}));
drawLossGraph();
updateLossChartSpec(lossValues);
}
if (runId === props.node.state.inProgressCalibrationId && data.updates.length > 0) {
const checkpoint = _.first(data.updates);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -958,13 +958,13 @@ const preparedForecastCharts = computed(() => {
const forecastChart = createForecastChart(
{
dataset: result,
data: result,
variables: [`${pyciemssMap[variable]}:pre`, pyciemssMap[variable]],
timeField: 'timepoint_id',
groupField: 'sample_id'
},
{
dataset: resultSummary,
data: resultSummary,
variables: [`${pyciemssMap[variable]}_mean:pre`, `${pyciemssMap[variable]}_mean`],
timeField: 'timepoint_id'
},
Expand All @@ -984,13 +984,13 @@ const preparedForecastCharts = computed(() => {
return createForecastChart(
{
dataset: result,
data: result,
variables: [`${pyciemssMap[variable]}:pre`, pyciemssMap[variable]],
timeField: 'timepoint_id',
groupField: 'sample_id'
},
{
dataset: resultSummary,
data: resultSummary,
variables: [`${pyciemssMap[variable]}_mean:pre`, `${pyciemssMap[variable]}_mean`],
timeField: 'timepoint_id'
},
Expand Down
Loading

0 comments on commit 46a232d

Please sign in to comment.