Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
uvaml committed Jan 27, 2023
1 parent cd14418 commit bf2dc22
Show file tree
Hide file tree
Showing 21 changed files with 994 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
dist/
*egg-info*
21 changes: 21 additions & 0 deletions LICENSE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 UVa ILP

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# valda

A Python Data Valuation Package
22 changes: 22 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "valda"
version = "0.1.5"
authors = [
{ name="Yangfeng Ji", email="[email protected]" },
]
description = "A Data Valuation Package for Machine Learning"
readme = "README.md"
requires-python = ">=3.6"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]

[project.urls]
"Homepage" = "https://uvanlp.org/valda"
"Bug Tracker" = "https://github.com/uvanlp/valda/issues"
Empty file added src/valda/__init__.py
Empty file.
69 changes: 69 additions & 0 deletions src/valda/beta_shapley.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
## beta_shapley.py
## Implementation of Beta Shapley

import numpy as np
from random import shuffle, seed, randint, sample, choice
from tqdm import tqdm
from sklearn.metrics import accuracy_score


## Local module
from .util import *


def beta_shapley(trnX, trnY, devX, devY, clf, alpha=1.0,
beta=1.0, rho=1.0005, K=10, T=10):
"""
alpha, beta - parameters for Beta distribution
rho - GR statistic threshold
K - number of Markov chains
T - upper bound of iterations
"""
N = trnX.shape[0]
Idx = list(range(N)) # Indices
val, t = np.zeros((N, K, T+1)), 0
rho_hat = 2*rho
# val_N_list = []

# Data information
N = len(trnY)
# Computation
# while np.any(rho_hat >= rho):
for t in tqdm(range(1, T+1)):
# print("Iteration: {}".format(t))
for j in range(N):
for k in range(K):
Idx = list(range(N))
Idx.remove(j) # remove j
s = randint(1, N-1)
sub_Idx = sample(Idx, s)
acc_ex, acc_in = None, None
# =========================
trnX_ex, trnY_ex = trnX[sub_Idx, :], trnY[sub_Idx]
try:
clf.fit(trnX_ex, trnY_ex)
acc_ex = accuracy_score(devY, clf.predict(devX))
except ValueError:
acc_ex = accuracy_score(devY, [trnY_ex[0]]*len(devY))
# =========================
sub_Idx.append(j) # Add example j back for training
trnX_in, trnY_in = trnX[sub_Idx, :], trnY[sub_Idx]
try:
clf.fit(trnX_in, trnY_in)
acc_in = accuracy_score(devY, clf.predict(devX))
except ValueError:
acc_in = accuracy_score(devY, [trnY_in[0]]*len(devY))
# Update the value
val[j,k,t] = ((t-1)*val[j,k,t-1])/t + (weight(j+1, N, alpha, beta)/t)*(acc_in - acc_ex)
# Update the Gelman-Rubin statistic rho_hat
if t > 3:
rho_hat = gr_statistic(val, t) # A temp solution for stopping
# print("rho_hat = {}".format(rho_hat[:5]))
if np.all(rho_hat < rho):
# terminate the outer loop earlier
break
# average all the sample values
# val_mean = val[:,:,1:t+1].mean(axis=2).mean(axis=1) # N
val_last = val[:,:,t].mean(axis=1)
# print(val_last)
return val_last
124 changes: 124 additions & 0 deletions src/valda/cs_shapley.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import numpy as np
from random import shuffle, seed, randint, sample, choice
from tqdm import tqdm
from sklearn.metrics import accuracy_score


def class_conditional_sampling(Y, label_set):
Idx_nonlabel = []
for label in label_set:
label_indices = list(np.where(Y == label)[0])
s = randint(1, len(label_indices))
Idx_nonlabel += sample(label_indices, s)
shuffle(Idx_nonlabel) # shuffle the sampled indices
# print('len(Idx_nonlabel) = {}'.format(len(Idx_nonlabel)))
return Idx_nonlabel


def cs_shapley(trnX, trnY, devX, devY, label, clf, T=200,
epsilon=1e-4, normalized_score=True, resample=1):
'''
normalized_score - whether normalizing the Shaple values within the class
resample - the number of resampling when estimating the values with one
specific permutation. Technically, larger values lead to better
results, but in practice, the difference may not be significant
'''
# Select data based on the class label
orig_indices = np.array(list(range(trnX.shape[0])))[trnY == label]
print("The number of training data with label {} is {}".format(label, len(orig_indices)))
trnX_label = trnX[trnY == label]
trnY_label = trnY[trnY == label]
trnX_nonlabel = trnX[trnY != label]
trnY_nonlabel = trnY[trnY != label]
devX_label = devX[devY == label]
devY_label = devY[devY == label]
devX_nonlabel = devX[devY != label]
devY_nonlabel = devY[devY != label]
N_nonlabel = trnX_nonlabel.shape[0]
nonlabel_set = list(set(trnY_nonlabel))
print("Labels on the other side: {}".format(nonlabel_set))

# Create indices and shuffle them
N = trnX_label.shape[0]
Idx = list(range(N))
# Shapley values, number of permutations, total number of iterations
val, k = np.zeros((N)), 0
for t in tqdm(range(1, T+1)):
# print("t = {}".format(t))
# Shuffle the data
shuffle(Idx)
# For each permutation, resample I times from the other classes
for i in range(resample):
k += 1
# value container for iteration i
val_i = np.zeros((N+1))
val_i_non = np.zeros((N+1))

# --------------------
# Sample a subset of training data from other labels for each i
if len(nonlabel_set) == 1:
s = randint(1, N_nonlabel)
# print('s = {}'.format(s))
Idx_nonlabel = sample(list(range(N_nonlabel)), s)
else:
Idx_nonlabel = class_conditional_sampling(trnY_nonlabel, nonlabel_set)
trnX_nonlabel_i = trnX_nonlabel[Idx_nonlabel, :]
trnY_nonlabel_i = trnY_nonlabel[Idx_nonlabel]

# --------------------
# With no data from the target class and the sampled data from other classes
val_i[0] = 0.0
try:
clf.fit(trnX_nonlabel_i, trnY_nonlabel_i)
val_i_non[0] = accuracy_score(devY_nonlabel, clf.predict(devX_nonlabel), normalize=False)/len(devY)
except ValueError:
# In the sampled trnY_nonlabel_i, there is only one class
# print("One class in the training set")
val_i_non[0] = accuracy_score(devY_nonlabel, [trnY_nonlabel_i[0]]*len(devY_nonlabel),
normalize=False)/len(devY)

# ---------------------
# With all data from the target class and the sampled data from other classes
tempX = np.concatenate((trnX_nonlabel_i, trnX_label))
tempY = np.concatenate((trnY_nonlabel_i, trnY_label))
clf.fit(tempX, tempY)
val_i[N] = accuracy_score(devY_label, clf.predict(devX_label), normalize=False)/len(devY)
val_i_non[N] = accuracy_score(devY_nonlabel, clf.predict(devX_nonlabel), normalize=False)/len(devY)

# --------------------
#
for j in range(1,N+1):
if abs(val_i[N] - val_i[j-1]) < epsilon:
val_i[j] = val_i[j-1]
else:
# Extract the first $j$ data points
trnX_j = trnX_label[Idx[:j],:]
trnY_j = trnY_label[Idx[:j]]
try:
# ---------------------------------
tempX = np.concatenate((trnX_nonlabel_i, trnX_j))
tempY = np.concatenate((trnY_nonlabel_i, trnY_j))
clf.fit(tempX, tempY)
val_i[j] = accuracy_score(devY_label, clf.predict(devX_label), normalize=False)/len(devY)
val_i_non[j] = accuracy_score(devY_nonlabel, clf.predict(devX_nonlabel), normalize=False)/len(devY)
except ValueError: # This should never happen in this algorithm
print("Only one class in the dataset")
# print(tempY)
return (None, None, None)
# ==========================================
# New implementation
wvalues = np.exp(val_i_non) * val_i
# print("wvalues = {}".format(wvalues))
diff = wvalues[1:] - wvalues[:N]
val[Idx] = ((1.0*(k-1)/k))*val[Idx] + (1.0/k)*(diff)


# Whether normalize the scores within the class
if normalized_score:
val = val/val.sum()
clf.fit(trnX, trnY)
score = accuracy_score(devY_label, clf.predict(devX_label), normalize=False)/len(devY)
print("score = {}".format(score))
val = val * score
return val, orig_indices
50 changes: 50 additions & 0 deletions src/valda/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
## data_removal.py
## Evaluate the performance of data valuation by removing one data point
## at a time from the training set

from sklearn.metrics import accuracy_score, auc
from sklearn.linear_model import LogisticRegression as LR

import operator

def data_removal(vals, trnX, trnY, tstX, tstY, clf=None,
remove_high_value=True):
'''
trnX, trnY - training examples
tstX, tstY - test examples
vals - a Python dict that contains data indices and values
clf - the classifier that will be used for evaluation
'''
# Create data indices for data removal
N = trnX.shape[0]
Idx_keep = [True]*N

if clf is None:
clf = LR(solver="liblinear", max_iter=500, random_state=0)
# Sorted the data indices with a descreasing order
sorted_dct = sorted(vals.items(), key=operator.itemgetter(1), reverse=True)
# Accuracy list
accs = []
if remove_high_value:
lst = range(N)
else:
lst = range(N-1, -1, -1)
# Compute
clf.fit(trnX, trnY)
acc = accuracy_score(clf.predict(tstX), tstY)
accs.append(acc)
for k in lst:
# print(k)
Idx_keep[sorted_dct[k][0]] = False
trnX_k = trnX[Idx_keep, :]
trnY_k = trnY[Idx_keep]
try:
clf.fit(trnX_k, trnY_k)
# print('trnX_k.shape = {}'.format(trnX_k.shape))
acc = accuracy_score(clf.predict(tstX), tstY)
# print('acc = {}'.format(acc))
accs.append(acc)
except ValueError:
# print("Training with data from a single class")
accs.append(0.0)
return accs
Loading

0 comments on commit bf2dc22

Please sign in to comment.