diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 25ce3632b0aa..25ecb4013d1f 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -222,9 +222,12 @@ def _compute_fans(shape: Sequence[int], Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the "receptive field" of a convolution (kernel spatial dimensions). """ - if len(shape) <= 1: - raise ValueError(f"Can't compute input and output sizes of a {len(shape)}" - "-dimensional weights tensor. Must be at least 2D.") + if in_axis == -2 and len(shape) <= 1: + raise ValueError( + f"Can't compute input and output sizes of a {len(shape)}-dimensional" + " weights tensor with default in_axis. Must be at least 2D or specify" + " in_axis explicitly." + ) if isinstance(in_axis, int): in_size = shape[in_axis] diff --git a/tests/nn_test.py b/tests/nn_test.py index 1c4261451cc0..7671721665b4 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -781,7 +781,7 @@ def testLog1mExpGrad(self): "InitializerRecord", ["name", "initializer", "shapes", "dtypes"]) -ALL_SHAPES = [(2,), (2, 2), (2, 3), (3, 2), (2, 3, 4), (4, 3, 2), (2, 3, 4, 5)] +ALL_SHAPES = [(), (2,), (2, 2), (2, 3), (3, 2), (2, 3, 4), (4, 3, 2), (2, 3, 4, 5)] def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4): shapes = [shape for shape in ALL_SHAPES @@ -805,6 +805,24 @@ def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4): partial(nn.initializers.variance_scaling, 1, "fan_geo_avg", "normal"), jtu.dtypes.floating, ), + initializer_record( + "variance_scaling_fan_in", + partial(nn.initializers.variance_scaling, 1, "fan_in", "normal", in_axis=[0], out_axis=[]), + jtu.dtypes.floating, + min_dims=1, + ), + initializer_record( + "variance_scaling_fan_in", + partial(nn.initializers.variance_scaling, 1, "fan_in", "normal", in_axis=[], out_axis=[0]), + jtu.dtypes.floating, + min_dims=1, + ), + initializer_record( + "variance_scaling_fan_in", + partial(nn.initializers.variance_scaling, 1, "fan_in", "normal", in_axis=[], out_axis=[]), + jtu.dtypes.floating, + min_dims=0, + ), ] @@ -869,8 +887,9 @@ def testVarianceScalingError(self): with self.assertRaisesRegex( ValueError, - "Can't compute input and output sizes of a 1" - "-dimensional weights tensor. Must be at least 2D." + "Can't compute input and output sizes of a 1-dimensional" + " weights tensor with default in_axis. Must be at least 2D or specify" + " in_axis explicitly.", ): initializer(rng, shape)