mm/pmm: Add overflow checks and memory leak fixes
diff --git a/common/mm/pmm.s2.c b/common/mm/pmm.s2.c
index 185a63ae..81a74212 100644
--- a/common/mm/pmm.s2.c
+++ b/common/mm/pmm.s2.c
@@ -177,7 +177,9 @@ void pmm_sanitise_entries(struct memmap_entry *m, size_t *_count, bool align_ent
continue;
if (!pmm_sanitiser_keep_first_page && m[i].base < 0x1000) {
- if (m[i].base + m[i].length <= 0x1000) {
+ uint64_t entry_top;
+ if (__builtin_add_overflow(m[i].base, m[i].length, &entry_top) ||
+ entry_top <= 0x1000) {
goto del_mm1;
}
@@ -196,7 +198,7 @@ del_mm1:
}
// Sort the entries
- for (size_t p = 0; p < count - 1; p++) {
+ for (size_t p = 0; p + 1 < count; p++) {
uint64_t min = m[p].base;
size_t min_index = p;
for (size_t i = p; i < count; i++) {
@@ -211,7 +213,7 @@ del_mm1:
}
// Merge contiguous bootloader-reclaimable, ACPI tables, usable entries
- for (size_t i = 0; i < count - 1; i++) {
+ for (size_t i = 0; i + 1 < count; i++) {
if (m[i].type != MEMMAP_BOOTLOADER_RECLAIMABLE
&& m[i].type != MEMMAP_ACPI_TABLES
&& m[i].type != MEMMAP_USABLE)
@@ -320,16 +322,22 @@ void init_memmap(void) {
status = gBS->AllocatePool(EfiLoaderData, memmap_max_entries * sizeof(struct memmap_entry), (void **)&memmap);
if (status) {
+ gBS->FreePool(efi_mmap);
goto fail;
}
status = gBS->AllocatePool(EfiLoaderData, memmap_max_entries * sizeof(struct memmap_entry), (void **)&untouched_memmap);
if (status) {
+ gBS->FreePool(efi_mmap);
+ gBS->FreePool(memmap);
goto fail;
}
status = gBS->GetMemoryMap(&efi_mmap_size, efi_mmap, &mmap_key, &efi_desc_size, &efi_desc_ver);
if (status) {
+ gBS->FreePool(efi_mmap);
+ gBS->FreePool(memmap);
+ gBS->FreePool(untouched_memmap);
goto fail;
}
@@ -689,12 +697,18 @@ static bool pmm_new_entry(struct memmap_entry *m, size_t *_count,
uint64_t base, uint64_t length, uint32_t type) {
size_t count = *_count;
- uint64_t top = base + length;
+ uint64_t top;
+ if (__builtin_add_overflow(base, length, &top)) {
+ panic(false, "pmm: Integer overflow in memory range calculation");
+ }
// Handle overlapping new entries.
for (size_t i = 0; i < count; i++) {
uint64_t entry_base = m[i].base;
- uint64_t entry_top = m[i].base + m[i].length;
+ uint64_t entry_top;
+ if (__builtin_add_overflow(m[i].base, m[i].length, &entry_top)) {
+ continue; // Skip malformed entries
+ }
// Full overlap
if (base <= entry_base && top >= entry_top) {
