-
Notifications
You must be signed in to change notification settings - Fork 1
/
modal_learner.py
112 lines (93 loc) · 3.57 KB
/
modal_learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import Optional, Callable
import warnings
import scipy
import numpy as np
from sklearn.base import BaseEstimator
from modAL import ActiveLearner
from modAL.uncertainty import uncertainty_sampling
from modAL.utils.data import modALinput
from modAL.utils.data import data_vstack
class IndexLearner(ActiveLearner):
    """
    Active learner which utilizes index sets instead of array modifications.

    Rather than copying rows out of the unlabelled pool into the training
    arrays, this learner records the indices of taught examples in
    ``taught_idx`` and exposes the combined training data and the remaining
    pool through read-only properties.

    Parameters
    ----------
    estimator : BaseEstimator
        The scikit-learn style estimator to fit.
    X_training, y_training : modALinput, optional
        The initial (seed) labelled training data.
    X_unlabelled, y_unlabelled : modALinput, optional
        The pool of unlabelled examples and their (hidden) labels.
    query_strategy : Callable
        Query strategy used to select examples from the pool.
    bootstrap_init : bool
        Whether the initial fit uses bootstrapping
        (see https://github.com/modAL-python/modAL/issues/103).
    on_transformed : bool
        Whether query strategies operate on transformed data.
    **fit_kwargs
        Extra keyword arguments forwarded to the estimator's ``fit``.
    """

    def __init__(
        self,
        estimator: BaseEstimator,
        X_training: Optional[modALinput],
        y_training: Optional[modALinput],
        X_unlabelled: Optional[modALinput],
        y_unlabelled: Optional[modALinput],
        query_strategy: Callable = uncertainty_sampling,
        bootstrap_init: bool = False,
        on_transformed: bool = False,
        **fit_kwargs,
    ) -> None:
        # The pools and the taught-index set must exist before
        # super().__init__ runs, because the base class assigns
        # X_training/y_training, which routes through the property
        # setters defined below.
        self._X_unlabelled = X_unlabelled
        self._y_unlabelled = y_unlabelled
        # See https://github.com/modAL-python/modAL/issues/103
        self.bootstrap_init = bootstrap_init
        self.taught_idx = np.array([], dtype=int)
        super().__init__(
            estimator,
            query_strategy,
            X_training,
            y_training,
            bootstrap_init,
            on_transformed,
            **fit_kwargs,
        )

    @property
    def X_training(self):
        """Seed training examples stacked with all examples taught so far."""
        return data_vstack((self._X_training, self._X_unlabelled[self.taught_idx]))

    @property
    def y_training(self):
        """Labels corresponding to ``X_training``."""
        return np.concatenate((self._y_training, self._y_unlabelled[self.taught_idx]))

    def _untaught_mask(self) -> np.ndarray:
        """Boolean mask over the pool selecting examples not yet taught."""
        mask = np.ones(self._X_unlabelled.shape[0], dtype=bool)
        mask[self.taught_idx] = False
        return mask

    @property
    def X_unlabelled(self):
        """Pool examples which have not yet been taught."""
        return self._X_unlabelled[self._untaught_mask()]

    @property
    def y_unlabelled(self):
        """Labels of the pool examples which have not yet been taught."""
        return self._y_unlabelled[self._untaught_mask()]

    @X_training.setter
    def X_training(self, X):
        self._X_training = X

    @y_training.setter
    def y_training(self, y):
        self._y_training = y

    def teach(
        self, query_idx: modALinput, bootstrap: bool = False, **fit_kwargs
    ) -> None:
        """
        Add the pool examples at ``query_idx`` to the training set and refit.

        Parameters
        ----------
        query_idx : modALinput
            One-dimensional array (or sequence) of pool indices to teach.
        bootstrap : bool
            Present for interface compatibility with modAL; not used here.
        **fit_kwargs
            Extra keyword arguments forwarded to the estimator's ``fit``.

        Raises
        ------
        ValueError
            If ``query_idx`` is not one-dimensional, or overlaps with
            previously taught indices (teaching an example twice would
            silently duplicate rows in the training set).
        """
        # Coerce so callers may pass plain sequences; validate with explicit
        # raises rather than asserts (asserts are stripped under ``python -O``).
        query_idx = np.asarray(query_idx)
        if query_idx.ndim != 1:
            raise ValueError(
                f"query_idx must be one-dimensional, got shape {query_idx.shape}"
            )
        # Reject indices that were already taught. (np.isin supersedes the
        # deprecated np.in1d.)
        overlap = np.isin(query_idx, self.taught_idx)
        if overlap.any():
            raise ValueError(
                "Attempt to add an example to the training pool which has already been learnt."
                f"\nThe already-learnt pool indexes are {query_idx[overlap]}"
                f" (at positions {np.where(overlap)[0]} within query_idx)."
                f"\nThere are currently {len(self.taught_idx)} learnt examples"
            )
        self.taught_idx = np.concatenate((self.taught_idx, query_idx))
        self.estimator.fit(self.X_training, self.y_training, **fit_kwargs)

    def __getstate__(self):
        """Drop the (potentially huge) data pools before pickling."""
        state = self.__dict__.copy()
        for attr in ("_X_training", "_y_training", "_X_unlabelled", "_y_unlabelled"):
            state.pop(attr, None)
        # A sparse matrix leaking into the serialized state would defeat the
        # point of stripping the pools — fail loudly instead.
        for k, v in state.items():
            if isinstance(v, scipy.sparse.csr_matrix):
                raise Exception(f"Serialized learner has a sparse matrix field {k}")
        return state

    def __setstate__(self, state):
        # Pools are restored by compressedstore on load.
        # Or by MyActiveLearner in the case of the active classifier.
        self.__dict__.update(state)