@@ -19,11 +19,12 @@ import Grenade.Core
19
19
import Grenade.Layers.Convolution
20
20
21
21
import Hedgehog
22
+ import Hedgehog.Gen ( Gen )
22
23
import qualified Hedgehog.Gen as Gen
23
24
24
- import Test.Jack .Hmatrix
25
- import Test.Jack .TypeLits
26
- import Test.Jack .Compat
25
+ import Test.Hedgehog .Hmatrix
26
+ import Test.Hedgehog .TypeLits
27
+ import Test.Hedgehog .Compat
27
28
28
29
data OpaqueConvolution :: * where
29
30
OpaqueConvolution :: Convolution channels filters kernelRows kernelColumns strideRows strideColumns -> OpaqueConvolution
@@ -39,10 +40,11 @@ genConvolution :: ( KnownNat channels
39
40
, KnownNat strideColumns
40
41
, KnownNat kernelFlattened
41
42
, kernelFlattened ~ (kernelRows * kernelColumns * channels )
42
- ) => Jack (Convolution channels filters kernelRows kernelColumns strideRows strideColumns )
43
+ , Monad m
44
+ ) => Gen m (Convolution channels filters kernelRows kernelColumns strideRows strideColumns )
43
45
genConvolution = Convolution <$> uniformSample <*> uniformSample
44
46
45
- genOpaqueOpaqueConvolution :: Jack OpaqueConvolution
47
+ genOpaqueOpaqueConvolution :: Monad m => Gen m OpaqueConvolution
46
48
genOpaqueOpaqueConvolution = do
47
49
channels <- genNat
48
50
filters <- genNat
@@ -58,7 +60,7 @@ genOpaqueOpaqueConvolution = do
58
60
p2 = natDict pkc
59
61
p3 = natDict pch
60
62
in case p1 %* p2 %* p3 of
61
- Dict -> OpaqueConvolution <$> (genConvolution :: Jack (Convolution ch fl kr kc sr sc ))
63
+ Dict -> OpaqueConvolution <$> (genConvolution :: Monad n => Gen n (Convolution ch fl kr kc sr sc ))
62
64
63
65
prop_conv_net_witness = property $
64
66
blindForAll genOpaqueOpaqueConvolution >>= \ onet ->
0 commit comments