@@ -38,7 +38,7 @@
 )


-def sigmoid_t(x):  # noqa: ANN001, ANN201
+def sigmoid_t(x):
     """Sigmoid."""
     if array_api_compat.is_jax_array(x):
         from deepmd.jax.env import (
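The hunk is truncated at the JAX import, but the function's job is clear from the signature and docstring: a sigmoid that dispatches on the input's array namespace. A minimal sketch of the generic (non-JAX) path, assuming only `array_api_compat` (the exact upstream body is not shown in this diff):

```python
import array_api_compat

def sigmoid_sketch(x):
    # Resolve the array-API namespace (NumPy, CuPy, ...) of the input array.
    xp = array_api_compat.array_namespace(x)
    return 1 / (1 + xp.exp(-x))
```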
@@ -55,7 +55,7 @@ class Identity(NativeOP):
     def __init__(self) -> None:
         super().__init__()

-    def call(self, x):  # noqa: ANN001, ANN201
+    def call(self, x):
         """The Identity operation layer."""
         return x

@@ -260,7 +260,7 @@ def dim_out(self) -> int:
         return self.w.shape[1]

     @support_array_api(version="2022.12")
-    def call(self, x):  # noqa: ANN001, ANN201
+    def call(self, x):
         """Forward pass.

         Parameters
@@ -301,22 +301,22 @@ def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]:
     activation_function = activation_function.lower()
     if activation_function == "tanh":

-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             return xp.tanh(x)

         return fn
     elif activation_function == "relu":

-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             # https://stackoverflow.com/a/47936476/9567349
             return x * xp.astype(x > 0, x.dtype)

         return fn
     elif activation_function in ("gelu", "gelu_tf"):

-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
             return (
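The `gelu`/`gelu_tf` body is cut off at `return (`. The standard tanh-approximate GELU this branch presumably computes is 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))); a hedged sketch, with the usual published constants rather than values read from this diff:

```python
import math
import array_api_compat

def gelu_tanh_sketch(x):
    # Tanh approximation of GELU; 0.044715 and sqrt(2/pi) are the standard constants.
    xp = array_api_compat.array_namespace(x)
    return 0.5 * x * (1 + xp.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
```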
@@ -328,7 +328,7 @@ def fn(x):  # noqa: ANN001, ANN202
         return fn
     elif activation_function == "relu6":

-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
             return xp.where(
@@ -338,22 +338,22 @@ def fn(x):  # noqa: ANN001, ANN202
         return fn
     elif activation_function == "softplus":

-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
             return xp.log(1 + xp.exp(x))

         return fn
     elif activation_function == "sigmoid":

-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             # generated by GitHub Copilot
             return sigmoid_t(x)

         return fn
     elif activation_function == "silu":

-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             # generated by GitHub Copilot
             return x * sigmoid_t(x)

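Each branch returns a closure over the chosen activation, so a caller uses the factory like this:

```python
import numpy as np

# Look up an activation once, then apply it to any array-API array.
fn = get_activation_fn("silu")
y = fn(np.array([-1.0, 0.0, 2.0]))  # x * sigmoid(x), elementwise
```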
@@ -362,13 +362,13 @@ def fn(x):  # noqa: ANN001, ANN202
         "custom_silu"
     ):

-        def sigmoid(x):  # noqa: ANN001, ANN202
+        def sigmoid(x):
             return 1 / (1 + np.exp(-x))

-        def silu(x):  # noqa: ANN001, ANN202
+        def silu(x):
             return x * sigmoid(x)

-        def silu_grad(x):  # noqa: ANN001, ANN202
+        def silu_grad(x):
             sig = sigmoid(x)
             return sig + x * sig * (1 - sig)

@@ -380,7 +380,7 @@ def silu_grad(x):  # noqa: ANN001, ANN202
         slope = float(silu_grad(threshold))
         const = float(silu(threshold))

-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             return xp.where(
                 x < threshold,
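Here `slope` and `const` are silu'(threshold) and silu(threshold), so the truncated `xp.where` presumably continues SiLU past the threshold with its tangent line, keeping the activation C1-continuous. A NumPy-only sketch under that assumption (the threshold value comes from elsewhere in the file; the default below is hypothetical):

```python
import numpy as np

def custom_silu_sketch(x, threshold=3.0):
    # Hypothetical threshold default; the real value is configured upstream.
    sig = 1 / (1 + np.exp(-threshold))
    slope = sig + threshold * sig * (1 - sig)  # silu'(threshold)
    const = threshold * sig                    # silu(threshold)
    silu_x = x / (1 + np.exp(-x))              # x * sigmoid(x)
    # Plain SiLU below the threshold, its tangent line above it.
    return np.where(x < threshold, silu_x, slope * (x - threshold) + const)
```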
@@ -391,7 +391,7 @@ def fn(x):  # noqa: ANN001, ANN202
         return fn
     elif activation_function.lower() in ("none", "linear"):

-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             return x

         return fn
@@ -535,7 +535,7 @@ def __getitem__(self, key: str) -> Any:
     def dim_out(self) -> int:
         return self.w.shape[0]

-    def call(self, x):  # noqa: ANN001, ANN201
+    def call(self, x):
         """Forward pass.

         Parameters
@@ -552,11 +552,11 @@ def call(self, x):  # noqa: ANN001, ANN201
         return y

     @staticmethod
-    def layer_norm_numpy(  # noqa: ANN205
-        x,  # noqa: ANN001
+    def layer_norm_numpy(
+        x,
         shape: tuple[int, ...],
-        weight=None,  # noqa: ANN001
-        bias=None,  # noqa: ANN001
+        weight=None,
+        bias=None,
         eps: float = 1e-5,
     ):
         xp = array_api_compat.array_namespace(x)
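The hunk stops right after resolving the namespace. A standard layer-norm body consistent with this signature (a sketch of the technique, not the file's verbatim code) would be:

```python
import array_api_compat

def layer_norm_sketch(x, shape, weight=None, bias=None, eps=1e-5):
    xp = array_api_compat.array_namespace(x)
    axes = tuple(range(-len(shape), 0))          # normalize over the trailing dims
    mean = xp.mean(x, axis=axes, keepdims=True)
    var = xp.var(x, axis=axes, keepdims=True)
    y = (x - mean) / xp.sqrt(var + eps)
    if weight is not None:                       # optional affine rescale
        y = y * weight
    if bias is not None:
        y = y + bias
    return y
```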
@@ -633,7 +633,7 @@ def check_shape_consistency(self) -> None:
                     f"output {self.layers[ii].dim_out}",
                 )

-    def call(self, x):  # noqa: ANN001, ANN202
+    def call(self, x):
         """Forward pass.

         Parameters
@@ -650,7 +650,7 @@ def call(self, x):  # noqa: ANN001, ANN202
             x = layer(x)
         return x

-    def call_until_last(self, x):  # noqa: ANN001, ANN202
+    def call_until_last(self, x):
         """Return the output before last layer.

         Parameters
@@ -1025,9 +1025,9 @@ def deserialize(cls, data: dict) -> "NetworkCollection":
         return cls(**data)


-def aggregate(  # noqa: ANN201
-    data,  # noqa: ANN001
-    owners,  # noqa: ANN001
+def aggregate(
+    data,
+    owners,
     average: bool = True,
     num_owner: Optional[int] = None,
 ):
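Only the signature is visible here; from the parameter names, `aggregate` is a segment sum/mean that scatters rows of `data` into one slot per owner. A hedged NumPy sketch of that pattern (not the upstream body):

```python
import numpy as np

def aggregate_sketch(data, owners, average=True, num_owner=None):
    # Number of output slots: given explicitly, or inferred from the owner ids.
    n = int(owners.max()) + 1 if num_owner is None else num_owner
    out = np.zeros((n, data.shape[-1]), dtype=data.dtype)
    np.add.at(out, owners, data)                    # scatter-add rows per owner
    if average:
        counts = np.bincount(owners, minlength=n)
        out = out / np.maximum(counts, 1)[:, None]  # mean; guard empty owners
    return out
```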
@@ -1065,10 +1065,10 @@ def aggregate(  # noqa: ANN201
     return output


-def get_graph_index(  # noqa: ANN201
-    nlist,  # noqa: ANN001
-    nlist_mask,  # noqa: ANN001
-    a_nlist_mask,  # noqa: ANN001
+def get_graph_index(
+    nlist,
+    nlist_mask,
+    a_nlist_mask,
     nall: int,
     use_loc_mapping: bool = True,
 ):