Skip to content

Commit

Permalink
Merge pull request #14 from ihmeuw-msca/add_type_hints
Browse files Browse the repository at this point in the history
Added type hints where appropriate
  • Loading branch information
AHsu98 authored Feb 3, 2023
2 parents a173b0b + afbef6d commit 91e99c8
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 52 deletions.
31 changes: 18 additions & 13 deletions src/pydisagg/disaggregate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import pandas as pd
from typing import Optional,List
from numpy.typing import NDArray
from pandas import DataFrame


import pandas as pd
from pydisagg.models import SplittingModel
from pydisagg.models import LMO_model


def split_datapoint(
measured_count,
bucket_populations,
rate_pattern,
measured_count_se=None,
model=LMO_model(1),
CI_method='delta-wald'
measured_count:float,
bucket_populations:NDArray,
rate_pattern:NDArray,
measured_count_se:Optional[float]=None,
model:Optional[SplittingModel]=LMO_model(1),
CI_method:Optional[str]='delta-wald'
):
'''
Disaggregates a datapoint using the model given as input.
Expand All @@ -29,12 +34,12 @@ def split_datapoint(


def split_dataframe(
groups_to_split_into,
observation_group_membership_df,
population_sizes,
baseline_patterns,
use_se=False,
model=LMO_model(1),
groups_to_split_into:list,
observation_group_membership_df:DataFrame,
population_sizes:DataFrame,
baseline_patterns:DataFrame,
use_se:Optional[bool]=False,
model:Optional[SplittingModel]=LMO_model(1),
):
'''
Disaggregates datapoints and pivots observations into estimates for each group per pop id
Expand Down
30 changes: 17 additions & 13 deletions src/pydisagg/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from typing import Optional,List
from numpy.typing import NDArray
from pandas import DataFrame

from pydisagg import transformations
from pydisagg.splittingmodel import SplittingModel

Expand All @@ -12,10 +16,10 @@ class RateMultiplicativeModel(SplittingModel):

def __init__(
self,
rate_pattern=None,
beta_parameter=None,
error_inflation=None,
beta_standard_error=None
rate_pattern:Optional[NDArray]=None,
beta_parameter:Optional[float]=None,
error_inflation:Optional[float]=None,
beta_standard_error:Optional[float]=None
):
super().__init__(
parameter_transformation=transformations.LogTransformation(),
Expand All @@ -33,11 +37,11 @@ class LMO_model(SplittingModel):

def __init__(
self,
m,
rate_pattern=None,
beta_parameter=None,
error_inflation=None,
beta_standard_error=None
m:float,
rate_pattern:Optional[NDArray]=None,
beta_parameter:Optional[float]=None,
error_inflation:Optional[float]=None,
beta_standard_error:Optional[float]=None
):
super().__init__(
parameter_transformation=transformations.LogModifiedOddsTransformation(m),
Expand All @@ -55,10 +59,10 @@ class LogOdds_model(SplittingModel):

def __init__(
self,
rate_pattern=None,
beta_parameter=None,
error_inflation=None,
beta_standard_error=None
rate_pattern:Optional[NDArray]=None,
beta_parameter:Optional[float]=None,
error_inflation:Optional[float]=None,
beta_standard_error:Optional[float]=None
):
super().__init__(
parameter_transformation=transformations.LogOddsTransformation(),
Expand Down
83 changes: 57 additions & 26 deletions src/pydisagg/splittingmodel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import numpy as np
from typing import Optional,List
from numpy.typing import NDArray
from pandas import DataFrame
from pydisagg.transformations import ParameterTransformation

from scipy.optimize import root_scalar
from scipy.stats import norm

Expand All @@ -14,11 +19,11 @@ class SplittingModel:

def __init__(
self,
parameter_transformation=None,
rate_pattern=None,
beta_parameter=None,
error_inflation=None,
beta_standard_error=None
parameter_transformation:ParameterTransformation,
rate_pattern:Optional[NDArray]=None,
beta_parameter:Optional[float]=None,
error_inflation:Optional[float]=None,
beta_standard_error:Optional[float]=None
):
self.rate_pattern = rate_pattern
self.beta_parameter = beta_parameter
Expand All @@ -30,7 +35,10 @@ def __init__(
self.T_inverse = self.parameter_transformation.inverse
self.T_diff = self.parameter_transformation.diff

def pull_beta(self, beta):
def pull_beta(
self,
beta
):
'''
Checks whether beta parameter is available in input, or if it is null
and returns beta if it is not none. If beta is none, then this will try and return
Expand All @@ -43,7 +51,10 @@ def pull_beta(self, beta):
else:
raise Exception("Not fitted, No Beta Parameter Available")

def pull_set_rate_pattern(self, rate_pattern):
def pull_set_rate_pattern(
self,
rate_pattern:NDArray
):
'''
Checks whether rate_pattern parameter is available in input, or if it is None
if rate_pattern is not none, it will return it and set it as self.rate_pattern
Expand All @@ -59,7 +70,11 @@ def pull_set_rate_pattern(self, rate_pattern):
else:
raise Exception("No Rate Pattern Available")

def predict_rates(self, beta=None, rate_pattern=None):
def predict_rates(
self,
beta:Optional[float]=None,
rate_pattern:Optional[NDArray]=None
):
'''
Generates a predicted rate within each bucket assuming
multiplicativity in the T-transformed space with the additive parameter
Expand Down Expand Up @@ -100,13 +115,13 @@ def _H_diff(self, beta, bucket_populations):

def fit_beta(
self,
bucket_populations,
measured_count,
measured_count_se=None,
rate_pattern=None,
lower_guess=-50,
upper_guess=50,
verbose=0
bucket_populations:NDArray,
measured_count:float,
measured_count_se:Optional[float]=None,
rate_pattern:Optional[NDArray]=None,
lower_guess:Optional[float]=-50,
upper_guess:Optional[float]=50,
verbose:Optional[int]=0
):
'''
Fits a value for beta from the age density of a population and a measured count
Expand Down Expand Up @@ -147,7 +162,10 @@ def predict_rates_SE(self):
raise Exception("No Beta Standard Error is available")
return self._predict_rates_SE(self.beta_parameter, self.beta_standard_error)

def predict_rates_CI(self, alpha=0.05, method='delta-wald'):
def predict_rates_CI(
self,
alpha:Optional[float]=0.05,
method:Optional[str]='delta-wald'):
'''
Computes a 1-alpha confidence interval on the rate function
from the standard error on beta
Expand Down Expand Up @@ -186,10 +204,16 @@ def predict_rates_CI(self, alpha=0.05, method='delta-wald'):

return (lower_rate, upper_rate)

def predict_count(self, bucket_populations):
def predict_count(
self,
bucket_populations:NDArray
):
return self.predict_rates()*bucket_populations

def predict_total_count_SE(self, bucket_populations):
def predict_total_count_SE(
self,
bucket_populations:NDArray
):
'''
Computes the standard error of the total number of events given an age density
using delta method on H
Expand All @@ -206,7 +230,10 @@ def predict_total_count_SE(self, bucket_populations):

return self._H_diff(self.beta_parameter, bucket_populations)*self.beta_standard_error

def predict_count_SE(self, bucket_populations):
def predict_count_SE(
self,
bucket_populations:NDArray
):
'''
Computes the standard error of the number events in each bucket given an age density
using delta method on H
Expand All @@ -219,7 +246,11 @@ def predict_count_SE(self, bucket_populations):

return self.predict_rates_SE()*bucket_populations

def predict_count_CI(self, bucket_populations, alpha=0.05, method='delta-wald'):
def predict_count_CI(
self,
bucket_populations:NDArray,
alpha:Optional[float]=0.05,
method:Optional[str]='delta-wald'):
'''
Computes a (one minus alpha) confidence interval on the events occuring in a population
given an an age density from the standard error on beta
Expand Down Expand Up @@ -248,12 +279,12 @@ def predict_count_CI(self, bucket_populations, alpha=0.05, method='delta-wald'):

def split_groups(
self,
bucket_populations,
measured_count=None,
measured_count_se=None,
rate_pattern=None,
CI_method='delta-wald',
alpha=0.05
bucket_populations:NDArray,
measured_count:Optional[float]=None,
measured_count_se:Optional[float]=None,
rate_pattern:Optional[NDArray]=None,
CI_method:Optional[str]='delta-wald',
alpha:Optional[float]=0.05
):
'''
Splits measured_count into the given bucket populations
Expand Down

0 comments on commit 91e99c8

Please sign in to comment.