66import neural_tangents
77from neural_tangents import stax
88
9- from pkg_resources import parse_version
10- if parse_version (neural_tangents .__version__ ) >= parse_version ('0.5.0' ):
11- from neural_tangents ._src .utils import utils , dataclasses
12- from neural_tangents ._src .stax .linear import _pool_kernel , Padding
13- from neural_tangents ._src .stax .linear import _Pooling as Pooling
14- else :
15- from neural_tangents .utils import utils , dataclasses
16- from neural_tangents .stax import _pool_kernel , Padding , Pooling
17-
18- from sketching import TensorSRHT2 , PolyTensorSRHT
9+ # from pkg_resources import parse_version
10+ # if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'):
11+ # from neural_tangents._src.utils import utils, dataclasses
12+ # from neural_tangents._src.stax.linear import _pool_kernel, Padding
13+ # from neural_tangents._src.stax.linear import _Pooling as Pooling
14+ # else:
15+ # from neural_tangents.utils import utils, dataclasses
16+ # from neural_tangents.stax import _pool_kernel, Padding, Pooling
17+ from neural_tangents ._src .utils import dataclasses
18+ # from neural_tangents._src.utils.typing import Optional
19+ from typing import Optional
20+ from neural_tangents ._src .stax .linear import _pool_kernel , Padding
21+ from neural_tangents ._src .stax .linear import _Pooling as Pooling
22+
23+ from experimental .sketching import TensorSRHT2 , PolyTensorSRHT
1924""" Implementation for NTK Sketching and Random Features """
2025
2126
@@ -50,13 +55,13 @@ def kappa1(x):
5055
5156@dataclasses .dataclass
5257class Features :
53- nngp_feat : np .ndarray
54- ntk_feat : np .ndarray
58+ nngp_feat : Optional [ np .ndarray ] = None
59+ ntk_feat : Optional [ np .ndarray ] = None
5560
5661 batch_axis : int = dataclasses .field (pytree_node = False )
5762 channel_axis : int = dataclasses .field (pytree_node = False )
5863
59- replace = ... # type: Callable[..., 'Features']
64+ replace = ...
6065
6166
6267def _inputs_to_features (x : np .ndarray ,
@@ -69,7 +74,7 @@ def _inputs_to_features(x: np.ndarray,
6974 nngp_feat = x / x .shape [channel_axis ]** 0.5
7075 ntk_feat = np .empty ((), dtype = nngp_feat .dtype )
7176
72- return Features (nngp_feat = nngp_feat ,
77+ return Features . replace (nngp_feat = nngp_feat ,
7378 ntk_feat = ntk_feat ,
7479 batch_axis = batch_axis ,
7580 channel_axis = channel_axis )
0 commit comments