Skip to content

Commit

Permalink
rework and address review
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Sep 4, 2024
1 parent dac3588 commit 3e4d088
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 23 deletions.
46 changes: 25 additions & 21 deletions python/cuml/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -857,34 +857,38 @@ class UMAP(UniversalBase,
del X_m
return embedding

def __getstate__(self):
state = self.__dict__.copy()
@property
def _n_neighbors(self):
return self.n_neighbors

if "handle" in state:
del state["handle"]
@property
def _a(self):
return self.a

state['_n_neighbors'] = self.n_neighbors
state['_a'] = self.a
state['_b'] = self.b
state['_initial_alpha'] = self.learning_rate
state['_disconnection_distance'] = DISCONNECTION_DISTANCES.get(self.metric, np.inf)
@property
def _b(self):
return self.b

@property
def _initial_alpha(self):
return self.learning_rate

@property
def _disconnection_distance(self):
return DISCONNECTION_DISTANCES.get(self.metric, np.inf)

def gpu_to_cpu(self):
if hasattr(self, 'knn_dists') and hasattr(self, 'knn_indices'):
state['_knn_dists'] = self.knn_dists.to_output('numpy')
state['_knn_indices'] = self.knn_indices.to_output('numpy')
state['_knn_search_index'] = None
self._knn_dists = self.knn_dists
self._knn_indices = self.knn_indices
self._knn_search_index = None
elif hasattr(self, '_raw_data'):
host_raw_data = self._raw_data.to_output('numpy')
state['_knn_dists'], state['_knn_indices'], state['_knn_search_index'] = \
nearest_neighbors(host_raw_data, self.n_neighbors, self.metric,
self._raw_data = self._raw_data.to_output('numpy')
self._knn_dists, self._knn_indices, self._knn_search_index = \
nearest_neighbors(self._raw_data, self.n_neighbors, self.metric,
self.metric_kwds, False, self.random_state)

return state

def __setstate__(self, state):
super(UMAP, self).__init__(handle=None,
verbose=state["verbose"])
self.__dict__.update(state)
super().gpu_to_cpu()

def get_param_names(self):
return super().get_param_names() + [
Expand Down
2 changes: 0 additions & 2 deletions python/cuml/cuml/tests/test_device_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,6 @@ def test_train_cpu_infer_cpu(test_data):

def test_train_gpu_infer_cpu(test_data):
cuEstimator = test_data["cuEstimator"]
if cuEstimator is UMAP:
pytest.skip("UMAP GPU training CPU inference not yet implemented")

model = cuEstimator(**test_data["kwargs"])
with using_device_type("gpu"):
Expand Down

0 comments on commit 3e4d088

Please sign in to comment.