Skip to content

Commit f9b5379

Browse files
committed
Network scale autodiff test
1 parent ffbbbf2 commit f9b5379

16 files changed

+292
-58
lines changed

examples/main/gan-mnist.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
-- of numbers similar to those in MNIST.
1212
--
1313
-- It demonstrates a different usage of the library. Within about 15
14-
-- minutes hour it was producing examples like this:
14+
-- minutes it was producing examples like this:
1515
--
1616
-- --.
1717
-- .=-.--..#=###

grenade.cabal

+1
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ test-suite test
137137
, random
138138
, ad
139139
, reflection
140+
, vector
140141

141142

142143
benchmark bench

test/Test/Grenade/Layers/Convolution.hs

+8-6
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ import Grenade.Core
1919
import Grenade.Layers.Convolution
2020

2121
import Hedgehog
22+
import Hedgehog.Gen ( Gen )
2223
import qualified Hedgehog.Gen as Gen
2324

24-
import Test.Jack.Hmatrix
25-
import Test.Jack.TypeLits
26-
import Test.Jack.Compat
25+
import Test.Hedgehog.Hmatrix
26+
import Test.Hedgehog.TypeLits
27+
import Test.Hedgehog.Compat
2728

2829
data OpaqueConvolution :: * where
2930
OpaqueConvolution :: Convolution channels filters kernelRows kernelColumns strideRows strideColumns -> OpaqueConvolution
@@ -39,10 +40,11 @@ genConvolution :: ( KnownNat channels
3940
, KnownNat strideColumns
4041
, KnownNat kernelFlattened
4142
, kernelFlattened ~ (kernelRows * kernelColumns * channels)
42-
) => Jack (Convolution channels filters kernelRows kernelColumns strideRows strideColumns)
43+
, Monad m
44+
) => Gen m (Convolution channels filters kernelRows kernelColumns strideRows strideColumns)
4345
genConvolution = Convolution <$> uniformSample <*> uniformSample
4446

45-
genOpaqueOpaqueConvolution :: Jack OpaqueConvolution
47+
genOpaqueOpaqueConvolution :: Monad m => Gen m OpaqueConvolution
4648
genOpaqueOpaqueConvolution = do
4749
channels <- genNat
4850
filters <- genNat
@@ -58,7 +60,7 @@ genOpaqueOpaqueConvolution = do
5860
p2 = natDict pkc
5961
p3 = natDict pch
6062
in case p1 %* p2 %* p3 of
61-
Dict -> OpaqueConvolution <$> (genConvolution :: Jack (Convolution ch fl kr kc sr sc))
63+
Dict -> OpaqueConvolution <$> (genConvolution :: Monad n => Gen n (Convolution ch fl kr kc sr sc))
6264

6365
prop_conv_net_witness = property $
6466
blindForAll genOpaqueOpaqueConvolution >>= \onet ->

test/Test/Grenade/Layers/FullyConnected.hs

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ import Grenade.Layers.FullyConnected
1717

1818
import Hedgehog
1919

20-
import Test.Jack.Compat
21-
import Test.Jack.Hmatrix
20+
import Test.Hedgehog.Compat
21+
import Test.Hedgehog.Hmatrix
2222

2323
data OpaqueFullyConnected :: * where
2424
OpaqueFullyConnected :: (KnownNat i, KnownNat o) => FullyConnected i o -> OpaqueFullyConnected
2525

2626
instance Show OpaqueFullyConnected where
2727
show (OpaqueFullyConnected n) = show n
2828

29-
genOpaqueFullyConnected :: Jack OpaqueFullyConnected
29+
genOpaqueFullyConnected :: Monad m => Gen m OpaqueFullyConnected
3030
genOpaqueFullyConnected = do
3131
input :: Integer <- choose 2 100
3232
output :: Integer <- choose 1 100

test/Test/Grenade/Layers/Internal/Convolution.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import qualified Hedgehog.Gen as Gen
1717
import qualified Hedgehog.Range as Range
1818

1919
import qualified Test.Grenade.Layers.Internal.Reference as Reference
20-
import Test.Jack.Compat
20+
import Test.Hedgehog.Compat
2121

2222
prop_im2col_col2im_symmetrical_with_kernel_stride =
2323
let factors n = [x | x <- [1..n], n `mod` x == 0]

test/Test/Grenade/Layers/Internal/Pooling.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import qualified Hedgehog.Gen as Gen
1717
import qualified Hedgehog.Range as Range
1818

1919
import qualified Test.Grenade.Layers.Internal.Reference as Reference
20-
import Test.Jack.Compat
20+
import Test.Hedgehog.Compat
2121

2222
prop_poolForwards_poolBackwards_behaves_as_reference =
2323
let ok extent kernel = [stride | stride <- [1..extent], (extent - kernel) `mod` stride == 0]

test/Test/Grenade/Layers/Nonlinear.hs

+3-5
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ import GHC.TypeLits
2020

2121
import Hedgehog
2222

23-
import Test.Jack.Compat
24-
import Test.Jack.Hmatrix
25-
import Test.Jack.TypeLits
23+
import Test.Hedgehog.Compat
24+
import Test.Hedgehog.Hmatrix
25+
import Test.Hedgehog.TypeLits
2626

2727
import Numeric.LinearAlgebra.Static ( norm_Inf )
2828

@@ -68,8 +68,6 @@ prop_softmax_grad = property $
6868
in assert ((case numericalGradient - ret of
6969
(S1D x) -> norm_Inf x < 0.0001) :: Bool)
7070

71-
72-
7371
tests :: IO Bool
7472
tests = $$(checkConcurrent)
7573

test/Test/Grenade/Layers/PadCrop.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import Hedgehog
1919

2020
import Numeric.LinearAlgebra.Static ( norm_Inf )
2121

22-
import Test.Jack.Hmatrix
22+
import Test.Hedgehog.Hmatrix
2323

2424
prop_pad_crop :: Property
2525
prop_pad_crop =

test/Test/Grenade/Layers/Pooling.hs

+4-2
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@ import GHC.TypeLits
1313
import Grenade.Layers.Pooling
1414

1515
import Hedgehog
16-
import Test.Jack.Compat
16+
import Hedgehog.Gen ( Gen )
17+
18+
import Test.Hedgehog.Compat
1719

1820
data OpaquePooling :: * where
1921
OpaquePooling :: (KnownNat kh, KnownNat kw, KnownNat sh, KnownNat sw) => Pooling kh kw sh sw -> OpaquePooling
2022

2123
instance Show OpaquePooling where
2224
show (OpaquePooling n) = show n
2325

24-
genOpaquePooling :: Jack OpaquePooling
26+
genOpaquePooling :: Monad m => Gen m OpaquePooling
2527
genOpaquePooling = do
2628
Just kernelHeight <- someNatVal <$> choose 2 15
2729
Just kernelWidth <- someNatVal <$> choose 2 15

test/Test/Grenade/Recurrent/Layers/LSTM.hs

+3-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
module Test.Grenade.Recurrent.Layers.LSTM where
1212

1313
import Hedgehog
14+
import Hedgehog.Gen ( Gen )
1415
import Hedgehog.Internal.Source
1516
import Hedgehog.Internal.Show
1617
import Hedgehog.Internal.Property ( failWith, Diff (..) )
@@ -26,10 +27,9 @@ import qualified Numeric.LinearAlgebra.Static as S
2627

2728

2829
import qualified Test.Grenade.Recurrent.Layers.LSTM.Reference as Reference
29-
import Test.Jack.Hmatrix
30-
import Test.Jack.Compat
30+
import Test.Hedgehog.Hmatrix
3131

32-
genLSTM :: forall i o. (KnownNat i, KnownNat o) => Jack (LSTM i o)
32+
genLSTM :: forall i o m. (KnownNat i, KnownNat o, Monad m) => Gen m (LSTM i o)
3333
genLSTM = do
3434
let w = uniformSample
3535
u = uniformSample
@@ -121,6 +121,5 @@ prop_lstm_reference_backwards_cell =
121121
failWith (Just $ Diff "Failed (" "- lhs" "~/~" "+ rhs" ")" diff) ""
122122
infix 4 ~~~
123123

124-
125124
tests :: IO Bool
126125
tests = $$(checkConcurrent)

test/Test/Hedgehog/Compat.hs

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
{-# LANGUAGE RankNTypes #-}
2+
module Test.Hedgehog.Compat where
3+
4+
import Control.Monad.Trans.Class (MonadTrans(..))
5+
6+
import Data.Typeable ( typeOf )
7+
8+
import Hedgehog.Gen ( Gen )
9+
import qualified Hedgehog.Gen as Gen
10+
import qualified Hedgehog.Range as Range
11+
import Hedgehog.Internal.Property
12+
import Hedgehog.Internal.Source
13+
import Hedgehog.Internal.Show
14+
15+
(...) :: (c -> d) -> (a -> b -> c) -> a -> b -> d
16+
(...) = (.) . (.)
17+
{-# INLINE (...) #-}
18+
19+
choose :: ( Monad m, Integral a ) => a -> a -> Gen m a
20+
choose = Gen.integral ... Range.constant
21+
22+
-- | Generates a random input for the test by running the provided generator.
23+
blindForAll :: Monad m => Gen m a -> Test m a
24+
blindForAll = Test . lift . lift
25+
26+
-- | Generates a random input for the test by running the provided generator.
27+
semiBlindForAll :: (Monad m, Show a, HasCallStack) => Gen m a -> Test m a
28+
semiBlindForAll gen = do
29+
x <- Test . lift $ lift gen
30+
writeLog $ Input (getCaller callStack) (typeOf ()) (showPretty x)
31+
return x
32+
33+
34+
-- | Generates a random input for the test by running the provided generator.
35+
forAllRender :: (Monad m, HasCallStack) => ( a -> String ) -> Gen m a -> Test m a
36+
forAllRender render gen = do
37+
x <- Test . lift $ lift gen
38+
writeLog $ Input (getCaller callStack) (typeOf ()) (render x)
39+
return x

test/Test/Jack/Hmatrix.hs test/Test/Hedgehog/Hmatrix.hs

+11-5
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,35 @@
33
{-# LANGUAGE ScopedTypeVariables #-}
44
{-# LANGUAGE GADTs #-}
55

6-
module Test.Jack.Hmatrix where
6+
module Test.Hedgehog.Hmatrix where
77

88
import Grenade
99
import Data.Singletons
1010

11+
import Hedgehog.Gen ( Gen )
1112
import qualified Hedgehog.Gen as Gen
1213
import qualified Hedgehog.Range as Range
1314

1415
import GHC.TypeLits
1516

1617
import qualified Numeric.LinearAlgebra.Static as HStatic
17-
import Test.Jack.Compat
1818

19-
randomVector :: forall n. KnownNat n => Jack (HStatic.R n)
19+
randomVector :: forall m n. ( Monad m, KnownNat n ) => Gen m (HStatic.R n)
2020
randomVector = (\s -> HStatic.randomVector s HStatic.Uniform * 2 - 1) <$> Gen.int Range.linearBounded
2121

22-
uniformSample :: forall m n. (KnownNat m, KnownNat n) => Jack (HStatic.L m n)
22+
uniformSample :: forall mm m n. ( Monad mm, KnownNat m, KnownNat n ) => Gen mm (HStatic.L m n)
2323
uniformSample = (\s -> HStatic.uniformSample s (-1) 1 ) <$> Gen.int Range.linearBounded
2424

2525
-- | Generate random data of the desired shape
26-
genOfShape :: forall x. ( SingI x ) => Jack (S x)
26+
genOfShape :: forall m x. ( Monad m, SingI x ) => Gen m (S x)
2727
genOfShape =
2828
case (sing :: Sing x) of
2929
D1Sing -> S1D <$> randomVector
3030
D2Sing -> S2D <$> uniformSample
3131
D3Sing -> S3D <$> uniformSample
32+
33+
34+
nice :: S shape -> String
35+
nice (S1D x) = show . HStatic.extract $ x
36+
nice (S2D x) = show . HStatic.extract $ x
37+
nice (S3D x) = show . HStatic.extract $ x

0 commit comments

Comments
 (0)