Skip to content

Commit

Permalink
Fix potential overflows in mspace bitmask handling
Browse files Browse the repository at this point in the history
  • Loading branch information
nickg committed Aug 21, 2024
1 parent 0087f5a commit efed9c3
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 55 deletions.
4 changes: 2 additions & 2 deletions src/jit/jit-optim.c
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,7 @@ static void lscan_walk_cfg(jit_func_t *f, jit_cfg_t *cfg, int bi,

jit_block_t *b = &(cfg->blocks[bi]);

for (int bit = -1; mask_iter(&b->livein, &bit);)
for (size_t bit = -1; mask_iter(&b->livein, &bit);)
lscan_grow_range(bit, li, b->first);

for (int i = b->first; i <= b->last; i++) {
Expand All @@ -1315,7 +1315,7 @@ static void lscan_walk_cfg(jit_func_t *f, jit_cfg_t *cfg, int bi,
lscan_grow_range(ir->arg2.reg, li, i);
}

for (int bit = -1; mask_iter(&b->liveout, &bit); )
for (size_t bit = -1; mask_iter(&b->liveout, &bit); )
lscan_grow_range(bit, li, b->last);

for (int i = 0; i < b->out.count; i++) {
Expand Down
56 changes: 28 additions & 28 deletions src/mask.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ static inline uint64_t mask_for_range(int low, int high)
return mask;
}

void mask_clear_range(bit_mask_t *m, int start, int count)
void mask_clear_range(bit_mask_t *m, size_t start, size_t count)
{
if (m->size <= 64) {
m->bits &= ~mask_for_range(start, start + count - 1);
Expand All @@ -68,9 +68,9 @@ void mask_clear_range(bit_mask_t *m, int start, int count)

if (count > 0 && start % 64 != 0) {
// Pre-loop: clear range of bits in first 64-bit word
const int low = start % 64;
const int high = MIN(low + count - 1, 63);
const int nbits = high - low + 1;
const size_t low = start % 64;
const size_t high = MIN(low + count - 1, 63);
const size_t nbits = high - low + 1;
m->ptr[start / 64] &= ~mask_for_range(low, high);
start += nbits;
count -= nbits;
Expand All @@ -91,7 +91,7 @@ void mask_clear_range(bit_mask_t *m, int start, int count)
}
}

void mask_set_range(bit_mask_t *m, int start, int count)
void mask_set_range(bit_mask_t *m, size_t start, size_t count)
{
if (m->size <= 64) {
m->bits |= mask_for_range(start, start + count - 1);
Expand All @@ -100,9 +100,9 @@ void mask_set_range(bit_mask_t *m, int start, int count)

if (count > 0 && start % 64 != 0) {
// Pre-loop: set range of bits in first 64-bit word
const int low = start % 64;
const int high = MIN(low + count - 1, 63);
const int nbits = high - low + 1;
const size_t low = start % 64;
const size_t high = MIN(low + count - 1, 63);
const size_t nbits = high - low + 1;
m->ptr[start / 64] |= mask_for_range(low, high);
start += nbits;
count -= nbits;
Expand All @@ -123,16 +123,16 @@ void mask_set_range(bit_mask_t *m, int start, int count)
}
}

bool mask_test_range(bit_mask_t *m, int start, int count)
bool mask_test_range(bit_mask_t *m, size_t start, size_t count)
{
if (m->size <= 64)
return !!(m->bits & mask_for_range(start, start + count - 1));

if (count > 0 && start % 64 != 0) {
// Pre-loop: test range of bits in first 64-bit word
const int low = start % 64;
const int high = MIN(low + count - 1, 63);
const int nbits = high - low + 1;
const size_t low = start % 64;
const size_t high = MIN(low + count - 1, 63);
const size_t nbits = high - low + 1;

if (m->ptr[start / 64] & mask_for_range(low, high))
return true;
Expand Down Expand Up @@ -163,11 +163,11 @@ bool mask_test_range(bit_mask_t *m, int start, int count)
return false;
}

int mask_popcount(bit_mask_t *m)
size_t mask_popcount(bit_mask_t *m)
{
if (m->size > 64) {
int sum = 0;
for (int i = 0; i < (m->size + 63) / 64; i++)
size_t sum = 0;
for (ssize_t i = 0; i < (m->size + 63) / 64; i++)
sum += __builtin_popcountll(m->ptr[i]);
return sum;
}
Expand All @@ -185,7 +185,7 @@ void mask_clearall(bit_mask_t *m)
mask_clear_range(m, 0, m->size);
}

int mask_scan_backwards(bit_mask_t *m, int bit)
ssize_t mask_scan_backwards(bit_mask_t *m, size_t bit)
{
if (m->size <= 64) {
uint64_t word0 = m->bits & mask_for_range(0, bit);
Expand All @@ -198,17 +198,17 @@ int mask_scan_backwards(bit_mask_t *m, int bit)
if (word0 != 0)
return (bit | 63) - __builtin_clzll(word0);

bit -= bit % 64 + 1;
ssize_t i = bit - (bit % 64 + 1);

for (; bit > 0 && m->ptr[bit / 64] == 0; bit -= 64);
for (; i > 0 && m->ptr[i / 64] == 0; i -= 64);

if (bit > 0)
return (bit | 63) - __builtin_clzll(m->ptr[bit / 64]);
if (i > 0)
return (i | 63) - __builtin_clzll(m->ptr[i / 64]);

return -1;
}

int mask_count_clear(bit_mask_t *m, int bit)
size_t mask_count_clear(bit_mask_t *m, size_t bit)
{
assert(bit < m->size);

Expand All @@ -217,7 +217,7 @@ int mask_count_clear(bit_mask_t *m, int bit)
return fs > 0 ? fs - 1 - bit : m->size - bit;
}

int count = 0;
size_t count = 0;

const int modbits = bit % 64, maxbits = MIN(64, m->size - (bit & ~63));
if (modbits > 0) {
Expand Down Expand Up @@ -251,7 +251,7 @@ void mask_subtract(bit_mask_t *m, const bit_mask_t *m2)
assert(m->size == m2->size);

if (m->size > 64) {
for (int i = 0; i < (m->size + 63) / 64; i++)
for (ssize_t i = 0; i < (m->size + 63) / 64; i++)
m->ptr[i] &= ~m2->ptr[i];
}
else
Expand All @@ -263,7 +263,7 @@ void mask_union(bit_mask_t *m, const bit_mask_t *m2)
assert(m->size == m2->size);

if (m->size > 64) {
for (int i = 0; i < (m->size + 63) / 64; i++)
for (ssize_t i = 0; i < (m->size + 63) / 64; i++)
m->ptr[i] |= m2->ptr[i];
}
else
Expand All @@ -275,7 +275,7 @@ void mask_copy(bit_mask_t *m, const bit_mask_t *m2)
assert(m->size == m2->size);

if (m->size > 64) {
for (int i = 0; i < (m->size + 63) / 64; i++)
for (ssize_t i = 0; i < (m->size + 63) / 64; i++)
m->ptr[i] = m2->ptr[i];
}
else
Expand All @@ -287,7 +287,7 @@ bool mask_eq(const bit_mask_t *m1, const bit_mask_t *m2)
assert(m1->size == m2->size);

if (m1->size > 64) {
for (int i = 0; i < (m1->size + 63) / 64; i++) {
for (ssize_t i = 0; i < (m1->size + 63) / 64; i++) {
if (m1->ptr[i] != m2->ptr[i])
return false;
}
Expand All @@ -298,12 +298,12 @@ bool mask_eq(const bit_mask_t *m1, const bit_mask_t *m2)
return m1->bits == m2->bits;
}

bool mask_iter(bit_mask_t *m, int *bit)
bool mask_iter(bit_mask_t *m, size_t *bit)
{
if (*bit + 1 < 0 || *bit + 1 >= m->size)
return false;
else if (m->size > 64) {
int word = (*bit + 1) / 64;
size_t word = (*bit + 1) / 64;

if ((*bit + 1) % 64 > 0) {
const uint64_t remain = m->ptr[word] & ~mask_for_range(0, *bit % 64);
Expand Down
24 changes: 13 additions & 11 deletions src/mask.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Copyright (C) 2022-2023 Nick Gasson
// Copyright (C) 2022-2024 Nick Gasson
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
Expand All @@ -21,6 +21,8 @@
#include "prim.h"

#include <stdbool.h>
#include <stddef.h>
#include <sys/types.h>

typedef struct _bit_mask {
size_t size;
Expand All @@ -31,15 +33,15 @@ typedef struct _bit_mask {
} bit_mask_t;

#define mask_test(m, bit) ({ \
const int _bit = (bit); \
const size_t _bit = (bit); \
assert(_bit < (m)->size); \
(m)->size > 64 \
? !!((m)->ptr[_bit / 64] & (UINT64_C(1) << (_bit % 64))) \
: !!((m)->bits & (UINT64_C(1) << (_bit % 64))); \
})

#define mask_set(m, bit) do { \
const int _bit = (bit); \
const size_t _bit = (bit); \
assert(_bit < (m)->size); \
if ((m)->size > 64) \
(m)->ptr[_bit / 64] |= (UINT64_C(1) << (_bit % 64)); \
Expand All @@ -48,7 +50,7 @@ typedef struct _bit_mask {
} while (0)

#define mask_clear(m, bit) do { \
const int _bit = (bit); \
const size_t _bit = (bit); \
assert(_bit < (m)->size); \
if ((m)->size > 64) \
(m)->ptr[_bit / 64] &= ~(UINT64_C(1) << (_bit % 64)); \
Expand All @@ -60,18 +62,18 @@ typedef struct _bit_mask {

void mask_init(bit_mask_t *m, size_t size);
void mask_free(bit_mask_t *m);
void mask_clear_range(bit_mask_t *m, int start, int count);
void mask_set_range(bit_mask_t *m, int start, int count);
bool mask_test_range(bit_mask_t *m, int start, int count);
int mask_popcount(bit_mask_t *m);
void mask_clear_range(bit_mask_t *m, size_t start, size_t count);
void mask_set_range(bit_mask_t *m, size_t start, size_t count);
bool mask_test_range(bit_mask_t *m, size_t start, size_t count);
size_t mask_popcount(bit_mask_t *m);
void mask_setall(bit_mask_t *m);
void mask_clearall(bit_mask_t *m);
int mask_scan_backwards(bit_mask_t *m, int bit);
int mask_count_clear(bit_mask_t *m, int bit);
ssize_t mask_scan_backwards(bit_mask_t *m, size_t bit);
size_t mask_count_clear(bit_mask_t *m, size_t bit);
void mask_subtract(bit_mask_t *m, const bit_mask_t *m2);
void mask_union(bit_mask_t *m, const bit_mask_t *m2);
void mask_copy(bit_mask_t *m, const bit_mask_t *m2);
bool mask_eq(const bit_mask_t *m1, const bit_mask_t *m2);
bool mask_iter(bit_mask_t *m, int *bit);
bool mask_iter(bit_mask_t *m, size_t *bit);

#endif // _MASK_H
6 changes: 3 additions & 3 deletions src/rt/model.c
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,9 @@ static const char *trace_states(bit_mask_t *mask)
tb_rewind(tb);
tb_append(tb, '{');

int bit = -1;
size_t bit = -1;
while (mask_iter(mask, &bit))
tb_printf(tb, "%s%d", tb_len(tb) > 1 ? "," : "", bit);
tb_printf(tb, "%s%zd", tb_len(tb) > 1 ? "," : "", bit);

tb_append(tb, '}');

Expand Down Expand Up @@ -2453,7 +2453,7 @@ static void update_property(rt_model_t *m, rt_prop_t *prop)

mask_clearall(&prop->newstate);

int bit = -1;
size_t bit = -1;
while (mask_iter(&prop->state, &bit)) {
jit_scalar_t state = { .integer = bit }, result;
if (!jit_fastcall(m->jit, prop->handle, &result, context,
Expand Down
21 changes: 12 additions & 9 deletions src/rt/mspace.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Copyright (C) 2022-2023 Nick Gasson
// Copyright (C) 2022-2024 Nick Gasson
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -82,7 +82,7 @@ struct _free_list {
struct _mspace {
nvc_lock_t lock;
size_t maxsize;
unsigned maxlines;
size_t maxlines;
char *space;
bit_mask_t headmask;
mptr_t roots;
Expand Down Expand Up @@ -184,7 +184,7 @@ static void *mspace_try_alloc(mspace_t *m, size_t size)
{
// Add one to size before rounding up to LINE_SIZE to allow a valid
// pointer to point at one element past the end of an array
const int nlines = (size + LINE_SIZE) / LINE_SIZE;
const size_t nlines = (size + LINE_SIZE) / LINE_SIZE;
const size_t asize = nlines * LINE_SIZE;

SCOPED_LOCK(m->lock);
Expand All @@ -194,10 +194,12 @@ static void *mspace_try_alloc(mspace_t *m, size_t size)
if ((*it)->size >= asize) {
char *base = (*it)->ptr;
assert((uintptr_t)base % LINE_SIZE == 0);
assert(base >= m->space);
assert(base < m->space + m->maxsize);

MSPACE_UNPOISON(base, size);

const int line = (base - m->space) / LINE_SIZE;
const ptrdiff_t line = (base - m->space) / LINE_SIZE;
mask_set(&(m->headmask), line);
if (nlines > 1)
mask_clear_range(&(m->headmask), line + 1, nlines - 1);
Expand Down Expand Up @@ -429,9 +431,10 @@ static void mspace_mark_root(mspace_t *m, intptr_t p, gc_state_t *state)
line = mask_scan_backwards(&(m->headmask), line);
assert(line != -1);

int objlen = 1;
size_t objlen = 1;
if (line + 1 < m->maxlines)
objlen += mask_count_clear(&(m->headmask), line + 1);
assert(objlen < UINT32_MAX);

if (!mask_test(&(state->markmask), line)) {
mask_set_range(&(state->markmask), line, objlen);
Expand Down Expand Up @@ -520,7 +523,7 @@ static void mspace_gc(mspace_t *m)
const uint32_t line = enc >> 32;
const uint32_t objlen = enc & 0xffffffff;

for (int i = 0; i < objlen; i++) {
for (size_t i = 0; i < objlen; i++) {
const ptrdiff_t off = (uintptr_t)(line + i) * LINE_SIZE;
intptr_t *words = (intptr_t *)(m->space + off);
for (int j = 0; j < LINE_WORDS; j++)
Expand All @@ -543,7 +546,7 @@ static void mspace_gc(mspace_t *m)

int freefrags = 0, freelines = 0;
free_list_t **tail = &(m->free_list);
for (int line = 0; line < m->maxlines;) {
for (size_t line = 0; line < m->maxlines;) {
const int clear = mask_count_clear(&(state.markmask), line);
if (clear == 0)
line++;
Expand All @@ -569,7 +572,7 @@ static void mspace_gc(mspace_t *m)

if (opt_get_verbose(OPT_GC_VERBOSE, NULL)) {
const int ticks = get_timestamp_us() - start_ticks;
debugf("GC: allocated %d/%zu; fragmentation %.2g%% [%d us]",
debugf("GC: allocated %zd/%zu; fragmentation %.2g%% [%d us]",
mask_popcount(&(state.markmask)) * LINE_SIZE, m->maxsize,
((double)(freefrags - 1) / (double)freelines) * 100.0, ticks);

Expand All @@ -590,7 +593,7 @@ void *mspace_find(mspace_t *m, void *ptr, size_t *size)
return NULL;
}

int line = ((char *)ptr - m->space) / LINE_SIZE;
ptrdiff_t line = ((char *)ptr - m->space) / LINE_SIZE;

// Scan backwards to the start of the object
line = mask_scan_backwards(&(m->headmask), line);
Expand Down
4 changes: 2 additions & 2 deletions test/test_misc.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Copyright (C) 2021-2023 Nick Gasson
// Copyright (C) 2021-2024 Nick Gasson
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
Expand Down Expand Up @@ -545,7 +545,7 @@ START_TEST(test_mask_iter)
mask_set(&m, 6);
mask_set(&m, 17);

int bit = -1;
size_t bit = -1;
fail_unless(mask_iter(&m, &bit));
fail_unless(bit == 1);
fail_unless(mask_iter(&m, &bit));
Expand Down

0 comments on commit efed9c3

Please sign in to comment.