Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC | not ready for merge] Speed up bootstrap using multiprocessing #185

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/mozanalysis/frequentist_stats/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import numpy as np
import pandas as pd
import multiprocessing as mp
import logging

import mozanalysis.bayesian_stats as mabs
from mozanalysis.utils import filter_outliers

logger = logging.getLogger(__name__)
mp.set_start_method("fork")


def compare_branches(
df,
Expand All @@ -17,6 +22,7 @@ def compare_branches(
threshold_quantile=None,
individual_summary_quantiles=mabs.DEFAULT_QUANTILES,
comparative_summary_quantiles=mabs.DEFAULT_QUANTILES,
processes=1,
):
"""Jointly sample bootstrapped statistics then compare them.

Expand Down Expand Up @@ -53,6 +59,9 @@ def compare_branches(
statistics (i.e. the change relative to the reference
branch, probably the control). Change these when making
Bonferroni corrections.
processes (int, optional): Speed up bootstrapping by specifying
more than one process. Set processes=None to use the maximum
number of processes available.

Returns a dictionary:
If ``stat_fn`` returns a scalar (this is the default), then
Expand Down Expand Up @@ -88,6 +97,7 @@ def compare_branches(
stat_fn,
num_samples,
threshold_quantile=threshold_quantile,
processes=processes,
)
for b in branch_list
}
Expand Down Expand Up @@ -144,6 +154,7 @@ def get_bootstrap_samples(
num_samples=10000,
seed_start=None,
threshold_quantile=None,
processes=1,
):
"""Return ``stat_fn`` evaluated on resampled and original data.

Expand All @@ -167,6 +178,9 @@ def get_bootstrap_samples(
in this calculation. By default, use a random seed.
threshold_quantile (float, optional): An optional threshold
quantile, above which to discard outliers. E.g. ``0.9999``.
processes (int, optional): Speed up bootstrapping by specifying
more than one process. Set processes=None to use the maximum
number of processes available.

Returns:
``stat_fn`` evaluated over ``num_samples`` samples.
Expand All @@ -192,9 +206,23 @@ def get_bootstrap_samples(
# Need to ensure every call has a unique, deterministic seed.
seed_range = range(seed_start, seed_start + num_samples)

summary_stat_samples = [
_resample_and_agg_once(data, stat_fn, unique_seed) for unique_seed in seed_range
]
if processes == 1:
summary_stat_samples = [
_resample_and_agg_once(data, stat_fn, unique_seed)
for unique_seed in seed_range
]
else:
global __resample_and_agg_once

def __resample_and_agg_once(unique_seed):
return _resample_and_agg_once(data, stat_fn, unique_seed)

with mp.Pool(processes=processes) as pool:
logger.info(
f"""Started a multiprocessing.pool with {pool._processes} processes to run
bootstrapping."""
)
summary_stat_samples = pool.map(__resample_and_agg_once, seed_range)

summary_df = pd.DataFrame(summary_stat_samples)
if len(summary_df.columns) == 1:
Expand Down
16 changes: 16 additions & 0 deletions tests/frequentist_stats/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import pytest
from time import time

import mozanalysis.frequentist_stats.bootstrap as mafsb

Expand Down Expand Up @@ -72,6 +73,21 @@ def test_get_bootstrap_samples_multistat(stack_depth=0):
test_get_bootstrap_samples_multistat(stack_depth + 1)


def test_get_bootstrap_samples_multiprocess():
t0 = time()
res1 = mafsb.get_bootstrap_samples(np.asarray(range(10**5)), processes=1)
t1 = time() - t0

t0 = time()
res2 = mafsb.get_bootstrap_samples(np.asarray(range(10**5)), processes=None)
t2 = time() - t0
assert t2 * 2 < t1 # The multiprocessing speedup is worth it if it's 2x faster.
# Not sure why we're getting different answers, but I believe this discrepancy is
# resolvable.
assert res1[0] == res2[0]
assert res1[1] == res2[1]


def test_bootstrap_one_branch():
data = np.concatenate([np.zeros(10000), np.ones(10000)])
res = mafsb.bootstrap_one_branch(
Expand Down