Skip to content

Commit

Permalink
ruff check and format
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Jan 23, 2025
1 parent bfd09cf commit d57b06d
Show file tree
Hide file tree
Showing 67 changed files with 447 additions and 612 deletions.
30 changes: 0 additions & 30 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,3 @@ jobs:
- uses: pre-commit/[email protected]
with:
extra_args: --all-files --show-diff-on-failure

pyright:
name: Check types
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
lfs: false

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.9'

- name: Cache dependency
id: cache-dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ubuntu-latest-pip-3.9
restore-keys: |
ubuntu-latest-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[dev]
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.0
hooks:
- id: ruff
- id: ruff-format
args: [--diff]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-added-large-files
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
- id: mixed-line-ending
args: [--fix=lf]
- id: trailing-whitespace
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# v1.1.0
# v1.1.0

- Fixed Gaussian Mixture task for `simulation_batch_size` > 1 (thanks to @h3jia, #54, #63). Since experiments in the paper were run with a simulation batch size of 1000, this has an effect on the results. We will issue an update of the results.
- Additional changes for compatibility with `sbi` v0.21.0 (thanks to @bkmi, #60, @janfb #55, #57, #59)
Expand Down
10 changes: 5 additions & 5 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Contributing

We invite contributions of new algorithms, tasks, and metrics. Please do not hestitate to get in touch [via email](mailto:[email protected],[email protected]) or by opening an issue on the repository.
We invite contributions of new algorithms, tasks, and metrics. Please do not hestitate to get in touch [via email](mailto:[email protected],[email protected]) or by opening an issue on the repository.

## Setup

Expand All @@ -19,19 +19,19 @@ source my-branch-venv/bin/activate
Upgrade pip:

```bash
python -m pip install --upgrade pip
python -m pip install --upgrade pip
```

Install dependencies including extras for development:

```bash
python -m pip install -e .[dev]
python -m pip install -e .[dev]
```

Run the tests (this currently takes quite a while):

```bash
python -m pytest -x .
python -m pytest -x .
```

To complete, run house keeping apps and try to make all errors disappear:
Expand All @@ -51,7 +51,7 @@ Now commit and push your changes to github and open a PR. Thank you!

Adding new tasks is straightforward. It is easiest to model them after existing tasks. First, take a close look at the base class for tasks in `sbibm/tasks/task.py`: you will find a `_setup` method: This method samples from the prior, generates observations, and finally calls `_sample_reference_posterior`, to generate samples from the reference posterior. All of these results are stored in csv files, and the generation of reference posterior samples happens in parallel.

For some tasks, e.g., the `gaussian_linear`, a closed form solution for the posterior is available, which is used in `_sample_reference_posterior`, while other tasks utilize MCMC.
For some tasks, e.g., the `gaussian_linear`, a closed form solution for the posterior is available, which is used in `_sample_reference_posterior`, while other tasks utilize MCMC.

Note also that each individual tasks ends with a `if __name__ == "__main__"` block at the end which calls `_setup`. This means that `_setup` is executed by calling `python sbibm/tasks/task_name/task.py`. This step overrides the existing reference posterior data, which is in the subfolder `sbibm/tasks/task_name/files/`. It should only be executed whenever a task is changed (and never by a user).

Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ c2st_accuracy = c2st(reference_samples, posterior_samples)

# Visualise both posteriors:
from sbibm.visualisation import fig_posterior
fig = fig_posterior(task_name="two_moons", observation=1, samples=[posterior_samples])
fig = fig_posterior(task_name="two_moons", observation=1, samples=[posterior_samples])
# Note: Use fig.show() or fig.save() to show or save the figure

# Get results from other algorithms for comparison:
Expand Down Expand Up @@ -174,7 +174,7 @@ The manuscript is [available through PMLR](http://proceedings.mlr.press/v130/lue
series = {Proceedings of Machine Learning Research},
month = {13--15 Apr},
publisher = {PMLR}
}
}
```


Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ tracker = "https://github.com/sbi-benchmark/sbibm/issues"

[tool.setuptools.packages.find]
include = ["sbibm*"]
exclude = ["tests", "tests.*"]
exclude = ["tests", "tests.*", "sbibm/third_party"]

[tool.setuptools.dynamic]
version = {attr = "sbibm.__version__.__version__"}
Expand All @@ -78,16 +78,19 @@ line-length = 88

[tool.ruff.lint]
# pycodestyle, Pyflakes, pyupgrade, flake8-bugbear, flake8-simplify, isort
select = ["E", "F", "W", "B", "SIM", "I"]
select = ["E", "F", "W", "B", "I"]
ignore = [
"E731", # allow naming lambda functions.
"B008", # allow function calls in default args.
"E722", # allow bare except.
"E721", # allow == for type comparison.
]

[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["E402", "F401", "F403"] # allow unused imports and undefined names
"test_*.py" = ["F403", "F405"]
"tutorials/*.ipynb" = ["E501"] # allow long lines in notebooks
"sbibm/third_party/*" = ["E402", "E501", "F811", "E741", "B904"]

[tool.ruff.lint.isort]
case-sensitive = true
Expand Down
8 changes: 5 additions & 3 deletions sbibm/algorithms/pyro/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def run(

if "num_simulations" in kwargs:
warnings.warn(
"`num_simulations` was passed as a keyword but will be ignored, see docstring for more info."
"`num_simulations` was passed as a keyword but will be ignored, "
"see docstring for more info.",
stacklevel=2,
)

# Prepare model and transforms
Expand Down Expand Up @@ -146,7 +148,7 @@ def run(
mcmc.run()

toc = time.time()
log.info(f"Finished MCMC after {toc-tic:.3f} seconds")
log.info(f"Finished MCMC after {toc - tic:.3f} seconds")
log.info(f"Automatic transforms {mcmc.transforms}")

log.info(f"Apply thinning of {thinning}")
Expand All @@ -156,7 +158,7 @@ def run(
mcmc._samples["parameters"].shape[0] * mcmc._samples["parameters"].shape[1]
)
if num_samples_available < num_samples:
warnings.warn("Some samples will be included multiple times")
warnings.warn("Some samples will be included multiple times", stacklevel=2)
samples = mcmc.get_samples(num_samples=num_samples, group_by_chain=False)[
"parameters"
].squeeze()
Expand Down
12 changes: 7 additions & 5 deletions sbibm/algorithms/pyro/utils/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def tb_acf(writer, mcmc, site_name="parameters", num_samples=1000, maxlags=50):
fig = plt.figure()
plt.gca().acorr(samples[c, :].squeeze()[:, p].numpy(), maxlags=maxlags)
writer.add_figure(
f"acf/chain {c+1}/parameter {p+1}",
f"acf/chain {c + 1}/parameter {p + 1}",
fig,
close=True,
)
Expand All @@ -32,7 +32,7 @@ def tb_posteriors(writer, mcmc, site_name="parameters", num_samples=1000):
samples = mcmc.get_samples(num_samples=num_samples, group_by_chain=True)[site_name]
for c in range(samples.shape[0]):
tb_plot_posterior(
writer=writer, samples=samples[c, :], tag=f"posterior/chain {c+1}"
writer=writer, samples=samples[c, :], tag=f"posterior/chain {c + 1}"
)


Expand All @@ -41,7 +41,7 @@ def tb_marginals(writer, mcmc, site_name="parameters", num_samples=1000):
for c in range(samples.shape[0]):
for p in range(samples.shape[-1]):
writer.add_histogram(
f"marginal/{site_name}/{p+1}",
f"marginal/{site_name}/{p + 1}",
samples[c, :].squeeze()[:, p],
c,
)
Expand Down Expand Up @@ -72,7 +72,7 @@ def hook_fn(kernel, samples, stage, i):
for p in range(len(samples[site_name].squeeze())):
# Trace
writer.add_scalar(
f"{stage_prefix}/chain/{num_chain+1}/trace/{site_name}/{p+1}",
f"{stage_prefix}/chain/{num_chain + 1}/trace/{site_name}/{p + 1}",
samples_inv.squeeze()[p],
i,
)
Expand All @@ -82,7 +82,9 @@ def hook_fn(kernel, samples, stage, i):
if stage_prefix == "warmup":
kernel_width = kernel._width[p]
writer.add_scalar(
f"{stage_prefix}/chain/{num_chain+1}/bracket width/{site_name}/{p+1}",
f"{stage_prefix}/chain/{num_chain + 1}/bracket width/{
site_name
}/{p + 1}",
kernel_width,
i,
)
Expand Down
8 changes: 4 additions & 4 deletions sbibm/algorithms/pytorch/baseline_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def run(

# Construct grid
grid = torch.stack(
torch.meshgrid(
[torch.linspace(low[d], high[d], resolution) for d in range(dim_parameters)]
)
torch.meshgrid([
torch.linspace(low[d], high[d], resolution) for d in range(dim_parameters)
])
) # dim_parameters x resolution x ... x resolution
grid_flat = grid.view(
dim_parameters, -1
Expand Down Expand Up @@ -123,6 +123,6 @@ def run(
log.info(f"Unique samples: {num_unique_samples}")

toc = time.time()
log.info(f"Finished after {toc-tic:.3f} seconds")
log.info(f"Finished after {toc - tic:.3f} seconds")

return samples
5 changes: 3 additions & 2 deletions sbibm/algorithms/pytorch/baseline_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def run(
log = sbibm.get_logger(__name__)

if "num_simulations" in kwargs:
log.warn(
"`num_simulations` was passed as a keyword but will be ignored, since this is a baseline method."
log.warning(
"`num_simulations` was passed as a keyword but will be ignored, "
"since this is a baseline method."
)

if rerun:
Expand Down
5 changes: 3 additions & 2 deletions sbibm/algorithms/pytorch/baseline_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ def run(
log = sbibm.get_logger(__name__)

if "num_simulations" in kwargs:
log.warn(
"`num_simulations` was passed as a keyword but will be ignored, since this is a baseline method."
log.warning(
"`num_simulations` was passed as a keyword but will be ignored, "
"since this is a baseline method."
)

prior = task.get_prior()
Expand Down
11 changes: 6 additions & 5 deletions sbibm/algorithms/pytorch/baseline_rejection.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def run(
log.info("Rejection sampling")

if "num_simulations" in kwargs:
log.warn(
"`num_simulations` was passed as a keyword but will be ignored, since this is a baseline method."
log.warning(
"`num_simulations` was passed as a keyword but will be ignored, "
"since this is a baseline method."
)

prior = task.get_prior()
Expand Down Expand Up @@ -114,13 +115,13 @@ def run(
samples.append(proposal[accept_idxs].detach())
pbar.update(len(accept_idxs))
pbar.set_postfix_str(
s=f"Acceptance rate: {num_accepted/num_sims:.9f}", refresh=True
s=f"Acceptance rate: {num_accepted / num_sims:.9f}", refresh=True
)

pbar.close()

log.info(f"Acceptance rate: {num_accepted/num_sims:.9f}")
log.info(f"Finished after {time.time()-tic:.3f} seconds")
log.info(f"Acceptance rate: {num_accepted / num_sims:.9f}")
log.info(f"Finished after {time.time() - tic:.3f} seconds")

samples = torch.cat(samples)[:num_samples, :]
assert samples.shape[0] == num_samples
Expand Down
2 changes: 1 addition & 1 deletion sbibm/algorithms/pytorch/baseline_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def run(
batch_size = min(batch_size, num_simulations)
num_batches = int(num_simulations / batch_size)

for i in tqdm(range(num_batches)):
for _i in tqdm(range(num_batches)):
_ = simulator(prior(num_samples=batch_size))

assert simulator.num_simulations == num_simulations
Expand Down
4 changes: 2 additions & 2 deletions sbibm/algorithms/pytorch/baseline_sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def run(

particles = []
log_weights = []
for i in tqdm(range(num_batches)):
for _i in tqdm(range(num_batches)):
batch_draws = proposal_dist.sample((batch_size,))
log_weights.append(
log_prob_fn(batch_draws) - proposal_dist.log_prob(batch_draws)
Expand All @@ -89,6 +89,6 @@ def run(
log.info(f"Unique particles: {num_unique} out of {len(samples)}")

toc = time.time()
log.info(f"Finished after {toc-tic:.3f} seconds")
log.info(f"Finished after {toc - tic:.3f} seconds")

return samples
13 changes: 5 additions & 8 deletions sbibm/algorithms/pytorch/utils/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_proposal(
prior_weight=prior_weight,
)

log.info(f"Proposal distribution is set up, took {time.time()-tic:.3f}sec")
log.info(f"Proposal distribution is set up, took {time.time() - tic:.3f}sec")

return proposal_dist

Expand Down Expand Up @@ -112,13 +112,10 @@ def sample_proposal(self, num_samples):

def log_prob(self, parameters):
return torch.logsumexp(
torch.stack(
[
math.log(self.prior_weight) + self.log_prob_prior(parameters),
math.log(1.0 - self.prior_weight)
+ self.log_prob_proposal(parameters),
]
),
torch.stack([
math.log(self.prior_weight) + self.log_prob_prior(parameters),
math.log(1.0 - self.prior_weight) + self.log_prob_proposal(parameters),
]),
dim=0,
)

Expand Down
Loading

0 comments on commit d57b06d

Please sign in to comment.