Skip to content

Commit 54fb6fd

Browse files
committed
Fix deprecated array size checks
1 parent 666d389 commit 54fb6fd

File tree

1 file changed

+73
-32
lines changed
  • unison-runtime/src/Unison/Runtime

1 file changed

+73
-32
lines changed

unison-runtime/src/Unison/Runtime/Array.hs

Lines changed: 73 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -141,44 +141,67 @@ checkIBArray name a f arr i
141141
checkIMBArray
142142
:: CheckCtx
143143
=> Prim a
144+
=> PrimMonad m
144145
=> String
145146
-> a
146-
-> (MutableByteArray s -> Int -> r)
147-
-> MutableByteArray s -> Int -> r
148-
checkIMBArray name a f arr i
149-
| i < 0 || sizeofMutableByteArray arr `quot` sizeOf a <= i
150-
= error $ name ++ " unsafe check out of bounds: " ++ show i
151-
| otherwise = f arr i
147+
-> (MutableByteArray (PrimState m) -> Int -> m r)
148+
-> MutableByteArray (PrimState m) -> Int -> m r
149+
checkIMBArray name a f arr i = do
150+
sz <- getSizeofMutableByteArray arr
151+
if (i < 0 || sz `quot` sizeOf a <= i)
152+
then error $ name ++ " unsafe check out of bounds: " ++ show i
153+
else f arr i
152154
{-# inline checkIMBArray #-}
153155

156+
-- check write mutable byte array
157+
checkWMBArray
158+
:: CheckCtx
159+
=> Prim a
160+
=> PrimMonad m
161+
=> String
162+
-> (MutableByteArray (PrimState m) -> Int -> a -> m r)
163+
-> MutableByteArray (PrimState m) -> Int -> a -> m r
164+
checkWMBArray name f arr i a = do
165+
sz <- getSizeofMutableByteArray arr
166+
if (i < 0 || sz `quot` sizeOf a <= i)
167+
then error $ name ++ " unsafe check out of bounds: " ++ show i
168+
else f arr i a
169+
{-# inline checkWMBArray #-}
170+
171+
154172
-- check copy byte array
155173
checkCBArray
156174
:: CheckCtx
175+
=> PrimMonad m
157176
=> String
158-
-> (MBA s -> Int -> BA -> Int -> Int -> r)
159-
-> MBA s -> Int -> BA -> Int -> Int -> r
160-
checkCBArray name f dst d src s l
161-
| d < 0
162-
|| s < 0
163-
|| sizeofMutableByteArray dst < d + l
164-
|| sizeofByteArray src < s + l
165-
= error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
166-
| otherwise = f dst d src s l
177+
-> (MBA (PrimState m) -> Int -> BA -> Int -> Int -> m r)
178+
-> MBA (PrimState m) -> Int -> BA -> Int -> Int -> m r
179+
checkCBArray name f dst d src s l = do
180+
szd <- getSizeofMutableByteArray dst
181+
if (d < 0
182+
|| s < 0
183+
|| szd < d + l
184+
|| sizeofByteArray src < s + l
185+
) then error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
186+
else f dst d src s l
167187
{-# inline checkCBArray #-}
168188

169189
-- check copy mutable byte array
170190
checkCMBArray
171191
:: CheckCtx
192+
=> PrimMonad m
172193
=> String
173-
-> (MBA s -> Int -> MBA s -> Int -> Int -> r)
174-
-> MBA s -> Int -> MBA s -> Int -> Int -> r
175-
checkCMBArray name f dst d src s l
176-
| d < 0
177-
|| s < 0
178-
|| sizeofMutableByteArray dst < d + l
179-
|| sizeofMutableByteArray src < s + l
180-
= error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
181-
| otherwise = f dst d src s l
194+
-> (MBA (PrimState m) -> Int -> MBA (PrimState m) -> Int -> Int -> m r)
195+
-> MBA (PrimState m) -> Int -> MBA (PrimState m) -> Int -> Int -> m r
196+
checkCMBArray name f dst d src s l = do
197+
szd <- getSizeofMutableByteArray dst
198+
szs <- getSizeofMutableByteArray src
199+
if ( d < 0
200+
|| s < 0
201+
|| szd < d + l
202+
|| szs < s + l
203+
) then error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
204+
else f dst d src s l
182205
{-# inline checkCMBArray #-}
183206

184207
-- check index prim array
@@ -197,16 +220,34 @@ checkIPArray name f arr i
197220
-- check index mutable prim array
198221
checkIMPArray
199222
:: CheckCtx
223+
=> PrimMonad m
200224
=> Prim a
201225
=> String
202-
-> (MutablePrimArray s a -> Int -> r)
203-
-> MutablePrimArray s a -> Int -> r
204-
checkIMPArray name f arr i
205-
| i < 0 || sizeofMutablePrimArray arr <= i
206-
= error $ name ++ " unsafe check out of bounds: " ++ show i
207-
| otherwise = f arr i
226+
-> (MutablePrimArray (PrimState m) a -> Int -> m r)
227+
-> MutablePrimArray (PrimState m) a -> Int -> m r
228+
checkIMPArray name f arr i = do
229+
asz <- getSizeofMutablePrimArray arr
230+
if (i < 0 || asz <= i)
231+
then error $ name ++ " unsafe check out of bounds: " ++ show i
232+
else f arr i
208233
{-# inline checkIMPArray #-}
209234

235+
-- check write mutable prim array
236+
checkWMPArray
237+
:: CheckCtx
238+
=> PrimMonad m
239+
=> Prim a
240+
=> String
241+
-> (MutablePrimArray (PrimState m) a -> Int -> a -> m r)
242+
-> MutablePrimArray (PrimState m) a -> Int -> a -> m r
243+
checkWMPArray name f arr i a = do
244+
asz <- getSizeofMutablePrimArray arr
245+
if (i < 0 || asz <= i)
246+
then error $ name ++ " unsafe check out of bounds: " ++ show i
247+
else f arr i a
248+
{-# inline checkWMPArray #-}
249+
250+
210251
#else
211252
type CheckCtx :: Constraint
212253
type CheckCtx = ()
@@ -301,7 +342,7 @@ writeByteArray ::
301342
Int ->
302343
a ->
303344
m ()
304-
writeByteArray = checkIMBArray @a "writeByteArray" undefined PA.writeByteArray
345+
writeByteArray = checkWMBArray @a "writeByteArray" PA.writeByteArray
305346
{-# INLINE writeByteArray #-}
306347

307348
indexByteArray ::
@@ -368,7 +409,7 @@ writePrimArray ::
368409
Int ->
369410
a ->
370411
m ()
371-
writePrimArray = checkIMPArray "writePrimArray" PA.writePrimArray
412+
writePrimArray = checkWMPArray "writePrimArray" PA.writePrimArray
372413
{-# INLINE writePrimArray #-}
373414

374415
indexPrimArray ::

0 commit comments

Comments
 (0)