Skip to content

Commit 8fef7be

Browse files
committed
isort.
1 parent 7f16f1d commit 8fef7be

File tree

7 files changed

+15
-23
lines changed

7 files changed

+15
-23
lines changed

notebooks/mnle_utils.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22

33
import numpy as np
44
import torch
5-
6-
from torch import Tensor, nn
75
from sbi.utils.sbiutils import standardizing_net
6+
from torch import Tensor, nn
87
from torch.distributions import Bernoulli
9-
from torch import Tensor
108

119

1210
def build_choice_net(

notebooks/revision-experiments/lan_generate_data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
# Load necessary packages
66
from copy import deepcopy
7-
import torch
87

98
import sbibm
109
import ssms
10+
import torch
1111

1212
# Get benchmark task to load Julia simulator.
1313
seed = torch.randint(100000, (1,)).item()

notebooks/revision-experiments/lan_run_inference.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
# function wrapper in utils.
44

55
import pickle
6-
from pathlib import Path
7-
from joblib import Parallel, delayed
86
import sys
7+
from pathlib import Path
98

109
import lanfactory
1110
import sbibm
1211
import torch
13-
12+
from joblib import Parallel, delayed
1413
from sbi.inference import MCMCPosterior
1514
from sbi.utils import mcmc_transform
1615

notebooks/revision-experiments/lan_run_training.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
# Adapted from: https://github.com/AlexanderFengler/LANfactory
22
# Script to run LAN training given a pre-simulated data set.
33

4-
# Load necessary packages
5-
import lanfactory
64
import os
75
from copy import deepcopy
8-
import torch
9-
10-
from lanfactory.trainers import ModelTrainerTorchMLP
116
from pathlib import Path
127

8+
# Load necessary packages
9+
import lanfactory
10+
import torch
11+
from lanfactory.trainers import ModelTrainerTorchMLP
1312

1413
BASE_DIR = Path.cwd()
1514

notebooks/revision-experiments/mnle_run_inference.py

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import sbibm
77
import torch
8-
98
from sbi.inference import MNLE
109

1110
# Get benchmark task to load observations

notebooks/revision-experiments/mnle_training_script.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
# Script for training MNLE with pre-simulated data.
22

33
import pickle
4-
import torch
4+
from pathlib import Path
55

6+
import torch
67
from joblib import Parallel, delayed
7-
from pathlib import Path
88
from sbi.inference import MNLE
99
from sbi.utils import likelihood_nn
1010

11-
1211
BASE_DIR = Path(__file__).resolve().parent.parent.parent.as_posix()
1312
data_folder = BASE_DIR + "/data/"
1413
save_folder = BASE_DIR + "/notebooks/mnle-lan-comparison/models/"

notebooks/utils.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
import math
22
import os
3-
from pathlib import Path
43
import pickle
4+
from pathlib import Path
55
from typing import Any, Tuple
66

7+
import lanfactory
8+
import matplotlib.pyplot as plt
79
import numpy as np
810
import pandas as pd
9-
import matplotlib.pyplot as plt
1011
import torch
1112
from omegaconf import OmegaConf
13+
from sbi.inference.potentials.base_potential import BasePotential
1214
from sbibm.utils.io import get_float_from_csv
13-
from tqdm.auto import tqdm
14-
1515
from torch.distributions.transforms import AffineTransform
16-
17-
import lanfactory
18-
from sbi.inference.potentials.base_potential import BasePotential
16+
from tqdm.auto import tqdm
1917

2018

2119
def compile_df(

0 commit comments

Comments
 (0)