Skip to content

Commit e28a16f

Browse files
committed
Fix deprecated array size checks
1 parent 666d389 commit e28a16f

File tree

1 file changed

+79
-34
lines changed
  • unison-runtime/src/Unison/Runtime

1 file changed

+79
-34
lines changed

unison-runtime/src/Unison/Runtime/Array.hs

Lines changed: 79 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -141,44 +141,67 @@ checkIBArray name a f arr i
141141
checkIMBArray
142142
:: CheckCtx
143143
=> Prim a
144+
=> PrimMonad m
144145
=> String
145146
-> a
146-
-> (MutableByteArray s -> Int -> r)
147-
-> MutableByteArray s -> Int -> r
148-
checkIMBArray name a f arr i
149-
| i < 0 || sizeofMutableByteArray arr `quot` sizeOf a <= i
150-
= error $ name ++ " unsafe check out of bounds: " ++ show i
151-
| otherwise = f arr i
147+
-> (MutableByteArray (PrimState m) -> Int -> m r)
148+
-> MutableByteArray (PrimState m) -> Int -> m r
149+
checkIMBArray name a f arr i = do
150+
sz <- getSizeofMutableByteArray arr
151+
if (i < 0 || sz `quot` sizeOf a <= i)
152+
then error $ name ++ " unsafe check out of bounds: " ++ show i
153+
else f arr i
152154
{-# inline checkIMBArray #-}
153155

156+
-- check write mutable byte array
157+
checkWMBArray
158+
:: CheckCtx
159+
=> Prim a
160+
=> PrimMonad m
161+
=> String
162+
-> (MutableByteArray (PrimState m) -> Int -> a -> m r)
163+
-> MutableByteArray (PrimState m) -> Int -> a -> m r
164+
checkWMBArray name f arr i a = do
165+
sz <- getSizeofMutableByteArray arr
166+
if (i < 0 || sz `quot` sizeOf a <= i)
167+
then error $ name ++ " unsafe check out of bounds: " ++ show i
168+
else f arr i a
169+
{-# inline checkWMBArray #-}
170+
171+
154172
-- check copy byte array
155173
checkCBArray
156174
:: CheckCtx
175+
=> PrimMonad m
157176
=> String
158-
-> (MBA s -> Int -> BA -> Int -> Int -> r)
159-
-> MBA s -> Int -> BA -> Int -> Int -> r
160-
checkCBArray name f dst d src s l
161-
| d < 0
162-
|| s < 0
163-
|| sizeofMutableByteArray dst < d + l
164-
|| sizeofByteArray src < s + l
165-
= error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
166-
| otherwise = f dst d src s l
177+
-> (MBA (PrimState m) -> Int -> BA -> Int -> Int -> m r)
178+
-> MBA (PrimState m) -> Int -> BA -> Int -> Int -> m r
179+
checkCBArray name f dst d src s l = do
180+
szd <- getSizeofMutableByteArray dst
181+
if (d < 0
182+
|| s < 0
183+
|| szd < d + l
184+
|| sizeofByteArray src < s + l
185+
) then error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
186+
else f dst d src s l
167187
{-# inline checkCBArray #-}
168188

169189
-- check copy mutable byte array
170190
checkCMBArray
171191
:: CheckCtx
192+
=> PrimMonad m
172193
=> String
173-
-> (MBA s -> Int -> MBA s -> Int -> Int -> r)
174-
-> MBA s -> Int -> MBA s -> Int -> Int -> r
175-
checkCMBArray name f dst d src s l
176-
| d < 0
177-
|| s < 0
178-
|| sizeofMutableByteArray dst < d + l
179-
|| sizeofMutableByteArray src < s + l
180-
= error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
181-
| otherwise = f dst d src s l
194+
-> (MBA (PrimState m) -> Int -> MBA (PrimState m) -> Int -> Int -> m r)
195+
-> MBA (PrimState m) -> Int -> MBA (PrimState m) -> Int -> Int -> m r
196+
checkCMBArray name f dst d src s l = do
197+
szd <- getSizeofMutableByteArray dst
198+
szs <- getSizeofMutableByteArray src
199+
if ( d < 0
200+
|| s < 0
201+
|| szd < d + l
202+
|| szs < s + l
203+
) then error $ name ++ " unsafe check out of bounds: " ++ show (d, s, l)
204+
else f dst d src s l
182205
{-# inline checkCMBArray #-}
183206

184207
-- check index prim array
@@ -197,35 +220,57 @@ checkIPArray name f arr i
197220
-- check index mutable prim array
198221
checkIMPArray
199222
:: CheckCtx
223+
=> PrimMonad m
200224
=> Prim a
201225
=> String
202-
-> (MutablePrimArray s a -> Int -> r)
203-
-> MutablePrimArray s a -> Int -> r
204-
checkIMPArray name f arr i
205-
| i < 0 || sizeofMutablePrimArray arr <= i
206-
= error $ name ++ " unsafe check out of bounds: " ++ show i
207-
| otherwise = f arr i
226+
-> (MutablePrimArray (PrimState m) a -> Int -> m r)
227+
-> MutablePrimArray (PrimState m) a -> Int -> m r
228+
checkIMPArray name f arr i = do
229+
asz <- getSizeofMutablePrimArray arr
230+
if (i < 0 || asz <= i)
231+
then error $ name ++ " unsafe check out of bounds: " ++ show i
232+
else f arr i
208233
{-# inline checkIMPArray #-}
209234

235+
-- check write mutable prim array
236+
checkWMPArray
237+
:: CheckCtx
238+
=> PrimMonad m
239+
=> Prim a
240+
=> String
241+
-> (MutablePrimArray (PrimState m) a -> Int -> a -> m r)
242+
-> MutablePrimArray (PrimState m) a -> Int -> a -> m r
243+
checkWMPArray name f arr i a = do
244+
asz <- getSizeofMutablePrimArray arr
245+
if (i < 0 || asz <= i)
246+
then error $ name ++ " unsafe check out of bounds: " ++ show i
247+
else f arr i a
248+
{-# inline checkWMPArray #-}
249+
250+
210251
#else
211252
type CheckCtx :: Constraint
212253
type CheckCtx = ()
213254

214-
checkIMArray, checkIMPArray, checkIPArray :: String -> r -> r
255+
checkIMArray, checkIMPArray, checkWMPArray, checkIPArray :: String -> r -> r
215256
checkCArray, checkCMArray, checkRMArray :: String -> r -> r
216257
checkIMArray _ = id
217258
checkIMPArray _ = id
259+
checkWMPArray _ = id
218260
checkCArray _ = id
219261
checkCMArray _ = id
220262
checkRMArray _ = id
221263
checkIPArray _ = id
222264

223-
checkIBArray, checkIMBArray :: String -> a -> r -> r
265+
checkIBArray, checkIMBArray:: String -> a -> r -> r
224266
checkCBArray, checkCMBArray :: String -> r -> r
225267
checkIBArray _ _ = id
226268
checkIMBArray _ _ = id
227269
checkCBArray _ = id
228270
checkCMBArray _ = id
271+
272+
checkWMBArray :: String -> r -> r
273+
checkWMBArray _ = id
229274
#endif
230275

231276
readArray ::
@@ -301,7 +346,7 @@ writeByteArray ::
301346
Int ->
302347
a ->
303348
m ()
304-
writeByteArray = checkIMBArray @a "writeByteArray" undefined PA.writeByteArray
349+
writeByteArray = checkWMBArray "writeByteArray" PA.writeByteArray
305350
{-# INLINE writeByteArray #-}
306351

307352
indexByteArray ::
@@ -368,7 +413,7 @@ writePrimArray ::
368413
Int ->
369414
a ->
370415
m ()
371-
writePrimArray = checkIMPArray "writePrimArray" PA.writePrimArray
416+
writePrimArray = checkWMPArray "writePrimArray" PA.writePrimArray
372417
{-# INLINE writePrimArray #-}
373418

374419
indexPrimArray ::

0 commit comments

Comments
 (0)