Merge branch 'render_images_in_md' into main
avivajpeyi committed Aug 4, 2023
2 parents dfded31 + 71e0358 commit d4e1166
Showing 3 changed files with 74 additions and 18 deletions.
74 changes: 58 additions & 16 deletions src/tess_atlas/data/inference_data_tools.py
@@ -1,13 +1,18 @@
 import logging
 import os
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import arviz as az
 import numpy as np
 import pandas as pd
 import pymc3_ext as pmx
 
-from ..file_management import INFERENCE_DATA_FNAME, SAMPLES_FNAME, get_filesize
+from ..file_management import (
+    INFERENCE_DATA_FNAME,
+    INFERENCE_SUMMARY_FNAME,
+    SAMPLES_FNAME,
+    get_filesize,
+)
 from ..logger import LOGGER_NAME
 
 logger = logging.getLogger(LOGGER_NAME)
@@ -37,11 +42,24 @@ def save_inference_data(inference_data: az.InferenceData, outdir: str):
         filename=fname, groups=["posterior", "log_likelihood", "sample_stats"]
     )
     save_samples(inference_data, outdir)
+    save_inference_summary(inference_data, outdir)
     logger.info(f"Saved inference data [{get_filesize(fname)} MB]")
 
 
+def get_max_rhat(inference_data: az.InferenceData) -> float:
+    rhat = summary(inference_data, print_warnings=False)["r_hat"]
+    return rhat.max()
+
+
+def save_inference_summary(inference_data: az.InferenceData, outdir: str):
+    fname = os.path.join(outdir, INFERENCE_SUMMARY_FNAME)
+    summary(inference_data).to_csv(fname, index=False)
+
+
 def summary(
-    inference_data: az.InferenceData, just_planet_params=False
+    inference_data: az.InferenceData,
+    just_planet_params=False,
+    print_warnings=True,
 ) -> pd.DataFrame:
     """Returns a dataframe with the mean+sd of each candidate's p, b, r"""
     df = az.summary(
@@ -57,25 +75,49 @@ def summary(
     df["parameter"] = df.index
     df.set_index(["parameter"], inplace=True, append=False, drop=True)
 
-    RHAT_THRESHOLD = 1.1
+    if print_warnings:
+        rhat_check(df)
+        grazing_check(df)
+
+    return df
+
+
+def rhat_check(summary_df, rhat_threshold=1.05, print_warnings=True):
     bogus_params = []
-    for param, row in df.iterrows():
-        if row["r_hat"] >= RHAT_THRESHOLD:
+    check_passed = True
+    for param, row in summary_df.iterrows():
+        if row["r_hat"] >= rhat_threshold:
             bogus_params.append(param)
 
     if len(bogus_params) > 0:
-        logger.warning(
-            f"Sampler may not have converged! r-hat > {RHAT_THRESHOLD} for {bogus_params}"
-        )
+        check_passed = False
+        if print_warnings:
+            logger.warning(
+                f"Sampler may not have converged! r-hat > {rhat_threshold} for {bogus_params}"
+            )
+    return check_passed
 
-    # TODO
-    """
-    for each b_param in all_params:
-        if median(b_param) > 0.8:
-            logger.warning("b[i] > 0.8 --> may be a grazing system!")
+
+def grazing_check(
+    summary_df=None, inference_data=None, b_threshold=0.8, print_warnings=True
+):
+    """Check for grazing systems
+    If summary_df contains rows with "b[]", then this function will check if any of the rows
+    have a median value > b_threshold. If so, a warning will be logged.
     """
+    if inference_data is not None:
+        summary_df = summary(
+            inference_data, just_planet_params=True, print_warnings=False
+        )
 
-    return df
+    check_passed = True
+    b_parms = [p for p in summary_df.index if "b[" in p]
+    for b_param in b_parms:
+        if summary_df.loc[b_param, "mean"] > b_threshold:
+            check_passed = False
+            if print_warnings:
+                logger.warning(f"{b_param} > 0.8 --> may be a grazing system!")
+    return check_passed
 
 
 def get_samples_dataframe(inference_data: az.InferenceData) -> pd.DataFrame:
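
A minimal usage sketch of the new rhat_check / grazing_check helpers added above (not part of the commit; toy_summary is a hypothetical stand-in for the output of summary(), so the threshold logic can be seen in isolation):

import pandas as pd

from tess_atlas.data.inference_data_tools import grazing_check, rhat_check

# Hypothetical stand-in for the summary() output: one row per parameter,
# with the "mean" and "r_hat" columns that the two checks read.
toy_summary = pd.DataFrame(
    {"mean": [10.2, 0.85, 0.04], "r_hat": [1.00, 1.20, 1.01]},
    index=["p[0]", "b[0]", "r[0]"],
)

print(rhat_check(toy_summary))                # False: r_hat of b[0] is 1.20 > 1.05
print(grazing_check(summary_df=toy_summary))  # False: mean of b[0] is 0.85 > 0.8
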
1 change: 1 addition & 0 deletions src/tess_atlas/file_management.py
@@ -7,6 +7,7 @@

 # CONSTANT FILENAMES
 SAMPLES_FNAME = "samples.csv"
+INFERENCE_SUMMARY_FNAME = "inference_summary.csv"
 INFERENCE_DATA_FNAME = "inference_data.netcdf"
 TOI_DIR = "toi_{toi}_files"
 TIC_CSV = "tic_data.csv"
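
A small sketch of how the new constant is combined with an output directory, mirroring save_inference_summary above (the outdir value is a placeholder, not from the commit):

import os

from tess_atlas.file_management import INFERENCE_SUMMARY_FNAME

outdir = "toi_103_files"  # placeholder output directory
fname = os.path.join(outdir, INFERENCE_SUMMARY_FNAME)
print(fname)  # toi_103_files/inference_summary.csv
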
17 changes: 15 additions & 2 deletions src/tess_atlas/plotting/diagnostic_plotter.py
@@ -9,6 +9,8 @@

 logger = logging.getLogger(LOGGER_NAME)
 
+from tess_atlas.data.inference_data_tools import get_max_rhat, grazing_check
+
 from ..data.data_utils import residual_rms
 from .labels import (
     DIAGNOSTIC_LC_PLOT,
@@ -197,7 +199,7 @@ def plot_lightcurve_gp_and_residuals(

     if total_outliers > 100:
         logger.warning(
-            "Large number of outliers in residuals after fitting model."
+            f"Large number of outliers in residuals after fitting model: {total_outliers}"
         )
 
     fig.subplots_adjust(hspace=0, wspace=0)
@@ -211,6 +213,8 @@
     else:
         return fig
 
+    return fig, total_outliers
+
 
 def plot_inference_trace(tic_entry, save=True):
     with az.style.context("default", after_reset=True):
@@ -231,7 +235,9 @@


 def plot_diagnostics(tic_entry, model, init_params, save=True):
-    plot_lightcurve_gp_and_residuals(tic_entry, model, save=save)
+    _, total_number_outliers = plot_lightcurve_gp_and_residuals(
+        tic_entry, model, save=save
+    )
     plot_thumbnail(
         tic_entry,
         model,
@@ -242,3 +248,10 @@
     # would be nice to plot the maximum posterior params on the phase plot
     # would be nice to plot the corner with the initial params / maximum posterior params
     # would be nice to plot the median posterior params + uncertainties on the EXOFOP Radius Ratio vs Period plot
+
+    # print some metadata
+    logger.info(f"Total number of outliers: {total_number_outliers}")
+    logger.info(f"Max rhat: {get_max_rhat(tic_entry.inference_data)}")
+    logger.info(
+        f"Grazing check passed: {grazing_check(inference_data=tic_entry.inference_data)}"
+    )
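
For reference, the same metadata that plot_diagnostics now logs can also be pulled from a previously saved run (a sketch, not part of the commit; the toi_103_files path is an illustrative placeholder built from the TOI_DIR and INFERENCE_DATA_FNAME constants):

import arviz as az

from tess_atlas.data.inference_data_tools import get_max_rhat, grazing_check

# Placeholder path for this sketch; any saved inference_data.netcdf works.
idata = az.from_netcdf("toi_103_files/inference_data.netcdf")

print(f"Max rhat: {get_max_rhat(idata):.3f}")
print(f"Grazing check passed: {grazing_check(inference_data=idata)}")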
