Skip to content

Commit 04c11cc

Browse files
authored
[BUG] Remove squaring distances in KNN regression (#1697)
* remove unnecessary ExponentTransform import * remove squaring
1 parent caf86f2 commit 04c11cc

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

aeon/regression/distance_based/_time_series_neighbors.py

-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ def _kneighbors(self, X):
172172

173173
if self.weights == "distance":
174174
ws = distances[closest_idx]
175-
ws = ws**2
176175

177176
# Using epsilon ~= 0 to avoid division by zero
178177
ws = 1 / (ws + np.finfo(float).eps)

aeon/regression/distance_based/tests/test_time_series_neighbors.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ def test_knn_neighbors():
2020
model.fit(X_train, y_train)
2121

2222
y_pred = model.predict(X_test)
23-
y_pred_expected = np.array([-216.06541863, -4.54133078, -324.7624233])
23+
y_pred_expected = np.array([-144.410377008, -25.55876587, -229.9764678])
2424

25-
assert np.abs(y_pred - y_pred_expected).max() < 1e-6
25+
assert np.abs(y_pred - y_pred_expected).max() < 1e-4

0 commit comments

Comments
 (0)