Skip to content

Commit

Permalink
Replace a few functions from Types with ox-array machinery
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 28, 2024
1 parent 4ef2740 commit 08aa944
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 34 deletions.
12 changes: 6 additions & 6 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1028,8 +1028,8 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
| otherwise ->
if knobExpand knobs
then Ast.AstGather sh0 v0 (vars0, ix0)
else case knownShR sh' of
Dict -> withSNat k $ \snat ->
else case shrRank sh' of
SNat -> withSNat k $ \snat ->
astReplicate snat (astGatherKnobsR knobs sh' v0 (vars, ix0))
where
restN = ixrInit ix0
Expand Down Expand Up @@ -2240,7 +2240,7 @@ astReplicateN :: forall n p s r.
astReplicateN sh v =
let go :: IShR n' -> AstTensor AstMethodLet s (TKR2 (n' + p) r)
go ZSR = v
go (k :$: sh2) | Dict <- knownShR sh2 = withSNat k $ \snat ->
go (k :$: sh2) | SNat <- shrRank sh2 = withSNat k $ \snat ->
astReplicate snat $ go sh2
in go (takeShape sh)

Expand All @@ -2251,7 +2251,7 @@ astReplicateNS :: forall shn shp s r.
astReplicateNS v =
let go :: ShS shn' -> AstTensor AstMethodLet s (TKS2 (shn' ++ shp) r)
go ZSS = v
go ((:$$) @k @shn2 SNat shn2) | Dict <- sshapeKnown shn2 =
go ((:$$) @k @shn2 SNat shn2) | Dict <- shsKnownShS shn2 =
withKnownShS (knownShS @shn2 `shsAppend` knownShS @shp) $
astReplicate (SNat @k) $ go shn2
in go (knownShS @shn)
Expand All @@ -2262,7 +2262,7 @@ astReplicate0N sh =
let go :: IShR n' -> AstTensor AstMethodLet s (TKR 0 r)
-> AstTensor AstMethodLet s (TKR n' r)
go ZSR v = v
go (k :$: sh') v | Dict <- knownShR sh' = withSNat k $ \snat ->
go (k :$: sh') v | SNat <- shrRank sh' = withSNat k $ \snat ->
astReplicate snat $ go sh' v
in go sh . fromPrimal . AstConcrete (FTKR ZSR FTKScalar) . rscalar

Expand All @@ -2272,7 +2272,7 @@ astReplicate0NS =
let go :: ShS sh' -> AstTensor AstMethodLet s (TKS '[] r)
-> AstTensor AstMethodLet s (TKS sh' r)
go ZSS v = v
go ((:$$) SNat sh') v | Dict <- sshapeKnown sh' =
go ((:$$) SNat sh') v | Dict <- shsKnownShS sh' =
astReplicate SNat $ go sh' v
in go (knownShS @shn) . fromPrimal . AstConcrete (FTKS ZSS FTKScalar) . sscalar

Expand Down
18 changes: 9 additions & 9 deletions src/HordeAd/Core/HVectorOps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ import Data.Array.Nested
, type (++)
)
import Data.Array.Nested.Internal.Shape
(shCvtRX, shCvtSX, shrAppend, shrRank, shsAppend, shsRank)
(shCvtRX, shCvtSX, shrAppend, shrRank, shsAppend, shsKnownShS, shsRank)

import HordeAd.Core.TensorClass
import HordeAd.Core.TensorKind
Expand Down Expand Up @@ -626,7 +626,7 @@ unravelDynamic (DynamicRanked @rp @p t) =
Nothing -> error "unravelDynamic: rank 0"
unravelDynamic (DynamicShaped @_ @sh t) = case knownShS @sh of
ZSS -> error "unravelDynamic: rank 0"
(:$$) SNat tl | Dict <- sshapeKnown tl -> map DynamicShaped $ sunravelToList t
(:$$) SNat tl | Dict <- shsKnownShS tl -> map DynamicShaped $ sunravelToList t
unravelDynamic (DynamicRankedDummy @rp @sh _ _) =
withListSh (Proxy @sh) $ \(sh :: IShR p) ->
case someNatVal $ valueOf @p - 1 of
Expand All @@ -636,7 +636,7 @@ unravelDynamic (DynamicRankedDummy @rp @sh _ _) =
Nothing -> error "unravelDynamic: rank 0"
unravelDynamic (DynamicShapedDummy @rp @sh _ _) = case knownShS @sh of
ZSS -> error "unravelDynamic: rank 0"
(:$$) SNat tl | Dict <- sshapeKnown tl ->
(:$$) SNat tl | Dict <- shsKnownShS tl ->
map DynamicShaped $ sunravelToList (srepl 0 :: target (TKS sh rp))

unravelHVector
Expand Down Expand Up @@ -798,20 +798,20 @@ mapRanked10 f (DynamicRanked t) = case rshape t of
_ :$: _ -> DynamicRanked $ f t
mapRanked10 f (DynamicShaped @_ @sh t) = case knownShS @sh of
ZSS -> error "mapRanked10: rank 0"
(:$$) @_ @sh0 _ tl | Dict <- sshapeKnown tl ->
(:$$) @_ @sh0 _ tl | Dict <- shsKnownShS tl ->
withListSh (Proxy @sh0) $ \(_ :: IShR n) ->
let res = f $ rfromS @_ @_ @sh t
in withShapeP (toList $ rshape res) $ \(Proxy @shr) ->
gcastWith (unsafeCoerceRefl :: Rank shr :~: n) $
DynamicShaped $ sfromR @_ @_ @shr res
mapRanked10 f (DynamicRankedDummy @r @sh _ _) = case knownShS @sh of
ZSS -> error "mapRanked10: rank 0"
(:$$) @_ @sh0 k tl | Dict <- sshapeKnown tl ->
(:$$) @_ @sh0 k tl | Dict <- shsKnownShS tl ->
withListSh (Proxy @sh0) $ \sh1 ->
DynamicRanked @r $ f (rzero $ sNatValue k :$: sh1)
mapRanked10 f (DynamicShapedDummy @r @sh _ _) = case knownShS @sh of
ZSS -> error "mapRanked10: rank 0"
(:$$) @_ @sh0 k tl | Dict <- sshapeKnown tl ->
(:$$) @_ @sh0 k tl | Dict <- shsKnownShS tl ->
withListSh (Proxy @sh0) $ \(sh1 :: IShR n) ->
let res = f @r (rzero $ sNatValue k :$: sh1)
in withShapeP (toList $ rshape res) $ \(Proxy @shr) ->
Expand All @@ -835,7 +835,7 @@ mapRanked11 f (DynamicRanked t) = case rshape t of
_ :$: _ -> DynamicRanked $ f t
mapRanked11 f (DynamicShaped @_ @sh t) = case knownShS @sh of
ZSS -> error "mapRanked11: rank 0"
(:$$) @_ @sh0 _ tl | Dict <- sshapeKnown tl ->
(:$$) @_ @sh0 _ tl | Dict <- shsKnownShS tl ->
withListSh (Proxy @sh0) $ \(_ :: IShR n) ->
let res = f $ rfromS @_ @_ @sh t
in withShapeP (toList $ rshape res) $ \(Proxy @shr) ->
Expand All @@ -847,12 +847,12 @@ mapRanked11 f (DynamicShaped @_ @sh t) = case knownShS @sh of
_ -> error "mapRanked01: impossible someNatVal"
mapRanked11 f (DynamicRankedDummy @r @sh _ _) = case knownShS @sh of
ZSS -> error "mapRanked11: rank 0"
(:$$) @_ @sh0 k tl | Dict <- sshapeKnown tl ->
(:$$) @_ @sh0 k tl | Dict <- shsKnownShS tl ->
withListSh (Proxy @sh0) $ \sh1 ->
DynamicRanked @r $ f (rzero $ sNatValue k :$: sh1)
mapRanked11 f (DynamicShapedDummy @r @sh _ _) = case knownShS @sh of
ZSS -> error "mapRanked11: rank 0"
(:$$) @_ @sh0 k tl | Dict <- sshapeKnown tl ->
(:$$) @_ @sh0 k tl | Dict <- shsKnownShS tl ->
withListSh (Proxy @sh0) $ \(sh1 :: IShR n) ->
let res = f @r (rzero $ sNatValue k :$: sh1)
in withShapeP (toList $ rshape res) $ \(Proxy @shr) ->
Expand Down
7 changes: 4 additions & 3 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ import Data.Array.Nested
, type (++)
)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Shape (shCvtSX, shrSize, shsAppend)
import Data.Array.Nested.Internal.Shape
(shCvtSX, shrRank, shrSize, shsAppend, shsKnownShS)

import HordeAd.Core.CarriersConcrete
import HordeAd.Core.TensorKind
Expand Down Expand Up @@ -299,7 +300,7 @@ class ( Num (IntOf target)
let buildSh :: IShR m1 -> (IxROf target m1 -> target (TKR2 n r))
-> target (TKR2 (m1 + n) r)
buildSh ZSR f = f ZIR
buildSh (k :$: sh) f | Dict <- knownShR sh =
buildSh (k :$: sh) f | SNat <- shrRank sh =
let g i = buildSh sh (\ix -> f (i :.: ix))
in rbuild1 k g
in buildSh (takeShape @m @n sh0) f0
Expand Down Expand Up @@ -631,7 +632,7 @@ class ( Num (IntOf target)
-> target (TKS2 (sh1 ++ Drop m sh) r)
buildSh sh1 sh1m f = case (sh1, sh1m) of
(ZSS, _) -> f ZIS
((:$$) SNat sh2, (:$$) _ sh2m) | Dict <- sshapeKnown sh2m ->
((:$$) SNat sh2, (:$$) _ sh2m) | Dict <- shsKnownShS sh2m ->
let g i = buildSh sh2 sh2m (f . (i :.$))
in sbuild1 g
in gcastWith (unsafeCoerceRefl
Expand Down
8 changes: 4 additions & 4 deletions src/HordeAd/Core/TensorKind.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import Data.Array.Nested
)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Mixed as Mixed
import Data.Array.Nested.Internal.Shape (shrRank, shsRank)
import Data.Array.Nested.Internal.Shape (shrRank, shsKnownShS, shsRank)

import HordeAd.Core.Types

Expand Down Expand Up @@ -515,15 +515,15 @@ index1DynamicF rshape sshape rindex sindex u i = case u of
_ :$: _ -> DynamicRanked $ rindex t (i :.: ZIR)
DynamicShaped t -> case sshape t of
ZSS -> error "index1Dynamic: rank 0"
(:$$) SNat tl | Dict <- sshapeKnown tl ->
(:$$) SNat tl | Dict <- shsKnownShS tl ->
DynamicShaped $ sindex t (i :.$ ZIS)
DynamicRankedDummy @r @sh p1 _ -> case knownShS @sh of
ZSS -> error "index1Dynamic: rank 0"
(:$$) @_ @sh2 _ tl | Dict <- sshapeKnown tl ->
(:$$) @_ @sh2 _ tl | Dict <- shsKnownShS tl ->
DynamicRankedDummy @r @sh2 p1 Proxy
DynamicShapedDummy @r @sh p1 _ -> case knownShS @sh of
ZSS -> error "index1Dynamic: rank 0"
(:$$) @_ @sh2 _ tl | Dict <- sshapeKnown tl ->
(:$$) @_ @sh2 _ tl | Dict <- shsKnownShS tl ->
DynamicShapedDummy @r @sh2 p1 Proxy

replicate1HVectorF :: (forall r n. (GoodScalar r, KnownNat n)
Expand Down
17 changes: 5 additions & 12 deletions src/HordeAd/Core/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module HordeAd.Core.Types
SNat, pattern SNat, withSNat, sNatValue, proxyFromSNat, valueOf
-- * Definitions for type-level list shapes
, withKnownShS, withKnownShX
, sshapeKnown, slistKnown, sixKnown, knownShR
, slistKnown, sixKnown
, shapeP, sizeT, sizeP
, withShapeP, sameShape, matchingRank
, Dict(..), PermC, trustMeThisIsAPermutation
Expand Down Expand Up @@ -121,25 +121,18 @@ valueOf = fromInteger $ fromSNat (SNat @n)

-- * Definitions for type-level list shapes

sshapeKnown :: ShS sh -> Dict KnownShS sh
sshapeKnown ZSS = Dict
sshapeKnown (SNat :$$ sh) | Dict <- sshapeKnown sh = Dict

-- TODO: this can probably be retired when we have conversions
-- from ShS to ShR, etc.
slistKnown :: ListS sh i -> Dict KnownShS sh
slistKnown ZS = Dict
slistKnown (_ ::$ sh) | Dict <- slistKnown sh = Dict

-- TODO: this can probably be retired when we have conversions
-- from ShS to ShR, etc.
sixKnown :: IxS sh i -> Dict KnownShS sh
sixKnown ZIS = Dict
sixKnown (_ :.$ sh) | Dict <- sixKnown sh = Dict

knownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
knownNatSucc = Dict

knownShR :: ShR n i -> Dict KnownNat n
knownShR ZSR = Dict
knownShR (_ :$: (l :: ShR m i)) | Dict <- knownShR l = knownNatSucc @m

shapeP :: forall sh. KnownShS sh => Proxy sh -> [Int]
shapeP _ = shsToList (knownShS @sh)

Expand Down

0 comments on commit 08aa944

Please sign in to comment.