
Commit 620459a

Authored by hummuscience (Muad Abd El Hay) and Claude
Calibration plot for frame diagnostics streamlit (#308)
* Remove unexpected save_heatmaps parameter from export_predictions_and_labeled_video call

  The function doesn't accept this parameter, causing a TypeError.

* Add multi-model calibration plot feature to labeled frame diagnostics

  - Add calibration plot section to the labeled_frame_diagnostics.py Streamlit app
  - Support multi-model comparison with interactive plot controls
  - New plot_calibration_diagram_multi() function for comparing multiple models
  - Display Expected Calibration Error (ECE) for each model in the legend
  - Configurable error threshold and number of bins
  - Color-coded calibration curves for easy visual comparison
  - Robust error handling for missing data cases

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: Muad Abd El Hay <[email protected]>
Co-authored-by: Claude <[email protected]>
1 parent fb3ae32 commit 620459a
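For orientation, here is a minimal sketch of how the new `plot_calibration_diagram_multi()` helper is intended to be called. The `models_data` layout and keyword arguments follow the diff below; the import path is inferred from the file this commit touches, and all values are made up for illustration only.

```python
import numpy as np

# Assumed import path, based on the file this commit adds the function to
# (lightning_pose/apps/plots.py); adjust if the module layout differs.
from lightning_pose.apps.plots import plot_calibration_diagram_multi

# Hypothetical values for illustration; in the app these come from each model's
# predicted likelihoods and per-keypoint pixel errors on a labeled data split.
error_threshold = 5.0  # pixels, same default as the app's number input
errors = np.array([2.1, 7.3, 1.4, 12.0, 3.3])
confidences = np.array([0.95, 0.80, 0.99, 0.40, 0.88])

models_data = [{
    "model_name": "example-model",
    "confidences": confidences,
    # a prediction counts as accurate if its pixel error is within the threshold
    "accuracies": (errors <= error_threshold).astype(float),
}]

fig = plot_calibration_diagram_multi(
    models_data=models_data,
    n_bins=10,
    keypoint_name="mean",
    data_type="test",
    error_threshold=error_threshold,
)
fig.show()  # inside the Streamlit app this is rendered via st.plotly_chart(fig)
```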

File tree

2 files changed: +354 −0 lines

lightning_pose/apps/labeled_frame_diagnostics.py

Lines changed: 86 additions & 0 deletions
@@ -21,6 +21,7 @@
     get_y_label,
     make_plotly_catplot,
     make_plotly_scatterplot,
+    plot_calibration_diagram_multi,
 )
 from lightning_pose.apps.utils import (
     build_precomputed_metrics_df,
@@ -282,6 +283,91 @@ def run():
 
     st.plotly_chart(fig_scatter)
 
+    # ---------------------------------------------------
+    # calibration plot
+    # ---------------------------------------------------
+    st.header("Model Calibration Analysis")
+
+    col9, col10, col11 = st.columns(3)
+
+    with col9:
+        models_for_calib = st.multiselect(
+            "Select models to compare:", new_names, key="models_calib"
+        )
+
+    with col10:
+        n_bins = st.slider("Number of bins:", min_value=5, max_value=20, value=10)
+
+    with col11:
+        error_threshold = st.number_input(
+            "Error threshold (pixels):",
+            min_value=1.0, max_value=50.0, value=5.0, step=1.0
+        )
+
+    # Process calibration data for all selected models
+    if models_for_calib and 'pixel error' in metric_options:
+        # Collect data for all selected models
+        models_data = []
+        for model_name in models_for_calib:
+            confidence_df = dframes_metrics[model_name]['confidence']
+            pixel_error_df = df_metrics['pixel error']
+            pixel_error_df_model = pixel_error_df[
+                (pixel_error_df.model_name == model_name) &
+                (pixel_error_df.set == data_type)
+            ]
+
+            if keypoint_to_plot != "mean":
+                # Get confidence and error for specific keypoint
+                conf_cols = [
+                    c for c in confidence_df.columns
+                    if c[0] == keypoint_to_plot and c[1] == 'likelihood'
+                ]
+                if conf_cols:
+                    confidences = confidence_df.loc[
+                        confidence_df.iloc[:, -1] == data_type, conf_cols[0]
+                    ].values
+                    errors = pixel_error_df_model[keypoint_to_plot].values
+                else:
+                    confidences = np.array([])
+                    errors = np.array([])
+            else:
+                # Calculate mean confidence and error across all keypoints
+                conf_cols = [c for c in confidence_df.columns if c[1] == 'likelihood']
+                confidences_all = confidence_df.loc[
+                    confidence_df.iloc[:, -1] == data_type, conf_cols
+                ].values
+                confidences = np.nanmean(confidences_all, axis=1)
+
+                error_cols = [kp for kp in keypoint_names]
+                errors_all = pixel_error_df_model[error_cols].values
+                errors = np.nanmean(errors_all, axis=1)
+
+            if len(confidences) > 0 and len(errors) > 0:
+                # Calculate accuracies based on error threshold
+                accuracies = (errors <= error_threshold).astype(float)
+                models_data.append({
+                    'model_name': model_name,
+                    'confidences': confidences,
+                    'accuracies': accuracies
+                })
+
+        if models_data:
+            # Create multi-model calibration plot
+            fig_calib = plot_calibration_diagram_multi(
+                models_data=models_data,
+                n_bins=n_bins,
+                keypoint_name=keypoint_to_plot,
+                data_type=data_type,
+                error_threshold=error_threshold
+            )
+            st.plotly_chart(fig_calib)
+        else:
+            st.warning("No data available for calibration plot.")
+    elif not models_for_calib:
+        st.info("Please select at least one model for calibration analysis.")
+    else:
+        st.warning("Pixel error metric not available for calibration analysis.")
+
 
 if __name__ == "__main__":

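The column-selection logic above implies that each model's confidence DataFrame carries a (keypoint, coordinate) column MultiIndex, with the train/val/test split stored in its last column. Below is a toy illustration of that selection pattern; the keypoint names, values, and "set" column are assumptions for this sketch only, not names taken from the repository.

```python
import numpy as np
import pandas as pd

# Toy stand-in for a confidence DataFrame (hypothetical data):
# (keypoint, coord) column MultiIndex, plus a final column holding the data split.
rng = np.random.default_rng(0)
cols = pd.MultiIndex.from_tuples([
    ("paw", "x"), ("paw", "y"), ("paw", "likelihood"),
    ("nose", "x"), ("nose", "y"), ("nose", "likelihood"),
])
confidence_df = pd.DataFrame(rng.random((4, 6)), columns=cols)
confidence_df["set"] = ["train", "test", "test", "val"]  # becomes the last column

keypoint_to_plot = "paw"
data_type = "test"

# Same selection pattern as the diff: the likelihood column for one keypoint,
# restricted to rows whose last column matches the chosen split.
conf_cols = [
    c for c in confidence_df.columns
    if c[0] == keypoint_to_plot and c[1] == "likelihood"
]
confidences = confidence_df.loc[
    confidence_df.iloc[:, -1] == data_type, conf_cols[0]
].values
print(confidences)  # likelihoods for the two "test" frames
```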
lightning_pose/apps/plots.py

Lines changed: 268 additions & 0 deletions
@@ -6,6 +6,7 @@
 import seaborn as sns
 from matplotlib import pyplot as plt
 from plotly.subplots import make_subplots
+from sklearn.calibration import calibration_curve
 
 pix_error_key = "pixel error"
 conf_error_key = "confidence"
@@ -267,3 +268,270 @@ def plot_precomputed_traces(df_metrics, df_traces, cols):
     )
 
     return fig_traces
+
+
+def plot_calibration_diagram(
+    confidences,
+    accuracies,
+    n_bins=10,
+    model_name="Model",
+    keypoint_name="",
+    data_type="",
+    error_threshold=5.0,
+):
+    """
+    Plot calibration diagram for pose estimation model using Plotly.
+
+    Args:
+        confidences: predicted confidence scores (0-1)
+        accuracies: binary array indicating if prediction was accurate (1) or not (0)
+        n_bins: number of bins for grouping confidences
+        model_name: name of the model for title
+        keypoint_name: name of the keypoint being analyzed
+        data_type: train/val/test data split
+        error_threshold: pixel error threshold used to determine accuracy
+
+    Returns:
+        Plotly figure object
+    """
+    # Calculate calibration curve
+    fraction_of_positives, mean_predicted_value = calibration_curve(
+        accuracies, confidences, n_bins=n_bins, strategy='uniform'
+    )
+
+    # Calculate expected calibration error (ECE) - simplified version
+    # ECE is the weighted average of the absolute differences between accuracy and confidence
+    if len(mean_predicted_value) > 0 and len(confidences) > 0:
+        # For each bin, calculate |accuracy - confidence| weighted by bin size
+        # We'll recompute bins to ensure consistency
+        bin_edges = np.linspace(0, 1, n_bins + 1)
+        ece = 0.0
+        total_count = 0
+
+        for i in range(n_bins):
+            # Find points in this bin
+            in_bin = (confidences >= bin_edges[i]) & (confidences < bin_edges[i + 1])
+            if i == n_bins - 1:  # Include right edge in last bin
+                in_bin = (confidences >= bin_edges[i]) & (confidences <= bin_edges[i + 1])
+
+            bin_count = np.sum(in_bin)
+            if bin_count > 0:
+                bin_accuracy = np.mean(accuracies[in_bin])
+                bin_confidence = np.mean(confidences[in_bin])
+                ece += bin_count * np.abs(bin_accuracy - bin_confidence)
+                total_count += bin_count
+
+        ece = ece / total_count if total_count > 0 else 0
+    else:
+        ece = 0
+
+    # Create Plotly figure
+    fig = go.Figure()
+
+    # Add perfect calibration line
+    fig.add_trace(go.Scatter(
+        x=[0, 1],
+        y=[0, 1],
+        mode='lines',
+        name='Perfect calibration',
+        line=dict(dash='dash', color='black'),
+        showlegend=True
+    ))
+
+    # Add model calibration curve
+    fig.add_trace(go.Scatter(
+        x=mean_predicted_value,
+        y=fraction_of_positives,
+        mode='markers+lines',
+        name=f'{model_name}',
+        marker=dict(size=10, color='blue'),
+        line=dict(color='blue'),
+        showlegend=True
+    ))
+
+    # Add confidence histogram as marginal
+    fig.add_trace(go.Histogram(
+        x=confidences,
+        name='Confidence distribution',
+        yaxis='y2',
+        opacity=0.3,
+        showlegend=False,
+        marker_color='gray',
+        nbinsx=20
+    ))
+
+    # Update layout
+    title_text = f'Calibration Plot - {model_name}'
+    if keypoint_name:
+        title_text += f' ({keypoint_name})'
+    if data_type:
+        title_text += f' - {data_type} set'
+    title_text += f'<br>Error threshold: {error_threshold:.1f} pixels | ECE: {ece:.3f}'
+
+    fig.update_layout(
+        title=title_text,
+        xaxis=dict(
+            title='Mean Predicted Confidence',
+            range=[0, 1],
+            tickmode='linear',
+            tick0=0,
+            dtick=0.1
+        ),
+        yaxis=dict(
+            title='Fraction of Accurate Predictions',
+            range=[0, 1],
+            tickmode='linear',
+            tick0=0,
+            dtick=0.1
+        ),
+        yaxis2=dict(
+            title='Count',
+            overlaying='y',
+            side='right',
+            showgrid=False
+        ),
+        width=700,
+        height=600,
+        showlegend=True,
+        legend=dict(
+            yanchor="top",
+            y=0.99,
+            xanchor="left",
+            x=0.01
+        ),
+        hovermode='x unified'
+    )
+
+    # Add grid
+    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
+    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
+
+    return fig
+
+
+def plot_calibration_diagram_multi(
+    models_data,
+    n_bins=10,
+    keypoint_name="",
+    data_type="",
+    error_threshold=5.0,
+):
+    """
+    Plot calibration diagram for multiple pose estimation models using Plotly.
+
+    Args:
+        models_data: list of dicts with keys 'model_name', 'confidences', 'accuracies'
+        n_bins: number of bins for grouping confidences
+        keypoint_name: name of the keypoint being analyzed
+        data_type: train/val/test data split
+        error_threshold: pixel error threshold used to determine accuracy
+
+    Returns:
+        Plotly figure object
+    """
+    # Define colors for different models
+    colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray']
+
+    # Create Plotly figure
+    fig = go.Figure()
+
+    # Add perfect calibration line
+    fig.add_trace(go.Scatter(
+        x=[0, 1],
+        y=[0, 1],
+        mode='lines',
+        name='Perfect calibration',
+        line=dict(dash='dash', color='black', width=2),
+        showlegend=True
+    ))
+
+    # Add calibration curves for each model
+    ece_values = []
+    for i, model_data in enumerate(models_data):
+        model_name = model_data['model_name']
+        confidences = model_data['confidences']
+        accuracies = model_data['accuracies']
+        color = colors[i % len(colors)]
+
+        # Calculate calibration curve
+        fraction_of_positives, mean_predicted_value = calibration_curve(
+            accuracies, confidences, n_bins=n_bins, strategy='uniform'
+        )
+
+        # Calculate ECE for this model
+        if len(mean_predicted_value) > 0 and len(confidences) > 0:
+            bin_edges = np.linspace(0, 1, n_bins + 1)
+            ece = 0.0
+            total_count = 0
+
+            for j in range(n_bins):
+                in_bin = (confidences >= bin_edges[j]) & (confidences < bin_edges[j + 1])
+                if j == n_bins - 1:  # Include right edge in last bin
+                    in_bin = (confidences >= bin_edges[j]) & (confidences <= bin_edges[j + 1])
+
+                bin_count = np.sum(in_bin)
+                if bin_count > 0:
+                    bin_accuracy = np.mean(accuracies[in_bin])
+                    bin_confidence = np.mean(confidences[in_bin])
+                    ece += bin_count * np.abs(bin_accuracy - bin_confidence)
+                    total_count += bin_count
+
+            ece = ece / total_count if total_count > 0 else 0
+        else:
+            ece = 0
+
+        ece_values.append(ece)
+
+        # Add model calibration curve
+        fig.add_trace(go.Scatter(
+            x=mean_predicted_value,
+            y=fraction_of_positives,
+            mode='markers+lines',
+            name=f'{model_name} (ECE: {ece:.3f})',
+            marker=dict(size=8, color=color),
+            line=dict(color=color, width=2),
+            showlegend=True
+        ))
+
+    # Create title
+    title_text = 'Model Calibration Comparison'
+    if keypoint_name:
+        title_text += f' - {keypoint_name}'
+    if data_type:
+        title_text += f' ({data_type} set)'
+    title_text += f'<br>Error threshold: {error_threshold:.1f} pixels'
+
+    # Update layout
+    fig.update_layout(
+        title=title_text,
+        xaxis=dict(
+            title='Mean Predicted Confidence',
+            range=[0, 1],
+            tickmode='linear',
+            tick0=0,
+            dtick=0.1
+        ),
+        yaxis=dict(
+            title='Fraction of Accurate Predictions',
+            range=[0, 1],
+            tickmode='linear',
+            tick0=0,
+            dtick=0.1
+        ),
+        width=800,
+        height=600,
+        showlegend=True,
+        legend=dict(
+            yanchor="top",
+            y=0.99,
+            xanchor="left",
+            x=0.01
+        ),
+        hovermode='x unified'
+    )
+
+    # Add grid
+    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
+    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
+
+    return fig

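For reference, the per-bin loop used in both plotting functions computes the standard expected calibration error over $M$ confidence bins $B_m$ and $n$ predictions:

$$\mathrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n}\,\bigl|\operatorname{acc}(B_m) - \operatorname{conf}(B_m)\bigr|$$

where $\operatorname{acc}(B_m)$ is the fraction of predictions in bin $m$ whose pixel error falls within the threshold and $\operatorname{conf}(B_m)$ is the mean predicted confidence in that bin; a perfectly calibrated model has ECE = 0.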