Skip to content

Commit 71b8fc4

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a132dbf commit 71b8fc4

File tree

2 files changed

+37
-37
lines changed

2 files changed

+37
-37
lines changed

deepmd/dpmodel/output_def.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def __getitem__(
267267
def get_data(self) -> dict[str, OutputVariableDef]:
268268
return self.var_defs
269269

270-
def keys(self): # noqa: ANN201
270+
def keys(self):
271271
return self.var_defs.keys()
272272

273273

@@ -319,25 +319,25 @@ def get_data(
319319
) -> dict[str, OutputVariableDef]:
320320
return self.var_defs
321321

322-
def keys(self): # noqa: ANN201
322+
def keys(self):
323323
return self.var_defs.keys()
324324

325-
def keys_outp(self): # noqa: ANN201
325+
def keys_outp(self):
326326
return self.def_outp.keys()
327327

328-
def keys_redu(self): # noqa: ANN201
328+
def keys_redu(self):
329329
return self.def_redu.keys()
330330

331-
def keys_derv_r(self): # noqa: ANN201
331+
def keys_derv_r(self):
332332
return self.def_derv_r.keys()
333333

334-
def keys_hess_r(self): # noqa: ANN201
334+
def keys_hess_r(self):
335335
return self.def_hess_r.keys()
336336

337-
def keys_derv_c(self): # noqa: ANN201
337+
def keys_derv_c(self):
338338
return self.def_derv_c.keys()
339339

340-
def keys_derv_c_redu(self): # noqa: ANN201
340+
def keys_derv_c_redu(self):
341341
return self.def_derv_c_redu.keys()
342342

343343

deepmd/dpmodel/utils/network.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939

4040

41-
def sigmoid_t(x): # noqa: ANN001, ANN201
41+
def sigmoid_t(x):
4242
"""Sigmoid."""
4343
if array_api_compat.is_jax_array(x):
4444
from deepmd.jax.env import (
@@ -55,7 +55,7 @@ class Identity(NativeOP):
5555
def __init__(self) -> None:
5656
super().__init__()
5757

58-
def call(self, x): # noqa: ANN001, ANN201
58+
def call(self, x):
5959
"""The Identity operation layer."""
6060
return x
6161

@@ -260,7 +260,7 @@ def dim_out(self) -> int:
260260
return self.w.shape[1]
261261

262262
@support_array_api(version="2022.12")
263-
def call(self, x): # noqa: ANN001, ANN201
263+
def call(self, x):
264264
"""Forward pass.
265265
266266
Parameters
@@ -301,22 +301,22 @@ def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.nda
301301
activation_function = activation_function.lower()
302302
if activation_function == "tanh":
303303

304-
def fn(x): # noqa: ANN001, ANN202 # noqa: ANN001, ANN202
304+
def fn(x):
305305
xp = array_api_compat.array_namespace(x)
306306
return xp.tanh(x)
307307

308308
return fn
309309
elif activation_function == "relu":
310310

311-
def fn(x): # noqa: ANN001, ANN202
311+
def fn(x):
312312
xp = array_api_compat.array_namespace(x)
313313
# https://stackoverflow.com/a/47936476/9567349
314314
return x * xp.astype(x > 0, x.dtype)
315315

316316
return fn
317317
elif activation_function in ("gelu", "gelu_tf"):
318318

319-
def fn(x): # noqa: ANN001, ANN202
319+
def fn(x):
320320
xp = array_api_compat.array_namespace(x)
321321
# generated by GitHub Copilot
322322
return (
@@ -328,7 +328,7 @@ def fn(x): # noqa: ANN001, ANN202
328328
return fn
329329
elif activation_function == "relu6":
330330

331-
def fn(x): # noqa: ANN001, ANN202
331+
def fn(x):
332332
xp = array_api_compat.array_namespace(x)
333333
# generated by GitHub Copilot
334334
return xp.where(
@@ -338,22 +338,22 @@ def fn(x): # noqa: ANN001, ANN202
338338
return fn
339339
elif activation_function == "softplus":
340340

341-
def fn(x): # noqa: ANN001, ANN202
341+
def fn(x):
342342
xp = array_api_compat.array_namespace(x)
343343
# generated by GitHub Copilot
344344
return xp.log(1 + xp.exp(x))
345345

346346
return fn
347347
elif activation_function == "sigmoid":
348348

349-
def fn(x): # noqa: ANN001, ANN202
349+
def fn(x):
350350
# generated by GitHub Copilot
351351
return sigmoid_t(x)
352352

353353
return fn
354354
elif activation_function == "silu":
355355

356-
def fn(x): # noqa: ANN001, ANN202
356+
def fn(x):
357357
# generated by GitHub Copilot
358358
return x * sigmoid_t(x)
359359

@@ -362,13 +362,13 @@ def fn(x): # noqa: ANN001, ANN202
362362
"custom_silu"
363363
):
364364

365-
def sigmoid(x): # noqa: ANN001, ANN202
365+
def sigmoid(x):
366366
return 1 / (1 + np.exp(-x))
367367

368-
def silu(x): # noqa: ANN001, ANN202
368+
def silu(x):
369369
return x * sigmoid(x)
370370

371-
def silu_grad(x): # noqa: ANN001, ANN202
371+
def silu_grad(x):
372372
sig = sigmoid(x)
373373
return sig + x * sig * (1 - sig)
374374

@@ -380,7 +380,7 @@ def silu_grad(x): # noqa: ANN001, ANN202
380380
slope = float(silu_grad(threshold))
381381
const = float(silu(threshold))
382382

383-
def fn(x): # noqa: ANN001, ANN202
383+
def fn(x):
384384
xp = array_api_compat.array_namespace(x)
385385
return xp.where(
386386
x < threshold,
@@ -391,7 +391,7 @@ def fn(x): # noqa: ANN001, ANN202
391391
return fn
392392
elif activation_function.lower() in ("none", "linear"):
393393

394-
def fn(x): # noqa: ANN001, ANN202
394+
def fn(x):
395395
return x
396396

397397
return fn
@@ -535,7 +535,7 @@ def __getitem__(self, key: str) -> Any:
535535
def dim_out(self) -> int:
536536
return self.w.shape[0]
537537

538-
def call(self, x): # noqa: ANN001, ANN201
538+
def call(self, x):
539539
"""Forward pass.
540540
541541
Parameters
@@ -552,11 +552,11 @@ def call(self, x): # noqa: ANN001, ANN201
552552
return y
553553

554554
@staticmethod
555-
def layer_norm_numpy( # noqa: ANN205
556-
x, # noqa: ANN001
555+
def layer_norm_numpy(
556+
x,
557557
shape: tuple[int, ...],
558-
weight=None, # noqa: ANN001
559-
bias=None, # noqa: ANN001
558+
weight=None,
559+
bias=None,
560560
eps: float = 1e-5,
561561
):
562562
xp = array_api_compat.array_namespace(x)
@@ -633,7 +633,7 @@ def check_shape_consistency(self) -> None:
633633
f"output {self.layers[ii].dim_out}",
634634
)
635635

636-
def call(self, x): # noqa: ANN001, ANN202
636+
def call(self, x):
637637
"""Forward pass.
638638
639639
Parameters
@@ -650,7 +650,7 @@ def call(self, x): # noqa: ANN001, ANN202
650650
x = layer(x)
651651
return x
652652

653-
def call_until_last(self, x): # noqa: ANN001, ANN202
653+
def call_until_last(self, x):
654654
"""Return the output before last layer.
655655
656656
Parameters
@@ -1025,9 +1025,9 @@ def deserialize(cls, data: dict) -> "NetworkCollection":
10251025
return cls(**data)
10261026

10271027

1028-
def aggregate( # noqa: ANN201
1029-
data, # noqa: ANN001
1030-
owners, # noqa: ANN001
1028+
def aggregate(
1029+
data,
1030+
owners,
10311031
average: bool = True,
10321032
num_owner: Optional[int] = None,
10331033
):
@@ -1065,10 +1065,10 @@ def aggregate( # noqa: ANN201
10651065
return output
10661066

10671067

1068-
def get_graph_index( # noqa: ANN201
1069-
nlist, # noqa: ANN001
1070-
nlist_mask, # noqa: ANN001
1071-
a_nlist_mask, # noqa: ANN001
1068+
def get_graph_index(
1069+
nlist,
1070+
nlist_mask,
1071+
a_nlist_mask,
10721072
nall: int,
10731073
use_loc_mapping: bool = True,
10741074
):

0 commit comments

Comments
 (0)