Skip to content

Commit

Permalink
Fix: pyright error: Type "floating[Any]" is not assignable to return …
Browse files Browse the repository at this point in the history
…type "ndarray[Unknown, Unknown]" (#765)
  • Loading branch information
wanghan-iapcm authored Dec 23, 2024
1 parent 5423efe commit a496b6f
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:

jobs:
build:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
strategy:
matrix:
python-version: ["3.7", "3.8", "3.12"]
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ repos:
# Python
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.8.1
rev: v0.8.2
hooks:
- id: ruff
args: ["--fix"]
Expand Down
17 changes: 9 additions & 8 deletions dpdata/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from abc import ABCMeta, abstractmethod
from functools import lru_cache
from typing import Any

import numpy as np

from dpdata.system import LabeledSystem, MultiSystems


def mae(errors: np.ndarray) -> np.float64:
def mae(errors: np.ndarray) -> np.floating[Any]:
"""Compute the mean absolute error (MAE).
Parameters
Expand All @@ -18,13 +19,13 @@ def mae(errors: np.ndarray) -> np.float64:
Returns
-------
np.float64
floating[Any]
mean absolute error (MAE)
"""
return np.mean(np.abs(errors))


def rmse(errors: np.ndarray) -> np.float64:
def rmse(errors: np.ndarray) -> np.floating[Any]:
"""Compute the root mean squared error (RMSE).
Parameters
Expand All @@ -34,7 +35,7 @@ def rmse(errors: np.ndarray) -> np.float64:
Returns
-------
np.float64
floating[Any]
root mean squared error (RMSE)
"""
return np.sqrt(np.mean(np.square(errors)))
Expand Down Expand Up @@ -74,22 +75,22 @@ def f_errors(self) -> np.ndarray:
"""Force errors."""

@property
def e_mae(self) -> np.float64:
def e_mae(self) -> np.floating[Any]:
"""Energy MAE."""
return mae(self.e_errors)

@property
def e_rmse(self) -> np.float64:
def e_rmse(self) -> np.floating[Any]:
"""Energy RMSE."""
return rmse(self.e_errors)

@property
def f_mae(self) -> np.float64:
def f_mae(self) -> np.floating[Any]:
"""Force MAE."""
return mae(self.f_errors)

@property
def f_rmse(self) -> np.float64:
def f_rmse(self) -> np.floating[Any]:
"""Force RMSE."""
return rmse(self.f_errors)

Expand Down
1 change: 1 addition & 0 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,7 @@ def remove_atom_names(self, atom_names: str | list[str]):
atom_idx = self.data["atom_types"] == idx
removed_atom_idx.append(atom_idx)
picked_atom_idx = ~np.any(removed_atom_idx, axis=0)
assert not isinstance(picked_atom_idx, np.bool_)
new_sys = self.pick_atom_idx(picked_atom_idx)
# let's remove atom_names
# firstly, rearrange atom_names and put these atom_names in the end
Expand Down
2 changes: 1 addition & 1 deletion tests/test_abacus_pw_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_noforcestress_job(self):
# check below will not throw error
system_ch4 = dpdata.LabeledSystem("abacus.scf", fmt="abacus/scf")
# check the returned force is empty
self.assertFalse(system_ch4.data["forces"])
self.assertFalse(system_ch4.data["forces"].size)
self.assertTrue("virials" not in system_ch4.data)
# test append self
system_ch4.append(system_ch4)
Expand Down

0 comments on commit a496b6f

Please sign in to comment.