From bf2dc22c63ce758691ab8ea041d89e25631f0f00 Mon Sep 17 00:00:00 2001
From: Yangfeng Ji
Date: Thu, 26 Jan 2023 22:26:18 -0500
Subject: [PATCH] update

---
 .gitignore                      |   2 +
 LICENSE.rst                     |  21 ++++
 README.md                       |   1 +
 pyproject.toml                  |  22 ++++
 src/valda/__init__.py           |   0
 src/valda/beta_shapley.py       |  69 +++++++++++++++++
 src/valda/cs_shapley.py         | 124 +++++++++++++++++++++++++++
 src/valda/eval.py               |  50 +++++++++++++
 src/valda/inf_func.py           | 124 +++++++++++++++++++++++++++
 src/valda/loo.py                |  40 ++++++++++
 src/valda/metrics.py            |  33 ++++++
 src/valda/params.py             |  43 +++++++++++
 src/valda/pyclassifier.py       | 109 +++++++++++++++++++++
 src/valda/tmc_shapley.py        |  43 +++++++++++
 src/valda/util.py               | 128 ++++++++++++++++++++++++++++
 src/valda/valuation.py          | 114 ++++++++++++++++++++++
 tests/__init__.py               |   0
 tests/data/diabetes.pkl         | Bin 0 -> 51199 bytes
 tests/example/__init__.py       |   0
 tests/example/simple.py         |  29 ++++++++
 tests/example/simple_pytorch.py |  42 +++++++++++
 21 files changed, 994 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 LICENSE.rst
 create mode 100644 pyproject.toml
 create mode 100644 src/valda/__init__.py
 create mode 100644 src/valda/beta_shapley.py
 create mode 100644 src/valda/cs_shapley.py
 create mode 100644 src/valda/eval.py
 create mode 100644 src/valda/inf_func.py
 create mode 100644 src/valda/loo.py
 create mode 100644 src/valda/metrics.py
 create mode 100644 src/valda/params.py
 create mode 100644 src/valda/pyclassifier.py
 create mode 100644 src/valda/tmc_shapley.py
 create mode 100644 src/valda/util.py
 create mode 100644 src/valda/valuation.py
 create mode 100644 tests/__init__.py
 create mode 100644 tests/data/diabetes.pkl
 create mode 100644 tests/example/__init__.py
 create mode 100644 tests/example/simple.py
 create mode 100644 tests/example/simple_pytorch.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ee88093
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+dist/
+*egg-info*
diff --git a/LICENSE.rst b/LICENSE.rst
new file mode 100644
index 0000000..eceb6af
--- /dev/null
+++ b/LICENSE.rst
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 UVa ILP
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 3badb39..59bcf62 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,3 @@
 # valda
+
 A Python Data Valuation Package
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..0a40859
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,22 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "valda"
+version = "0.1.5"
+authors = [
+  { name="Yangfeng Ji", email="yangfeng@virginia.edu" },
+]
+description = "A Data Valuation Package for Machine Learning"
+readme = "README.md"
+requires-python = ">=3.6"
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+]
+
+[project.urls]
+"Homepage" = "https://uvanlp.org/valda"
+"Bug Tracker" = "https://github.com/uvanlp/valda/issues"
diff --git a/src/valda/__init__.py b/src/valda/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/valda/beta_shapley.py b/src/valda/beta_shapley.py
new file mode 100644
index 0000000..a268ccd
--- /dev/null
+++ b/src/valda/beta_shapley.py
@@ -0,0 +1,69 @@
+## beta_shapley.py
+## Implementation of Beta Shapley (Kwon and Zou, 2022)
+
+import numpy as np
+from random import randint, sample
+from tqdm import tqdm
+from sklearn.metrics import accuracy_score
+
+
+## Local module
+from .util import *
+
+
+def beta_shapley(trnX, trnY, devX, devY, clf, alpha=1.0,
+                 beta=1.0, rho=1.0005, K=10, T=10):
+    """
+    alpha, beta - parameters of the Beta distribution
+    rho - threshold on the Gelman-Rubin (GR) statistic for early stopping
+    K - number of Markov chains
+    T - upper bound on the number of iterations
+    """
+    N = trnX.shape[0]
+    val, t = np.zeros((N, K, T+1)), 0
+    rho_hat = 2*rho
+    for t in tqdm(range(1, T+1)):
+        for j in range(N):
+            for k in range(K):
+                Idx = list(range(N))
+                Idx.remove(j)   # remove example j
+                s = randint(1, N-1)   # sampled subset size
+                sub_Idx = sample(Idx, s)
+                acc_ex, acc_in = None, None
+                # Dev-set accuracy without example j
+                trnX_ex, trnY_ex = trnX[sub_Idx, :], trnY[sub_Idx]
+                try:
+                    clf.fit(trnX_ex, trnY_ex)
+                    acc_ex = accuracy_score(devY, clf.predict(devX))
+                except ValueError:
+                    # Only one class in the sampled subset
+                    acc_ex = accuracy_score(devY, [trnY_ex[0]]*len(devY))
+                # Dev-set accuracy with example j added back
+                sub_Idx.append(j)
+                trnX_in, trnY_in = trnX[sub_Idx, :], trnY[sub_Idx]
+                try:
+                    clf.fit(trnX_in, trnY_in)
+                    acc_in = accuracy_score(devY, clf.predict(devX))
+                except ValueError:
+                    acc_in = accuracy_score(devY, [trnY_in[0]]*len(devY))
+                # Running average of the weighted marginal contributions;
+                # the Beta weight depends on the position s+1 at which j
+                # joins the sampled subset, not on the index of j
+                val[j,k,t] = ((t-1)*val[j,k,t-1])/t + (weight(s+1, N, alpha, beta)/t)*(acc_in - acc_ex)
+        # Update the Gelman-Rubin statistic rho_hat
+        if t > 3:
+            rho_hat = gr_statistic(val, t)   # a temporary stopping criterion
+            if np.all(rho_hat < rho):
+                # all chains have converged; terminate early
+                break
+    # Average the K chain estimates from the last iteration
+    val_last = val[:,:,t].mean(axis=1)
+    return val_last
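A minimal smoke-test sketch for beta_shapley, not part of the patch: the synthetic data and hyper-parameter choices below are illustrative, and any sklearn-style classifier with fit/predict works.

    # Illustrative smoke test for beta_shapley (not part of the patch)
    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from valda.beta_shapley import beta_shapley

    X, y = make_classification(n_samples=60, n_features=5, random_state=0)
    trnX, trnY, devX, devY = X[:40], y[:40], X[40:], y[40:]
    clf = LogisticRegression(solver='liblinear', random_state=0)
    # (alpha, beta) = (1, 16) is the package default; (1, 1) recovers
    # the uniform Data Shapley weighting
    vals = beta_shapley(trnX, trnY, devX, devY, clf,
                        alpha=1.0, beta=16.0, K=5, T=5)
    print(vals[:5])   # one value per training example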
diff --git a/src/valda/cs_shapley.py b/src/valda/cs_shapley.py
new file mode 100644
index 0000000..64da738
--- /dev/null
+++ b/src/valda/cs_shapley.py
@@ -0,0 +1,124 @@
+import numpy as np
+from random import shuffle, randint, sample
+from tqdm import tqdm
+from sklearn.metrics import accuracy_score
+
+
+def class_conditional_sampling(Y, label_set):
+    Idx_nonlabel = []
+    for label in label_set:
+        label_indices = list(np.where(Y == label)[0])
+        s = randint(1, len(label_indices))
+        Idx_nonlabel += sample(label_indices, s)
+    shuffle(Idx_nonlabel)   # shuffle the sampled indices
+    return Idx_nonlabel
+
+
+def cs_shapley(trnX, trnY, devX, devY, label, clf, T=200,
+               epsilon=1e-4, normalized_score=True, resample=1):
+    '''
+    normalized_score - whether to normalize the Shapley values within the class
+
+    resample - the number of resamples used when estimating the values under
+               one specific permutation. Technically, larger values lead to
+               better estimates, but in practice the difference may not be
+               significant
+    '''
+    # Select data based on the class label
+    orig_indices = np.array(list(range(trnX.shape[0])))[trnY == label]
+    print("The number of training data with label {} is {}".format(label, len(orig_indices)))
+    trnX_label = trnX[trnY == label]
+    trnY_label = trnY[trnY == label]
+    trnX_nonlabel = trnX[trnY != label]
+    trnY_nonlabel = trnY[trnY != label]
+    devX_label = devX[devY == label]
+    devY_label = devY[devY == label]
+    devX_nonlabel = devX[devY != label]
+    devY_nonlabel = devY[devY != label]
+    N_nonlabel = trnX_nonlabel.shape[0]
+    nonlabel_set = list(set(trnY_nonlabel))
+    print("Labels on the other side: {}".format(nonlabel_set))
+
+    # Create indices and shuffle them
+    N = trnX_label.shape[0]
+    Idx = list(range(N))
+    # Shapley values and the running count of permutation samples
+    val, k = np.zeros((N)), 0
+    for t in tqdm(range(1, T+1)):
+        # Shuffle the data
+        shuffle(Idx)
+        # For each permutation, resample `resample` times from the other classes
+        for i in range(resample):
+            k += 1
+            # value containers for iteration i
+            val_i = np.zeros((N+1))
+            val_i_non = np.zeros((N+1))
+
+            # --------------------
+            # Sample a subset of training data from the other labels
+            if len(nonlabel_set) == 1:
+                s = randint(1, N_nonlabel)
+                Idx_nonlabel = sample(list(range(N_nonlabel)), s)
+            else:
+                Idx_nonlabel = class_conditional_sampling(trnY_nonlabel, nonlabel_set)
+            trnX_nonlabel_i = trnX_nonlabel[Idx_nonlabel, :]
+            trnY_nonlabel_i = trnY_nonlabel[Idx_nonlabel]
+
+            # --------------------
+            # With no data from the target class and the sampled data from other classes
+            val_i[0] = 0.0
+            try:
+                clf.fit(trnX_nonlabel_i, trnY_nonlabel_i)
+                val_i_non[0] = accuracy_score(devY_nonlabel, clf.predict(devX_nonlabel), normalize=False)/len(devY)
+            except ValueError:
+                # The sampled trnY_nonlabel_i contains only one class
+                val_i_non[0] = accuracy_score(devY_nonlabel, [trnY_nonlabel_i[0]]*len(devY_nonlabel),
+                                              normalize=False)/len(devY)
+
+            # ---------------------
+            # With all data from the target class and the sampled data from other classes
+            tempX = np.concatenate((trnX_nonlabel_i, trnX_label))
+            tempY = np.concatenate((trnY_nonlabel_i, trnY_label))
+            clf.fit(tempX, tempY)
+            val_i[N] = accuracy_score(devY_label, clf.predict(devX_label), normalize=False)/len(devY)
+            val_i_non[N] = accuracy_score(devY_nonlabel, clf.predict(devX_nonlabel), normalize=False)/len(devY)
+
+            # --------------------
+            # Walk through the permutation, truncating when the remaining
+            # gain falls below epsilon
+            for j in range(1, N+1):
+                if abs(val_i[N] - val_i[j-1]) < epsilon:
+                    val_i[j] = val_i[j-1]
+                else:
+                    # Extract the first $j$ data points
+                    trnX_j = trnX_label[Idx[:j], :]
+                    trnY_j = trnY_label[Idx[:j]]
+                    try:
+                        tempX = np.concatenate((trnX_nonlabel_i, trnX_j))
+                        tempY = np.concatenate((trnY_nonlabel_i, trnY_j))
+                        clf.fit(tempX, tempY)
+                        val_i[j] = accuracy_score(devY_label, clf.predict(devX_label), normalize=False)/len(devY)
+                        val_i_non[j] = accuracy_score(devY_nonlabel, clf.predict(devX_nonlabel), normalize=False)/len(devY)
+                    except ValueError:   # This should never happen in this algorithm
+                        print("Only one class in the dataset")
+                        return (None, None, None)
+            # ==========================================
+            # Class-weighted utility: in-class accuracy weighted by the
+            # exponential of the out-of-class accuracy
+            wvalues = np.exp(val_i_non) * val_i
+            diff = wvalues[1:] - wvalues[:N]
+            # Running average over all permutation samples
+            val[Idx] = ((1.0*(k-1)/k))*val[Idx] + (1.0/k)*(diff)
+
+    # Optionally normalize the values within the class and rescale them by
+    # the in-class accuracy of a model trained on the full training set
+    if normalized_score:
+        val = val/val.sum()
+        clf.fit(trnX, trnY)
+        score = accuracy_score(devY_label, clf.predict(devX_label), normalize=False)/len(devY)
+        print("score = {}".format(score))
+        val = val * score
+    return val, orig_indices
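Since cs_shapley values one class at a time and returns the original indices of that class, callers are expected to merge the per-class results; a sketch of that merging, mirroring the cs-shapley branch of valuation.py further below:

    # Sketch: merge per-class CS-Shapley values into one dict
    values = {}
    for label in set(trnY):
        vals, orig_indices = cs_shapley(trnX, trnY, devX, devY, label, clf)
        for idx, val in zip(list(orig_indices), list(vals)):
            values[int(idx)] = float(val)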
diff --git a/src/valda/eval.py b/src/valda/eval.py
new file mode 100644
index 0000000..92f5cca
--- /dev/null
+++ b/src/valda/eval.py
@@ -0,0 +1,50 @@
+## eval.py
+## Evaluate the performance of data valuation by removing one data point
+## at a time from the training set
+
+from sklearn.metrics import accuracy_score
+from sklearn.linear_model import LogisticRegression as LR
+
+import operator
+
+def data_removal(vals, trnX, trnY, tstX, tstY, clf=None,
+                 remove_high_value=True):
+    '''
+    vals - a Python dict that maps data indices to values
+    trnX, trnY - training examples
+    tstX, tstY - test examples
+    clf - the classifier used for evaluation (logistic regression by default)
+    '''
+    # Create data indices for data removal
+    N = trnX.shape[0]
+    Idx_keep = [True]*N
+
+    if clf is None:
+        clf = LR(solver="liblinear", max_iter=500, random_state=0)
+    # Sort the data indices in decreasing order of value
+    sorted_dct = sorted(vals.items(), key=operator.itemgetter(1), reverse=True)
+    # Accuracy list
+    accs = []
+    if remove_high_value:
+        lst = range(N)
+    else:
+        lst = range(N-1, -1, -1)
+    # Test accuracy with the full training set
+    clf.fit(trnX, trnY)
+    acc = accuracy_score(tstY, clf.predict(tstX))
+    accs.append(acc)
+    for k in lst:
+        # Remove one more data point, from either end of the ranking
+        Idx_keep[sorted_dct[k][0]] = False
+        trnX_k = trnX[Idx_keep, :]
+        trnY_k = trnY[Idx_keep]
+        try:
+            clf.fit(trnX_k, trnY_k)
+            acc = accuracy_score(tstY, clf.predict(tstX))
+            accs.append(acc)
+        except ValueError:
+            # The remaining training data contains a single class only
+            accs.append(0.0)
+    return accs
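A sketch of the evaluation protocol data_removal supports: removing high-value points first should degrade test accuracy faster than a random baseline. All variable names below are illustrative.

    # Sketch: compare a valuation against a random baseline
    import numpy as np
    from valda.eval import data_removal
    from valda.metrics import weighted_acc_drop

    rand_vals = {i: v for i, v in enumerate(np.random.rand(trnX.shape[0]))}
    accs_method = data_removal(vals, trnX, trnY, tstX, tstY)
    accs_random = data_removal(rand_vals, trnX, trnY, tstX, tstY)
    # A larger weighted accuracy drop means more influential points were found
    print(weighted_acc_drop(accs_method), weighted_acc_drop(accs_random))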
diff --git a/src/valda/inf_func.py b/src/valda/inf_func.py
new file mode 100644
index 0000000..8e5c409
--- /dev/null
+++ b/src/valda/inf_func.py
@@ -0,0 +1,124 @@
+import numpy as np
+import torch
+from torch.autograd import grad
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+
+
+## Local module
+from .pyclassifier import *
+
+"""
+Important Params:
+ys: scalar to be differentiated
+params: list of vectors (torch.tensors) w.r.t. each of which the Hessian is computed
+vs: the list of vectors, each of which is to be multiplied by the Hessian w.r.t. each parameter
+params2: another list of params for the second `grad` call, in case the second derivative is taken w.r.t. a different set of parameters
+"""
+
+
+def hessian_vector_product(ys, params, vs, params2=None):
+    grads1 = grad(ys, params, create_graph=True)
+    if params2 is not None:
+        params = params2
+    grads2 = grad(grads1, params, grad_outputs=vs)
+    return grads2
+
+
+# Returns a list of Hessians of `ys`, one per parameter in `params`;
+# i.e., `ys` is differentiated twice w.r.t. each parameter.
+"""
+Important Params:
+ys: scalar that is to be differentiated
+params: list of torch.tensors; a Hessian is computed for each
+"""
+
+
+def hessians(ys, params):
+    jacobians = grad(ys, params, create_graph=True)
+
+    outputs = []   # container for the Hessians
+    for j, param in zip(jacobians, params):
+        hess = []
+        j_flat = j.flatten()
+        for i in range(len(j_flat)):
+            grad_outputs = torch.zeros_like(j_flat)
+            grad_outputs[i] = 1
+            grad2 = grad(j_flat, param, grad_outputs=grad_outputs, retain_graph=True)[0]
+            hess.append(grad2)
+        outputs.append(torch.stack(hess).reshape(j.shape + param.shape))
+    return outputs
+
+
+# Compute the product of the inverse Hessian of the empirical risk and a given
+# vector, estimated numerically with LiSSA.
+# Returns a list of inverse-HVPs, one per parameter.
+
+'''
+Important params:
+vs: list of vectors in the inverse-hvp, one per parameter
+batch_size: size of the minibatch sampled at each iteration
+scale: the factor used to scale down the loss (to keep the Hessian <= I)
+damping: lambda added to guarantee the Hessian is positive definite
+num_repeats: hyperparameter 'r' in the paper (to reduce variance)
+Note: the recursion depth of LiSSA equals the number of minibatches drawn
+from `dataset` in each repetition
+'''
+
+
+def get_inverse_hvp_lissa(model, criterion, dataset, vs,
+                          batch_size,
+                          scale=1,
+                          damping=0.1,
+                          num_repeats=1,
+                          verbose=False):
+    assert criterion is not None, "ERROR: Criterion cannot be None."
+    assert batch_size <= len(dataset), "ERROR: Minibatch size for LiSSA should be less than dataset size"
+    assert isinstance(dataset, Dataset), "ERROR: `dataset` must be a PyTorch Dataset"
+
+    params = [param for param in model.parameters() if param.requires_grad]
+    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+    inverse_hvp = None
+
+    for rep in range(num_repeats):
+        cur_estimate = vs
+        for batch in iter(data_loader):
+            batch_inputs, batch_targets = batch
+            batch_out = model(batch_inputs)
+
+            loss = criterion(batch_out, batch_targets) / batch_size
+
+            # LiSSA recursion: H^{-1}v is approximated by v + (I - H)H^{-1}v
+            hvp = hessian_vector_product(loss, params, vs=cur_estimate)
+            cur_estimate = [v + (1 - damping) * ce - hv / scale \
+                            for (v, ce, hv) in zip(vs, cur_estimate, hvp)]
+
+        # Accumulate over repetitions (and undo the loss scaling)
+        inverse_hvp = [hv1 + hv2 / scale for (hv1, hv2) in zip(inverse_hvp, cur_estimate)] \
+            if inverse_hvp is not None \
+            else [hv2 / scale for hv2 in cur_estimate]
+
+    # Average over repetitions
+    inverse_hvp = [item / num_repeats for item in inverse_hvp]
+    return inverse_hvp
+
+
+def inf_func(trnX, trnY, devX, devY, clf,
+             epochs=5, trn_batch_size=1, dev_batch_size=16):
+    # NOTE: this estimator scores points with raw gradient dot-products
+    # between training and dev examples (an identity-Hessian approximation);
+    # it does not call get_inverse_hvp_lissa above.
+    if epochs > 0:
+        print("Training models for IF with {} iterations ...".format(epochs))
+        clf.fit(trnX, trnY, epochs=epochs, batch_size=trn_batch_size)
+
+    train_grads = clf.grad(trnX, trnY, batch_size=1)
+    test_grads = clf.grad(devX, devY, batch_size=dev_batch_size)
+
+    infs = []
+    for test_grad in test_grads:
+        inf_up_loss = []
+        for train_grad in train_grads:
+            inf = 0
+            for train_grad_p, test_grad_p in zip(train_grad, test_grad):
+                inf += torch.sum(train_grad_p * test_grad_p)
+            inf_up_loss.append(inf.cpu().detach().numpy())
+        infs.append(inf_up_loss)
+
+    # Sum over dev examples, giving one value per training example
+    vals = list(np.sum(infs, axis=0))
+    return vals
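A sketch of how get_inverse_hvp_lissa could slot into the full influence-function estimate (Koh and Liang, 2017), which inf_func above approximates with the identity Hessian. It assumes `clf = PytorchClassifier(model)` has already been fit; the scale and damping values are illustrative.

    # Sketch: influence scoring with the LiSSA inverse-HVP (assumptions above)
    import torch
    from valda.pyclassifier import Data
    from valda.inf_func import get_inverse_hvp_lissa

    dataset = Data(trnX, trnY)
    criterion = torch.nn.CrossEntropyLoss()
    test_grads = clf.grad(devX, devY, batch_size=len(devY))  # one dev batch
    ihvp = get_inverse_hvp_lissa(model, criterion, dataset,
                                 vs=list(test_grads[0]),     # one vector per parameter
                                 batch_size=8, scale=10, damping=0.01)
    train_grads = clf.grad(trnX, trnY, batch_size=1)
    influences = [sum(torch.sum(g * h) for g, h in zip(g_i, ihvp)).item()
                  for g_i in train_grads]   # up to sign convention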
diff --git a/src/valda/loo.py b/src/valda/loo.py
new file mode 100644
index 0000000..d5961e8
--- /dev/null
+++ b/src/valda/loo.py
@@ -0,0 +1,40 @@
+## loo.py
+## The implementation of leave-one-out (LOO) for data valuation
+
+import numpy as np
+from tqdm import tqdm
+from sklearn.metrics import accuracy_score
+
+
+def loo(trnX, trnY, devX, devY, clf):
+    '''
+    trnX, trnY - inputs/outputs of training examples
+    devX, devY - inputs/outputs of development examples
+    '''
+    N = trnX.shape[0]
+    val = np.zeros((N))
+    # Dev accuracy with the full training set (the same for every i)
+    try:
+        clf.fit(trnX, trnY)
+        acc_in = accuracy_score(devY, clf.predict(devX))
+    except ValueError:
+        # Training set only has a single class
+        acc_in = accuracy_score(devY, [trnY[0]]*len(devY))
+    for i in tqdm(range(N)):
+        # Exclude the data point i
+        Idx = list(range(N))
+        Idx.remove(i)
+        tempX, tempY = trnX[Idx, :], trnY[Idx]
+        try:
+            clf.fit(tempX, tempY)
+            acc_ex = accuracy_score(devY, clf.predict(devX))
+        except ValueError:
+            acc_ex = accuracy_score(devY, [trnY[0]]*len(devY))
+        # LOO value: the drop in dev accuracy when i is removed
+        val[i] = acc_in - acc_ex
+    return val
diff --git a/src/valda/metrics.py b/src/valda/metrics.py
new file mode 100644
index 0000000..3904997
--- /dev/null
+++ b/src/valda/metrics.py
@@ -0,0 +1,33 @@
+## metrics.py
+## Date: 01/26/2023
+## Evaluation metrics used for data valuation
+
+import numpy as np
+import copy
+from sklearn.metrics import auc
+
+def weighted_acc_drop(accs):
+    ''' Weighted accuracy drop; please refer to (Schoch et al., 2022)
+        for the definition
+    '''
+    accs = copy.copy(accs)   # do not mutate the caller's list
+    accs.append(0.)
+    accs = np.array(accs)
+    diff = accs[:-1] - accs[1:]
+    c_sum = np.cumsum(diff)
+    weights = np.array(list(range(1, diff.shape[0]+1)))
+    weights = 1.0/weights
+    score = weights * c_sum
+    return score.sum()
+
+
+def pr_curve(target_list, ranked_list):
+    ''' Compute precision/recall of ranked_list against target_list
+    '''
+    p, r = [], []
+    for idx in range(3, len(ranked_list)+1):
+        partial_list = ranked_list[:idx]
+        overlap = list(set(target_list) & set(partial_list))
+        r.append(1.0*len(overlap)/len(target_list))
+        p.append(1.0*len(overlap)/len(partial_list))
+    score = auc(r, p)
+    return (p, r, score)
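A tiny worked check of weighted_acc_drop with made-up accuracies; the cumulative drops are weighted by 1/k, so early (high-value) removals dominate the score.

    # Worked example with made-up numbers: accs = [0.9, 0.6, 0.5]
    # After appending 0.0: diffs = [0.3, 0.1, 0.5], cumsum = [0.3, 0.4, 0.9]
    # score = 0.3/1 + 0.4/2 + 0.9/3 = 0.8
    from valda.metrics import weighted_acc_drop
    assert abs(weighted_acc_drop([0.9, 0.6, 0.5]) - 0.8) < 1e-9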
diff --git a/src/valda/params.py b/src/valda/params.py
new file mode 100644
index 0000000..c9573c1
--- /dev/null
+++ b/src/valda/params.py
@@ -0,0 +1,43 @@
+## params.py
+## Date: 01/23/2022
+## Initialize default parameters
+
+
+class Parameters(object):
+    def __init__(self):
+        self.params = {
+            # For TMC Shapley
+            'tmc_iter':500,
+            'tmc_thresh':0.001,
+            # For CS Shapley
+            'cs_iter':500,
+            'cs_thresh':0.001,
+            # For Beta Shapley
+            'beta_iter':50,
+            'alpha':1.0,
+            'beta':16.0,
+            'rho':1.0005,
+            'beta_chain':10,
+            # For Influence Function
+            'if_iter':30,
+            'trn_batch_size':16,
+            'dev_batch_size':16
+        }
+
+    def update(self, new_params):
+        for (key, val) in new_params.items():
+            if key in self.params:
+                self.params[key] = val
+            else:
+                print("Undefined key {} with value {}".format(key, val))
+
+    def get_values(self):
+        return self.params
+
+    def print_values(self):
+        print("The current hyper-parameter setting:")
+        for (key, val) in self.params.items():
+            print("\t{} : {}".format(key, val))
diff --git a/src/valda/pyclassifier.py b/src/valda/pyclassifier.py
new file mode 100644
index 0000000..05d6cf7
--- /dev/null
+++ b/src/valda/pyclassifier.py
@@ -0,0 +1,109 @@
+## pyclassifier.py
+## Date: 01/18/2023
+## A general framework for defining a classifier class
+
+
+import torch
+import torch.nn as nn
+from torch.autograd import grad
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from tqdm import tqdm
+
+
+class Data(Dataset):
+    def __init__(self, X, y=None):
+        self.X = torch.from_numpy(X.astype(np.float32))
+        if y is not None:
+            self.y = torch.from_numpy(y).type(torch.LongTensor)
+        else:
+            self.y = None
+        self.len = self.X.shape[0]
+        self.dim = self.X.shape[1]
+
+    def __getitem__(self, index):
+        if self.y is None:
+            # Unlabeled data (e.g., at prediction time)
+            return self.X[index]
+        return self.X[index], self.y[index]
+
+    def __len__(self):
+        return self.len
+
+    def __dim__(self):
+        return self.dim
+
+
+class PytorchClassifier(object):
+    def __init__(self, model, optim=None, loss=None):
+        '''
+        model - a torch.nn.Module mapping inputs to output scores
+        optim - a torch optimizer (Adam with default settings, by default)
+        loss - a loss function (CrossEntropyLoss, by default)
+        '''
+        self.model = model
+        self.params = list(self.model.parameters())
+        # optimizer
+        if optim is None:
+            self.optim = torch.optim.Adam(model.parameters())
+        else:
+            self.optim = optim
+        # loss function
+        if loss is None:
+            self.loss = nn.CrossEntropyLoss()
+        else:
+            self.loss = loss
+
+    def fit(self, X, y, epochs=10, batch_size=1):
+        '''
+        X, y - training examples
+        '''
+        loader = DataLoader(Data(X, y), batch_size=batch_size, shuffle=True,
+                            num_workers=0)
+        for epoch in tqdm(range(epochs)):
+            for (inputs, labels) in loader:
+                self.optim.zero_grad()
+                outputs = self.model(inputs)
+                batch_loss = self.loss(outputs, labels)
+                batch_loss.backward()
+                self.optim.step()
+        print("Done training")
+
+    def predict(self, X, batch_size=16):
+        loader = DataLoader(Data(X), batch_size=batch_size, shuffle=False,
+                            num_workers=0)
+        pred_labels = []
+        with torch.no_grad():
+            for inputs in loader:
+                outputs = self.model(inputs)
+                # Predicted label = index of the largest output score
+                _, predicted = torch.max(outputs.data, 1)
+                pred_labels += predicted.tolist()
+        return pred_labels
+
+    def online_train(self, x, y):
+        ''' Online training
+        '''
+        raise NotImplementedError("Not implemented")
+
+    def grad(self, X, y, batch_size=1):
+        ''' Compute the gradients of the parameters w.r.t. each batch of (X, y)
+        '''
+        grads = []
+        loader = DataLoader(Data(X, y), batch_size=batch_size, shuffle=False,
+                            num_workers=0)
+        for (inputs, labels) in loader:
+            outputs = self.model(inputs)
+            batch_loss = self.loss(outputs, labels)
+            # Append one tuple of per-parameter gradients per batch, so the
+            # structure matches how inf_func.py consumes the result
+            batch_grads = grad(batch_loss, self.params)
+            grads.append(batch_grads)
+        return grads
+
+    def get_parameters(self):
+        ''' Get the model parameters
+        '''
+        return self.params
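A short sketch of how these defaults are meant to be overridden through DataValuation.estimate; the dict keys are the ones defined in params.py above, and the data arrays are assumed loaded as in the test examples.

    # Sketch: override a subset of the default hyper-parameters
    from valda.params import Parameters
    from valda.valuation import DataValuation

    params = Parameters()
    params.update({'tmc_iter': 100})    # known key: accepted
    params.update({'tmc_iters': 100})   # unknown key: reported and ignored
    dv = DataValuation(trnX, trnY, devX, devY)
    vals = dv.estimate(method='tmc-shapley', params=params.get_values())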
diff --git a/src/valda/tmc_shapley.py b/src/valda/tmc_shapley.py
new file mode 100644
index 0000000..f3bf864
--- /dev/null
+++ b/src/valda/tmc_shapley.py
@@ -0,0 +1,43 @@
+## tmc_shapley.py
+## Truncated Monte Carlo (TMC) Shapley (Ghorbani and Zou, 2019)
+
+import numpy as np
+from random import shuffle
+from tqdm import tqdm
+from sklearn.metrics import accuracy_score
+
+
+def tmc_shapley(trnX, trnY, devX, devY, clf, T=20, epsilon=0.001):
+    '''
+    T - number of sampled permutations
+    epsilon - truncation threshold
+    '''
+    N = trnX.shape[0]
+    Idx = list(range(N))   # Indices
+    val = np.zeros((N))
+    for t in tqdm(range(1, T+1)):
+        # Shuffle the data
+        shuffle(Idx)
+        val_t = np.zeros((N+1))
+        # Pre-computed value with all the training data
+        try:
+            clf.fit(trnX, trnY)
+            val_t[N] = accuracy_score(devY, clf.predict(devX))
+        except ValueError:
+            # Training set only has a single class
+            val_t[N] = accuracy_score(devY, [trnY[0]]*len(devY))
+        for j in range(1, N+1):
+            # Truncation: once the value stops changing, skip retraining
+            if abs(val_t[N] - val_t[j-1]) < epsilon:
+                val_t[j] = val_t[j-1]
+            else:
+                # Extract the first $j$ data points
+                trnX_j = trnX[Idx[:j], :]
+                trnY_j = trnY[Idx[:j]]
+                try:
+                    clf.fit(trnX_j, trnY_j)
+                    val_t[j] = accuracy_score(devY, clf.predict(devX))
+                except ValueError:
+                    # Predict the only class present in the training subset
+                    val_t[j] = accuracy_score(devY, [trnY_j[0]]*len(devY))
+        # Update the Data Shapley values with a running average
+        val[Idx] = ((1.0*(t-1)/t))*val[Idx] + (1.0/t)*(val_t[1:] - val_t[:N])
+    return val
diff --git a/src/valda/util.py b/src/valda/util.py
new file mode 100644
index 0000000..36fb8cd
--- /dev/null
+++ b/src/valda/util.py
@@ -0,0 +1,128 @@
+import numpy as np
+from math import factorial
+from sklearn.metrics import accuracy_score
+
+
+def weight(j, n, alpha=1.0, beta=1.0):
+    ''' Beta Shapley weight of position j in a permutation of n points
+    '''
+    log_1, log_2, log_3 = 0.0, 0.0, 0.0
+    for k in range(1, j):
+        log_1 += np.log(beta + k - 1)
+    for k in range(1, n-j+1):
+        log_2 += np.log(alpha + k - 1)
+    for k in range(1, n):
+        log_3 += np.log(alpha + beta + k - 1)
+    log_total = np.log(n) + log_1 + log_2 - log_3
+    # log of the binomial coefficient C(n-1, j-1); Stirling's approximation
+    # log(m!) ~ m(log(m) - 1) is used for factorials beyond 20
+    log_comb = None
+    if n <= 20:
+        log_comb = np.log(factorial(n-1))
+    else:
+        log_comb = (n-1)*(np.log(n-1) - 1)
+    if j <= 20:
+        log_comb -= np.log(factorial(j-1))
+    else:
+        log_comb -= (j-1)*(np.log(j-1) - 1)
+    if (n-j) <= 20:
+        log_comb -= np.log(factorial(n-j))
+    else:
+        log_comb -= (n-j)*(np.log(n-j) - 1)
+    v = np.exp(log_comb + log_total)
+    return v
+
+
+def gr_statistic(val, t):
+    ''' Gelman-Rubin convergence statistic across the K chains
+    '''
+    v = val[:,:,1:t+1]
+    sample_var = np.var(v, axis=2, ddof=1)   # N x K, along dimension T
+    mean_sample_var = np.mean(sample_var, axis=1)   # N, along dimension K; s^2 in the paper
+    sample_mean = np.mean(v, axis=2)   # N x K, along dimension T
+    sample_mean_var = np.var(sample_mean, axis=1, ddof=1)   # N, along dimension K; B/n in the paper
+    sigma_hat_2 = ((t-1)*mean_sample_var)/t + sample_mean_var
+    rho_hat = np.sqrt(sigma_hat_2/(mean_sample_var + 1e-4))
+    return rho_hat
+
+
+def data_removal_figure(neg_lab, pos_lab, trnX, trnY, devX, devY, sorted_dct,
+                        clf_label, remove_high_value=True):
+    # NOTE: this helper assumes a `Classifier` factory keyed by `clf_label`,
+    # which is not defined in this module
+    # Create data indices for data removal
+    N = trnX.shape[0]
+    Idx_keep = [True]*N
+    # Accuracy list
+    accs = []
+    if remove_high_value:
+        lst = range(N)
+    else:
+        lst = range(N-1, -1, -1)
+    # Per-class dev accuracy with the full training set
+    clf = Classifier(clf_label)
+    clf.fit(trnX, trnY)
+    dev = list(zip(devX, devY))
+    dev_X0, dev_Y0 = [], []
+    dev_X1, dev_Y1 = [], []
+    for i in dev:
+        if i[1] == pos_lab:
+            dev_X1.append(i[0])
+            dev_Y1.append(i[1])
+        elif i[1] == neg_lab:
+            dev_X0.append(i[0])
+            dev_Y0.append(i[1])
+        else:
+            print(i)
+
+    acc_0 = accuracy_score(dev_Y0, clf.predict(dev_X0), normalize=False)/len(devY)
+    acc_1 = accuracy_score(dev_Y1, clf.predict(dev_X1), normalize=False)/len(devY)
+    print(acc_0, acc_1)
+
+    accs_0, accs_1 = [], []
+    acc = accuracy_score(clf.predict(devX), devY)
+    accs.append(acc)
+    accs_0.append(acc_0)
+    accs_1.append(acc_1)
+    vals, labels, points, ks = [], [], [], []
+    for k in lst:
+        Idx_keep = [True] * N
+        Idx_keep[sorted_dct[k][0]] = False
+        trnX_k = trnX[Idx_keep, :]
+        trnY_k = trnY[Idx_keep]
+        clf = Classifier(clf_label)
+        try:
+            clf.fit(trnX_k, trnY_k)
+            labels.append(trnY[k])
+            points.append(trnX[k])
+            acc = accuracy_score(clf.predict(devX), devY)
+            acc_0 = accuracy_score(dev_Y0, clf.predict(dev_X0), normalize=False)/len(devY)
+            acc_1 = accuracy_score(dev_Y1, clf.predict(dev_X1), normalize=False)/len(devY)
+            ks.append(k)
+            accs.append(acc)
+            accs_0.append(acc_0)
+            accs_1.append(acc_1)
+            vals.append(sorted_dct[k][1])
+        except ValueError:
+            # The remaining training data contains a single class only
+            accs.append(0.0)
+    return accs, accs_0, accs_1, vals, labels, points, ks
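One property worth noting: with alpha = beta = 1 the Beta weights in util.py reduce to the uniform Data Shapley weighting, i.e. weight(j, n, 1, 1) = 1 for every position j (exactly so for n <= 20, where exact factorials are used). A quick sanity check:

    # Sanity check: Beta(1,1) weights are 1 for every position j
    from valda.util import weight

    n = 15   # small n, so the exact-factorial branch is taken
    ws = [weight(j, n, alpha=1.0, beta=1.0) for j in range(1, n+1)]
    assert all(abs(w - 1.0) < 1e-8 for w in ws)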
diff --git a/src/valda/valuation.py b/src/valda/valuation.py
new file mode 100644
index 0000000..94e9242
--- /dev/null
+++ b/src/valda/valuation.py
@@ -0,0 +1,114 @@
+## valuation.py
+## Date: 01/18/2023
+## A general framework for defining a data valuation class
+
+
+from sklearn.linear_model import LogisticRegression as LR
+from sklearn.metrics import accuracy_score
+
+
+## Local module: load data valuation methods
+from .loo import loo
+from .tmc_shapley import tmc_shapley
+from .cs_shapley import cs_shapley
+from .beta_shapley import beta_shapley
+from .inf_func import inf_func
+from .params import Parameters   # Import default model parameters
+
+
+class DataValuation(object):
+    def __init__(self, trnX, trnY, devX=None, devY=None):
+        '''
+        trnX, trnY - input/output of the training examples, which are also
+                     the examples being valued
+        devX, devY - input/output of the development examples, which are used
+                     for estimating the values of (trnX, trnY); the training
+                     examples are reused when no development set is given
+        '''
+        self.trnX, self.trnY = trnX, trnY
+        if devX is None:
+            self.devX, self.devY = trnX, trnY
+        else:
+            self.devX, self.devY = devX, devY
+        self.values = {}   # Maps each training index to its estimated value
+        self.clf = None    # Instance of the classifier
+        params = Parameters()
+        self.params = params.get_values()
+
+
+    def estimate(self, clf=None, method='loo', params=None):
+        '''
+        clf - a classifier instance (logistic regression, by default)
+        method - the data valuation method (LOO, by default)
+        params - hyper-parameters for the data valuation methods
+        '''
+        self.values = {}
+        if clf is None:
+            self.clf = LR(solver="liblinear", max_iter=500, random_state=0)
+        else:
+            self.clf = clf
+
+        if params is not None:
+            print("Override the default parameters with the user-specified ones: {}".format(params))
+            self.params.update(params)
+
+        # Call data valuation functions
+        if method == 'loo':
+            # Leave-one-out
+            vals = loo(self.trnX, self.trnY, self.devX, self.devY, self.clf)
+            for idx in range(len(vals)):
+                self.values[idx] = vals[idx]
+        elif method == 'tmc-shapley':
+            # TMC Data Shapley (Ghorbani and Zou, 2019)
+            # Get the default parameter values
+            n_iter = self.params['tmc_iter']
+            tmc_thresh = self.params['tmc_thresh']
+            vals = tmc_shapley(self.trnX, self.trnY, self.devX, self.devY,
+                               self.clf, n_iter, tmc_thresh)
+            for idx in range(len(vals)):
+                self.values[idx] = vals[idx]
+        elif method == 'cs-shapley':
+            # CS Shapley (Schoch et al., 2022)
+            n_iter = self.params['cs_iter']
+            cs_thresh = self.params['cs_thresh']
+            labels = list(set(self.trnY))
+            for label in labels:
+                vals, orig_indices = cs_shapley(self.trnX, self.trnY, self.devX,
+                                                self.devY, label,
+                                                self.clf, n_iter, cs_thresh)
+                for (idx, val) in zip(list(orig_indices), list(vals)):
+                    self.values[idx] = val
+        elif method == 'beta-shapley':
+            # Beta Shapley (Kwon and Zou, 2022)
+            n_iter = self.params['beta_iter']
+            alpha, beta = self.params['alpha'], self.params['beta']
+            rho = self.params['rho']
+            n_chain = self.params['beta_chain']
+            vals = beta_shapley(self.trnX, self.trnY, self.devX, self.devY,
+                                self.clf, alpha, beta, rho, n_chain, n_iter)
+            for idx in range(len(vals)):
+                self.values[idx] = vals[idx]
+        elif method == 'inf-func':
+            # Influence functions
+            n_iter = self.params['if_iter']
+            trn_batch_size = self.params['trn_batch_size']
+            dev_batch_size = self.params['dev_batch_size']
+            vals = inf_func(self.trnX, self.trnY, self.devX, self.devY,
+                            clf=self.clf,
+                            epochs=n_iter,
+                            trn_batch_size=trn_batch_size,
+                            dev_batch_size=dev_batch_size)
+            for idx in range(len(vals)):
+                self.values[idx] = vals[idx]
+        else:
+            raise ValueError("Unrecognized data valuation method: {}".format(method))
+        return self.values
+
+
+    def get_values(self):
+        '''
+        Return the data values
+        '''
+        if self.values:
+            return self.values
+        else:
+            raise ValueError("No values have been computed yet; call estimate() first")
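The dict returned by estimate() maps each training index to its score, so ranking the data is a single sort away; a small sketch (the cutoff k=10 is arbitrary, and `vals` is assumed to come from estimate()):

    # Sketch: rank training points by estimated value, best first
    import operator
    ranked = sorted(vals.items(), key=operator.itemgetter(1), reverse=True)
    top_k = [idx for idx, _ in ranked[:10]]   # most valuable training indices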
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/data/diabetes.pkl b/tests/data/diabetes.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..86dccd3c31dfa4c39d71333dbaeaa31c98daa0b4
GIT binary patch
literal 51199
[51199 bytes of base85-encoded binary pickle data omitted]
zm}fC(mic6hzrIy+in)!&1na32CeO1lx0!A|+rh3(i)Jlam^5o;(xRD3^Y$jKnwzw0 J<*=a1{{u!!Ppbd` literal 0 HcmV?d00001 diff --git a/tests/example/__init__.py b/tests/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/example/simple.py b/tests/example/simple.py new file mode 100644 index 0000000..7e1d204 --- /dev/null +++ b/tests/example/simple.py @@ -0,0 +1,29 @@ +## simple.py +## Date: 01/26/2023 +## A simple example + +from pickle import load +from eval import data_removal +from metrics import weighted_acc_drop + + +from valda.valuation import DataValuation + + +data = load(open('diabetes.pkl', 'rb')) + +trnX, trnY = data['trnX'], data['trnY'] +devX, devY = data['devX'], data['devY'] + + +dv = DataValuation(trnX, trnY, devX, devY) + + +vals = dv.estimate(method='beta-shapley') + + +tstX, tstY = data['tstX'], data['tstY'] + +accs = data_removal(vals, trnX, trnY, tstX, tstY) +res = weighted_acc_drop(accs) +print("The weighted accuracy drop is {}".format(res)) diff --git a/tests/example/simple_pytorch.py b/tests/example/simple_pytorch.py new file mode 100644 index 0000000..59e5400 --- /dev/null +++ b/tests/example/simple_pytorch.py @@ -0,0 +1,42 @@ +## simple.py +## A simple example + +import torch +from pickle import load +from sklearn import preprocessing + +from valda.valuation import DataValuation +from valda.pyclassifier import PytorchClassifier + + +class LogisticRegression(torch.nn.Module): + def __init__(self, input_dim, output_dim): + super(LogisticRegression, self).__init__() + self.linear = torch.nn.Linear(input_dim, output_dim) + self.softmax = torch.nn.Softmax(dim=1) + + def forward(self, x): + outputs = self.softmax(self.linear(x)) + return outputs + + +data = load(open('../data/diabetes.pkl', 'rb')) +trnX, trnY = data['trnX'], data['trnY'] +devX, devY = data['devX'], data['devY'] +print('trnX.shape = {}'.format(trnX.shape)) + +labels = list(set(trnY)) +le = preprocessing.LabelEncoder() +le.fit(labels) +trnY = le.transform(trnY) +devY = le.transform(devY) + + +model = LogisticRegression(input_dim=trnX.shape[1], output_dim=len(labels)) +clf = PytorchClassifier(model) + +dv = DataValuation(trnX, trnY, devX, devY) + +vals = dv.estimate(clf=clf, method='inf-func') + +print(vals)