Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ History
0.1.10 (2024-??-??)
------------------
* Long EM and RPCA operations wrapped with tqdm progress bars
* Readme code sample updated, and results table made consistant

0.1.9 (2024-08-29)
------------------
Expand Down
3 changes: 1 addition & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ With just these few lines of code, you can see how easy it is to
generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=4, ratio_masked=0.1)
comparison = comparator.Comparator(
dict_imputers,
columns,
generator_holes = generator_holes,
metrics = ["mae", "wmape", "kl_columnwise", "ks_test", "energy"],
metrics = ["mae", "wmape", "kl_columnwise", "frechet"],
)
results = comparison.compare(df_with_nan)
results.style.highlight_min(color="lightsteelblue", axis=1)
Expand Down
Binary file modified docs/images/readme_tabular_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 0 additions & 1 deletion examples/tutorials/plot_tuto_benchmark_TS.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@

comparison = comparator.Comparator(
dict_imputers,
cols_to_impute,
generator_holes=generator_holes,
metrics=["mae", "wmape", "kl_columnwise", "wasserstein_columnwise"],
max_evals=10,
Expand Down
1 change: 0 additions & 1 deletion examples/tutorials/plot_tuto_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@

comparison = comparator.Comparator(
dict_imputers,
cols_to_impute,
generator_holes=generator_holes,
metrics=metrics,
max_evals=2,
Expand Down
2 changes: 0 additions & 2 deletions examples/tutorials/plot_tuto_diffusion_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@

comparison = comparator.Comparator(
dict_imputers,
selected_columns=df_data.columns,
generator_holes=missing_patterns.UniformHoleGenerator(n_splits=2, random_state=rng),
metrics=["mae", "kl_columnwise"],
)
Expand Down Expand Up @@ -224,7 +223,6 @@

comparison = comparator.Comparator(
dict_imputers,
selected_columns=df_data.columns,
generator_holes=missing_patterns.UniformHoleGenerator(n_splits=2, random_state=rng),
metrics=["mae", "kl_columnwise"],
)
Expand Down
1 change: 0 additions & 1 deletion examples/tutorials/plot_tuto_mean_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@

comparison = comparator.Comparator(
dict_imputers,
cols_to_impute,
generator_holes=generator_holes,
metrics=metrics,
max_evals=5,
Expand Down
5 changes: 0 additions & 5 deletions qolmat/benchmark/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ class Comparator:
----------
dict_models: Dict[str, any]
dictionary of imputation methods
selected_columns: List[str]Œ
list of column's names selected (all with at least one null value will
be imputed)
columnwise_evaluation : Optional[bool], optional
whether the metric should be calculated column-wise or not,
by default False
Expand All @@ -46,7 +43,6 @@ class Comparator:
def __init__(
self,
dict_models: Dict[str, Any],
selected_columns: List[str],
generator_holes: _HoleGenerator,
metrics: List = ["mae", "wmape", "kl_columnwise"],
dict_config_opti: Optional[Dict[str, Any]] = {},
Expand All @@ -55,7 +51,6 @@ def __init__(
verbose: bool = False,
):
self.dict_imputers = dict_models
self.selected_columns = selected_columns
self.generator_holes = generator_holes
self.metrics = metrics
self.dict_config_opti = dict_config_opti
Expand Down
13 changes: 11 additions & 2 deletions qolmat/benchmark/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,7 @@ def sum_pairwise_distances(
def frechet_distance_base(
df1: pd.DataFrame,
df2: pd.DataFrame,
df_mask: pd.DataFrame,
) -> pd.Series:
"""Compute the Fréchet distance between two dataframes df1 and df2.

Expand All @@ -853,16 +854,24 @@ def frechet_distance_base(
true dataframe
df2 : pd.DataFrame
predicted dataframe
df_mask : pd.DataFrame
Elements of the dataframes to compute on

Returns
-------
pd.Series
Frechet distance in a Series object

"""
if df1.shape != df2.shape:
if df1.shape != df2.shape or df1.shape != df_mask.shape:
raise Exception("inputs have to be of same dimensions.")

df1 = df1.copy()
df2 = df2.copy()
# Set to nan the values not in the mask
df1[~df_mask] = np.nan
df2[~df_mask] = np.nan

std = (np.std(df1) + np.std(df2) + EPS) / 2
mu = (np.nanmean(df1, axis=0) + np.nanmean(df2, axis=0)) / 2
df1 = (df1 - mu) / std
Expand Down Expand Up @@ -911,7 +920,7 @@ def frechet_distance(

"""
if method == "single":
return frechet_distance_base(df1, df2)
return frechet_distance_base(df1, df2, df_mask)
return pattern_based_weighted_mean_metric(
df1,
df2,
Expand Down
33 changes: 16 additions & 17 deletions qolmat/imputations/imputers_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
from numpy.typing import NDArray
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

# from typing_extensions import Self
from qolmat.benchmark import metrics
Expand Down Expand Up @@ -106,23 +107,21 @@ def _fit_estimator(
optimizer = optim.Adam(estimator.parameters(), lr=self.learning_rate)
loss_fn = self.loss_fn

for epoch in range(self.epochs):
estimator.train()
optimizer.zero_grad()

input_data = torch.Tensor(X.values)
target_data = torch.Tensor(y.values)
target_data = target_data.unsqueeze(1)
outputs = estimator(input_data)
loss = loss_fn(outputs, target_data)

loss.backward()
optimizer.step()
if (epoch + 1) % 10 == 0:
logging.info(
f"Epoch [{epoch + 1}/{self.epochs}], "
f"Loss: {loss.item():.4f}"
)
with tqdm(total=self.epochs, desc="Training", unit="epoch") as pbar:
for _ in range(self.epochs):
estimator.train()
optimizer.zero_grad()

input_data = torch.Tensor(X.values)
target_data = torch.Tensor(y.values)
target_data = target_data.unsqueeze(1)
outputs = estimator(input_data)
loss = loss_fn(outputs, target_data)

loss.backward()
optimizer.step()
pbar.set_postfix(loss=f"{loss.item():.4f}")
pbar.update(1)
return estimator

def _predict_estimator(
Expand Down
120 changes: 63 additions & 57 deletions qolmat/imputations/rpca/rpca_noisy.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,67 +317,73 @@ def minimise_loss(
Ir = np.eye(rank)
In = identity(n_rows)

for _ in tqdm(
range(max_iterations),
with tqdm(
total=max_iterations,
desc="Noisy RPCA loss minimization",
unit="iteration",
disable=not verbose,
):
M_temp = M.copy()
A_temp = A.copy()
L_temp = L.copy()
Q_temp = Q.copy()
if norm == "L1":
R_temp = R.copy()
sums = np.zeros((n_rows, n_cols))
for i_period, _ in enumerate(list_periods):
sums += mu * R[i_period] - list_H[i_period] @ Y

M = spsolve(
(1 + mu) * In + HtH,
D - A + mu * L @ Q - Y + sums,
)
else:
M = spsolve(
(1 + mu) * In + 2 * HtH,
D - A + mu * L @ Q - Y,
)
M = M.reshape(D.shape)

A_Omega = rpca_utils.soft_thresholding(D - M, lam)
A_Omega_C = D - M
A = np.where(Omega, A_Omega, A_Omega_C)
Q = scp.linalg.solve(
a=tau * Ir + mu * (L.T @ L),
b=L.T @ (mu * M + Y),
)

L = scp.linalg.solve(
a=tau * Ir + mu * (Q @ Q.T),
b=Q @ (mu * M.T + Y.T),
).T

Y += mu * (M - L @ Q)
if norm == "L1":
for i_period, _ in enumerate(list_periods):
eta = list_etas[i_period]
R[i_period] = rpca_utils.soft_thresholding(
R[i_period] / mu, eta / mu
) as pbar:
for _ in range(max_iterations):
M_temp = M.copy()
A_temp = A.copy()
L_temp = L.copy()
Q_temp = Q.copy()
if norm == "L1":
R_temp = R.copy()
sums = np.zeros((n_rows, n_cols))
for i_period, _ in enumerate(list_periods):
sums += mu * R[i_period] - list_H[i_period] @ Y

M = spsolve(
(1 + mu) * In + HtH,
D - A + mu * L @ Q - Y + sums,
)
else:
M = spsolve(
(1 + mu) * In + 2 * HtH,
D - A + mu * L @ Q - Y,
)
M = M.reshape(D.shape)

A_Omega = rpca_utils.soft_thresholding(D - M, lam)
A_Omega_C = D - M
A = np.where(Omega, A_Omega, A_Omega_C)
Q = scp.linalg.solve(
a=tau * Ir + mu * (L.T @ L),
b=L.T @ (mu * M + Y),
)

mu = min(mu * rho, mu_bar)

Mc = np.linalg.norm(M - M_temp, np.inf)
Ac = np.linalg.norm(A - A_temp, np.inf)
Lc = np.linalg.norm(L - L_temp, np.inf)
Qc = np.linalg.norm(Q - Q_temp, np.inf)
error_max = max([Mc, Ac, Lc, Qc]) # type: ignore # noqa
if norm == "L1":
for i_period, _ in enumerate(list_periods):
Rc = np.linalg.norm(R[i_period] - R_temp[i_period], np.inf)
error_max = max(error_max, Rc) # type: ignore # noqa

if error_max < tolerance:
break
L = scp.linalg.solve(
a=tau * Ir + mu * (Q @ Q.T),
b=Q @ (mu * M.T + Y.T),
).T

Y += mu * (M - L @ Q)
if norm == "L1":
for i_period, _ in enumerate(list_periods):
eta = list_etas[i_period]
R[i_period] = rpca_utils.soft_thresholding(
R[i_period] / mu, eta / mu
)

mu = min(mu * rho, mu_bar)

Mc = np.linalg.norm(M - M_temp, np.inf)
Ac = np.linalg.norm(A - A_temp, np.inf)
Lc = np.linalg.norm(L - L_temp, np.inf)
Qc = np.linalg.norm(Q - Q_temp, np.inf)
error_max = max([Mc, Ac, Lc, Qc]) # type: ignore # noqa
if norm == "L1":
for i_period, _ in enumerate(list_periods):
Rc = np.linalg.norm(
R[i_period] - R_temp[i_period], np.inf
)
error_max = max(error_max, Rc) # type: ignore # noqa

if error_max < tolerance:
break
pbar.set_postfix(error=f"{error_max.item():.4f}")
pbar.update(1)

M = L @ Q

Expand Down
2 changes: 0 additions & 2 deletions tests/benchmark/test_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def generator_holes_mock(mocker: MockerFixture) -> _HoleGenerator:
def comparator(generator_holes_mock: _HoleGenerator) -> Comparator:
return Comparator(
dict_models={},
selected_columns=["A", "B"],
generator_holes=generator_holes_mock,
metrics=["mae", "mse"],
)
Expand Down Expand Up @@ -439,7 +438,6 @@ def test_compare_reproducibility():
)
comparator = Comparator(
dict_models=dict_models,
selected_columns=df_data.columns,
generator_holes=generator_holes,
metrics=["mae", "mse"],
)
Expand Down
14 changes: 9 additions & 5 deletions tests/benchmark/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,16 @@ def test_kl_divergence_gaussian(

@pytest.mark.parametrize("df1", [df_incomplete])
@pytest.mark.parametrize("df2", [df_imputed])
def test_frechet_distance_base(df1: pd.DataFrame, df2: pd.DataFrame) -> None:
result = metrics.frechet_distance_base(df1, df1)
@pytest.mark.parametrize("df_mask", [df_mask])
def test_frechet_distance_base(
df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
) -> None:
result = metrics.frechet_distance_base(df1, df1, df_mask)
np.testing.assert_allclose(result, 0, atol=1e-3)

result = metrics.frechet_distance_base(df1, df2)
np.testing.assert_allclose(result, 0.134, atol=1e-3)
result = metrics.frechet_distance_base(df1, df2, df_mask)
assert np.all(0 < result)
assert np.all(result < 1)


@pytest.mark.parametrize("df1", [df_incomplete])
Expand Down Expand Up @@ -360,7 +364,7 @@ def test_exception_raise_different_shapes(
df1, df2, df_mask
)
with pytest.raises(Exception):
metrics.frechet_distance_base(df1, df2)
metrics.frechet_distance_base(df1, df2, df_mask)


@pytest.mark.parametrize("df1", [df_incomplete_cat])
Expand Down
Loading