diff --git a/book/src/dmr_scoring_details.md b/book/src/dmr_scoring_details.md index a836f85..37ee8e6 100644 --- a/book/src/dmr_scoring_details.md +++ b/book/src/dmr_scoring_details.md @@ -90,7 +90,7 @@ Starting with a prior of \\(\text{Beta}(0.5, 0.5)\\), we can calculate the poste ![posterior_distributions](./images/beta_distributions.png) What we need to calculate is the probability distribution of the _difference_ (the effect size) between the two conditions (high-modification and low-modification). -This can be done using a piecewise solution described by [Pham-Gia, Turkkan, and Eng in 1993](https://www.tandfonline.com/doi/abs/10.1080/03610929308831114), shown below: +This distribution can be calculated using a piecewise solution described by [Pham-Gia, Turkkan, and Eng in 1993](https://www.tandfonline.com/doi/abs/10.1080/03610929308831114); the distribution is shown below: ![beta_diff](./images/estimated_map_pvalue2.png) @@ -108,3 +108,52 @@ d = \hat{p}_1 - \hat{p}_2 \\ \\[ \hat{p} = \frac{ N_{\text{mod}} }{ N_{\text{canonical}} } \\ \\] + +## DMR segmentation hidden Markov model + +When performing "single-site" analysis with `modkit dmr pair` (by omitting the `--regions` option) you can optionally run the "segmentation" model at the same time by passing the `--segment` option with a filepath to write the segments to. +The model is a simple 2-state hidden Markov model, shown below, where the two hidden states, "Different" and "Same", indicate that the position is either differentially methylated or not. +
+ +![hmm](./images/hmm2.png "2-state segmenting HMM") + +
+The model is run over the intersection of the modified positions in a [pileup](https://nanoporetech.github.io/modkit/intro_bedmethyl.html#description-of-bedmethyl-output) for which there is enough coverage, from one or more samples. + +## Transition parameters +There are two transition probability parameters, \\(p\\) and \\(d\\). +The \\(p\\) parameter is the probability of transitioning to the "Different" state, and can be roughly thought of as the probability of a given site being differentially modified without any information about the site. +The \\(d\\) parameter is the maximum probability of remaining in the "Different" state; it is a maximum because the value of \\(d\\) will change dynamically depending on the proximity of the next modified site. +The model proposes that modified bases in close proximity will share modification characteristics. +Specifically, when a site is differentially modified, the probability of the next site also being differentially modified depends on how close the next site happens to be. +For example, if a CpG dinucleotide is differentially modified and is immediately followed by another CpG (sequence is `CGCG`) we have the maximum expectation that the following site is also differentially modified. +However, as the next site becomes farther away (say the next site is one thousand base pairs away, `CG[N x 1000]CG`) these sites are not likely correlated and the probability of the next site being differentially modified decreases towards \\(p\\). +The chart below shows how the probability of staying in the "Different" state, \\(d\\), decreases as the distance to the next modified base increases. + +
+ +![hmm](./images/dynamic_probs.png "dynamic transition probabilities") + +
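The two decay curves can be reproduced with a short Rust sketch. It follows the arithmetic of the `Projection` struct added in `src/hmm.rs` by this change, but it is a simplified stand-alone illustration rather than the implementation itself: the free functions `linear_stay_prob` and `log_stay_prob`, their argument names, and the `MIN_DISTANCE` constant are hypothetical, and gaps shorter than 2 base pairs are assumed to be clamped to the maximum stay probability.

```rust
/// Assumption: interpolation starts at a gap of 2 bp (two adjacent CpG sites).
const MIN_DISTANCE: f64 = 2.0;

/// Linear decay of the "Different" stay probability from `diff_stay`
/// (at MIN_DISTANCE) down to `dmr_prior` (at `decay_distance`).
fn linear_stay_prob(gap: f64, diff_stay: f64, dmr_prior: f64, decay_distance: f64) -> f64 {
    // Gaps outside [MIN_DISTANCE, decay_distance] are clamped to the ends.
    let gap = gap.clamp(MIN_DISTANCE, decay_distance);
    let fraction = (gap - MIN_DISTANCE) / (decay_distance - MIN_DISTANCE);
    diff_stay - fraction * (diff_stay - dmr_prior)
}

/// Logarithmic decay: same end points, but interpolated on ln(gap).
fn log_stay_prob(gap: f64, diff_stay: f64, dmr_prior: f64, decay_distance: f64) -> f64 {
    let gap = gap.clamp(MIN_DISTANCE, decay_distance);
    let fraction =
        (gap.ln() - MIN_DISTANCE.ln()) / (decay_distance.ln() - MIN_DISTANCE.ln());
    diff_stay - fraction * (diff_stay - dmr_prior)
}

fn main() {
    // Values used for the chart: d = 0.9, p = 0.1, decay distance = 500 bp.
    let (diff_stay, dmr_prior, decay_distance) = (0.9, 0.1, 500.0);
    for gap in [2.0, 10.0, 50.0, 250.0, 500.0, 1000.0] {
        println!(
            "gap {gap:>6}: linear {:.3}, log {:.3}",
            linear_stay_prob(gap, diff_stay, dmr_prior, decay_distance),
            log_stay_prob(gap, diff_stay, dmr_prior, decay_distance),
        );
    }
}
```

With these values the logarithmic variant falls off much faster at short distances: at a gap of 50 base pairs it gives roughly 0.43, whereas the linear variant gives roughly 0.82. Both reach 0.1 at 500 base pairs and stay there for larger gaps.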
+ +In this case, the maximum value of \\(d\\) is 0.9, \\(p\\) is 0.1, and the `decay_distance` is 500 base pairs (these also happen to be the defaults). +This can be seen in the chart above: the maximum value of both curves is 0.9, and the minimum value, reached at 500 base pairs, is 0.1. +These parameters can be set with the `--diff-stay`, `--dmr-prior`, and `--decay-distance` parameters, respectively. +The two curves represent two different ways of interpolating the decay between the minimum distance (1) and the `decay_distance`: `linear` and `logarithmic`. +The `--log-transition-decay` flag will use the orange curve whereas the default is to use the blue curve. + +In general, these settings don't need to be adjusted. +However, if you want very fine-grained segmentation, use the `--fine-grained` option, which will produce smaller regions but also decrease the rate at which sites are classified as "Different" when they are in fact not different. + +## Emission parameters +The emissions of the model are derived from the [likelihood ratio score](https://nanoporetech.github.io/modkit/dmr_scoring_details.html#likelihood-ratio-scoring-details). +One advantage to using this score is that differences in methylation type (e.g. changes from 5hmC to 5mC) will be modeled and detected. +The score is transformed into a probability by \\( p = e^{\text{-score}} \\). 
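As a rough illustration of this transform, the sketch below converts a score into the probability for the "Same" state and its complement for the "Different" state. The function name `emission_probs` and its tuple return value are hypothetical. Clamping negative scores to zero follows `viterbi_path` in `src/hmm.rs` from this change; the additional scaling factors that the implementation applies in log space are deliberately left out.

```rust
/// Returns `(p_same, p_different)` for a likelihood ratio score.
/// Simplified sketch: the real model further scales these in log space.
fn emission_probs(score: f64) -> (f64, f64) {
    // Negative scores are treated as zero, so p_same never exceeds 1.
    let score = score.max(0.0);
    let p_same = (-score).exp();
    (p_same, 1.0 - p_same)
}

fn main() {
    for score in [0.0, 0.5, 2.0, 10.0] {
        let (p_same, p_diff) = emission_probs(score);
        println!("score {score:>4}: same {p_same:.4}, different {p_diff:.4}");
    }
}
```

A score of zero therefore gives full weight to "Same", and the weight shifts towards "Different" as the score grows.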
+The full description of the emission probabilities for the two states is: + +\\[ + p_{\text{Same}} = e^{\text{-score}} +\\] +\\[ + p_{\text{Different}} = 1 - p_{\text{Same}} +\\] diff --git a/book/src/images/dynamic_probs.png b/book/src/images/dynamic_probs.png new file mode 100644 index 0000000..af323fe Binary files /dev/null and b/book/src/images/dynamic_probs.png differ diff --git a/book/src/images/hmm2.png b/book/src/images/hmm2.png new file mode 100644 index 0000000..155f864 Binary files /dev/null and b/book/src/images/hmm2.png differ diff --git a/book/src/intro_dmr.md b/book/src/intro_dmr.md index 091e722..5af68a8 100644 --- a/book/src/intro_dmr.md +++ b/book/src/intro_dmr.md @@ -93,7 +93,7 @@ chr20 10034962 10035266 CpG: 35 1.294227443419004 C:7 1513 C:14 The full schema is described [below](#differential-methylation-output-format). -### 2. Perform differential methylation detection on all pairs of samples over regions from the genome. +## 2. Perform differential methylation detection on all pairs of samples over regions from the genome. The `modkit dmr multi` command runs all pairwise comparisons for more than two samples for all regions provided in the regions BED file. The preparation of the data is identical to that for the [previous section](#preparing-the-input-data) (for each sample, of course). An example command could be: @@ -233,3 +233,45 @@ modkit dmr pair \ ``` these columns will not be present. + +## Segmenting on differential methylation + +When running `modkit dmr pair` without `--regions` (i.e. [single-site analysis](#3-detecting-differential-modification-at-single-base-positions)) you can generate regions of differential methylation on the fly using the segmenting [hidden Markov model](./dmr_scoring_details.html#dmr-segmentation-hidden-markov-model) (HMM). 
+To run segmenting on the fly, add the `--segment $segments_bed_fp` option to the command, for example: + + +```bash +dmr_result=single_base_haplotype_dmr.bed +dmr_segments=single_base_segments.bed + +# passing --segment indicates to run segmentation +modkit dmr pair \ + -a ${hp1_pileup}.gz \ + -b ${hp2_pileup}.gz \ + -o ${dmr_result} \ + --segment ${dmr_segments} \ + --ref ${ref} \ + --base C \ + --threads ${threads} \ + --log-filepath dmr.log +``` + +The default settings for the HMM are to run in "coarse-grained" mode, which will more eagerly join neighboring sites, potentially at the cost of including sites that are not differentially modified within "Different" blocks. +To activate "fine-grained" mode, pass the `--fine-grained` flag. +The output schema for the segments is: + +| column | name | description | type | +|--------|------------------------------|-------------------------------------------------------------------------------------------|-------| +| 1 | chrom | name of reference sequence from bedMethyl input samples | str | +| 2 | start position | 0-based start position of the segment | int | +| 3 | end position | 0-based exclusive end position of the segment | int | +| 4 | state-name | "different" when sites are differentially modified, "same" otherwise | str | +| 5 | score | Difference score, more positive values have increased difference | float | +| 6 | N_sites | Number of sites (bedmethyl records) in the segment | int | +| 7 | samplea counts | Counts of each base modification in the region, comma-separated, for sample A | str | +| 8 | samplea total | Total number of base modification calls in the region, including unmodified, for sample A | str | +| 9 | sampleb counts | Counts of each base modification in the region, comma-separated, for sample B | str | +| 10 | sampleb total | Total number of base modification calls in the region, including unmodified, for sample B | str | +| 11 | samplea fractions | Fraction of calls for each base 
modification in the region, comma-separated, for sample A | str | +| 12 | sampleb fractions | Fraction of calls for each base modification in the region, comma-separated, for sample B | str | +| 13 | effect size | Percent modified in sample A (col 11) minus percent modified in sample B (col 12) | float | + diff --git a/src/dmr/single_site.rs b/src/dmr/single_site.rs index 0bb511c..4213fd1 100644 --- a/src/dmr/single_site.rs +++ b/src/dmr/single_site.rs @@ -1,11 +1,14 @@ use std::cmp::Ordering; -use std::collections::{HashMap, VecDeque}; +use std::collections::{BTreeMap, HashMap, VecDeque}; +use std::fs::File; use std::io::Write; use std::ops::Range; +use std::path::PathBuf; use std::sync::Arc; use anyhow::{bail, Context}; -use indicatif::MultiProgress; +use derive_new::new; +use indicatif::{MultiProgress, ProgressBar}; use itertools::Itertools; use log::{debug, error, info}; use rayon::prelude::*; @@ -17,10 +20,12 @@ use crate::dmr::tabix::{ }; use crate::dmr::util::DmrBatchOfPositions; use crate::genome_positions::GenomePositions; +use crate::hmm::{HmmModel, States}; use crate::mod_base_code::ModCodeRepr; use crate::monoid::BorrowingMoniod; use crate::thresholds::percentile_linear_interp; -use crate::util::{get_subroutine_progress_bar, get_ticker}; +use crate::util::{get_subroutine_progress_bar, get_ticker, Region}; +use crate::writers::TsvWriter; pub(super) struct SingleSiteDmrAnalysis { sample_index: Arc, @@ -29,6 +34,7 @@ pub(super) struct SingleSiteDmrAnalysis { batch_size: usize, interval_size: u64, header: bool, + segmentation_fp: Option, } impl SingleSiteDmrAnalysis { @@ -45,6 +51,7 @@ impl SingleSiteDmrAnalysis { rope: f64, sample_n: usize, header: bool, + segmentation_fp: Option<&PathBuf>, progress: &MultiProgress, ) -> anyhow::Result { let sample_index = @@ -109,6 +116,7 @@ impl SingleSiteDmrAnalysis { batch_size, interval_size, header, + segmentation_fp: segmentation_fp.cloned(), }) } @@ -116,6 +124,12 @@ impl SingleSiteDmrAnalysis { &self, 
multi_progress_bar: MultiProgress, pool: rayon::ThreadPool, + max_gap_size: u64, + dmr_prior: f64, + diff_stay: f64, + significance_factor: f64, + decay_distance: u32, + linear_transitions: bool, mut writer: Box, ) -> anyhow::Result<()> { let matched_samples = self.sample_index.matched_replicate_samples(); @@ -133,14 +147,26 @@ impl SingleSiteDmrAnalysis { )?; } - let batch_iter = SingleSiteBatches::new( - self.sample_index.clone(), - self.genome_positions.clone(), - self.batch_size, - self.interval_size, - )?; + let mut segmenter: Box = + if let Some(segmentation_fp) = &self.segmentation_fp { + Box::new(HmmDmrSegmenter::new( + segmentation_fp, + max_gap_size, + dmr_prior, + diff_stay, + 0.3f64, + -0.1f64, + significance_factor, + linear_transitions, + decay_distance, + &multi_progress_bar, + )?) + } else { + Box::new(DummySegmenter::new()) + }; - let (snd, rcv) = crossbeam::channel::bounded(1000); + let (scores_snd, scores_rcv) = crossbeam::channel::bounded(1000); + // let (segment_snd, segment_rcv) = crossbeam::channel::bounded(1000); let processed_batches = multi_progress_bar.add(get_ticker()); let failure_counter = multi_progress_bar.add(get_ticker()); let success_counter = multi_progress_bar.add(get_ticker()); @@ -149,6 +175,13 @@ impl SingleSiteDmrAnalysis { failure_counter.set_message("sites failed"); success_counter.set_message("sites processed successfully"); + let batch_iter = SingleSiteBatches::new( + self.sample_index.clone(), + self.genome_positions.clone(), + self.batch_size, + self.interval_size, + )?; + let sample_index = self.sample_index.clone(); let pmap_estimator = self.pmap_estimator.clone(); pool.spawn(move || { @@ -184,7 +217,7 @@ impl SingleSiteDmrAnalysis { || { results.into_iter().for_each( |chrom_to_scores: Vec| { - match snd.send(chrom_to_scores) { + match scores_snd.send(chrom_to_scores) { Ok(_) => processed_batches.inc(1), Err(e) => { error!( @@ -197,19 +230,22 @@ impl SingleSiteDmrAnalysis { }, ); results = super_batch_results; - 
results.into_iter().for_each(|chrom_to_scores| { - match snd.send(chrom_to_scores) { + results.into_iter().for_each( + |chrom_to_scores| match scores_snd.send(chrom_to_scores) { Ok(_) => processed_batches.inc(1), Err(e) => { error!("failed to send on channel, {e}"); } - } - }); + }, + ); } }); let mut success_count = 0usize; - for batch_result in rcv { + for batch_result in scores_rcv { + if let Err(e) = segmenter.add(&batch_result) { + debug!("segmentation error, {e}"); + } for (chrom, results) in batch_result { for result in results { match result { @@ -235,8 +271,12 @@ impl SingleSiteDmrAnalysis { } } + if let Err(e) = segmenter.run_current_chunk() { + debug!("segmentation error, {e}") + } success_counter.finish_and_clear(); failure_counter.finish_and_clear(); + segmenter.clean_up()?; info!( "finished, processed {} sites successfully, {} failed", @@ -271,6 +311,7 @@ impl SingleSiteBatches { .contig_sizes() .filter(|(name, _)| sample_index.has_contig(name)) .map(|(name, length)| (name.to_owned(), (0u64..(length as u64)))) + .sorted_by(|(a, _), (b, _)| a.cmp(b)) .collect::)>>(); if let Some((curr_contig, curr_contig_range)) = @@ -398,6 +439,7 @@ struct SingleSiteDmrScore { effect_size: f64, balanced_map_pval: f64, balanced_effect_size: f64, + balanced_score: f64, replicate_map_pval: Vec, replicate_effect_sizes: Vec, pct_a_samples: usize, @@ -480,6 +522,8 @@ impl SingleSiteDmrScore { let balanced_counts_b = collapse_counts(counts_b, true); let epmap_balanced = estimator.predict(&balanced_counts_a, &balanced_counts_b)?; + let balanced_llr_score = + llk_ratio(&balanced_counts_a, &balanced_counts_b)?; let collapsed_a = collapse_counts(counts_a, false); let collapsed_b = collapse_counts(counts_b, false); let epmap = estimator.predict(&collapsed_a, &collapsed_b)?; @@ -493,6 +537,7 @@ impl SingleSiteDmrScore { effect_size: epmap.effect_size, balanced_map_pval: epmap_balanced.e_pmap, balanced_effect_size: epmap_balanced.effect_size, + balanced_score: balanced_llr_score, 
replicate_map_pval: replicate_epmap, replicate_effect_sizes, pct_a_samples, @@ -790,3 +835,342 @@ fn calculate_max_coverages( info!("calculated max coverage for a: {a_max_cov} and b: {b_max_cov}"); Ok([a_max_cov, b_max_cov]) } + +trait DmrSegmenter { + fn add(&mut self, dmr_scores: &[ChromToSingleScores]) + -> anyhow::Result<()>; + fn run_current_chunk(&mut self) -> anyhow::Result<()>; + fn clean_up(&mut self) -> anyhow::Result<()>; +} + +#[derive(new)] +struct DummySegmenter {} + +impl DmrSegmenter for DummySegmenter { + fn add( + &mut self, + _dmr_scores: &[ChromToSingleScores], + ) -> anyhow::Result<()> { + Ok(()) + } + + fn run_current_chunk(&mut self) -> anyhow::Result<()> { + Ok(()) + } + + fn clean_up(&mut self) -> anyhow::Result<()> { + Ok(()) + } +} + +struct HmmDmrSegmenter { + writer: TsvWriter, + hmm: HmmModel, + curr_region_scores: Vec, + curr_region_positions: Vec, + curr_counts_a: BTreeMap, + curr_counts_b: BTreeMap, + curr_chrom: Option, + curr_end: Option, + max_gap_size: u64, + size_gauge: ProgressBar, + segments_written: ProgressBar, +} + +impl DmrSegmenter for HmmDmrSegmenter { + fn add( + &mut self, + dmr_scores: &[ChromToSingleScores], + ) -> anyhow::Result<()> { + for (chrom, scores) in dmr_scores.iter() { + if let Some(curr_chrom) = self.curr_chrom.as_ref() { + if chrom == curr_chrom { + let min_pos = scores.iter().find_map(|r| match r { + Ok(score) => Some(score.position), + Err(_) => None, + }); + match (min_pos, self.curr_end) { + (Some(pos), Some(end)) => { + if pos + .checked_sub(end) + .map(|x| x < self.max_gap_size) + .unwrap_or(false) + { + // within limits, add to current + self.append_scores(&scores); + } else { + // next chunk is too far away, run current and + // reset + self.run_current_chunk()?; + self.append_scores(&scores); + } + } + (Some(_pos), None) => { + // maybe this never happens? 
+ // don't have any data, append + self.append_scores(&scores); + } + (None, _) => { + // nothing to do + debug!("no valid results.."); + } + } + } else { + // finish current chunk and add this chunk to current + self.run_current_chunk()?; + // update chrom + self.curr_chrom = Some(chrom.to_string()); + // update scores + assert_eq!( + self.curr_chrom.as_ref(), + Some(chrom), + "chroms aren't the same?" + ); + self.append_scores(&scores); + } + } else { + self.curr_chrom = Some(chrom.to_string()); + assert_eq!( + self.curr_chrom.as_ref(), + Some(chrom), + "chroms aren't the same?" + ); + self.append_scores(&scores); + } + } + Ok(()) + } + + #[inline] + fn run_current_chunk(&mut self) -> anyhow::Result<()> { + if self.curr_region_scores.is_empty() { + debug!("no scores to run"); + assert!( + self.curr_region_positions.is_empty(), + "should not have positions and no scores" + ); + return Ok(()); + } + assert_eq!( + self.curr_region_positions.len(), + self.curr_region_scores.len(), + "scores and positions should be the same length" + ); + + // these expects and asserts are safe because this method is only called + // when self.curr_chrom is some + let region = + self.current_chunk_region().expect("region should not be None"); + assert!(self.curr_chrom.is_some()); + let start_time = std::time::Instant::now(); + let path = self.hmm.viterbi_path( + &self.curr_region_scores, + &self.curr_region_positions, + ); + let took = start_time.elapsed(); + debug!( + "segmenting {} ({} scores), took {took:?}", + region.to_string(), + self.curr_region_scores.len() + ); + let integrated_path = + path_to_region_labels(&path, &self.curr_region_positions); + for (start, end, state) in integrated_path.iter() { + let counts_a = self.get_counts_a(*start, *end); + let counts_b = self.get_counts_b(*start, *end); + let score = llk_ratio(&counts_a, &counts_b)?; + let frac_mod_a = counts_a.pct_modified(); + let frac_mod_b = counts_b.pct_modified(); + let effect_size = frac_mod_a - frac_mod_b; + 
let num_sites = self.curr_counts_a.range(*start..*end).count(); + + let sep = '\t'; + let row = format!( + "{}{sep}\ + {start}{sep}\ + {end}{sep}\ + {state}{sep}\ + {score}{sep}\ + {num_sites}{sep}\ + {}{sep}\ + {}{sep}\ + {}{sep}\ + {}{sep}\ + {frac_mod_a}{sep}\ + {frac_mod_b}{sep}\ + {effect_size}\n", + self.curr_chrom.as_ref().unwrap(), + counts_a.string_counts(), + counts_b.string_counts(), + counts_a.string_percentages(), + counts_b.string_percentages(), + ); + self.writer.write(row.as_bytes())?; + } + debug!("wrote {} segments", integrated_path.len()); + + // reset everything + self.curr_region_positions = Vec::new(); + self.curr_region_scores = Vec::new(); + self.curr_counts_a = BTreeMap::new(); + self.curr_counts_b = BTreeMap::new(); + self.curr_end = None; + self.segments_written.inc(integrated_path.len() as u64); + self.size_gauge.set_position(0u64); + Ok(()) + } + + fn clean_up(&mut self) -> anyhow::Result<()> { + self.size_gauge.finish_and_clear(); + self.segments_written.finish_and_clear(); + debug!( + "HMM segmenter finished, wrote {} segments", + self.segments_written.position() + ); + Ok(()) + } +} + +impl HmmDmrSegmenter { + fn new( + out_fp: &PathBuf, + max_gap_size: u64, + dmr_prior: f64, + diff_stay: f64, + same_state_factor: f64, + diff_state_factor: f64, + significance_factor: f64, + linear_transitions: bool, + decay_distance: u32, + multi_progress: &MultiProgress, + ) -> anyhow::Result { + let hmm = HmmModel::new( + dmr_prior, + diff_stay, + same_state_factor, + diff_state_factor, + significance_factor, + decay_distance, + linear_transitions, + )?; + let writer = TsvWriter::new_path(out_fp, true, None)?; + let size_gauge = multi_progress.add(get_ticker()); + let segments_written = multi_progress.add(get_ticker()); + size_gauge.set_message("[segmenter] current region size"); + segments_written.set_message("[segmenter] segments finished"); + + Ok(Self { + writer, + hmm, + max_gap_size, + curr_region_scores: Vec::new(), + curr_region_positions: 
Vec::new(), + curr_counts_a: BTreeMap::new(), + curr_counts_b: BTreeMap::new(), + curr_chrom: None, + curr_end: None, + size_gauge, + segments_written, + }) + } + + fn append_scores(&mut self, scores: &[anyhow::Result]) { + let mut rightmost = 0u64; + for score in scores.iter().filter_map(|r| r.as_ref().ok()) { + self.curr_region_scores.push(score.score); + self.curr_region_positions.push(score.position); + let check = self + .curr_counts_a + .insert(score.position, score.counts_a.clone()); + assert!(check.is_none()); + let check = self + .curr_counts_b + .insert(score.position, score.counts_b.clone()); + assert!(check.is_none()); + rightmost = std::cmp::max(rightmost, score.position); + } + // check, todo remove after testing + if let Some(end) = self.curr_end { + if rightmost > 0u64 { + assert!( + end < rightmost, + "results were not sorted? {end} {rightmost}", + ); + } + } + self.curr_end = Some(rightmost); + self.size_gauge.set_position(self.curr_region_positions.len() as u64); + } + + #[inline] + fn current_chunk_start(&self) -> Option<&u64> { + // todo can make this a .first() + self.curr_region_positions.iter().min() + } + + #[inline] + fn current_chunk_region(&self) -> Option { + match ( + self.curr_chrom.as_ref(), + self.current_chunk_start(), + self.curr_end, + ) { + (Some(chrom), Some(&start), Some(end)) => { + Some(Region::new(chrom.to_string(), start as u32, end as u32)) + } + _ => None, + } + } + + fn get_counts_a(&self, start: u64, stop: u64) -> AggregatedCounts { + Self::get_counts_range(start..stop, &self.curr_counts_a) + } + + fn get_counts_b(&self, start: u64, stop: u64) -> AggregatedCounts { + Self::get_counts_range(start..stop, &self.curr_counts_b) + } + + fn get_counts_range( + r: Range, + counts: &BTreeMap, + ) -> AggregatedCounts { + counts.range(r).map(|(_, counts)| counts).fold( + AggregatedCounts::zero(), + |mut agg, x| { + agg.op_mut(x); + agg + }, + ) + } +} + +fn path_to_region_labels( + path: &[States], + positions: &[u64], +) -> 
Vec<(u64, u64, States)> { + assert_eq!(path.len(), positions.len() - 1); + if path.is_empty() { + return Vec::new(); + } else { + let mut curr_state = *path.first().unwrap(); + let mut curr_position = *positions.first().unwrap(); + let mut last_position = curr_position + 1; + let mut agg = Vec::new(); + for (state, &pos) in path.iter().zip(positions).skip(1) { + let position = pos; + if state != &curr_state { + let bedline = (curr_position, last_position, curr_state); + agg.push(bedline); + curr_position = position; + last_position = position + 1; + curr_state = *state; + } else { + last_position = position + 1; + } + } + let final_bedline = (curr_position, last_position, curr_state); + agg.push(final_bedline); + + agg + } +} diff --git a/src/dmr/subcommands.rs b/src/dmr/subcommands.rs index 2564583..1750f1c 100644 --- a/src/dmr/subcommands.rs +++ b/src/dmr/subcommands.rs @@ -80,6 +80,69 @@ pub struct PairwiseDmr { /// Path to reference fasta for used in the pileup/alignment. #[arg(long = "ref")] reference_fasta: PathBuf, + /// Run segmentation, output segmented differentially methylated regions to + /// this file. + #[arg(long = "segment", conflicts_with = "regions_bed")] + segmentation_fp: Option, + + /// Maximum number of base pairs between modified bases for them to be + /// segmented together. + #[arg(long, requires = "segmentation_fp", default_value_t = 5000)] + max_gap_size: u64, + /// Prior probability of a differentially methylated position + #[arg( + long, + requires = "segmentation_fp", + default_value_t = 0.1, + hide_short_help = true + )] + dmr_prior: f64, + /// Maximum probability of continuing a differentially methylated block, + /// decay will be dynamic based on proximity to the next position. + #[arg( + long, + requires = "segmentation_fp", + default_value_t = 0.9, + hide_short_help = true + )] + diff_stay: f64, + /// Significance factor, effective p-value necessary to favor the + /// "Different" state. 
+ #[arg( + long, + requires = "segmentation_fp", + default_value_t = 0.01, + hide_short_help = true + )] + significance_factor: f64, + /// Use logarithmic decay for "Different" stay probability + #[arg( + long, + requires = "segmentation_fp", + default_value_t = false, + hide_short_help = true + )] + log_transition_decay: bool, + /// After this many base pairs, the transition probability will become the + /// prior probability of encountering a differentially modified + /// position. + #[arg( + long, + requires = "segmentation_fp", + default_value_t = 500, + hide_short_help = true + )] + decay_distance: u32, + /// Preset HMM segmentation parameters for higher propensity to switch from + /// "Same" to "Different" state. Results will be shorter segments, but + /// potentially higher sensitivity. + #[arg( + long, + requires = "segmentation_fp", + conflicts_with_all=["log_transition_decay", "significance_factor", "diff_stay", "dmr_prior"], + default_value_t=false + )] + fine_grained: bool, /// Bases to use to calculate DMR, may be multiple. For example, to /// calculate differentially methylated regions using only cytosine /// modifications use --base C. @@ -294,6 +357,11 @@ impl PairwiseDmr { if self.is_single_site() { info!("running single-site analysis"); + let linear_transitions = if self.fine_grained { + false + } else { + !self.log_transition_decay + }; return SingleSiteDmrAnalysis::new( sample_index, genome_positions, @@ -307,9 +375,20 @@ impl PairwiseDmr { self.delta, self.n_sample_records, self.header, + self.segmentation_fp.as_ref(), &mpb, )? 
- .run(mpb, pool, writer); + .run( + mpb, + pool, + self.max_gap_size, + self.dmr_prior, + self.diff_stay, + self.significance_factor, + self.decay_distance, + linear_transitions, + writer, + ); } let sample_index = Arc::new(sample_index); diff --git a/src/dmr/tabix.rs b/src/dmr/tabix.rs index 8abd2e9..159b611 100644 --- a/src/dmr/tabix.rs +++ b/src/dmr/tabix.rs @@ -180,6 +180,7 @@ impl MultiSampleIndex { #[inline] fn read_bedmethyl_lines_from_batch( &self, + // todo needs documentation of what this data structure is chunks: &FxHashMap>, ) -> anyhow::Result>>> { diff --git a/src/hmm.rs b/src/hmm.rs new file mode 100644 index 0000000..eae4b71 --- /dev/null +++ b/src/hmm.rs @@ -0,0 +1,404 @@ +use anyhow::bail; +use std::fmt::{Display, Formatter}; +use std::ops::Range; + +const STATE_NUM: usize = 2usize; + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] +#[repr(usize)] +pub(crate) enum States { + Same = 0, + Different = 1, +} + +impl Into for States { + fn into(self) -> usize { + self as usize + } +} + +impl From for States { + fn from(value: usize) -> Self { + match value { + 0 => Self::Same, + 1 => Self::Different, + _ => unreachable!("invalid {value}"), + } + } +} + +impl Display for States { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let label = match self { + States::Same => "same", + States::Different => "different", + }; + write!(f, "{label}") + } +} + +#[derive(Copy, Clone, Debug)] +struct DpCell { + inner: [f64; STATE_NUM], // todo make the state num const generic +} + +impl DpCell { + #[inline] + fn new_full(val: f64) -> Self { + Self { inner: [val; STATE_NUM] } + } + + fn new_empty() -> Self { + Self::new_full(f64::NEG_INFINITY) + } + + fn total_probability(&self) -> f64 { + rv::misc::logsumexp(&self.inner) + } + + fn get_value(&self, state: States) -> f64 { + self.inner[state as usize] + } + + fn get_value_mut(&mut self, state: States) -> &mut f64 { + &mut self.inner[state as usize] + } + + fn set_value(&mut self, 
state: States, value: f64) { + assert!( + value.is_finite() && !value.is_nan(), + "cannot set {value} to state {state:?}" + ); + self.inner[state as usize] = value; + } + + fn argmax(&self) -> States { + self.inner + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .map(|(i, _)| States::from(i)) + .unwrap() + } +} + +#[derive(Debug)] +struct PointerCell { + inner: [Option; STATE_NUM], +} + +impl PointerCell { + fn empty() -> Self { + Self { inner: [None; STATE_NUM] } + } + + fn get_value(&self, state: States) -> Option { + self.inner[state as usize] + } + + fn set_value(&mut self, state: States, value: States) { + self.inner[state as usize] = Some(value); + } +} + +pub(crate) struct HmmModel { + same_to_same: f64, + // diff_to_diff: f64, + same_to_diff: f64, + // diff_to_same: f64, + dmr_prior: f64, + + same_state_factor: f64, + diff_state_factor: f64, + significance_factor: f64, + + linear_proj: bool, + projection: Projection, +} + +impl HmmModel { + fn prob_to_factor(fpr: f64) -> anyhow::Result { + if fpr < 0f64 { + bail!("fpr cannot be less than 0") + } else if fpr >= 1.0 { + bail!("fpr cannot be >= 1.0") + } else { + Ok((fpr / (1f64 - fpr)).ln()) + } + } + + pub(crate) fn new( + dmr_prior: f64, + diff_stay: f64, + same_state_factor: f64, + diff_state_factor: f64, + significance_factor: f64, + decay_distance: u32, + linear_proj: bool, + ) -> anyhow::Result { + let same_to_diff = dmr_prior.ln(); + let same_to_same = (1f64 - dmr_prior).ln(); + // let diff_to_diff = diff_stay.ln(); + // let diff_to_same = (1f64 - diff_stay).ln(); + + let projection = Projection::new(decay_distance, diff_stay, dmr_prior)?; + let significance_factor = Self::prob_to_factor(significance_factor)?; + + Ok(Self { + same_to_same, + same_to_diff, + same_state_factor, + dmr_prior, + diff_state_factor, + significance_factor, + linear_proj, + projection, + }) + } + + pub(crate) fn viterbi_path( + &self, + scores: &[f64], + positions: &[u64], + ) -> Vec { + // P_s 
= e^(-score) + let probs = scores + .iter() + .map(|&x| if x < 0f64 { 0f64 } else { x }) + .map(|x| (-1f64 * x).exp()) + .collect::>(); + + let transitions = + positions.windows(2).fold(vec![self.dmr_prior], |mut agg, wind| { + assert_eq!(wind.len(), 2); + assert!(wind[1] > wind[0]); + let gap = (wind[1] - wind[0]) as f64; + assert!(gap > 0f64, "gap should be greater than zero"); + let p_diff_to_diff = if self.linear_proj { + self.projection.linear_project_prob(gap) + } else { + self.projection.ln_project_prob(gap) + }; + agg.push(p_diff_to_diff); + agg + }); + assert_eq!(probs.len(), transitions.len()); + let (dp_matrix, pointers) = self.viterbi_forward(&probs, &transitions); + let path = self.viterbi_decode(&dp_matrix, &pointers); + assert_eq!(path.len(), scores.len() - 1); + path + } + + fn viterbi_decode( + &self, + dp_matrix: &[DpCell], + pointers: &[PointerCell], + ) -> Vec { + let final_state = dp_matrix.last().unwrap().argmax(); + // dbg!(final_state); + let mut path = vec![final_state]; + let mut curr_pointer = + pointers.last().unwrap().get_value(final_state).unwrap(); + for pointers in pointers.iter().rev().skip(1) { + let pointer = pointers.get_value(curr_pointer); + if let Some(pointer) = pointer { + path.push(pointer); + curr_pointer = pointer; + } else { + break; + } + } + + path.pop(); + path.reverse(); + path + } + + fn viterbi_forward( + &self, + scores: &[f64], + transitions: &[f64], + ) -> (Vec, Vec) { + let first_cell = { + let mut first_cell = DpCell::new_full(0f64); + self.initialize_start_end_cell(&mut first_cell); + first_cell + }; + let first_pointers = PointerCell::empty(); + let (mut dp_matrix, pointers, last_cell) = + scores.iter().zip(transitions).enumerate().fold( + (Vec::new(), vec![first_pointers], first_cell), + |(mut cells, mut pointers, prev_cell), (i, (x, t))| { + let mut next_cell = DpCell::new_empty(); + let mut pointer_cell = PointerCell::empty(); + self.forward( + &prev_cell, + &mut next_cell, + &mut pointer_cell, + *t, + 
*x, + i, + ); + cells.push(prev_cell); + pointers.push(pointer_cell); + (cells, pointers, next_cell) + }, + ); + dp_matrix.push(last_cell); + assert_eq!(dp_matrix.len(), pointers.len()); + assert_eq!(dp_matrix.len(), scores.len() + 1); + (dp_matrix, pointers) + } + + #[inline] + fn emission_probs(&self, p: f64, state: States) -> f64 { + assert!(p <= 1f64, "p {p} cannot be greater than 1"); + let (factor, p) = match state { + States::Same => (self.same_state_factor, p.ln()), + States::Different => { + (self.diff_state_factor, (1f64 - p + 1e-5).ln()) + } + }; + let p = p - self.significance_factor; + factor * p + } + + fn forward( + &self, + prev_cell: &DpCell, + current_cell: &mut DpCell, + pointers: &mut PointerCell, + p_diff2diff: f64, + score: f64, + _idx: usize, + ) { + // todo make the naming convention here less terrible! + // emission probs + let e_diff = self.emission_probs(score, States::Different); + let e_same = self.emission_probs(score, States::Same); + // "dynamic" transition probs + assert!(p_diff2diff > 0f64, "p_diff2diff should not be zero"); + assert!( + p_diff2diff < 1.0, + "p_diff2diff should be less than one {p_diff2diff}" + ); + let lnp_diff2diff = p_diff2diff.ln(); + let lnp_diff_to_same = (1f64 - p_diff2diff).ln(); + // previous state + let p_same = prev_cell.get_value(States::Same); + let p_diff = prev_cell.get_value(States::Different); + + Self::check_emission_prob(e_diff, "e_d"); + Self::check_emission_prob(e_same, "e_s"); + Self::check_emission_prob(p_diff, "p_d"); + Self::check_emission_prob(p_same, "p_s"); + Self::check_emission_prob(p_diff2diff, "p_diff2diff"); + Self::check_emission_prob(lnp_diff2diff, "lnp_diff2diff"); + Self::check_emission_prob(lnp_diff_to_same, "lnp_diff_to_same"); + + // Same-state + let same2same = p_same + self.same_to_same; + let diff2same = p_diff + lnp_diff_to_same; + let (current_same, same_pointer) = + [(same2same, States::Same), (diff2same, States::Different)] + .into_iter() + .max_by(|(a, _), (b, _)| 
+                    a.partial_cmp(b).unwrap())
+                .unwrap();
+
+        // Diff-state
+        let diff2diff = p_diff + lnp_diff2diff;
+        let same2diff = p_same + self.same_to_diff;
+
+        let (current_diff, diff_pointer) =
+            [(diff2diff, States::Different), (same2diff, States::Same)]
+                .into_iter()
+                .max_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap())
+                .unwrap();
+
+        Self::check_emission_prob(current_diff, "current_diff");
+        Self::check_emission_prob(current_same, "current_same");
+
+        current_cell.set_value(States::Same, current_same + e_same);
+        current_cell.set_value(States::Different, current_diff + e_diff);
+        pointers.set_value(States::Same, same_pointer);
+        pointers.set_value(States::Different, diff_pointer);
+    }
+
+    // todo make this a compile time no-op
+    #[inline(always)]
+    fn check_emission_prob(x: f64, which: &str) {
+        assert!(x.is_finite(), "{which} is not finite {x}");
+        assert!(!x.is_nan(), "{which} is NaN {x}");
+    }
+
+    fn initialize_start_end_cell(&self, cell: &mut DpCell) {
+        *cell.get_value_mut(States::Same) = self.same_to_same;
+        *cell.get_value_mut(States::Different) = self.same_to_diff;
+    }
+}
+
+struct Projection {
+    prob_range: Range<f64>,
+    distance_range: Range<f64>,
+    prob_span: f64,
+    ratio: f64,
+}
+
+impl Projection {
+    fn new(
+        max_distance: u32,
+        max_diff_stay: f64,
+        dmr_prob: f64,
+    ) -> anyhow::Result<Self> {
+        if max_diff_stay <= dmr_prob {
+            bail!("max_diff_stay must be > switch_prob")
+        }
+        let low = 1f64 - max_diff_stay;
+        let high = 1f64 - dmr_prob;
+
+        let prob_range = low..high;
+        let max_distance = max_distance as f64;
+        let distance_range = 2f64..max_distance;
+        let prob_span = prob_range.end - prob_range.start;
+        let ratio = prob_span / (distance_range.end - distance_range.start);
+
+        Ok(Self { prob_range, distance_range, prob_span, ratio })
+    }
+
+    #[inline]
+    fn clamp_value(&self, x: f64) -> f64 {
+        if x > self.distance_range.end {
+            self.distance_range.end
+        } else {
+            x
+        }
+    }
+
+    fn linear_project_prob(&self, x: f64) -> f64 {
+        let x = self.clamp_value(x);
+        let
+            adjusted = ((x - self.distance_range.start) * self.ratio)
+            + self.prob_range.start;
+
+        1f64 - adjusted
+    }
+
+    fn ln_project_prob(&self, x: f64) -> f64 {
+        if x == 1f64 {
+            return 1f64 - self.prob_range.start;
+        }
+        let x = self.clamp_value(x);
+        let ln_ratio =
+            self.distance_range.end.ln() - self.distance_range.start.ln();
+        let adjusted = ((x.ln() - self.distance_range.start.ln()) / ln_ratio)
+            * (self.prob_span)
+            + self.prob_range.start;
+        let prob = 1f64 - adjusted;
+        if prob > 1.0 {
+            panic!(
+                "prob should not be >1 x: {x}, prob: {prob}, adjusted \
+                 {adjusted}"
+            )
+        }
+        prob
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 155a399..65193c4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -22,6 +22,7 @@ pub(crate) mod command_utils;
 pub mod dmr;
 /// Contains functions for genome arithmatic/overlaps, etc.
 pub(crate) mod genome_positions;
+mod hmm;
 pub(crate) mod parsing_utils;
 mod read_cache;
 mod read_ids_to_base_mod_probs;
diff --git a/src/util.rs b/src/util.rs
index f67d4db..5097826 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -337,7 +337,7 @@ impl ReferenceRecord {
     }
 }
 
-#[derive(Debug, Eq, PartialEq)]
+#[derive(new, Debug, Eq, PartialEq)]
 pub struct Region {
     pub name: String,
     pub start: u32,