Skip to content

Commit

Permalink
disable on GPU since it's not implemented there
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes-intel committed Oct 11, 2024
1 parent 394fd67 commit 469ec0f
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions sklearnex/linear_model/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
from abc import ABC
from functools import partialmethod

import numpy as np
from sklearn.exceptions import NotFittedError
Expand Down Expand Up @@ -140,7 +141,7 @@ def score(self, X, y, sample_weight=None):
sample_weight=sample_weight,
)

def _onedal_fit_supported(self, method_name, *data):
def _onedal_fit_supported(self, is_gpu, method_name, *data):
assert method_name == "fit"
assert len(data) == 3
X, y, sample_weight = data
Expand Down Expand Up @@ -179,7 +180,7 @@ def _onedal_fit_supported(self, method_name, *data):
"Forced positive coefficients are not supported.",
),
(
not (is_underdetermined and not supports_all_variants),
not (is_underdetermined and (not supports_all_variants or is_gpu)),
"The shape of X (fitting) does not satisfy oneDAL requirements:"
"Number of features + 1 >= number of samples.",
),
Expand Down Expand Up @@ -212,15 +213,15 @@ def _onedal_predict_supported(self, method_name, *data):

return patching_status

def _onedal_supported(self, method_name, *data):
def _onedal_supported(self, is_gpu, method_name, *data):
if method_name == "fit":
return self._onedal_fit_supported(method_name, *data)
return self._onedal_fit_supported(is_gpu, method_name, *data)
if method_name in ["predict", "score"]:
return self._onedal_predict_supported(method_name, *data)
raise RuntimeError(f"Unknown method {method_name} in {self.__class__.__name__}")

_onedal_gpu_supported = _onedal_supported
_onedal_cpu_supported = _onedal_supported
_onedal_gpu_supported = partialmethod(_onedal_supported, True)
_onedal_cpu_supported = partialmethod(_onedal_supported, False)

def _initialize_onedal_estimator(self):
onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
Expand Down

0 comments on commit 469ec0f

Please sign in to comment.