Skip to content

Commit

Permalink
4276 parameter distribution charts (#4414)
Browse files Browse the repository at this point in the history
  • Loading branch information
jryu01 authored Aug 12, 2024
1 parent 2086dd0 commit eccd02a
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
{{ renderErrorMessage }}
</p>
<div ref="vegaContainer"></div>
<footer>
<slot name="footer" />
</footer>
</div>
</template>

Expand Down Expand Up @@ -128,6 +131,9 @@ watch([vegaContainer, () => props.visualizationSpec], () => {
border: 1px solid var(--surface-border-light);
margin-bottom: var(--gap-4);
padding-top: var(--gap-2);
footer {
padding: var(--gap-3);
}
}
/* adjust style, position and rotation of action button */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ const DOCUMENTATION_URL = 'https://github.com/ciemss/pyciemss/blob/main/pyciemss

export interface CalibrationOperationStateCiemss extends BaseState {
method: string;
chartConfigs: string[][];
selectedParameters: string[];
selectedVariables: string[];
mapping: CalibrateMap[];
simulationsInProgress: string[];

Expand Down Expand Up @@ -51,7 +52,8 @@ export const CalibrationOperationCiemss: Operation = {
initState: () => {
const init: CalibrationOperationStateCiemss = {
method: 'dopri5',
chartConfigs: [],
selectedParameters: [],
selectedVariables: [],
mapping: [{ modelVariable: '', datasetVariable: '' }],
simulationsInProgress: [],
currentProgress: 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,56 @@
<!-- Variable charts -->
<div v-if="!showSpinner" class="form-section">
<h5>Variables</h5>
<section v-if="modelConfig && csvAsset" ref="outputPanel">
<template v-for="(cfg, index) of node.state.chartConfigs" :key="index">
<tera-chart-control
:chart-config="{ selectedRun: 'fixme', selectedVariable: cfg }"
:multi-select="false"
:show-remove-button="true"
:variables="Object.keys(pyciemssMap)"
@configuration-change="chartProxy.configurationChange(index, $event)"
@remove="chartProxy.removeChart(index)"
/>
<vega-chart :are-embed-actions-visible="true" :visualization-spec="preparedCharts[index]" />
<h5>Parameters</h5>
<tera-chart-control
:chart-config="{ selectedRun: 'fixme', selectedVariable: selectedParameters }"
:multi-select="true"
:show-remove-button="false"
:variables="Object.keys(pyciemssMap).filter((c) => modelPartTypesMap[c] === 'parameter')"
@configuration-change="updateSelectedParameters"
/>
<template v-for="param of node.state.selectedParameters" :key="param">
<vega-chart
:are-embed-actions-visible="true"
:visualization-spec="preparedDistributionCharts[param].histogram"
>
<template v-slot:footer>
<table class="distribution-table">
<thead>
<tr>
<th scope="col"></th>
<th scope="col">{{ preparedDistributionCharts[param].stat.header[0] }}</th>
<th scope="col">{{ preparedDistributionCharts[param].stat.header[1] }}</th>
</tr>
</thead>
<tbody>
<tr>
<th scope="row">Mean</th>
<td>{{ preparedDistributionCharts[param].stat.mean[0] }}</td>
<td>{{ preparedDistributionCharts[param].stat.mean[1] }}</td>
</tr>
<tr>
<th scope="row">Variance</th>
<td>{{ preparedDistributionCharts[param].stat.variance[0] }}</td>
<td>{{ preparedDistributionCharts[param].stat.variance[1] }}</td>
</tr>
</tbody>
</table>
</template>
</vega-chart>
</template>
<h5>Variables</h5>
<tera-chart-control
:chart-config="{ selectedRun: 'fixme', selectedVariable: selectedVariables }"
:multi-select="true"
:show-remove-button="false"
:variables="Object.keys(pyciemssMap).filter((c) => modelPartTypesMap[c] !== 'parameter')"
@configuration-change="updateSelectedVariables"
/>
<template v-for="variable of node.state.selectedVariables" :key="variable">
<vega-chart :are-embed-actions-visible="true" :visualization-spec="preparedCharts[variable]" />
</template>
<Button size="small" text @click="chartProxy.addChart()" label="Add chart" icon="pi pi-plus" />
</section>
<section v-else-if="!modelConfig" class="emptyState">
<img src="@assets/svg/seed.svg" alt="" draggable="false" />
Expand All @@ -187,7 +223,7 @@
<script setup lang="ts">
import _ from 'lodash';
import { csvParse, autoType } from 'd3';
import { csvParse, autoType, mean, variance } from 'd3';
import { computed, onMounted, ref, shallowRef, watch } from 'vue';
import Button from 'primevue/button';
import DataTable from 'primevue/datatable';
Expand All @@ -210,7 +246,7 @@ import {
DatasetColumn,
ModelConfiguration
} from '@/types/Types';
import { getTimespan, chartActionsProxy, drilldownChartSize, nodeMetadata } from '@/components/workflow/util';
import { getTimespan, drilldownChartSize, nodeMetadata } from '@/components/workflow/util';
import { useToastService } from '@/services/toast';
import { autoCalibrationMapping } from '@/services/concept';
import {
Expand All @@ -225,11 +261,12 @@ import {
} from '@/services/models/simulation-service';
import type { WorkflowNode } from '@/types/workflow';
import { createForecastChart } from '@/services/charts';
import { createForecastChart, createHistogramChart } from '@/services/charts';
import VegaChart from '@/components/widgets/VegaChart.vue';
import TeraChartControl from '@/components/workflow/tera-chart-control.vue';
import { CiemssPresetTypes, DrilldownTabs } from '@/types/common';
import TeraInputText from '@/components/widgets/tera-input-text.vue';
import { displayNumber } from '@/utils/number';
import type { CalibrationOperationStateCiemss } from './calibrate-operation';
import { renameFnGenerator, mergeResults } from './calibrate-utils';
Expand Down Expand Up @@ -283,6 +320,7 @@ const csvAsset = shallowRef<CsvAsset | undefined>(undefined);
const modelConfig = ref<ModelConfiguration>();
const modelVarUnits = ref<{ [key: string]: string }>({});
const modelPartTypesMap = ref<{ [key: string]: string }>({});
const modelConfigId = computed<string | undefined>(() => props.node.inputs[0]?.value?.[0]);
const datasetId = computed<string | undefined>(() => props.node.inputs[1]?.value?.[0]);
Expand Down Expand Up @@ -361,12 +399,17 @@ const disableRunButton = computed(
);
const selectedOutputId = ref<string>();
const outputPanel = ref(null);
const chartSize = computed(() => drilldownChartSize(outputPanel.value));
const selectedParameters = ref<string[]>(props.node.state.selectedParameters);
const selectedVariables = ref<string[]>(props.node.state.selectedVariables);
let pyciemssMap: Record<string, string> = {};
const preparedCharts = computed(() => {
const preparedChartInputs = computed(() => {
const state = props.node.state;
if (!state.calibrationId) return [];
if (!state.calibrationId) return null;
// Merge before/after for chart
const { result, resultSummary } = mergeResults(
Expand All @@ -385,6 +428,17 @@ const preparedCharts = computed(() => {
state.mapping.forEach((mapObj) => {
reverseMap[mapObj.datasetVariable] = 'Observations';
});
return {
result,
resultSummary,
reverseMap
};
});
const preparedCharts = computed(() => {
if (!preparedChartInputs.value) return [];
const { result, resultSummary, reverseMap } = preparedChartInputs.value;
const state = props.node.state;
// FIXME: Hacky re-parse CSV with correct data types
let groundTruth: DataArray = [];
Expand All @@ -397,25 +451,23 @@ const preparedCharts = computed(() => {
// Need to get the dataset's time field
const datasetTimeField = state.mapping.find((d) => d.modelVariable === 'timestamp')?.datasetVariable;
return state.chartConfigs.map((config) => {
const charts = {};
state.selectedVariables.forEach((variable) => {
const datasetVariables: string[] = [];
config.forEach((variableName) => {
const mapObj = state.mapping.find((d) => d.modelVariable === variableName);
if (mapObj) {
datasetVariables.push(mapObj.datasetVariable);
}
});
return createForecastChart(
const mapObj = state.mapping.find((d) => d.modelVariable === variable);
if (mapObj) {
datasetVariables.push(mapObj.datasetVariable);
}
charts[variable] = createForecastChart(
{
dataset: result,
variables: [...config.map((d) => `${pyciemssMap[d]}:pre`), ...config.map((d) => pyciemssMap[d])],
variables: [`${pyciemssMap[variable]}:pre`, pyciemssMap[variable]],
timeField: 'timepoint_id',
groupField: 'sample_id'
},
{
dataset: resultSummary,
variables: [...config.map((d) => `${pyciemssMap[d]}_mean:pre`), ...config.map((d) => `${pyciemssMap[d]}_mean`)],
variables: [`${pyciemssMap[variable]}_mean:pre`, `${pyciemssMap[variable]}_mean`],
timeField: 'timepoint_id'
},
{
Expand All @@ -425,24 +477,53 @@ const preparedCharts = computed(() => {
groupField: 'sample_id'
},
{
title: '',
title: variable,
width: chartSize.value.width,
height: chartSize.value.height,
legend: true,
translationMap: reverseMap,
xAxisTitle: modelVarUnits.value._time || 'Time',
yAxisTitle: _.uniq(config.map((v) => modelVarUnits.value[v]).filter((v) => !!v)).join(',') || '',
yAxisTitle: modelVarUnits.value[variable] || '',
colorscheme: ['#AAB3C6', '#1B8073']
}
);
});
return charts;
});
const outputPanel = ref(null);
const chartSize = computed(() => drilldownChartSize(outputPanel.value));
const chartProxy = chartActionsProxy(props.node, (state: CalibrationOperationStateCiemss) => {
emit('update-state', state);
const preparedDistributionCharts = computed(() => {
if (!preparedChartInputs.value) return [];
const { result } = preparedChartInputs.value;
const state = props.node.state;
const labelBefore = 'Before calibration';
const labelAfter = 'After calibration';
const charts = {};
state.selectedParameters.forEach((param) => {
const fieldName = pyciemssMap[param];
const beforeFieldName = `${fieldName}:pre`;
const histogram = createHistogramChart(result, {
title: `${param}`,
width: chartSize.value.width,
height: chartSize.value.height,
xAxisTitle: `${param}`,
yAxisTitle: 'Count',
maxBins: 10,
variables: [
{ field: beforeFieldName, label: labelBefore, width: 54, color: '#AAB3C6' },
{ field: fieldName, label: labelAfter, width: 24, color: '#1B8073' }
]
});
const toDisplayNumber = (num?: number) => (num ? displayNumber(num.toString()) : '');
const stat = {
header: [labelBefore, labelAfter],
mean: [mean(result, (d) => d[beforeFieldName]), mean(result, (d) => d[fieldName])].map(toDisplayNumber),
variance: [variance(result, (d) => d[beforeFieldName]), variance(result, (d) => d[fieldName])].map(
toDisplayNumber
)
};
charts[param] = { histogram, stat };
});
return charts;
});
const runCalibrate = async () => {
Expand Down Expand Up @@ -510,6 +591,14 @@ const onSelection = (id: string) => {
emit('select-output', id);
};
function updateSelectedParameters(event) {
emit('update-state', { ...props.node.state, selectedParameters: event.selectedVariable });
}
function updateSelectedVariables(event) {
emit('update-state', { ...props.node.state, selectedVariables: event.selectedVariable });
}
// Used from button to add new entry to the mapping object
function addMapping() {
mapping.value.push({
Expand Down Expand Up @@ -562,10 +651,13 @@ onMounted(async () => {
}
// Model configuration input
const { modelConfiguration, modelOptions, modelVariableUnits } = await setupModelInput(modelConfigId.value);
const { modelConfiguration, modelOptions, modelPartUnits, modelPartTypes } = await setupModelInput(
modelConfigId.value
);
modelConfig.value = modelConfiguration;
modelStateOptions.value = modelOptions;
modelVarUnits.value = modelVariableUnits ?? {};
modelVarUnits.value = modelPartUnits ?? {};
modelPartTypesMap.value = modelPartTypes ?? {};
// dataset input
const { filename, csv, datasetOptions } = await setupDatasetInput(datasetId.value);
Expand Down Expand Up @@ -720,4 +812,22 @@ img {
border: 1px solid var(--surface-border-light);
border-radius: var(--border-radius-medium);
}
.distribution-table {
width: 100%;
border-collapse: collapse;
thead {
background-color: var(--surface-200);
}
tr {
height: 1.75rem;
}
tbody tr {
border-bottom: 1px solid var(--surface-border-light);
}
td,
th {
text-align: center;
}
}
</style>
Loading

0 comments on commit eccd02a

Please sign in to comment.