:: commit d7f0160ddc7a9cf9763a5e4664006ac0d78261e2

Mintsuki <mintsuki@protonmail.com> — 2026-03-31 12:58

parents: cc586e6cc0

mm/pmm: Add ext_mem_alloc_counted(), use it everywhere for checked array allocations

diff --git a/common/drivers/gop.c b/common/drivers/gop.c
index 30f8517f..3d2ae48b 100644
--- a/common/drivers/gop.c
+++ b/common/drivers/gop.c
@@ -189,7 +189,7 @@ static bool try_mode(struct fb_info *ret, EFI_GRAPHICS_OUTPUT_PROTOCOL *gop,
 static struct fb_info *get_mode_list(size_t *count, EFI_GRAPHICS_OUTPUT_PROTOCOL *gop) {
     UINTN modes_count = gop->Mode->MaxMode;
 
-    struct fb_info *ret = ext_mem_alloc(modes_count * sizeof(struct fb_info));
+    struct fb_info *ret = ext_mem_alloc_counted(modes_count, sizeof(struct fb_info));
 
     size_t actual_count = 0;
     for (size_t i = 0; i < modes_count; i++) {
@@ -198,7 +198,7 @@ static struct fb_info *get_mode_list(size_t *count, EFI_GRAPHICS_OUTPUT_PROTOCOL
         }
     }
 
-    struct fb_info *tmp = ext_mem_alloc(actual_count * sizeof(struct fb_info));
+    struct fb_info *tmp = ext_mem_alloc_counted(actual_count, sizeof(struct fb_info));
     memcpy(tmp, ret, actual_count * sizeof(struct fb_info));
 
     pmm_free(ret, modes_count * sizeof(struct fb_info));
@@ -248,7 +248,7 @@ void init_gop(struct fb_info **ret, size_t *_fbs_count,
 
     size_t handles_count = handles_size / sizeof(EFI_HANDLE);
 
-    *ret = ext_mem_alloc(handles_count * sizeof(struct fb_info));
+    *ret = ext_mem_alloc_counted(handles_count, sizeof(struct fb_info));
 
     const struct resolution fallback_resolutions[] = {
         { 0,    0,   0  },   // Overridden by EDID
diff --git a/common/drivers/vbe.c b/common/drivers/vbe.c
index eff5c000..4c7c4dbc 100644
--- a/common/drivers/vbe.c
+++ b/common/drivers/vbe.c
@@ -174,7 +174,7 @@ struct fb_info *vbe_get_mode_list(size_t *count) {
         modes_count++;
     }
 
-    struct fb_info *ret = ext_mem_alloc(modes_count * sizeof(struct fb_info));
+    struct fb_info *ret = ext_mem_alloc_counted(modes_count, sizeof(struct fb_info));
 
     for (size_t i = 0, j = 0; i < VBE_MAX_MODES && vid_modes[i] != 0xffff; i++) {
         struct vbe_mode_info_struct vbe_mode_info;
diff --git a/common/fs/fat32.s2.c b/common/fs/fat32.s2.c
index dc287868..7625a219 100644
--- a/common/fs/fat32.s2.c
+++ b/common/fs/fat32.s2.c
@@ -337,16 +337,12 @@ static uint32_t *cache_cluster_chain(struct fat32_context *context,
         return NULL;
     }
 
-    size_t alloc_size;
-    if (__builtin_mul_overflow(chain_length, sizeof(uint32_t), &alloc_size)) {
-        return NULL;
-    }
-    uint32_t *cluster_chain = ext_mem_alloc(alloc_size);
+    uint32_t *cluster_chain = ext_mem_alloc_counted(chain_length, sizeof(uint32_t));
     cluster = initial_cluster;
     for (size_t i = 0; i < chain_length; i++) {
         cluster_chain[i] = cluster;
         if (read_cluster_from_map(context, cluster, &cluster) != 0) {
-            pmm_free(cluster_chain, alloc_size);
+            pmm_free(cluster_chain, chain_length * sizeof(uint32_t));
             return NULL;
         }
     }
diff --git a/common/fs/iso9660.s2.c b/common/fs/iso9660.s2.c
index a953c277..80b4c3c4 100644
--- a/common/fs/iso9660.s2.c
+++ b/common/fs/iso9660.s2.c
@@ -438,7 +438,7 @@ struct file_handle *iso9660_open(struct volume *vol, const char *path) {
             }
 
             // Allocate extent array
-            ret->extents = ext_mem_alloc(extent_count * sizeof(struct iso9660_extent));
+            ret->extents = ext_mem_alloc_counted(extent_count, sizeof(struct iso9660_extent));
             ret->extent_count = extent_count;
             ret->total_size = total_size;
 
diff --git a/common/lib/elf.c b/common/lib/elf.c
index cfe5b6f6..1b12f342 100644
--- a/common/lib/elf.c
+++ b/common/lib/elf.c
@@ -465,7 +465,7 @@ end_of_pt_segment:
         }
         relocs_i += dt_pltrelsz / rela_ent;
     }
-    struct elf64_rela **relocs = ext_mem_alloc(relocs_i * sizeof(struct elf64_rela *));
+    struct elf64_rela **relocs = ext_mem_alloc_counted(relocs_i, sizeof(struct elf64_rela *));
 
     if (relr_size != 0) {
         size_t relr_i;
@@ -780,7 +780,7 @@ static void elf64_get_ranges(uint8_t *elf, uint64_t slide, struct mem_range **_r
         panic(true, "elf: No higher half PHDRs exist");
     }
 
-    struct mem_range *ranges = ext_mem_alloc(ranges_count * sizeof(struct mem_range));
+    struct mem_range *ranges = ext_mem_alloc_counted(ranges_count, sizeof(struct mem_range));
 
     size_t r = 0;
     for (uint16_t i = 0; i < hdr->ph_num; i++) {
diff --git a/common/lib/gterm.c b/common/lib/gterm.c
index 975184ff..dd0248ea 100644
--- a/common/lib/gterm.c
+++ b/common/lib/gterm.c
@@ -786,7 +786,7 @@ bool gterm_init(struct fb_info **_fbs, size_t *_fbs_count,
     gterm_parse_config(config, &cfg);
 
     terms_i = 0;
-    terms = ext_mem_alloc(fbs_count * sizeof(void *));
+    terms = ext_mem_alloc_counted(fbs_count, sizeof(void *));
 
     for (size_t i = 0; i < fbs_count; i++) {
         struct fb_info *fb = &fbs[i];
diff --git a/common/lib/pe.c b/common/lib/pe.c
index c5e52a66..8c727fd8 100644
--- a/common/lib/pe.c
+++ b/common/lib/pe.c
@@ -455,7 +455,7 @@ again:
             range_count++;
         }
 
-        struct mem_range *ranges = ext_mem_alloc(range_count * sizeof(struct mem_range));
+        struct mem_range *ranges = ext_mem_alloc_counted(range_count, sizeof(struct mem_range));
 
         *_ranges = ranges;
         *_ranges_count = range_count;
diff --git a/common/lib/rand.c b/common/lib/rand.c
index 6d34ce1f..f31419d5 100644
--- a/common/lib/rand.c
+++ b/common/lib/rand.c
@@ -38,7 +38,7 @@ static void init_rand(void) {
     }
 #endif
 
-    status = ext_mem_alloc(n * sizeof(uint32_t));
+    status = ext_mem_alloc_counted(n, sizeof(uint32_t));
 
     srand(seed);
 
diff --git a/common/menu.c b/common/menu.c
index 5599f848..2b251492 100644
--- a/common/menu.c
+++ b/common/menu.c
@@ -966,7 +966,7 @@ noreturn void _menu(bool first_run) {
         rewound_bss = ext_mem_alloc(bss_size);
 #endif
         /* addition due to allocation potentially adding new memory map entries */
-        rewound_memmap = ext_mem_alloc((memmap_entries + 16) * sizeof(struct memmap_entry));
+        rewound_memmap = ext_mem_alloc_counted(memmap_entries + 16, sizeof(struct memmap_entry));
         memcpy(rewound_memmap, memmap, memmap_entries * sizeof(struct memmap_entry));
         rewound_memmap_entries = memmap_entries;
         memcpy(rewound_data, data_begin, data_size);
diff --git a/common/mm/pmm.h b/common/mm/pmm.h
index e3bc4045..db14dba5 100644
--- a/common/mm/pmm.h
+++ b/common/mm/pmm.h
@@ -57,6 +57,7 @@ void pmm_randomise_memory(void);
 
 void *ext_mem_alloc_size_t(size_t count);
 void *ext_mem_alloc(uint64_t count);
+void *ext_mem_alloc_counted(uint64_t count, uint64_t elem_size);
 void *ext_mem_alloc_type(uint64_t count, uint32_t type);
 void *ext_mem_alloc_type_aligned(uint64_t count, uint32_t type, size_t alignment);
 void *ext_mem_alloc_type_aligned_mode(uint64_t count, uint32_t type, size_t alignment, bool allow_high_allocs);
diff --git a/common/mm/pmm.s2.c b/common/mm/pmm.s2.c
index 2210fbc0..693b9637 100644
--- a/common/mm/pmm.s2.c
+++ b/common/mm/pmm.s2.c
@@ -445,7 +445,7 @@ void init_memmap(void) {
 
     pmm_sanitise_entries(memmap, &memmap_entries, false);
 
-    recl = ext_mem_alloc(1024 * sizeof(struct memmap_entry));
+    recl = ext_mem_alloc_counted(1024, sizeof(struct memmap_entry));
 
     return;
 
@@ -617,6 +617,14 @@ void *ext_mem_alloc(uint64_t count) {
     return ext_mem_alloc_type(count, MEMMAP_BOOTLOADER_RECLAIMABLE);
 }
 
+void *ext_mem_alloc_counted(uint64_t count, uint64_t elem_size) {
+    uint64_t total;
+    if (__builtin_mul_overflow(count, elem_size, &total)) {
+        panic(false, "ext_mem_alloc_counted: allocation size overflow");
+    }
+    return ext_mem_alloc(total);
+}
+
 void *ext_mem_alloc_type(uint64_t count, uint32_t type) {
     return ext_mem_alloc_type_aligned(count, type, 4096);
 }
diff --git a/common/protos/limine.c b/common/protos/limine.c
index 475ce4e5..773d2b64 100644
--- a/common/protos/limine.c
+++ b/common/protos/limine.c
@@ -107,7 +107,7 @@ static pagemap_t build_identity_map(void) {
 
     size_t _memmap_entries = memmap_entries;
     struct memmap_entry *_memmap =
-        ext_mem_alloc(_memmap_entries * sizeof(struct memmap_entry));
+        ext_mem_alloc_counted(_memmap_entries, sizeof(struct memmap_entry));
     for (size_t i = 0; i < _memmap_entries; i++) {
         _memmap[i] = memmap[i];
     }
@@ -197,7 +197,7 @@ static pagemap_t build_pagemap(int base_revision,
 
     size_t _memmap_entries = memmap_entries;
     struct memmap_entry *_memmap =
-        ext_mem_alloc(_memmap_entries * sizeof(struct memmap_entry));
+        ext_mem_alloc_counted(_memmap_entries, sizeof(struct memmap_entry));
     for (size_t i = 0; i < _memmap_entries; i++)
         _memmap[i] = memmap[i];
 
@@ -594,7 +594,7 @@ noreturn void limine_load(char *config, char *cmdline) {
 
     // Load requests
     uint64_t *limine_reqs = NULL;
-    requests = ext_mem_alloc(MAX_REQUESTS * sizeof(void *));
+    requests = ext_mem_alloc_counted(MAX_REQUESTS, sizeof(void *));
     requests_count = 0;
     if (base_revision == 0 && kernel_format == EXECUTABLE_FORMAT_ELF && elf64_load_section(kernel, kernel_file->size, &limine_reqs, ".limine_reqs", 0, slide)) {
         for (size_t i = 0; ; i++) {
@@ -1209,7 +1209,7 @@ FEAT_START
 
     module_response->revision = 2;
 
-    struct limine_file *modules = ext_mem_alloc(module_count * sizeof(struct limine_file));
+    struct limine_file *modules = ext_mem_alloc_counted(module_count, sizeof(struct limine_file));
 
     size_t final_module_count = 0;
     for (size_t i = 0; i < module_count; i++) {
@@ -1300,7 +1300,7 @@ FEAT_START
         fclose(f);
     }
 
-    uint64_t *modules_list = ext_mem_alloc(final_module_count * sizeof(uint64_t));
+    uint64_t *modules_list = ext_mem_alloc_counted(final_module_count, sizeof(uint64_t));
     for (size_t i = 0; i < final_module_count; i++) {
         modules_list[i] = reported_addr(&modules[i]);
     }
@@ -1400,17 +1400,17 @@ FEAT_START
         break;
     }
 
-    fbp = ext_mem_alloc(fbs_count * sizeof(struct limine_framebuffer));
+    fbp = ext_mem_alloc_counted(fbs_count, sizeof(struct limine_framebuffer));
 
     struct limine_framebuffer_response *framebuffer_response =
         ext_mem_alloc(sizeof(struct limine_framebuffer_response));
 
     framebuffer_response->revision = 1;
 
-    uint64_t *fb_list = ext_mem_alloc(fbs_count * sizeof(uint64_t));
+    uint64_t *fb_list = ext_mem_alloc_counted(fbs_count, sizeof(uint64_t));
 
     for (size_t i = 0; i < fbs_count; i++) {
-        uint64_t *modes_list = ext_mem_alloc(fbs[i].mode_count * sizeof(uint64_t));
+        uint64_t *modes_list = ext_mem_alloc_counted(fbs[i].mode_count, sizeof(uint64_t));
         for (size_t j = 0; j < fbs[i].mode_count; j++) {
             fbs[i].mode_list[j].memory_model = LIMINE_FRAMEBUFFER_RGB;
             modes_list[j] = reported_addr(&fbs[i].mode_list[j]);
@@ -1456,12 +1456,12 @@ FEAT_START
         break;
     }
 
-    struct flanterm_params *fip_raw = ext_mem_alloc(fbs_count * sizeof(struct flanterm_params));
+    struct flanterm_params *fip_raw = ext_mem_alloc_counted(fbs_count, sizeof(struct flanterm_params));
     size_t fip_count = gterm_prepare_flanterm_params(fbs, fbs_count, fip_raw, fbs_count);
 
     struct limine_flanterm_fb_init_params *fip_entries =
-        ext_mem_alloc(fbs_count * sizeof(struct limine_flanterm_fb_init_params));
-    uint64_t *fip_list = ext_mem_alloc(fbs_count * sizeof(uint64_t));
+        ext_mem_alloc_counted(fbs_count, sizeof(struct limine_flanterm_fb_init_params));
+    uint64_t *fip_list = ext_mem_alloc_counted(fbs_count, sizeof(uint64_t));
 
     size_t fip_idx = 0;
     for (size_t i = 0; i < fbs_count; i++) {
@@ -1741,7 +1741,7 @@ FEAT_START
 #error Unknown architecture
 #endif
 
-    uint64_t *mp_list = ext_mem_alloc(cpu_count * sizeof(uint64_t));
+    uint64_t *mp_list = ext_mem_alloc_counted(cpu_count, sizeof(uint64_t));
     for (size_t i = 0; i < cpu_count; i++) {
         mp_list[i] = reported_addr(&mp_info[i]);
     }
@@ -1789,7 +1789,7 @@ FEAT_START
     if (memmap_request != NULL) {
         memmap_response = ext_mem_alloc(sizeof(struct limine_memmap_response));
         _memmap = ext_mem_alloc(sizeof(struct limine_memmap_entry) * MEMMAP_MAX);
-        memmap_list = ext_mem_alloc(MEMMAP_MAX * sizeof(uint64_t));
+        memmap_list = ext_mem_alloc_counted(MEMMAP_MAX, sizeof(uint64_t));
     }
 
     size_t mmap_entries;
diff --git a/common/protos/linux_x86.c b/common/protos/linux_x86.c
index e99d46dc..72e0a9d1 100644
--- a/common/protos/linux_x86.c
+++ b/common/protos/linux_x86.c
@@ -429,7 +429,7 @@ noreturn void linux_load(char *config, char *cmdline) {
         goto no_modules;
     }
 
-    struct file_handle **modules = ext_mem_alloc(module_count * sizeof(struct file_handle *));
+    struct file_handle **modules = ext_mem_alloc_counted(module_count, sizeof(struct file_handle *));
 
     for (size_t i = 0; ; i++) {
         char *module_path = config_get_value(config, i, "MODULE_PATH");
diff --git a/common/sys/lapic.c b/common/sys/lapic.c
index bc8c94d2..08d40fd3 100644
--- a/common/sys/lapic.c
+++ b/common/sys/lapic.c
@@ -441,7 +441,7 @@ void init_io_apics(void) {
         }
     }
 
-    io_apics = ext_mem_alloc(max_io_apics * sizeof(struct madt_io_apic *));
+    io_apics = ext_mem_alloc_counted(max_io_apics, sizeof(struct madt_io_apic *));
     max_io_apics = 0;
 
     for (uint8_t *madt_ptr = (uint8_t *)madt->madt_entries_begin;
diff --git a/common/sys/smp.c b/common/sys/smp.c
index b2313132..7db57b45 100644
--- a/common/sys/smp.c
+++ b/common/sys/smp.c
@@ -209,7 +209,7 @@ struct limine_mp_info *init_smp(size_t   *cpu_count,
         return NULL;
     }
 
-    struct limine_mp_info *ret = ext_mem_alloc(max_cpus * sizeof(struct limine_mp_info));
+    struct limine_mp_info *ret = ext_mem_alloc_counted(max_cpus, sizeof(struct limine_mp_info));
     *cpu_count = 0;
 
     // Try to start all APs
@@ -532,7 +532,7 @@ static struct limine_mp_info *try_acpi_smp(size_t   *cpu_count,
         }
     }
 
-    struct limine_mp_info *ret = ext_mem_alloc(max_cpus * sizeof(struct limine_mp_info));
+    struct limine_mp_info *ret = ext_mem_alloc_counted(max_cpus, sizeof(struct limine_mp_info));
     *cpu_count = 0;
 
     // Try to start all APs
@@ -662,7 +662,7 @@ static struct limine_mp_info *try_dtb_smp( void *dtb,
         max_cpus++;
     }
 
-    struct limine_mp_info *ret = ext_mem_alloc(max_cpus * sizeof(struct limine_mp_info));
+    struct limine_mp_info *ret = ext_mem_alloc_counted(max_cpus, sizeof(struct limine_mp_info));
 
     fdt_for_each_subnode(node, dtb, cpus) {
         const void *prop;
@@ -850,7 +850,7 @@ struct limine_mp_info *init_smp(size_t *cpu_count, pagemap_t pagemap, uint64_t h
         }
     }
 
-    struct limine_mp_info *ret = ext_mem_alloc(num_cpus * sizeof(struct limine_mp_info));
+    struct limine_mp_info *ret = ext_mem_alloc_counted(num_cpus, sizeof(struct limine_mp_info));
 
     *cpu_count = 0;
     for (struct riscv_hart *hart = hart_list; hart != NULL; hart = hart->next) {
@@ -1015,7 +1015,7 @@ static struct limine_mp_info *try_acpi_smp(size_t *cpu_count, uint32_t *bsp_phys
     if (max_cpus == 0)
         return NULL;
 
-    struct limine_mp_info *ret = ext_mem_alloc(max_cpus * sizeof(struct limine_mp_info));
+    struct limine_mp_info *ret = ext_mem_alloc_counted(max_cpus, sizeof(struct limine_mp_info));
 
     for (uint8_t *madt_ptr = (uint8_t *)madt->madt_entries_begin;
          (uintptr_t)madt_ptr + 1 < (uintptr_t)madt + madt->header.length;
@@ -1131,7 +1131,7 @@ static struct limine_mp_info *try_dtb_smp(void *dtb, size_t *cpu_count,
     if (max_cpus == 0)
         return NULL;
 
-    struct limine_mp_info *ret = ext_mem_alloc(max_cpus * sizeof(struct limine_mp_info));
+    struct limine_mp_info *ret = ext_mem_alloc_counted(max_cpus, sizeof(struct limine_mp_info));
 
     fdt_for_each_subnode(node, dtb, cpus) {
         const void *prop;
tab: 248 wrap: offon