mm: Add CHECKED_ADD macro, use it for all base+length computations
diff --git a/common/lib/elsewhere.c b/common/lib/elsewhere.c
index 67d63f8b..960020af 100644
--- a/common/lib/elsewhere.c
+++ b/common/lib/elsewhere.c
@@ -21,7 +21,7 @@ bool elsewhere_append(
uint64_t top = 0;
for (size_t i = 0; i < *ranges_count; i++) {
- uint64_t r_top = ranges[i].target + ranges[i].length;
+ uint64_t r_top = CHECKED_ADD(ranges[i].target, ranges[i].length, continue);
if (top < r_top) {
top = r_top;
@@ -39,7 +39,7 @@ retry:
}
for (size_t i = 0; i < *ranges_count; i++) {
- uint64_t t_top = *target + t_length;
+ uint64_t t_top = CHECKED_ADD(*target, t_length, return false);
// Ensure allocation stays within 32-bit address space.
if (t_top > 0x100000000) {
@@ -50,7 +50,7 @@ retry:
{
uint64_t base = ranges[i].target;
uint64_t length = ranges[i].length;
- uint64_t top = base + length;
+ uint64_t top = CHECKED_ADD(base, length, continue);
if (elsewhere_overlap_check(base, top, *target, t_top)) {
if (!flexible_target) {
@@ -65,7 +65,7 @@ retry:
{
uint64_t base = ranges[i].elsewhere;
uint64_t length = ranges[i].length;
- uint64_t top = base + length;
+ uint64_t top = CHECKED_ADD(base, length, continue);
if (elsewhere_overlap_check(base, top, *target, t_top)) {
if (!flexible_target) {
diff --git a/common/lib/misc.h b/common/lib/misc.h
index e8de42b3..4f074bc5 100644
--- a/common/lib/misc.h
+++ b/common/lib/misc.h
@@ -71,6 +71,16 @@ uint64_t strtoui(const char *s, const char **end, int base);
MAX_a > MAX_b ? MAX_a : MAX_b; \
})
+#define CHECKED_ADD(a, b, onerror) ({ \
+ __auto_type CHECKED_ADD_a = (a); \
+ __auto_type CHECKED_ADD_b = (b); \
+ typeof(CHECKED_ADD_a + CHECKED_ADD_b) CHECKED_ADD_res; \
+ if (__builtin_add_overflow(CHECKED_ADD_a, CHECKED_ADD_b, &CHECKED_ADD_res)) { \
+ onerror; \
+ } \
+ CHECKED_ADD_res; \
+})
+
#define DIV_ROUNDUP(a, b) ({ \
__auto_type DIV_ROUNDUP_a = (a); \
__auto_type DIV_ROUNDUP_b = (b); \
diff --git a/common/mm/pmm.c b/common/mm/pmm.c
index 123c1569..9dde1ec4 100644
--- a/common/mm/pmm.c
+++ b/common/mm/pmm.c
@@ -4,6 +4,7 @@
#include <mm/pmm.h>
#include <lib/rand.h>
#include <lib/print.h>
+#include <lib/misc.h>
static bool full_overlap_check(uint64_t base1, uint64_t top1,
uint64_t base2, uint64_t top2) {
@@ -24,7 +25,7 @@ bool check_usable_memory(uint64_t base, uint64_t top) {
continue;
}
- uint64_t memmap_top = memmap[i].base + memmap[i].length;
+ uint64_t memmap_top = CHECKED_ADD(memmap[i].base, memmap[i].length, continue);
if (full_overlap_check(base, top, memmap[i].base, memmap_top)) {
return true;
diff --git a/common/mm/pmm.s2.c b/common/mm/pmm.s2.c
index b1c54222..2210fbc0 100644
--- a/common/mm/pmm.s2.c
+++ b/common/mm/pmm.s2.c
@@ -137,11 +137,11 @@ void pmm_sanitise_entries(struct memmap_entry *m, size_t *_count, bool align_ent
uint64_t base = m[i].base;
uint64_t length = m[i].length;
- uint64_t top = base + length;
+ uint64_t top = CHECKED_ADD(base, length, goto del_mm0);
uint64_t res_base = m[j].base;
uint64_t res_length = m[j].length;
- uint64_t res_top = res_base + res_length;
+ uint64_t res_top = CHECKED_ADD(res_base, res_length, continue);
// Non-usable entry fully contains usable entry
if (res_base <= base && res_top >= top) {
@@ -170,6 +170,7 @@ void pmm_sanitise_entries(struct memmap_entry *m, size_t *_count, bool align_ent
if (!m[i].length
|| (align_entries && !align_entry(&m[i].base, &m[i].length))) {
+del_mm0:
// Remove i from memmap
if (i < count - 1) {
m[i] = m[count - 1];
@@ -184,9 +185,8 @@ void pmm_sanitise_entries(struct memmap_entry *m, size_t *_count, bool align_ent
continue;
if (!pmm_sanitiser_keep_first_page && m[i].base < 0x1000) {
- uint64_t entry_top;
- if (__builtin_add_overflow(m[i].base, m[i].length, &entry_top) ||
- entry_top <= 0x1000) {
+ uint64_t entry_top = CHECKED_ADD(m[i].base, m[i].length, goto del_mm1);
+ if (entry_top <= 0x1000) {
goto del_mm1;
}
@@ -226,9 +226,10 @@ del_mm1:
&& m[i].type != MEMMAP_USABLE)
continue;
+ uint64_t merge_top = CHECKED_ADD(m[i].base, m[i].length, continue);
if (m[i+1].type == m[i].type
- && m[i+1].base == m[i].base + m[i].length) {
- m[i].length += m[i+1].length;
+ && m[i+1].base == merge_top) {
+ m[i].length = CHECKED_ADD(m[i].length, m[i+1].length, continue);
// Eradicate from memmap
for (size_t j = i + 2; j < count; j++) {
@@ -273,7 +274,7 @@ void init_memmap(void) {
memmap[memmap_entries] = e820_map[i];
- uint64_t top = memmap[memmap_entries].base + memmap[memmap_entries].length;
+ uint64_t top = CHECKED_ADD(memmap[memmap_entries].base, memmap[memmap_entries].length, continue);
if (memmap[memmap_entries].type == MEMMAP_USABLE) {
if (memmap[memmap_entries].base >= EBDA && memmap[memmap_entries].base < 0x100000) {
@@ -414,7 +415,7 @@ void init_memmap(void) {
EFI_PHYSICAL_ADDRESS base = untouched_memmap[i].base;
#if defined (__i386__)
- if (untouched_memmap[i].base + untouched_memmap[i].length > 0x100000000) {
+ if (CHECKED_ADD(untouched_memmap[i].base, untouched_memmap[i].length, continue) > 0x100000000) {
continue;
}
#endif
@@ -473,7 +474,7 @@ static void pmm_reclaim_uefi_mem(struct memmap_entry *m, size_t *_count, bool ra
EFI_MEMORY_DESCRIPTOR *entry = (void *)efi_mmap + i * efi_desc_size;
uint64_t base = r->base;
- uint64_t top = base + r->length;
+ uint64_t top = CHECKED_ADD(base, r->length, continue);
uint64_t efi_base = entry->PhysicalStart;
uint64_t efi_size;
if (__builtin_mul_overflow(entry->NumberOfPages, (uint64_t)4096, &efi_size)) {
@@ -487,7 +488,7 @@ static void pmm_reclaim_uefi_mem(struct memmap_entry *m, size_t *_count, bool ra
efi_base = base;
}
- uint64_t efi_top = efi_base + efi_size;
+ uint64_t efi_top = CHECKED_ADD(efi_base, efi_size, continue);
if (efi_top > top) {
if (efi_size <= efi_top - top)
@@ -653,7 +654,7 @@ again:
continue;
int64_t entry_base = (int64_t)(memmap[i].base);
- int64_t entry_top = (int64_t)(memmap[i].base + memmap[i].length);
+ int64_t entry_top = (int64_t)CHECKED_ADD(memmap[i].base, memmap[i].length, continue);
if ((uint64_t)entry_top > limit) {
entry_top = (int64_t)limit;
@@ -720,7 +721,7 @@ struct meminfo mmap_get_info(size_t mmap_count, struct memmap_entry *mmap) {
if (mmap[i].type != MEMMAP_USABLE)
continue;
uint64_t base = mmap[i].base;
- uint64_t top = base + mmap[i].length;
+ uint64_t top = CHECKED_ADD(base, mmap[i].length, continue);
if (base <= lower_end && top > lower_end) {
lower_end = top;
progress = true;
@@ -745,18 +746,12 @@ static bool pmm_new_entry(struct memmap_entry *m, size_t *_count,
uint64_t base, uint64_t length, uint32_t type) {
size_t count = *_count;
- uint64_t top;
- if (__builtin_add_overflow(base, length, &top)) {
- panic(false, "pmm: Integer overflow in memory range calculation");
- }
+ uint64_t top = CHECKED_ADD(base, length, panic(false, "pmm: Integer overflow in memory range calculation"));
// Handle overlapping new entries.
for (size_t i = 0; i < count; i++) {
uint64_t entry_base = m[i].base;
- uint64_t entry_top;
- if (__builtin_add_overflow(m[i].base, m[i].length, &entry_top)) {
- continue; // Skip malformed entries
- }
+ uint64_t entry_top = CHECKED_ADD(m[i].base, m[i].length, continue);
// Full overlap
if (base <= entry_base && top >= entry_top) {
@@ -819,7 +814,7 @@ static bool pmm_new_entry(struct memmap_entry *m, size_t *_count,
uint64_t pmm_check_type(uint64_t addr) {
for (size_t i = 0; i < memmap_entries; i++) {
uint64_t entry_base = memmap[i].base;
- uint64_t entry_top = memmap[i].base + memmap[i].length;
+ uint64_t entry_top = CHECKED_ADD(memmap[i].base, memmap[i].length, continue);
if (addr >= entry_base && addr < entry_top) {
return memmap[i].type;
@@ -840,21 +835,18 @@ bool memmap_alloc_range_in(struct memmap_entry *m, size_t *_count,
return true;
}
- uint64_t top;
- if (__builtin_add_overflow(base, length, &top)) {
+ uint64_t top = CHECKED_ADD(base, length, ({
if (do_panic)
panic(false, "Memory allocation overflow.");
return false;
- }
+ }));
for (size_t i = 0; i < count; i++) {
if (overlay_type != 0 && m[i].type != overlay_type)
continue;
uint64_t entry_base = m[i].base;
- uint64_t entry_top;
- if (__builtin_add_overflow(m[i].base, m[i].length, &entry_top))
- continue;
+ uint64_t entry_top = CHECKED_ADD(m[i].base, m[i].length, continue);
if (base >= entry_base && base < entry_top && top <= entry_top) {
if (simulation)
