From afbef6d1f24f5b8bdae29031e4f935d3a4d85f4c Mon Sep 17 00:00:00 2001 From: AHsu98 <34590951+AHsu98@users.noreply.github.com> Date: Thu, 2 Feb 2023 23:38:40 -0800 Subject: [PATCH] Added type hints where appropriate --- src/pydisagg/disaggregate.py | 31 +++++++------ src/pydisagg/models.py | 30 ++++++------ src/pydisagg/splittingmodel.py | 83 +++++++++++++++++++++++----------- 3 files changed, 92 insertions(+), 52 deletions(-) diff --git a/src/pydisagg/disaggregate.py b/src/pydisagg/disaggregate.py index b0bce71..19fa5fd 100644 --- a/src/pydisagg/disaggregate.py +++ b/src/pydisagg/disaggregate.py @@ -1,15 +1,20 @@ -import pandas as pd +from typing import Optional,List +from numpy.typing import NDArray +from pandas import DataFrame + +import pandas as pd +from pydisagg.models import SplittingModel from pydisagg.models import LMO_model def split_datapoint( - measured_count, - bucket_populations, - rate_pattern, - measured_count_se=None, - model=LMO_model(1), - CI_method='delta-wald' + measured_count:float, + bucket_populations:NDArray, + rate_pattern:NDArray, + measured_count_se:Optional[float]=None, + model:Optional[SplittingModel]=LMO_model(1), + CI_method:Optional[str]='delta-wald' ): ''' Disaggregates a datapoint using the model given as input. @@ -29,12 +34,12 @@ def split_datapoint( def split_dataframe( - groups_to_split_into, - observation_group_membership_df, - population_sizes, - baseline_patterns, - use_se=False, - model=LMO_model(1), + groups_to_split_into:list, + observation_group_membership_df:DataFrame, + population_sizes:DataFrame, + baseline_patterns:DataFrame, + use_se:Optional[bool]=False, + model:Optional[SplittingModel]=LMO_model(1), ): ''' Disaggregates datapoints and pivots observations into estimates for each group per pop id diff --git a/src/pydisagg/models.py b/src/pydisagg/models.py index e6ba816..4193f9e 100644 --- a/src/pydisagg/models.py +++ b/src/pydisagg/models.py @@ -1,3 +1,7 @@ +from typing import Optional,List +from numpy.typing import NDArray +from pandas import DataFrame + from pydisagg import transformations from pydisagg.splittingmodel import SplittingModel @@ -12,10 +16,10 @@ class RateMultiplicativeModel(SplittingModel): def __init__( self, - rate_pattern=None, - beta_parameter=None, - error_inflation=None, - beta_standard_error=None + rate_pattern:Optional[NDArray]=None, + beta_parameter:Optional[float]=None, + error_inflation:Optional[float]=None, + beta_standard_error:Optional[float]=None ): super().__init__( parameter_transformation=transformations.LogTransformation(), @@ -33,11 +37,11 @@ class LMO_model(SplittingModel): def __init__( self, - m, - rate_pattern=None, - beta_parameter=None, - error_inflation=None, - beta_standard_error=None + m:float, + rate_pattern:Optional[NDArray]=None, + beta_parameter:Optional[float]=None, + error_inflation:Optional[float]=None, + beta_standard_error:Optional[float]=None ): super().__init__( parameter_transformation=transformations.LogModifiedOddsTransformation(m), @@ -55,10 +59,10 @@ class LogOdds_model(SplittingModel): def __init__( self, - rate_pattern=None, - beta_parameter=None, - error_inflation=None, - beta_standard_error=None + rate_pattern:Optional[NDArray]=None, + beta_parameter:Optional[float]=None, + error_inflation:Optional[float]=None, + beta_standard_error:Optional[float]=None ): super().__init__( parameter_transformation=transformations.LogOddsTransformation(), diff --git a/src/pydisagg/splittingmodel.py b/src/pydisagg/splittingmodel.py index c350b54..ec0a027 100644 --- a/src/pydisagg/splittingmodel.py +++ b/src/pydisagg/splittingmodel.py @@ -1,4 +1,9 @@ import numpy as np +from typing import Optional,List +from numpy.typing import NDArray +from pandas import DataFrame +from pydisagg.transformations import ParameterTransformation + from scipy.optimize import root_scalar from scipy.stats import norm @@ -14,11 +19,11 @@ class SplittingModel: def __init__( self, - parameter_transformation=None, - rate_pattern=None, - beta_parameter=None, - error_inflation=None, - beta_standard_error=None + parameter_transformation:ParameterTransformation, + rate_pattern:Optional[NDArray]=None, + beta_parameter:Optional[float]=None, + error_inflation:Optional[float]=None, + beta_standard_error:Optional[float]=None ): self.rate_pattern = rate_pattern self.beta_parameter = beta_parameter @@ -30,7 +35,10 @@ def __init__( self.T_inverse = self.parameter_transformation.inverse self.T_diff = self.parameter_transformation.diff - def pull_beta(self, beta): + def pull_beta( + self, + beta + ): ''' Checks whether beta parameter is available in input, or if it is null and returns beta if it is not none. If beta is none, then this will try and return @@ -43,7 +51,10 @@ def pull_beta(self, beta): else: raise Exception("Not fitted, No Beta Parameter Available") - def pull_set_rate_pattern(self, rate_pattern): + def pull_set_rate_pattern( + self, + rate_pattern:NDArray + ): ''' Checks whether rate_pattern parameter is available in input, or if it is None if rate_pattern is not none, it will return it and set it as self.rate_pattern @@ -59,7 +70,11 @@ def pull_set_rate_pattern(self, rate_pattern): else: raise Exception("No Rate Pattern Available") - def predict_rates(self, beta=None, rate_pattern=None): + def predict_rates( + self, + beta:Optional[float]=None, + rate_pattern:Optional[NDArray]=None + ): ''' Generates a predicted rate within each bucket assuming multiplicativity in the T-transformed space with the additive parameter @@ -100,13 +115,13 @@ def _H_diff(self, beta, bucket_populations): def fit_beta( self, - bucket_populations, - measured_count, - measured_count_se=None, - rate_pattern=None, - lower_guess=-50, - upper_guess=50, - verbose=0 + bucket_populations:NDArray, + measured_count:float, + measured_count_se:Optional[float]=None, + rate_pattern:Optional[NDArray]=None, + lower_guess:Optional[float]=-50, + upper_guess:Optional[float]=50, + verbose:Optional[int]=0 ): ''' Fits a value for beta from the age density of a population and a measured count @@ -147,7 +162,10 @@ def predict_rates_SE(self): raise Exception("No Beta Standard Error is available") return self._predict_rates_SE(self.beta_parameter, self.beta_standard_error) - def predict_rates_CI(self, alpha=0.05, method='delta-wald'): + def predict_rates_CI( + self, + alpha:Optional[float]=0.05, + method:Optional[str]='delta-wald'): ''' Computes a 1-alpha confidence interval on the rate function from the standard error on beta @@ -186,10 +204,16 @@ def predict_rates_CI(self, alpha=0.05, method='delta-wald'): return (lower_rate, upper_rate) - def predict_count(self, bucket_populations): + def predict_count( + self, + bucket_populations:NDArray + ): return self.predict_rates()*bucket_populations - def predict_total_count_SE(self, bucket_populations): + def predict_total_count_SE( + self, + bucket_populations:NDArray + ): ''' Computes the standard error of the total number of events given an age density using delta method on H @@ -206,7 +230,10 @@ def predict_total_count_SE(self, bucket_populations): return self._H_diff(self.beta_parameter, bucket_populations)*self.beta_standard_error - def predict_count_SE(self, bucket_populations): + def predict_count_SE( + self, + bucket_populations:NDArray + ): ''' Computes the standard error of the number events in each bucket given an age density using delta method on H @@ -219,7 +246,11 @@ def predict_count_SE(self, bucket_populations): return self.predict_rates_SE()*bucket_populations - def predict_count_CI(self, bucket_populations, alpha=0.05, method='delta-wald'): + def predict_count_CI( + self, + bucket_populations:NDArray, + alpha:Optional[float]=0.05, + method:Optional[str]='delta-wald'): ''' Computes a (one minus alpha) confidence interval on the events occuring in a population given an an age density from the standard error on beta @@ -248,12 +279,12 @@ def predict_count_CI(self, bucket_populations, alpha=0.05, method='delta-wald'): def split_groups( self, - bucket_populations, - measured_count=None, - measured_count_se=None, - rate_pattern=None, - CI_method='delta-wald', - alpha=0.05 + bucket_populations:NDArray, + measured_count:Optional[float]=None, + measured_count_se:Optional[float]=None, + rate_pattern:Optional[NDArray]=None, + CI_method:Optional[str]='delta-wald', + alpha:Optional[float]=0.05 ): ''' Splits measured_count into the given bucket populations