Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calibration Option for Burden Scores #111

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
653cff2
add center & scaling feature for burden scores
meyerkm Jul 15, 2024
71c3488
fixup snakemake rule spacing
meyerkm Jul 17, 2024
950468c
Add in center_scale_burden option to config generation
meyerkm Jul 17, 2024
9626e01
fixup! Format Python code with psf/black pull_request
Jul 17, 2024
542e190
remove unneeded function argument
meyerkm Jul 18, 2024
d933813
Merge branch 'score-calibration' of https://github.com/PMBio/deeprvat…
meyerkm Jul 18, 2024
8afef13
fixup! Format Python code with psf/black pull_request
Jul 18, 2024
583f54b
reduce max computation and set to 1.0
meyerkm Jul 22, 2024
5baa52e
Merge branch 'score-calibration' of https://github.com/PMBio/deeprvat…
meyerkm Jul 22, 2024
7e85215
fixup! Format Python code with psf/black pull_request
Jul 22, 2024
18b5ba3
move max computation to compute_burdens
meyerkm Aug 20, 2024
9a0934e
fixup! Format Python code with psf/black pull_request
Aug 20, 2024
4ddbb5a
update default to perform score calibration of burdens
meyerkm Aug 21, 2024
9105273
Merge remote-tracking branch 'origin/main' into score-calibration
meyerkm Aug 21, 2024
036a0de
Merge remote-tracking branch 'origin/main' into score-calibration
meyerkm Oct 7, 2024
9c8901f
remove lsf dir
meyerkm Oct 7, 2024
4879d9b
fixup! Format Python code with psf/black pull_request
Oct 7, 2024
c39cf8d
bugfix compute_burdens rule
meyerkm Oct 7, 2024
49e06f6
bugfix
meyerkm Oct 7, 2024
ee88758
remove skip_burdens from get_burden
meyerkm Oct 7, 2024
9a112dd
bugfix max computation from main branch merge
meyerkm Oct 7, 2024
df40898
fixup! Format Python code with psf/black pull_request
Oct 7, 2024
9cf31f8
Merge remote-tracking branch 'origin/main' into score-calibration
meyerkm Oct 7, 2024
ae8450a
github action test
meyerkm Oct 8, 2024
ccc5d56
specify upload-artifact version
meyerkm Oct 8, 2024
e28f873
Reset default of center_scale_burdens to True
meyerkm Oct 9, 2024
042aa42
remove unused module
meyerkm Oct 14, 2024
9587f90
remove extra space
meyerkm Oct 14, 2024
68aabb3
remove extra space
meyerkm Oct 14, 2024
08e01dd
add release notes
bfclarke Oct 14, 2024
97bd07c
Merge branch 'score-calibration' of github.com:PMBio/deeprvat into sc…
bfclarke Oct 14, 2024
1b626fc
update scaling function
meyerkm Oct 17, 2024
a260b8b
fixup! Format Python code with psf/black pull_request
Oct 17, 2024
7c56720
specify center_scale_burdens option in example configs
meyerkm Oct 17, 2024
e4d31d6
remove hard-coded annotation length
meyerkm Oct 22, 2024
c86757b
fixup! Format Python code with psf/black pull_request
Oct 22, 2024
d69451f
Merge remote-tracking branch 'origin/main' into score-calibration
meyerkm Oct 22, 2024
b77a98a
add in center_scale_burdens option from config
meyerkm Nov 4, 2024
08b5c93
Merge remote-tracking branch 'origin/main' into score-calibration
meyerkm Nov 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/run-pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ jobs:
- name: Upload Training Outputs
id: uploaded_training_outputs
if: inputs.upload_training_outputs
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v4.4.0
with:
name: completed_training_outputs
path: |
Expand All @@ -148,7 +148,7 @@ jobs:
- name: Upload Pretrained Outputs
id: uploaded_pretrained_outputs
if: inputs.upload_pretrained_outputs
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v4.4.0
with:
name: completed_pretrained_outputs
path: |
Expand All @@ -163,7 +163,7 @@ jobs:
- name: Upload Regenie Outputs
id: uploaded_regenie_outputs
if: inputs.upload_regenie_outputs
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v4.4.0
with:
name: completed_regenie_outputs
path: |
Expand Down
78 changes: 73 additions & 5 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tqdm import tqdm, trange
import zarr
import re
import dask.array as da
meyerkm marked this conversation as resolved.
Show resolved Hide resolved

import deeprvat.deeprvat.models as deeprvat_models
from deeprvat.data import DenseGTDataset
Expand Down Expand Up @@ -63,10 +64,8 @@ def get_burden(
:type agg_models: Dict[str, List[nn.Module]]
:param device: Device to perform computations on, defaults to "cpu".
:type device: torch.device
:param skip_burdens: Flag to skip burden computation, defaults to False.
:type skip_burdens: bool
:return: Tuple containing burden scores, target y phenotype values, x phenotypes and sample ids.
:rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
:rtype: Tuple[torch.Tensor, torch.Tensor]

.. note::
Checkpoint models all corresponding to the same repeat are averaged for that repeat.
Expand Down Expand Up @@ -923,6 +922,14 @@ def compute_burdens_(
if bottleneck and i > 20:
break

# Calculate Max for this chunk and store for later
max_df = pd.DataFrame(columns=["max"])
for r in range(len(agg_models)):
chunk_max = np.max(chunk_burden[:, :, r]) # samples x genes x repeats
max_df.loc[r, "max"] = chunk_max
print(f"Saving Burden Max Scores")
max_df.to_csv(f"{Path(cache_dir)}/chunk{chunk}_max.csv", index=False)

burdens[:] = chunk_burden[:]
sample_ids[:] = chunk_sampleid[:]

Expand All @@ -941,6 +948,7 @@ def compute_burdens_(
@click.option("--n-chunks", type=int)
@click.option("--chunk", type=int)
@click.option("--dataset-file", type=click.Path(exists=True))
@click.option("--center-scale-burdens", is_flag=True)
@click.argument("data-config-file", type=click.Path(exists=True))
@click.argument("model-config-file", type=click.Path(exists=True))
@click.argument("checkpoint-files", type=click.Path(exists=True), nargs=-1)
Expand All @@ -952,6 +960,7 @@ def compute_burdens(
n_chunks: Optional[int],
chunk: Optional[int],
dataset_file: Optional[str],
center_scale_burdens: bool,
data_config_file: str,
model_config_file: str,
checkpoint_files: Tuple[str],
Expand All @@ -972,6 +981,8 @@ def compute_burdens(
:type chunk: Optional[int]
:param dataset_file: Path to the dataset file, i.e., association_dataset.pkl.
:type dataset_file: Optional[str]
:param center_scale_burdens: Flag to enable calculation of center and scaling parameters for centering and scaling burden results.
:type center_scale_burdens: bool
:param data_config_file: Path to the data configuration file.
:type data_config_file: str
:param model_config_file: Path to the model configuration file.
Expand Down Expand Up @@ -1024,6 +1035,29 @@ def compute_burdens(
bottleneck=bottleneck,
)

if center_scale_burdens:
if (chunk == 0) or not chunk:
# Calculate Mode
empty_batch = {
"rare_variant_annotations": torch.zeros(1, 1, 34, 1),
"y": None,
"x_phenotypes": None,
"sample": None,
}
this_mode, _ = get_burden(
empty_batch,
agg_models,
device=device,
)
this_mode = this_mode.flatten()
center_scale_df = pd.DataFrame(columns=["mode"])
for r in range(len(agg_models)):
center_scale_df.loc[r, "mode"] = this_mode[r]
pprint(f"Calculated Zero-Effect Burden Score :\n {this_mode}")
center_scale_df.to_csv(
f"{Path(out_dir)}/computed_burdens_stats.csv", index=False
)

logger.info("Saving computed burdens, corresponding genes, and targets")
np.save(Path(out_dir) / "genes.npy", genes)

Expand Down Expand Up @@ -1145,7 +1179,6 @@ def regress_on_gene_scoretest(
:rtype: Tuple[List[str], List[float], List[float]]
"""
burdens = burdens.reshape(burdens.shape[0], -1)
assert np.all(burdens != 0) # because DeepRVAT burdens are corrently all non-zero
logger.info(f"Burdens shape: {burdens.shape}")

if np.all(np.abs(burdens) < 1e-6):
Expand Down Expand Up @@ -1504,13 +1537,15 @@ def combine_regression_results(


@cli.command()
@click.option("--center-scale-burdens", is_flag=True)
@click.option("--n-chunks", type=int)
@click.option("--chunk", type=int)
@click.option("-r", "--repeats", multiple=True, type=int)
@click.option("--agg-fct", type=str, default="mean")
@click.argument("burden-file", type=click.Path(exists=True))
@click.argument("burden-out-file", type=click.Path())
def average_burdens(
center_scale_burdens: bool,
repeats: Tuple,
burden_file: str,
burden_out_file: str,
Expand All @@ -1523,6 +1558,20 @@ def average_burdens(
logger.info(f"Reading burdens to aggregate from {burden_file}")
burdens = zarr.open(burden_file)
n_total_samples = burdens.shape[0]

if center_scale_burdens:
center_scale_params_file = (
Path(os.path.split(burden_out_file)[0]) / "computed_burdens_stats.csv"
)
center_scale_df = pd.read_csv(center_scale_params_file)

max_dfs = pd.DataFrame()
max_files_path = Path(os.path.split(burden_out_file)[0]).glob("chunk*_max.csv")
for i, filename in enumerate(max_files_path):
max_dfs[f"Max_Chunk{i}"] = pd.read_csv(filename)["max"]
# compute max across all chunks
max_dfs["max"] = max_dfs.max(axis=1)

if chunk is not None:
if n_chunks is None:
raise ValueError("n_chunks must be specified if chunk is not None")
Expand All @@ -1541,7 +1590,7 @@ def average_burdens(
f"Computing result for chunk {chunk} out of {n_chunks} in range {chunk_start}, {chunk_end}"
)

batch_size = 100
batch_size = 1000
logger.info(f"Batch size: {batch_size}")
n_batches = n_samples // batch_size + (n_samples % batch_size != 0)

Expand Down Expand Up @@ -1570,6 +1619,25 @@ def average_burdens(
end_idx = min(start_idx + batch_size, chunk_end)
print(start_idx, end_idx)
this_burdens = np.take(burdens[start_idx:end_idx, :, :], repeats, axis=2)

# Double-check zarr creation - no computed burdens should equal zero
assert np.all(this_burdens != 0)

if center_scale_burdens:
print("Centering and Scaling Burdens before aggregating")
for r in range(len(repeats)):
zero_effect_val = center_scale_df.loc[r, "mode"]
repeat_max = max_dfs.loc[r, "max"]
# Subtract off zero effect burden value (mode)
this_burdens[:, :, r] -= zero_effect_val
adjusted_max = repeat_max - zero_effect_val
min_val = this_burdens[:, :, r].min()
# Scale values between -1 and 1
this_burdens[:, :, r] = (
meyerkm marked this conversation as resolved.
Show resolved Hide resolved
2 * ((this_burdens[:, :, r] - min_val) / (adjusted_max - min_val))
- 1
)

this_burdens = AGG_FCT[agg_fct](this_burdens, axis=2)

burdens_new[start_idx:end_idx, :, 0] = this_burdens
Expand Down
4 changes: 4 additions & 0 deletions deeprvat/deeprvat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ def create_main_config(
"correction_method": input_config["evaluation"]["correction_method"],
"alpha": input_config["evaluation"]["alpha"],
}
if "center_scale_burdens" in input_config["evaluation"]:
full_config["center_scale_burdens"] = input_config["evaluation"][
"center_scale_burdens"
]

if pretrained_setup:
full_config.update(
Expand Down
1 change: 1 addition & 0 deletions example/config/deeprvat_input_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ y_transformation: quantile_transform
evaluation:
correction_method: Bonferroni
alpha: 0.05
center_scale_burdens: False

# Subsetting samples for training or association testing
#sample_files:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ y_transformation: quantile_transform
evaluation:
correction_method: Bonferroni
alpha: 0.05
center_scale_burdens: False

# Subsetting samples for association testing
#sample_files:
Expand Down
131 changes: 0 additions & 131 deletions lsf/lsf.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions pipelines/association_testing/burdens.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ rule compute_burdens:
' '.join([
'deeprvat_associate compute-burdens '
+ debug +
' --n-chunks ' + str(n_burden_chunks) + ' '
' --n-chunks '+ str(n_burden_chunks) + ' '
'--chunk {wildcards.chunk} '
'--dataset-file {input.dataset} '
+ center_scale_burdens +
'{input.data_config} '
'{input.model_config} '
'{input.checkpoints} '
'{params.prefix}/burdens'],
)


rule reverse_models:
input:
checkpoints = expand(model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt',
Expand Down
1 change: 1 addition & 0 deletions pipelines/association_testing/regress_eval.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ rule average_burdens:
shell:
' && '.join([
('deeprvat_associate average-burdens '
+ center_scale_burdens +
'--n-chunks ' + str(n_avg_chunks) + ' '
'--chunk {wildcards.chunk} '
'{params.repeats} '
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ rule average_burdens:
shell:
' && '.join([
('deeprvat_associate average-burdens '
+ center_scale_burdens +
'--n-chunks ' + str(n_avg_chunks) + ' '
'--chunk {wildcards.chunk} '
'{params.repeats} '
Expand Down
1 change: 1 addition & 0 deletions pipelines/association_testing_pretrained.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ training_phenotypes = []
n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2
n_regression_chunks = config.get('n_regression_chunks', 40) if not debug_flag else 2
n_avg_chunks = config.get('n_avg_chunks', 1)
center_scale_burdens = '--center-scale-burdens ' if config.get('center_scale_burdens', True) else ''
n_trials = config['hyperparameter_optimization']['n_trials']
n_bags = config['training']['n_bags'] if not debug_flag else 3
n_repeats = config['n_repeats']
Expand Down
Loading
Loading