diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ee88093
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+dist/
+*egg-info*
diff --git a/LICENSE.rst b/LICENSE.rst
new file mode 100644
index 0000000..eceb6af
--- /dev/null
+++ b/LICENSE.rst
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 UVa ILP
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 3badb39..59bcf62 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,3 @@
 # valda
+
 A Python Data Valuation Package
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..0a40859
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,22 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "valda"
+version = "0.1.5"
+authors = [
+  { name="Yangfeng Ji", email="yangfeng@virginia.edu" },
+]
+description = "A Data Valuation Package for Machine Learning"
+readme = "README.md"
+requires-python = ">=3.6"
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+]
+
+[project.urls]
+"Homepage" = "https://uvanlp.org/valda"
+"Bug Tracker" = "https://github.com/uvanlp/valda/issues"
diff --git a/src/valda/__init__.py b/src/valda/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/valda/beta_shapley.py b/src/valda/beta_shapley.py
new file mode 100644
index 0000000..a268ccd
--- /dev/null
+++ b/src/valda/beta_shapley.py
@@ -0,0 +1,69 @@
+## beta_shapley.py
+## Implementation of Beta Shapley
+
+import numpy as np
+from random import shuffle, seed, randint, sample, choice
+from tqdm import tqdm
+from sklearn.metrics import accuracy_score
+
+
+## Local module
+from .util import *
+
+
+def beta_shapley(trnX, trnY, devX, devY, clf, alpha=1.0,
+                 beta=1.0, rho=1.0005, K=10, T=10):
+    """
+    alpha, beta - parameters of the Beta distribution
+    rho - Gelman-Rubin (GR) statistic threshold for convergence
+    K - number of Markov chains
+    T - upper bound on the number of iterations
+    """
+    N = trnX.shape[0]
+    Idx = list(range(N))  # Indices
+    val, t = np.zeros((N, K, T+1)), 0
+    rho_hat = 2*rho
+    # Computation
+    for t in tqdm(range(1, T+1)):
+        for j in range(N):
+            for k in range(K):
+                Idx = list(range(N))
+                Idx.remove(j)  # remove example j
+                s = randint(1, N-1)
+                sub_Idx = sample(Idx, s)
+                acc_ex, acc_in = None, None
+                # Accuracy on a random subset that excludes example j
+                trnX_ex, trnY_ex = trnX[sub_Idx, :], trnY[sub_Idx]
+                try:
+                    clf.fit(trnX_ex, trnY_ex)
+                    acc_ex = accuracy_score(devY, clf.predict(devX))
+                except ValueError:
+                    # The sampled subset contains only a single class
+                    acc_ex = accuracy_score(devY, [trnY_ex[0]]*len(devY))
+                # Accuracy with example j added back for training
+                sub_Idx.append(j)
+                trnX_in, trnY_in = trnX[sub_Idx, :], trnY[sub_Idx]
+                try:
+                    clf.fit(trnX_in, trnY_in)
+                    acc_in = accuracy_score(devY, clf.predict(devX))
+                except ValueError:
+                    acc_in = accuracy_score(devY, [trnY_in[0]]*len(devY))
+                # Update the running value estimate
+                val[j,k,t] = ((t-1)*val[j,k,t-1])/t + (weight(j+1, N, alpha, beta)/t)*(acc_in - acc_ex)
+        # Update the Gelman-Rubin statistic rho_hat
+        if t > 3:
+            rho_hat = gr_statistic(val, t)  # a temporary solution for stopping
+            if np.all(rho_hat < rho):
+                # Terminate the outer loop early
+                break
+    # Average the sample values from the last iteration over the K chains
+    val_last = val[:,:,t].mean(axis=1)
+    return val_last
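For orientation, a minimal usage sketch of beta_shapley (not part of the patch; the make_classification data and LogisticRegression classifier are illustrative assumptions):

# Hypothetical usage sketch -- not part of this patch.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from valda.beta_shapley import beta_shapley

X, y = make_classification(n_samples=120, n_features=5, random_state=0)
trnX, devX, trnY, devY = train_test_split(X, y, test_size=0.25, random_state=0)
clf = LogisticRegression(solver='liblinear')
# alpha=1, beta=16 emphasizes small subsets, matching the package defaults
vals = beta_shapley(trnX, trnY, devX, devY, clf, alpha=1.0, beta=16.0, K=10, T=10)
print(vals[:5])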
diff --git a/src/valda/cs_shapley.py b/src/valda/cs_shapley.py
new file mode 100644
index 0000000..64da738
--- /dev/null
+++ b/src/valda/cs_shapley.py
@@ -0,0 +1,124 @@
+import numpy as np
+from random import shuffle, seed, randint, sample, choice
+from tqdm import tqdm
+from sklearn.metrics import accuracy_score
+
+
+def class_conditional_sampling(Y, label_set):
+    Idx_nonlabel = []
+    for label in label_set:
+        label_indices = list(np.where(Y == label)[0])
+        s = randint(1, len(label_indices))
+        Idx_nonlabel += sample(label_indices, s)
+    shuffle(Idx_nonlabel)  # shuffle the sampled indices
+    return Idx_nonlabel
+
+
+def cs_shapley(trnX, trnY, devX, devY, label, clf, T=200,
+               epsilon=1e-4, normalized_score=True, resample=1):
+    '''
+    normalized_score - whether to normalize the Shapley values within the class
+    resample - the number of resamples drawn when estimating the values with
+        one specific permutation. Technically, larger values lead to better
+        results, but in practice the difference may not be significant
+    '''
+    # Select data based on the class label
+    orig_indices = np.array(list(range(trnX.shape[0])))[trnY == label]
+    print("The number of training data with label {} is {}".format(label, len(orig_indices)))
+    trnX_label = trnX[trnY == label]
+    trnY_label = trnY[trnY == label]
+    trnX_nonlabel = trnX[trnY != label]
+    trnY_nonlabel = trnY[trnY != label]
+    devX_label = devX[devY == label]
+    devY_label = devY[devY == label]
+    devX_nonlabel = devX[devY != label]
+    devY_nonlabel = devY[devY != label]
+    N_nonlabel = trnX_nonlabel.shape[0]
+    nonlabel_set = list(set(trnY_nonlabel))
+    print("Labels on the other side: {}".format(nonlabel_set))
+
+    # Create indices and shuffle them
+    N = trnX_label.shape[0]
+    Idx = list(range(N))
+    # Shapley values and the total number of sampled permutations
+    val, k = np.zeros((N)), 0
+    for t in tqdm(range(1, T+1)):
+        # Shuffle the data
+        shuffle(Idx)
+        # For each permutation, resample `resample` times from the other classes
+        for i in range(resample):
+            k += 1
+            # value containers for iteration i
+            val_i = np.zeros((N+1))
+            val_i_non = np.zeros((N+1))
+
+            # --------------------
+            # Sample a subset of training data from the other labels for each i
+            if len(nonlabel_set) == 1:
+                s = randint(1, N_nonlabel)
+                Idx_nonlabel = sample(list(range(N_nonlabel)), s)
+            else:
+                Idx_nonlabel = class_conditional_sampling(trnY_nonlabel, nonlabel_set)
+            trnX_nonlabel_i = trnX_nonlabel[Idx_nonlabel, :]
+            trnY_nonlabel_i = trnY_nonlabel[Idx_nonlabel]
+
+            # --------------------
+            # With no data from the target class and the sampled data from other classes
+            val_i[0] = 0.0
+            try:
+                clf.fit(trnX_nonlabel_i, trnY_nonlabel_i)
+                val_i_non[0] = accuracy_score(devY_nonlabel, clf.predict(devX_nonlabel), normalize=False)/len(devY)
+            except ValueError:
+                # The sampled trnY_nonlabel_i contains only one class
+                val_i_non[0] = accuracy_score(devY_nonlabel, [trnY_nonlabel_i[0]]*len(devY_nonlabel),
+                                              normalize=False)/len(devY)
+
+            # ---------------------
+            # With all data from the target class and the sampled data from other classes
+            tempX = np.concatenate((trnX_nonlabel_i, trnX_label))
+            tempY = np.concatenate((trnY_nonlabel_i, trnY_label))
+            clf.fit(tempX, tempY)
+            val_i[N] = accuracy_score(devY_label, clf.predict(devX_label), normalize=False)/len(devY)
+            val_i_non[N] = accuracy_score(devY_nonlabel, clf.predict(devX_nonlabel), normalize=False)/len(devY)
+
+            # --------------------
+            for j in range(1, N+1):
+                if abs(val_i[N] - val_i[j-1]) < epsilon:
+                    # Truncation: the remaining marginal contributions are negligible
+                    val_i[j] = val_i[j-1]
+                else:
+                    # Extract the first j data points of the permutation
+                    trnX_j = trnX_label[Idx[:j],:]
+                    trnY_j = trnY_label[Idx[:j]]
+                    try:
+                        tempX = np.concatenate((trnX_nonlabel_i, trnX_j))
+                        tempY = np.concatenate((trnY_nonlabel_i, trnY_j))
+                        clf.fit(tempX, tempY)
+                        val_i[j] = accuracy_score(devY_label, clf.predict(devX_label), normalize=False)/len(devY)
+                        val_i_non[j] = accuracy_score(devY_nonlabel, clf.predict(devX_nonlabel), normalize=False)/len(devY)
+                    except ValueError:  # This should never happen in this algorithm
+                        print("Only one class in the dataset")
+                        return (None, None, None)
+            # ==========================================
+            # Weight the in-class accuracies by the out-of-class accuracies
+            wvalues = np.exp(val_i_non) * val_i
+            diff = wvalues[1:] - wvalues[:N]
+            val[Idx] = ((1.0*(k-1))/k)*val[Idx] + (1.0/k)*diff
+
+    # Whether to normalize the scores within the class
+    if normalized_score:
+        val = val/val.sum()
+        clf.fit(trnX, trnY)
+        score = accuracy_score(devY_label, clf.predict(devX_label), normalize=False)/len(devY)
+        print("score = {}".format(score))
+        val = val * score
+    return val, orig_indices
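A hedged sketch of calling cs_shapley once per class and merging the values back to the original indices, mirroring how valuation.py drives it (reuses the toy split and classifier from the previous sketch):

# Hypothetical usage sketch -- not part of this patch.
from sklearn.linear_model import LogisticRegression
from valda.cs_shapley import cs_shapley

clf = LogisticRegression(solver='liblinear')
values = {}
for label in set(trnY):               # trnX/trnY, devX/devY as above
    vals, orig_indices = cs_shapley(trnX, trnY, devX, devY, label, clf, T=50)
    for idx, v in zip(orig_indices, vals):
        values[int(idx)] = float(v)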
diff --git a/src/valda/eval.py b/src/valda/eval.py
new file mode 100644
index 0000000..92f5cca
--- /dev/null
+++ b/src/valda/eval.py
@@ -0,0 +1,50 @@
+## eval.py
+## Evaluate the performance of data valuation by removing one data point
+## at a time from the training set
+
+from sklearn.metrics import accuracy_score, auc
+from sklearn.linear_model import LogisticRegression as LR
+
+import operator
+
+def data_removal(vals, trnX, trnY, tstX, tstY, clf=None,
+                 remove_high_value=True):
+    '''
+    vals - a Python dict that maps data indices to values
+    trnX, trnY - training examples
+    tstX, tstY - test examples
+    clf - the classifier used for evaluation (logistic regression by default)
+    '''
+    # Create data indices for data removal
+    N = trnX.shape[0]
+    Idx_keep = [True]*N
+
+    if clf is None:
+        clf = LR(solver="liblinear", max_iter=500, random_state=0)
+    # Sort the data indices in decreasing order of value
+    sorted_dct = sorted(vals.items(), key=operator.itemgetter(1), reverse=True)
+    # Accuracy list
+    accs = []
+    if remove_high_value:
+        lst = range(N)
+    else:
+        lst = range(N-1, -1, -1)
+    # Accuracy with the full training set
+    clf.fit(trnX, trnY)
+    acc = accuracy_score(tstY, clf.predict(tstX))
+    accs.append(acc)
+    for k in lst:
+        Idx_keep[sorted_dct[k][0]] = False
+        trnX_k = trnX[Idx_keep, :]
+        trnY_k = trnY[Idx_keep]
+        try:
+            clf.fit(trnX_k, trnY_k)
+            acc = accuracy_score(tstY, clf.predict(tstX))
+            accs.append(acc)
+        except ValueError:
+            # Training with data from a single class
+            accs.append(0.0)
+    return accs
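A minimal sketch of the removal curve this produces, assuming the toy data and the `vals` array from the earlier beta_shapley sketch (the dev split stands in for a test set):

# Hypothetical usage sketch -- not part of this patch.
from valda.eval import data_removal

vals_dict = {i: v for i, v in enumerate(vals)}   # vals from beta_shapley above
accs = data_removal(vals_dict, trnX, trnY, devX, devY)
# accs[0] is the accuracy with all data; accs[k] is the accuracy after removing
# the k highest-valued points. For well-estimated values, removing high-value
# points first should degrade accuracy fastest.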
diff --git a/src/valda/inf_func.py b/src/valda/inf_func.py
new file mode 100644
index 0000000..8e5c409
--- /dev/null
+++ b/src/valda/inf_func.py
@@ -0,0 +1,124 @@
+import numpy as np
+from tqdm import tqdm
+
+import torch
+from torch.autograd import grad
+from torch.utils.data import Dataset, DataLoader
+
+## Local module
+from .pyclassifier import *
+
+"""
+Important params:
+ys: scalar to be differentiated
+params: list of vectors (torch.Tensors) w.r.t. each of which the Hessian is computed
+vs: the list of vectors, each of which is to be multiplied by the Hessian
+    w.r.t. each parameter
+params2: another list of params for the second `grad` call, in case the second
+    derivative is w.r.t. a different set of parameters
+"""
+
+
+def hessian_vector_product(ys, params, vs, params2=None):
+    grads1 = grad(ys, params, create_graph=True)
+    if params2 is not None:
+        params = params2
+    grads2 = grad(grads1, params, grad_outputs=vs)
+    return grads2
+
+
+# Returns a list of Hessians of `ys`, one per parameter in `params`, i.e.
+# `ys` is differentiated twice w.r.t. each parameter. Each output in the list
+# is obtained by differentiating `ys` w.r.t. only a single parameter.
+"""
+Important params:
+ys: scalar to be differentiated
+params: list of torch.Tensors; a Hessian is computed for each
+"""
+
+
+def hessians(ys, params):
+    jacobians = grad(ys, params, create_graph=True)
+
+    outputs = []  # container for the Hessians
+    for j, param in zip(jacobians, params):
+        hess = []
+        j_flat = j.flatten()
+        for i in range(len(j_flat)):
+            grad_outputs = torch.zeros_like(j_flat)
+            grad_outputs[i] = 1
+            grad2 = grad(j_flat, param, grad_outputs=grad_outputs, retain_graph=True)[0]
+            hess.append(grad2)
+        outputs.append(torch.stack(hess).reshape(j.shape + param.shape))
+    return outputs
+
+
+# Compute the product of the inverse Hessian of the empirical risk and a given
+# vector `vs`, estimated numerically with LiSSA.
+# Returns a list of inverse-HVPs, one per parameter.
+'''
+Important params:
+vs: list of vectors in the inverse-HVP, one per parameter
+batch_size: size of the minibatch sampled at each iteration
+scale: factor to scale down the loss (to keep the Hessian <= I)
+damping: lambda added to keep the Hessian positive definite
+num_repeats: hyperparameter `r` in the paper (to reduce variance)
+'''
+
+
+def get_inverse_hvp_lissa(model, criterion, dataset, vs,
+                          batch_size,
+                          scale=1,
+                          damping=0.1,
+                          num_repeats=1,
+                          verbose=False):
+    assert criterion is not None, "ERROR: Criterion cannot be None."
+    assert batch_size <= len(dataset), "ERROR: Minibatch size for LiSSA should be less than the dataset size"
+    assert isinstance(dataset, Dataset), "ERROR: `dataset` must be a PyTorch Dataset"
+
+    params = [param for param in model.parameters() if param.requires_grad]
+    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+    inverse_hvp = None
+
+    for rep in range(num_repeats):
+        cur_estimate = vs
+        for batch in iter(data_loader):
+            batch_inputs, batch_targets = batch
+            batch_out = model(batch_inputs)
+
+            loss = criterion(batch_out, batch_targets) / batch_size
+
+            hvp = hessian_vector_product(loss, params, vs=cur_estimate)
+            cur_estimate = [v + (1 - damping) * ce - hv / scale \
+                            for (v, ce, hv) in zip(vs, cur_estimate, hvp)]
+
+        inverse_hvp = [hv1 + hv2 / scale for (hv1, hv2) in zip(inverse_hvp, cur_estimate)] \
+            if inverse_hvp is not None \
+            else [hv2 / scale for hv2 in cur_estimate]
+
+    # Average over repetitions
+    inverse_hvp = [item / num_repeats for item in inverse_hvp]
+    return inverse_hvp
+
+
+def inf_func(trnX, trnY, devX, devY, clf,
+             epochs=5, trn_batch_size=1, dev_batch_size=16):
+
+    if epochs > 0:
+        print("Training models for IF with {} iterations ...".format(epochs))
+        clf.fit(trnX, trnY, epochs=epochs, batch_size=trn_batch_size)
+
+    train_grads = clf.grad(trnX, trnY, batch_size=1)
+    test_grads = clf.grad(devX, devY, batch_size=dev_batch_size)
+
+    # Influence of each training point on the dev loss, approximated here by
+    # the inner product of the training and dev gradients
+    infs = []
+    for test_grad in test_grads:
+        inf_up_loss = []
+        for train_grad in train_grads:
+            inf = 0
+            for train_grad_p, test_grad_p in zip(train_grad, test_grad):
+                inf += torch.sum(train_grad_p * test_grad_p)
+            inf_up_loss.append(inf.cpu().detach().numpy())
+        infs.append(inf_up_loss)
+
+    vals = list(np.sum(infs, axis=0))
+    return vals
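A small sanity check of hessian_vector_product on a toy quadratic, where the result can be verified by hand (entirely illustrative, not part of the patch):

# Hypothetical check -- not part of this patch.
import torch
from valda.inf_func import hessian_vector_product

w = torch.tensor([1.0, 2.0], requires_grad=True)
ys = (w * w).sum()                       # Hessian of ys w.r.t. w is 2*I
v = torch.tensor([3.0, 4.0])
hvp = hessian_vector_product(ys, [w], [v])
print(hvp)                               # (tensor([6., 8.]),) i.e. 2*v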
+""" +Important Params: +ys: scalar that is to be differentiated +params: list of torch.tensors, hessian is computed for each +""" + + +def hessians(ys, params): + jacobians = grad(ys, params, create_graph=True) + + outputs = [] # container for hessians + for j, param in zip(jacobians, params): + hess = [] + j_flat = j.flatten() + for i in range(len(j_flat)): + grad_outputs = torch.zeros_like(j_flat) + grad_outputs[i] = 1 + grad2 = grad(j_flat, param, grad_outputs=grad_outputs, retain_graph=True)[0] + hess.append(grad2) + outputs.append(torch.stack(hess).reshape(j.shape + param.shape)) + return outputs + + +# Compute product of inverse hessian of empirical risk and given vector 'v', computed numerically using LiSSA. +# Return a list of inverse-hvps, computed for each param. + +''' +Important params: +vs: list of vectors in the inverse-hvp, one per parameter +batch_size: size of minibatch sample at each iteration +scale: the factor to scale down loss (to keep hessian <= I) +damping: lambda added to guarantee hessian be p.d. +num_repeats: hyperparameter 'r' in in the paper (to reduce variance) +recursion_depth: number of iterations for LiSSA algorithm +''' + + +def get_inverse_hvp_lissa(model, criterion, dataset, vs, + batch_size, + scale=1, + damping=0.1, + num_repeats=1, + verbose=False): + assert criterion is not None, "ERROR: Criterion cannot be None." + assert batch_size <= len(dataset), "ERROR: Minibatch size for LiSSA should be less than dataset size" + # assert len(dataset) % batch_size == 0, "ERROR: Dataset size for LiSSA should be a multiple of minibatch size" + assert isinstance(dataset, Dataset), "ERROR: `dataset` must be PyTorch Dataset" + + params = [param for param in model.parameters() if param.requires_grad] + data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + inverse_hvp = None + + for rep in range(num_repeats): + cur_estimate = vs + for batch in iter(data_loader): + batch_inputs, batch_targets = batch + batch_out = model(batch_inputs) + + loss = criterion(batch_out, batch_targets) / batch_size + + hvp = hessian_vector_product(loss, params, vs=cur_estimate) + cur_estimate = [v + (1 - damping) * ce - hv / scale \ + for (v, ce, hv) in zip(vs, cur_estimate, hvp)] + + inverse_hvp = [hv1 + hv2 / scale for (hv1, hv2) in zip(inverse_hvp, cur_estimate)] \ + if inverse_hvp is not None \ + else [hv2 / scale for hv2 in cur_estimate] + + # avg. 
diff --git a/src/valda/metrics.py b/src/valda/metrics.py
new file mode 100644
index 0000000..3904997
--- /dev/null
+++ b/src/valda/metrics.py
@@ -0,0 +1,33 @@
+## metrics.py
+## Date: 01/26/2023
+## Evaluation metrics used for data valuation
+
+import numpy as np
+import copy
+from sklearn.metrics import auc
+
+def weighted_acc_drop(accs):
+    ''' Weighted accuracy drop; please refer to (Schoch et al., 2022)
+        for the definition
+    '''
+    accs = copy.copy(accs)  # do not modify the caller's list
+    accs.append(0.)
+    accs = np.array(accs)
+    diff = accs[:-1] - accs[1:]
+    c_sum = np.cumsum(diff)
+    weights = np.array(list(range(1, diff.shape[0]+1)))
+    weights = 1.0/weights
+    score = weights * c_sum
+    return score.sum()
+
+
+def pr_curve(target_list, ranked_list):
+    ''' Compute the P/R curve for two given lists
+    '''
+    p, r = [], []
+    for idx in range(3, len(ranked_list)+1):
+        partial_list = ranked_list[:idx]
+        # intersection of the target list and the current prefix
+        union = list(set(target_list) & set(partial_list))
+        r.append(1.0*len(union)/len(target_list))
+        p.append(1.0*len(union)/len(partial_list))
+    score = auc(r, p)
+    return (p, r, score)
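A hand-checkable toy case for weighted_acc_drop (illustrative, not part of the patch):

# Hypothetical check -- not part of this patch.
from valda.metrics import weighted_acc_drop

accs = [0.9, 0.6, 0.3]                # accuracy after each removal step
# padded to [0.9, 0.6, 0.3, 0.0]; diffs -> [0.3, 0.3, 0.3]
# cumsum -> [0.3, 0.6, 0.9]; weights 1, 1/2, 1/3 -> 0.3 + 0.3 + 0.3
print(weighted_acc_drop(accs))        # ~0.9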
diff --git a/src/valda/params.py b/src/valda/params.py
new file mode 100644
index 0000000..c9573c1
--- /dev/null
+++ b/src/valda/params.py
@@ -0,0 +1,43 @@
+## params.py
+## Date: 01/23/2022
+## Initialize default parameters
+
+
+class Parameters(object):
+    def __init__(self):
+        self.params = {
+            # For TMC Shapley
+            'tmc_iter':500,
+            'tmc_thresh':0.001,
+            # For CS Shapley
+            'cs_iter':500,
+            'cs_thresh':0.001,
+            # For Beta Shapley
+            'beta_iter':50,
+            'alpha':1.0,
+            'beta':16.0,
+            'rho':1.0005,
+            'beta_chain':10,
+            # For Influence Function
+            'if_iter':30,
+            'trn_batch_size':16,
+            'dev_batch_size':16
+        }
+
+    def update(self, new_params):
+        # Assigning to a dict never raises KeyError, so check membership instead
+        for (key, val) in new_params.items():
+            if key in self.params:
+                self.params[key] = val
+            else:
+                print("Undefined key {} with value {}".format(key, val))
+
+    def get_values(self):
+        return self.params
+
+    def print_values(self):
+        print("The current hyper-parameter setting:")
+        for (key, val) in self.params.items():
+            print("\t{} : {}".format(key, val))
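A short sketch of overriding the defaults through Parameters (illustrative):

# Hypothetical usage sketch -- not part of this patch.
from valda.params import Parameters

params = Parameters()
params.update({'tmc_iter': 100, 'beta_chain': 5})
params.print_values()        # shows the overridden defaults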
diff --git a/src/valda/pyclassifier.py b/src/valda/pyclassifier.py
new file mode 100644
index 0000000..05d6cf7
--- /dev/null
+++ b/src/valda/pyclassifier.py
@@ -0,0 +1,109 @@
+## pyclassifier.py
+## Date: 01/18/2023
+## A general framework of defining a classifier class
+
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable, grad
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+from tqdm import tqdm
+
+
+class Data(Dataset):
+    def __init__(self, X, y=None):
+        self.X = torch.from_numpy(X.astype(np.float32))
+        if y is not None:
+            self.y = torch.from_numpy(y).type(torch.LongTensor)
+        else:
+            self.y = None
+        self.len = self.X.shape[0]
+        self.dim = self.X.shape[1]
+
+    def __getitem__(self, index):
+        # For unlabeled data (y is None), return the inputs only
+        if self.y is None:
+            return self.X[index]
+        return self.X[index], self.y[index]
+
+    def __len__(self):
+        return self.len
+
+    def __dim__(self):
+        return self.dim
+
+
+class PytorchClassifier(object):
+    def __init__(self, model, optim=None, loss=None):
+        '''
+        model - a PyTorch module that maps inputs to class scores
+        optim - a PyTorch optimizer (Adam, by default)
+        loss - a PyTorch loss function (cross-entropy, by default)
+        '''
+        self.model = model
+        self.params = list(self.model.parameters())
+        # optimizer
+        if optim is None:
+            self.optim = torch.optim.Adam(model.parameters())
+        else:
+            self.optim = optim
+        # loss function
+        if loss is None:
+            self.loss = nn.CrossEntropyLoss()
+        else:
+            self.loss = loss
+
+    def fit(self, X, y, epochs=10, batch_size=1):
+        '''
+        X, y - training examples
+        '''
+        loader = DataLoader(Data(X,y), batch_size=batch_size, shuffle=True,
+                            num_workers=0)
+        for epoch in tqdm(range(epochs)):
+            for (inputs, labels) in loader:
+                self.optim.zero_grad()
+                outputs = self.model(inputs)
+                batch_loss = self.loss(outputs, labels)
+                batch_loss.backward()
+                self.optim.step()
+        print("Done training")
+
+    def predict(self, X, batch_size=16):
+        loader = DataLoader(Data(X), batch_size=batch_size, shuffle=False,
+                            num_workers=0)
+        pred_labels = []
+        with torch.no_grad():
+            for inputs in loader:
+                outputs = self.model(inputs)
+                _, predicted = torch.max(outputs.data, 1)
+                pred_labels += predicted.tolist()
+        return pred_labels
+
+    def online_train(self, x, y):
+        ''' Online training
+        '''
+        raise NotImplementedError("Not implemented")
+
+    def grad(self, X, y, batch_size=1):
+        ''' Compute the gradients of the parameters w.r.t. (X, y)
+        '''
+        grads = []
+        loader = DataLoader(Data(X,y), batch_size=batch_size, shuffle=False,
+                            num_workers=0)
+        for (inputs, labels) in loader:
+            outputs = self.model(inputs)
+            batch_loss = self.loss(outputs, labels)
+            batch_grads = grad(batch_loss, self.params)
+            # keep one tuple of per-parameter gradients per batch, so the
+            # structure matches what inf_func expects
+            grads.append(batch_grads)
+        return grads
+
+    def get_parameters(self):
+        ''' Get the model parameters
+        '''
+        return self.params
diff --git a/src/valda/tmc_shapley.py b/src/valda/tmc_shapley.py
new file mode 100644
index 0000000..f3bf864
--- /dev/null
+++ b/src/valda/tmc_shapley.py
@@ -0,0 +1,43 @@
+import numpy as np
+from random import shuffle, seed, randint, sample, choice
+from tqdm import tqdm
+from sklearn.metrics import accuracy_score
+
+
+def tmc_shapley(trnX, trnY, devX, devY, clf, T=20, epsilon=0.001):
+    N = trnX.shape[0]
+    Idx = list(range(N))  # Indices
+    val, t = np.zeros((N)), 0
+    # Start calculation
+    val_T = np.zeros((T, N))
+    for t in tqdm(range(1, T+1)):
+        # Shuffle the data
+        shuffle(Idx)
+        val_t = np.zeros((N+1))
+        # Pre-computed value with all the training data
+        try:
+            clf.fit(trnX, trnY)
+            val_t[N] = accuracy_score(devY, clf.predict(devX))
+        except ValueError:
+            # The training set only has a single class
+            val_t[N] = accuracy_score(devY, [trnY[0]]*len(devY))
+        for j in range(1,N+1):
+            if abs(val_t[N] - val_t[j-1]) < epsilon:
+                # Truncation: the remaining marginal contributions are negligible
+                val_t[j] = val_t[j-1]
+            else:
+                # Extract the first j data points of the permutation
+                trnX_j = trnX[Idx[:j],:]
+                trnY_j = trnY[Idx[:j]]
+                try:
+                    clf.fit(trnX_j, trnY_j)
+                    val_t[j] = accuracy_score(devY, clf.predict(devX))
+                except ValueError:
+                    # Predict the only class present in the subset
+                    val_t[j] = accuracy_score(devY, [trnY_j[0]]*len(devY))
+        # Update the Data Shapley values with a running average
+        val[Idx] = ((1.0*(t-1))/t)*val[Idx] + (1.0/t)*(val_t[1:] - val_t[:N])
+        val_T[t-1,:] = (np.array(val_t)[1:])[Idx]
+    return val
diff --git a/src/valda/util.py b/src/valda/util.py
new file mode 100644
index 0000000..36fb8cd
--- /dev/null
+++ b/src/valda/util.py
@@ -0,0 +1,128 @@
+import numpy as np
+from math import factorial
+from sklearn.metrics import accuracy_score, auc
+import copy
+
+
+def weight(j, n, alpha=1.0, beta=1.0):
+    ''' The Beta(alpha, beta) weight of position j among n examples,
+        computed in log space for numerical stability
+    '''
+    log_1, log_2, log_3 = 0.0, 0.0, 0.0
+    for k in range(1, j):
+        log_1 += np.log(beta + k - 1)
+    for k in range(1, n-j+1):
+        log_2 += np.log(alpha + k - 1)
+    for k in range(1, n):
+        log_3 += np.log(alpha + beta + k - 1)
+    log_total = np.log(n) + log_1 + log_2 - log_3
+    # log C(n-1, j-1), using Stirling's approximation log(m!) ~ m*(log(m) - 1)
+    # for the larger factorials
+    log_comb = None
+    if n <= 20:
+        log_comb = np.log(factorial(n-1))
+    else:
+        log_comb = (n-1)*(np.log(n-1) - 1)
+    if j <= 20:
+        log_comb -= np.log(factorial(j-1))
+    else:
+        log_comb -= (j-1)*(np.log(j-1) - 1)
+    if (n-j) <= 20:
+        log_comb -= np.log(factorial(n-j))
+    else:
+        log_comb -= (n-j)*(np.log(n-j) - 1)
+    v = np.exp(log_comb + log_total)
+    return v
+
+
+def gr_statistic(val, t):
+    ''' The Gelman-Rubin statistic over the first t samples of K chains
+    '''
+    v = val[:,:,1:t+1]
+    sample_var = np.var(v, axis=2, ddof=1)  # N x K, along dimension T
+    mean_sample_var = np.mean(sample_var, axis=1)  # N, along dimension K; s^2 in the paper
+    sample_mean = np.mean(v, axis=2)  # N x K, along dimension T
+    sample_mean_var = np.var(sample_mean, axis=1, ddof=1)  # N, along dimension K; B/n in the paper
+    sigma_hat_2 = ((t-1)*mean_sample_var)/t + sample_mean_var
+    rho_hat = np.sqrt(sigma_hat_2/(mean_sample_var + 1e-4))
+    return rho_hat
+
+
+def data_removal_figure(neg_lab, pos_lab, trnX, trnY, devX, devY, sorted_dct,
+                        clf_label, remove_high_value=True):
+    # NOTE: `Classifier` is assumed to be provided by the caller's environment;
+    # it is not defined in this module.
+    # Create data indices for data removal
+    N = trnX.shape[0]
+    Idx_keep = [True]*N
+    # Accuracy list
+    accs = []
+    if remove_high_value:
+        lst = range(N)
+    else:
+        lst = range(N-1, -1, -1)
+    # Compute
+    clf = Classifier(clf_label)
+    clf.fit(trnX, trnY)
+    dev = list(zip(devX, devY))
+    dev_X0, dev_X1 = [], []
+    dev_Y0, dev_Y1 = [], []
+
+    for i in dev:
+        if i[1] == pos_lab:
+            dev_X1.append(i[0])
+            dev_Y1.append(i[1])
+        elif i[1] == neg_lab:
+            dev_X0.append(i[0])
+            dev_Y0.append(i[1])
+        else:
+            print(i)
+
+    acc_0 = accuracy_score(dev_Y0, clf.predict(dev_X0), normalize=False)/len(devY)
+    acc_1 = accuracy_score(dev_Y1, clf.predict(dev_X1), normalize=False)/len(devY)
+    print(acc_0, acc_1)
+
+    accs_0 = []
+    accs_1 = []
+    acc = accuracy_score(devY, clf.predict(devX))
+    accs.append(acc)
+    accs_0.append(acc_0)
+    accs_1.append(acc_1)
+    vals = []
+    labels = []
+    points = []
+    ks = []
+    for k in lst:
+        Idx_keep = [True] * N
+        Idx_keep[sorted_dct[k][0]] = False
+        trnX_k = trnX[Idx_keep, :]
+        trnY_k = trnY[Idx_keep]
+        clf = Classifier(clf_label)
+        try:
+            clf.fit(trnX_k, trnY_k)
+            labels.append(trnY[k])
+            points.append(trnX[k])
+            acc = accuracy_score(devY, clf.predict(devX))
+            acc_0 = accuracy_score(dev_Y0, clf.predict(dev_X0), normalize=False)/len(devY)
+            acc_1 = accuracy_score(dev_Y1, clf.predict(dev_X1), normalize=False)/len(devY)
+            ks.append(k)
+            accs.append(acc)
+            accs_0.append(acc_0)
+            accs_1.append(acc_1)
+            vals.append(sorted_dct[k][1])
+        except ValueError:
+            # Training with data from a single class; keep the lists aligned
+            accs.append(0.0)
+            accs_0.append(0.0)
+            accs_1.append(0.0)
+    return accs, accs_0, accs_1, vals, labels, points, ks
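A quick sanity check of weight() against the exact value for small n, where the factorials are computed exactly (illustrative): for alpha = beta = 1 the weights should be uniform, i.e. weight(j, n) = 1 for every position j.

# Hypothetical check -- not part of this patch.
from valda.util import weight

n = 10
print([round(weight(j, n, 1.0, 1.0), 6) for j in range(1, n+1)])  # all ~1.0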
diff --git a/src/valda/valuation.py b/src/valda/valuation.py
new file mode 100644
index 0000000..94e9242
--- /dev/null
+++ b/src/valda/valuation.py
@@ -0,0 +1,114 @@
+## valuation.py
+## Date: 01/18/2023
+## A general framework of defining a data valuation class
+
+
+from sklearn.linear_model import LogisticRegression as LR
+from sklearn.metrics import accuracy_score
+
+
+## Local module: load data valuation methods
+from .loo import loo
+from .tmc_shapley import tmc_shapley
+from .cs_shapley import cs_shapley
+from .beta_shapley import beta_shapley
+from .inf_func import inf_func
+from .params import Parameters  # Import default model parameters
+
+
+class DataValuation(object):
+    def __init__(self, trnX, trnY, devX=None, devY=None):
+        '''
+        trnX, trnY - input/output for training, also the examples
+            being valued
+        devX, devY - input/output for validation, also the examples used
+            for estimating the values of (trnX, trnY)
+        '''
+        self.trnX, self.trnY = trnX, trnY
+        if devX is None:
+            self.devX, self.devY = trnX, trnY
+        else:
+            self.devX, self.devY = devX, devY
+        self.values = {}  # data indices mapped to their estimated values
+        self.clf = None  # instance of a classifier
+        params = Parameters()
+        self.params = params.get_values()
+
+    def estimate(self, clf=None, method='loo', params=None):
+        '''
+        clf - a classifier instance (logistic regression, by default)
+        method - the data valuation method (LOO, by default)
+        params - hyper-parameters for the data valuation methods
+        '''
+        self.values = {}
+        if clf is None:
+            self.clf = LR(solver="liblinear", max_iter=500, random_state=0)
+        else:
+            self.clf = clf
+
+        if params is not None:
+            print("Overriding the default parameters with the user-specified ones: {}".format(params))
+            self.params.update(params)
+
+        # Call the data valuation functions
+        if method == 'loo':
+            # Leave-one-out
+            vals = loo(self.trnX, self.trnY, self.devX, self.devY, self.clf)
+            for idx in range(len(vals)):
+                self.values[idx] = vals[idx]
+        elif method == 'tmc-shapley':
+            # TMC Data Shapley (Ghorbani and Zou, 2019)
+            # Get the default parameter values
+            n_iter = self.params['tmc_iter']
+            tmc_thresh = self.params['tmc_thresh']
+            vals = tmc_shapley(self.trnX, self.trnY, self.devX, self.devY,
+                               self.clf, n_iter, tmc_thresh)
+            for idx in range(len(vals)):
+                self.values[idx] = vals[idx]
+        elif method == 'cs-shapley':
+            # CS Shapley (Schoch et al., 2022)
+            n_iter = self.params['cs_iter']
+            cs_thresh = self.params['cs_thresh']
+            labels = list(set(self.trnY))
+            for label in labels:
+                vals, orig_indices = cs_shapley(self.trnX, self.trnY, self.devX,
+                                                self.devY, label,
+                                                self.clf, n_iter, cs_thresh)
+                for (idx, val) in zip(list(orig_indices), list(vals)):
+                    self.values[idx] = val
+        elif method == 'beta-shapley':
+            # Beta Shapley (Kwon and Zou, 2022)
+            n_iter = self.params['beta_iter']
+            alpha, beta = self.params['alpha'], self.params['beta']
+            rho = self.params['rho']
+            n_chain = self.params['beta_chain']
+            vals = beta_shapley(self.trnX, self.trnY, self.devX, self.devY,
+                                self.clf, alpha, beta, rho, n_chain, n_iter)
+            for idx in range(len(vals)):
+                self.values[idx] = vals[idx]
+        elif method == 'inf-func':
+            # Influence functions (Koh and Liang, 2017)
+            n_iter = self.params['if_iter']
+            trn_batch_size = self.params['trn_batch_size']
+            dev_batch_size = self.params['dev_batch_size']
+            vals = inf_func(self.trnX, self.trnY, self.devX, self.devY,
+                            clf=self.clf,
+                            epochs=n_iter,
+                            trn_batch_size=trn_batch_size,
+                            dev_batch_size=dev_batch_size)
+            for idx in range(len(vals)):
+                self.values[idx] = vals[idx]
+        else:
+            raise ValueError("Unrecognized data valuation method: {}".format(method))
+        return self.values
+
+    def get_values(self):
+        '''
+        Return the data values
+        '''
+        if self.values:
+            return self.values
+        raise ValueError("No values have been computed")
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/data/diabetes.pkl b/tests/data/diabetes.pkl
new file mode 100644
index 0000000..86dccd3
Binary files /dev/null and b/tests/data/diabetes.pkl differ
diff --git a/tests/example/__init__.py b/tests/example/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/example/simple.py b/tests/example/simple.py
new file mode 100644
index 0000000..7e1d204
--- /dev/null
+++ b/tests/example/simple.py
@@ -0,0 +1,29 @@
+## simple.py
+## Date: 01/26/2023
+## A simple example
+
+from pickle import load
+
+from valda.valuation import DataValuation
+from valda.eval import data_removal
+from valda.metrics import weighted_acc_drop
+
+
+data = load(open('../data/diabetes.pkl', 'rb'))
+
+trnX, trnY = data['trnX'], data['trnY']
+devX, devY = data['devX'], data['devY']
+
+dv = DataValuation(trnX, trnY, devX, devY)
+
+vals = dv.estimate(method='beta-shapley')
+
+tstX, tstY = data['tstX'], data['tstY']
+
+accs = data_removal(vals, trnX, trnY, tstX, tstY)
+res = weighted_acc_drop(accs)
+print("The weighted accuracy drop is {}".format(res))
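A hedged sketch of overriding the defaults through estimate(), assuming the data loading from simple.py above; the parameter keys come from params.py, and partial overrides rely on estimate() merging them into the defaults as implemented above:

# Hypothetical usage sketch -- not part of this patch.
from valda.valuation import DataValuation

dv = DataValuation(trnX, trnY, devX, devY)
# Override a subset of the defaults defined in valda.params
vals = dv.estimate(method='tmc-shapley',
                   params={'tmc_iter': 100, 'tmc_thresh': 0.005})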
diff --git a/tests/example/simple_pytorch.py b/tests/example/simple_pytorch.py
new file mode 100644
index 0000000..59e5400
--- /dev/null
+++ b/tests/example/simple_pytorch.py
@@ -0,0 +1,42 @@
+## simple_pytorch.py
+## A simple example with a PyTorch classifier
+
+import torch
+from pickle import load
+from sklearn import preprocessing
+
+from valda.valuation import DataValuation
+from valda.pyclassifier import PytorchClassifier
+
+
+class LogisticRegression(torch.nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super(LogisticRegression, self).__init__()
+        self.linear = torch.nn.Linear(input_dim, output_dim)
+
+    def forward(self, x):
+        # Return raw logits: nn.CrossEntropyLoss applies log-softmax itself,
+        # so an explicit Softmax layer here would be applied twice
+        return self.linear(x)
+
+
+data = load(open('../data/diabetes.pkl', 'rb'))
+trnX, trnY = data['trnX'], data['trnY']
+devX, devY = data['devX'], data['devY']
+print('trnX.shape = {}'.format(trnX.shape))
+
+labels = list(set(trnY))
+le = preprocessing.LabelEncoder()
+le.fit(labels)
+trnY = le.transform(trnY)
+devY = le.transform(devY)
+
+model = LogisticRegression(input_dim=trnX.shape[1], output_dim=len(labels))
+clf = PytorchClassifier(model)
+
+dv = DataValuation(trnX, trnY, devX, devY)
+
+vals = dv.estimate(clf=clf, method='inf-func')
+
+print(vals)