Skip to content

Commit

Permalink
Unboxing and streamlining Map maps
Browse files Browse the repository at this point in the history
* Use an unboxed-sum version of `Maybe` to implement `mapMaybeWithKey`.
  This potentially (I suspect usually) allows all the `Maybe`s to be
  erased.

* Comprehensive rewrite rules for both strict and lazy versions of
  `map`, `mapWithKey`, `mapMaybeWithKey`, and `filterWithKey` quickly
  get out of hand. Following `unordered-containers`, tame the mess
  by implementing both lazy and strict mapping functions in terms of
  versions that use unboxed results. Rewrite rules on these underlying
  functions will then apply uniformly. One concern: I found it a bit
  tricky to get the unfoldings I wanted; lots of things had to be marked
  `INLINABLE` explicitly.
  • Loading branch information
treeowl committed Nov 20, 2022
1 parent 3db464d commit 75a721b
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 40 deletions.
2 changes: 2 additions & 0 deletions containers-tests/containers-tests.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ library
Utils.Containers.Internal.BitQueue
Utils.Containers.Internal.BitUtil
Utils.Containers.Internal.StrictPair
Utils.Containers.Internal.UnboxedMaybe
Utils.Containers.Internal.UnboxedSolo
if impl(ghc >= 8.6.0)
exposed-modules:
Utils.NoThunks
Expand Down
2 changes: 2 additions & 0 deletions containers/containers.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ Library
Utils.Containers.Internal.BitUtil
Utils.Containers.Internal.BitQueue
Utils.Containers.Internal.StrictPair
Utils.Containers.Internal.UnboxedMaybe
Utils.Containers.Internal.UnboxedSolo

other-modules:
Prelude
Expand Down
151 changes: 135 additions & 16 deletions containers/src/Data/Map/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
{-# LANGUAGE PatternGuards #-}
#if defined(__GLASGOW_HASKELL__)
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE Trustworthy #-}
Expand Down Expand Up @@ -236,7 +237,9 @@ module Data.Map.Internal (
-- * Traversal
-- ** Map
, map
, mapU
, mapWithKey
, mapWithKeyU
, traverseWithKey
, traverseMaybeWithKey
, mapAccum
Expand Down Expand Up @@ -301,6 +304,7 @@ module Data.Map.Internal (

, mapMaybe
, mapMaybeWithKey
, mapMaybeWithKeyU
, mapEither
, mapEitherWithKey

Expand Down Expand Up @@ -407,6 +411,8 @@ import Data.Data
import qualified Control.Category as Category
import Data.Coerce
#endif
import Utils.Containers.Internal.UnboxedMaybe
import Utils.Containers.Internal.UnboxedSolo


{--------------------------------------------------------------------
Expand Down Expand Up @@ -2849,6 +2855,7 @@ isProperSubmapOfBy f t1 t2
filter :: (a -> Bool) -> Map k a -> Map k a
filter p m
= filterWithKey (\_ x -> p x) m
{-# INLINE filter #-}

-- | \(O(n)\). Filter all keys\/values that satisfy the predicate.
--
Expand All @@ -2863,6 +2870,32 @@ filterWithKey p t@(Bin _ kx x l r)
| otherwise = link2 pl pr
where !pl = filterWithKey p l
!pr = filterWithKey p r
{-# NOINLINE [1] filterWithKey #-}

{-# RULES
"filterWK/filterWK" forall p q m. filterWithKey p (filterWithKey q m) =
filterWithKey (\k x -> q k x && p k x) m
"filterWK/mapU" forall p f m. filterWithKey p (mapU f m) =
mapMaybeWithKeyU (\k x -> case f x of
SoloU y
| p k y -> JustU y
| otherwise -> NothingU) m
"filterWK/mapWK#" forall p f m. filterWithKey p (mapWithKeyU f m) =
mapMaybeWithKeyU (\k x -> case f k x of
SoloU y
| p k y -> JustU y
| otherwise -> NothingU) m
"mapU/filterWK" forall f p m. mapU f (filterWithKey p m) =
mapMaybeWithKeyU (\k x ->
if p k x
then case f x of SoloU y -> JustU y
else NothingU) m
"mapWK#/filterWK" forall f p m. mapWithKeyU f (filterWithKey p m) =
mapMaybeWithKeyU (\k x ->
if p k x
then case f k x of SoloU y -> JustU y
else NothingU) m
#-}

-- | \(O(n)\). Filter keys and values using an 'Applicative'
-- predicate.
Expand Down Expand Up @@ -2977,17 +3010,60 @@ partitionWithKey p0 t0 = toPair $ go p0 t0

mapMaybe :: (a -> Maybe b) -> Map k a -> Map k b
mapMaybe f = mapMaybeWithKey (\_ x -> f x)
{-# INLINE mapMaybe #-}

-- | \(O(n)\). Map keys\/values and collect the 'Just' results.
--
-- > let f k _ = if k < 5 then Just ("key : " ++ (show k)) else Nothing
-- > mapMaybeWithKey f (fromList [(5,"a"), (3,"b")]) == singleton 3 "key : 3"

mapMaybeWithKey :: (k -> a -> Maybe b) -> Map k a -> Map k b
{-
mapMaybeWithKey _ Tip = Tip
mapMaybeWithKey f (Bin _ kx x l r) = case f kx x of
Just y -> link kx y (mapMaybeWithKey f l) (mapMaybeWithKey f r)
Nothing -> link2 (mapMaybeWithKey f l) (mapMaybeWithKey f r)
-}
mapMaybeWithKey f = \m ->
mapMaybeWithKeyU (\k x -> toMaybeU (f k x)) m
{-# INLINE mapMaybeWithKey #-}

mapMaybeWithKeyU :: (k -> a -> MaybeU b) -> Map k a -> Map k b
mapMaybeWithKeyU _ Tip = Tip
mapMaybeWithKeyU f (Bin _ kx x l r) = case f kx x of
JustU y -> link kx y (mapMaybeWithKeyU f l) (mapMaybeWithKeyU f r)
NothingU -> link2 (mapMaybeWithKeyU f l) (mapMaybeWithKeyU f r)
{-# NOINLINE [1] mapMaybeWithKeyU #-}

{-# RULES
"mapMaybeWK#/mapU" forall f g m. mapMaybeWithKeyU f (mapU g m) =
mapMaybeWithKeyU (\k x -> case g x of SoloU y -> f k y) m
"mapU/mapMaybeWK#" forall f g m. mapU f (mapMaybeWithKeyU g m) =
mapMaybeWithKeyU
(\k x -> case g k x of
NothingU -> NothingU
JustU y -> case f y of SoloU z -> JustU z) m
"mapMaybeWK#/mapWK#" forall f g m. mapMaybeWithKeyU f (mapWithKeyU g m) =
mapMaybeWithKeyU (\k x -> case g k x of SoloU y -> f k y) m
"mapWK#/mapMaybeWK#" forall f g m. mapWithKeyU f (mapMaybeWithKeyU g m) =
mapMaybeWithKeyU
(\k x -> case g k x of
NothingU -> NothingU
JustU y -> case f k y of SoloU z -> JustU z) m
"mapMaybeWK#/mapMaybeWK#" forall f g m. mapMaybeWithKeyU f (mapMaybeWithKeyU g m) =
mapMaybeWithKeyU
(\k x -> case g k x of
NothingU -> NothingU
JustU y -> f k y) m
"mapMaybeWK#/filterWK" forall f p m. mapMaybeWithKeyU f (filterWithKey p m) =
mapMaybeWithKeyU (\k x -> if p k x then f k x else NothingU) m
"filterWK/mapMaybeWK#" forall p f m. filterWithKey p (mapMaybeWithKeyU f m) =
mapMaybeWithKeyU (\k x -> case f k x of
NothingU -> NothingU
JustU y
| p k y -> JustU y
| otherwise -> NothingU) m
#-}

-- | \(O(n)\). Traverse keys\/values and collect the 'Just' results.
--
Expand Down Expand Up @@ -3045,17 +3121,41 @@ mapEitherWithKey f0 t0 = toPair $ go f0 t0
-- > map (++ "x") (fromList [(5,"a"), (3,"b")]) == fromList [(3, "bx"), (5, "ax")]

map :: (a -> b) -> Map k a -> Map k b
#ifdef __GLASGOW_HASKELL__
-- We define map using mapU solely to reduce the number of rewrite
-- rules we need.
map f = mapU (\x -> SoloU (f x))
-- We delay inlinability of map to support map/coerce. While a
-- mapU/coerce rule seems to work when everything is done just so,
-- it feels too brittle to me for now (GHC 9.4).
{-# INLINABLE [1] map #-}
#else
map f = go where
go Tip = Tip
go (Bin sx kx x l r) = Bin sx kx (f x) (go l) (go r)
-- We use a `go` function to allow `map` to inline. This makes
-- a big difference if someone uses `map (const x) m` instead
-- of `x <$ m`; it doesn't seem to do any harm.
#endif

#ifdef __GLASGOW_HASKELL__
{-# NOINLINE [1] map #-}
mapU :: (a -> SoloU b) -> Map k a -> Map k b
mapU f = go where
go Tip = Tip
go (Bin sx kx x l r)
| SoloU y <- f x
= Bin sx kx y (go l) (go r)
#if defined (__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 806) && (__GLASGOW_HASKELL__ < 810)
-- Something goes wrong checking SoloU completeness
-- in these versions
go _ = error "impossible"
#endif
-- We use a `go` function to allow `mapU` to inline. Without this,
-- we'd slow down both strict and lazy map, which wouldn't be great.
-- This also lets us avoid a custom implementation of <$

-- We don't let mapU inline until phase 0 because we need a step
-- after map inlines.
{-# NOINLINE [0] mapU #-}
{-# RULES
"map/map" forall f g xs . map f (map g xs) = map (f . g) xs
"mapU/mapU" forall f g xs . mapU f (mapU g xs) = mapU (\x -> case g x of SoloU y -> f y) xs
"map/coerce" map coerce = coerce
#-}
#endif
Expand All @@ -3066,21 +3166,38 @@ map f = go where
-- > mapWithKey f (fromList [(5,"a"), (3,"b")]) == fromList [(3, "3:b"), (5, "5:a")]

mapWithKey :: (k -> a -> b) -> Map k a -> Map k b
#ifdef __GLASGOW_HASKELL__
mapWithKey f = mapWithKeyU (\k a -> SoloU (f k a))
{-# INLINABLE mapWithKey #-}
#else
mapWithKey _ Tip = Tip
mapWithKey f (Bin sx kx x l r) = Bin sx kx (f kx x) (mapWithKey f l) (mapWithKey f r)
#endif

-- | A version of 'mapWithKey' that takes a function producing a unary
-- unboxed tuple.
mapWithKeyU :: (k -> a -> SoloU b) -> Map k a -> Map k b
mapWithKeyU f = go where
go Tip = Tip
go (Bin sx kx x l r)
| SoloU y <- f kx x
= Bin sx kx y (go l) (go r)
#if defined (__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 806) && (__GLASGOW_HASKELL__ < 810)
-- Something goes wrong checking SoloU completeness
-- in these versions
go _ = error "impossible"
#endif

#ifdef __GLASGOW_HASKELL__
{-# NOINLINE [1] mapWithKey #-}
{-# NOINLINE [1] mapWithKeyU #-}
{-# RULES
"mapWithKey/mapWithKey" forall f g xs . mapWithKey f (mapWithKey g xs) =
mapWithKey (\k a -> f k (g k a)) xs
"mapWithKey/map" forall f g xs . mapWithKey f (map g xs) =
mapWithKey (\k a -> f k (g a)) xs
"map/mapWithKey" forall f g xs . map f (mapWithKey g xs) =
mapWithKey (\k a -> f (g k a)) xs
"mapWK#/mapWK#" forall f g xs. mapWithKeyU f (mapWithKeyU g xs) = mapWithKeyU (\k x -> case g k x of SoloU y -> f k y) xs
"mapWK#/mapU" forall f g xs. mapWithKeyU f (mapU g xs) = mapWithKeyU (\k x -> case g x of SoloU y -> f k y) xs
"mapU/mapWK#" forall f g xs. mapU f (mapWithKeyU g xs) = mapWithKeyU (\k x -> case g k x of SoloU y -> f y) xs
#-}
#endif


-- | \(O(n)\).
-- @'traverseWithKey' f m == 'fromList' <$> 'traverse' (\(k, v) -> (,) k <$> f k v) ('toList' m)@
-- That is, behaves exactly like a regular 'traverse' except that the traversing
Expand Down Expand Up @@ -4195,10 +4312,12 @@ instance (Ord k, Read k) => Read1 (Map k) where
--------------------------------------------------------------------}
instance Functor (Map k) where
fmap f m = map f m
#ifdef __GLASGOW_HASKELL__
_ <$ Tip = Tip
a <$ (Bin sx kx _ l r) = Bin sx kx a (a <$ l) (a <$ r)
#endif
{-# INLINABLE fmap #-}
a <$ m = map (const a) m
-- For some reason, we need an explicit INLINE or INLINABLE pragma to
-- get the unfolding to use map rather than expanding into a recursive
-- function that RULES will never match. Hmm....
{-# INLINABLE (<$) #-}

-- | Traverses in order of increasing key.
instance Traversable (Map k) where
Expand Down
46 changes: 22 additions & 24 deletions containers/src/Data/Map/Strict/Internal.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
#if defined(__GLASGOW_HASKELL__)
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE UnboxedTuples #-}
#endif
{-# OPTIONS_HADDOCK not-home #-}

Expand Down Expand Up @@ -420,6 +423,8 @@ import Data.Semigroup (Arg (..))
import qualified Data.Set.Internal as Set
import qualified Data.Map.Internal as L
import Utils.Containers.Internal.StrictPair
import Utils.Containers.Internal.UnboxedMaybe (pattern NothingU, pattern JustU)
import Utils.Containers.Internal.UnboxedSolo (pattern SoloU)

import Data.Bits (shiftL, shiftR)
#ifdef __GLASGOW_HASKELL__
Expand Down Expand Up @@ -1271,17 +1276,26 @@ mergeWithKey f g1 g2 = go

mapMaybe :: (a -> Maybe b) -> Map k a -> Map k b
mapMaybe f = mapMaybeWithKey (\_ x -> f x)
{-# INLINABLE mapMaybe #-}

-- | \(O(n)\). Map keys\/values and collect the 'Just' results.
--
-- > let f k _ = if k < 5 then Just ("key : " ++ (show k)) else Nothing
-- > mapMaybeWithKey f (fromList [(5,"a"), (3,"b")]) == singleton 3 "key : 3"

mapMaybeWithKey :: (k -> a -> Maybe b) -> Map k a -> Map k b
{-
-
mapMaybeWithKey _ Tip = Tip
mapMaybeWithKey f (Bin _ kx x l r) = case f kx x of
Just y -> y `seq` link kx y (mapMaybeWithKey f l) (mapMaybeWithKey f r)
Nothing -> link2 (mapMaybeWithKey f l) (mapMaybeWithKey f r)
-}
mapMaybeWithKey f = \m ->
L.mapMaybeWithKeyU (\k x -> case f k x of
Nothing -> NothingU
Just !a -> JustU a) m
{-# INLINABLE mapMaybeWithKey #-}

-- | \(O(n)\). Traverse keys\/values and collect the 'Just' results.
--
Expand Down Expand Up @@ -1340,19 +1354,16 @@ mapEitherWithKey f0 t0 = toPair $ go f0 t0
-- > map (++ "x") (fromList [(5,"a"), (3,"b")]) == fromList [(3, "bx"), (5, "ax")]

map :: (a -> b) -> Map k a -> Map k b
#ifdef __GLASGOW_HASKELL__
map f = L.mapU (\x -> let !y = f x in SoloU y)
{-# INLINABLE map #-}
#else
map f = go
where
go Tip = Tip
go (Bin sx kx x l r) = let !x' = f x in Bin sx kx x' (go l) (go r)
-- We use `go` to let `map` inline. This is important if `f` is a constant
-- function.

#ifdef __GLASGOW_HASKELL__
{-# NOINLINE [1] map #-}
{-# RULES
"map/map" forall f g xs . map f (map g xs) = map (\x -> f $! g x) xs
"map/mapL" forall f g xs . map f (L.map g xs) = map (\x -> f (g x)) xs
#-}
#endif

-- | \(O(n)\). Map a function over all values in the map.
Expand All @@ -1361,27 +1372,14 @@ map f = go
-- > mapWithKey f (fromList [(5,"a"), (3,"b")]) == fromList [(3, "3:b"), (5, "5:a")]

mapWithKey :: (k -> a -> b) -> Map k a -> Map k b
#ifdef __GLASGOW_HASKELL__
mapWithKey f = L.mapWithKeyU (\k x -> let !y = f k x in SoloU y)
{-# INLINABLE mapWithKey #-}
#else
mapWithKey _ Tip = Tip
mapWithKey f (Bin sx kx x l r) =
let x' = f kx x
in x' `seq` Bin sx kx x' (mapWithKey f l) (mapWithKey f r)

#ifdef __GLASGOW_HASKELL__
{-# NOINLINE [1] mapWithKey #-}
{-# RULES
"mapWithKey/mapWithKey" forall f g xs . mapWithKey f (mapWithKey g xs) =
mapWithKey (\k a -> f k $! g k a) xs
"mapWithKey/mapWithKeyL" forall f g xs . mapWithKey f (L.mapWithKey g xs) =
mapWithKey (\k a -> f k (g k a)) xs
"mapWithKey/map" forall f g xs . mapWithKey f (map g xs) =
mapWithKey (\k a -> f k $! g a) xs
"mapWithKey/mapL" forall f g xs . mapWithKey f (L.map g xs) =
mapWithKey (\k a -> f k (g a)) xs
"map/mapWithKey" forall f g xs . map f (mapWithKey g xs) =
mapWithKey (\k a -> f $! g k a) xs
"map/mapWithKeyL" forall f g xs . map f (L.mapWithKey g xs) =
mapWithKey (\k a -> f (g k a)) xs
#-}
#endif

-- | \(O(n)\).
Expand Down
Loading

0 comments on commit 75a721b

Please sign in to comment.