// limine: common/pxe/tftp.s2.c

1
#include <pxe/tftp.h>
2
#include <pxe/pxe.h>
3
#if defined (BIOS)
4
#  include <lib/real.h>
5
#elif defined (UEFI)
6
#  include <efi.h>
7
#endif
8
#include <lib/print.h>
9
#include <lib/libc.h>
10
#include <mm/pmm.h>
11
#include <lib/misc.h>
12
13
// cache the dhcp packet
14
uint8_t cached_dhcp_packet[DHCP_ACK_PACKET_LEN] = { 0 };
15
bool cached_dhcp_ack_valid = false;
16
17
#if defined (BIOS)
18
19
static uint32_t get_boot_server_info(void) {
20
    struct pxenv_get_cached_info cachedinfo = { 0 };
21
    cachedinfo.packet_type = PXENV_PACKET_TYPE_CACHED_REPLY;
22
    int ret = pxe_call(PXENV_GET_CACHED_INFO, ((uint16_t)rm_seg(&cachedinfo)), (uint16_t)rm_off(&cachedinfo));
23
    if (ret || cachedinfo.buffer == 0) {
24
        panic(false, "tftp: Failed to get DHCP cached info");
25
    }
26
    struct bootph *ph = (struct bootph*)(void *) (((((uint32_t)cachedinfo.buffer) >> 16) << 4) + (((uint32_t)cachedinfo.buffer) & 0xFFFF));
27
    if (!cached_dhcp_ack_valid) {
28
        size_t copy_len = cachedinfo.buffer_size < DHCP_ACK_PACKET_LEN
29
                        ? cachedinfo.buffer_size : DHCP_ACK_PACKET_LEN;
30
        memcpy(cached_dhcp_packet, ph, copy_len);
31
        cached_dhcp_ack_valid = true;
32
    }
33
    return ph->sip;
34
}
35
36
static uint32_t parse_ip_addr(const char *server_addr) {
37
    uint32_t out;
38
39
    if (!server_addr || !strlen(server_addr)) {
40
        return get_boot_server_info();
41
    }
42
43
    if (inet_pton(server_addr, &out)) {
44
        panic(true, "tftp: Invalid IPv4 address: \"%s\"", server_addr);
45
    }
46
47
    return out;
48
}
49
50
struct file_handle *tftp_open(struct volume *part, const char *server_addr, const char *name) {
51
    uint32_t server_ip = parse_ip_addr(server_addr);
52
    const uint16_t server_port = 69; // This couldn't be changed previously either
53
    int ret = 0;
54
55
    (void)part;
56
57
    struct PXENV_UNDI_GET_INFORMATION undi_info = { 0 };
58
    ret = pxe_call(UNDI_GET_INFORMATION, ((uint16_t)rm_seg(&undi_info)), (uint16_t)rm_off(&undi_info));
59
    if (ret) {
60
        return NULL;
61
    }
62
63
    //TODO figure out a more proper way to do this.
64
    if (undi_info.MaxTranUnit < 48) {
65
        print("tftp: Invalid MTU (%u), too small for TFTP headers\n", undi_info.MaxTranUnit);
66
        return NULL;
67
    }
68
    uint16_t mtu = undi_info.MaxTranUnit - 48;
69
70
    size_t name_len = strlen(name);
71
    if (name_len >= 128) {
72
        print("tftp: Filename too long (max 127 chars)\n");
73
        return NULL;
74
    }
75
76
    struct pxenv_get_file_size fsize = {
77
        .status = 0,
78
        .sip = server_ip,
79
    };
80
    memcpy(fsize.name, name, name_len + 1);
81
    ret = pxe_call(TFTP_GET_FILE_SIZE, ((uint16_t)rm_seg(&fsize)), (uint16_t)rm_off(&fsize));
82
    if (ret) {
83
        return NULL;
84
    }
85
86
    struct file_handle *handle = ext_mem_alloc(sizeof(struct file_handle));
87
88
    handle->size = fsize.file_size;
89
    handle->is_memfile = true;
90
91
    handle->pxe = true;
92
    handle->pxe_ip = server_ip;
93
    handle->pxe_port = server_port;
94
95
    handle->path = ext_mem_alloc(1 + name_len + 1);
96
    handle->path[0] = '/';
97
    memcpy(&handle->path[1], name, name_len);
98
    handle->path_len = 1 + name_len + 1;
99
100
    struct pxenv_open open = {
101
        .status = 0,
102
        .sip = server_ip,
103
        .port = (server_port >> 8) | (server_port << 8),
104
        .packet_size = mtu
105
    };
106
    memcpy(open.name, name, name_len + 1);
107
108
    ret = pxe_call(TFTP_OPEN, ((uint16_t)rm_seg(&open)), (uint16_t)rm_off(&open));
109
    if (ret) {
110
        print("tftp: Failed to open file %x or bad packet size", open.status);
111
        pmm_free(handle->path, handle->path_len);
112
        pmm_free(handle, sizeof(struct file_handle));
113
        return NULL;
114
    }
115
116
    // Validate server's negotiated packet size doesn't exceed our buffer
117
    if (open.packet_size > mtu) {
118
        print("tftp: Server requested packet size %u exceeds our MTU %u\n", open.packet_size, mtu);
119
        uint16_t close = 0;
120
        pxe_call(TFTP_CLOSE, ((uint16_t)rm_seg(&close)), (uint16_t)rm_off(&close));
121
        pmm_free(handle->path, handle->path_len);
122
        pmm_free(handle, sizeof(struct file_handle));
123
        return NULL;
124
    }
125
126
    uint16_t alloc_mtu = mtu;  // Save original MTU for allocation/free
127
    uint8_t *buf = conv_mem_alloc(alloc_mtu);
128
129
    mtu = open.packet_size;
130
    handle->fd = ext_mem_alloc(handle->size);
131
132
    size_t progress = 0;
133
    bool slow = false;
134
135
    while (progress < handle->size) {
136
        struct pxenv_read read = {
137
            .boff = ((uint16_t)rm_off(buf)),
138
            .bseg = ((uint16_t)rm_seg(buf)),
139
        };
140
141
        ret = pxe_call(TFTP_READ, ((uint16_t)rm_seg(&read)), (uint16_t)rm_off(&read));
142
        if (ret) {
143
            panic(false, "tftp: Read failure");
144
        }
145
146
        // Validate read size doesn't overflow the buffer (use alloc_mtu, not server's mtu)
147
        if (read.bsize > alloc_mtu || progress + read.bsize > handle->size) {
148
            panic(false, "tftp: Server sent more data than expected");
149
        }
150
151
        // Prevent infinite loop from zero-byte reads
152
        if (read.bsize == 0 && progress < handle->size) {
153
            panic(false, "tftp: Server returned zero bytes before transfer complete");
154
        }
155
156
        memcpy(handle->fd + progress, buf, read.bsize);
157
158
        progress += read.bsize;
159
160
        if (read.bsize < mtu && !slow && progress < handle->size) {
161
            slow = true;
162
            print("tftp: Server is sending the file in smaller packets (it sent %d bytes), download might take longer.\n", read.bsize);
163
        }
164
    }
165
166
    uint16_t close = 0;
167
    ret = pxe_call(TFTP_CLOSE, ((uint16_t)rm_seg(&close)), (uint16_t)rm_off(&close));
168
    if (ret) {
169
        panic(false, "tftp: Close failure");
170
    }
171
172
    pmm_free(buf, alloc_mtu);
173
174
    return handle;
175
}
176
177
#elif defined (UEFI)
178
179
static EFI_IP_ADDRESS *parse_ip_addr(struct volume *part, const char *server_addr) {
180
    static EFI_IP_ADDRESS out;
181
182
    if (!server_addr || !strlen(server_addr)) {
183
        EFI_PXE_BASE_CODE_PACKET* packet;
184
        if (part->pxe_base_code->Mode->PxeReplyReceived) packet = &part->pxe_base_code->Mode->PxeReply;
185
        else if (part->pxe_base_code->Mode->ProxyOfferReceived) packet = &part->pxe_base_code->Mode->ProxyOffer;
186
        else packet = &part->pxe_base_code->Mode->DhcpAck;
187
        memcpy(out.Addr, packet->Dhcpv4.BootpSiAddr, 4);
188
        if (!cached_dhcp_ack_valid) {
189
            memcpy(cached_dhcp_packet, packet, DHCP_ACK_PACKET_LEN);
190
            cached_dhcp_ack_valid = true;
191
        }
192
    } else {
193
        if (inet_pton(server_addr, &out.Addr)) {
194
            panic(true, "tftp: Invalid IPv4 address: \"%s\"", server_addr);
195
        }
196
    }
197
198
    return &out;
199
}
200
201
struct file_handle *tftp_open(struct volume *part, const char *server_addr, const char *name) {
202
    if (part == NULL || !part->pxe_base_code) {
203
        return NULL;
204
    }
205
206
    EFI_IP_ADDRESS *ip = parse_ip_addr(part, server_addr);
207
208
    uint64_t file_size;
209
    EFI_STATUS status;
210
211
    status = part->pxe_base_code->Mtftp(
212
            part->pxe_base_code,
213
            EFI_PXE_BASE_CODE_TFTP_GET_FILE_SIZE,
214
            NULL,
215
            false,
216
            &file_size,
217
            NULL,
218
            ip,
219
            (uint8_t *)name,
220
            NULL,
221
            false);
222
223
    if (status) {
224
        return NULL;
225
    }
226
227
    struct file_handle *handle = ext_mem_alloc(sizeof(struct file_handle));
228
229
    uint64_t expected_size = file_size;
230
231
    handle->efi_part_handle = part->efi_handle;
232
    handle->size = expected_size;
233
    handle->is_memfile = true;
234
235
    handle->pxe = true;
236
    handle->pxe_ip = *(uint32_t *)ip->Addr;
237
    handle->pxe_port = 69;
238
239
    size_t name_len = strlen(name);
240
    handle->path = ext_mem_alloc(1 + name_len + 1);
241
    handle->path[0] = '/';
242
    memcpy(&handle->path[1], name, name_len);
243
    handle->path_len = 1 + name_len + 1;
244
245
    handle->fd = ext_mem_alloc(handle->size);
246
247
    status = part->pxe_base_code->Mtftp(
248
            part->pxe_base_code,
249
            EFI_PXE_BASE_CODE_TFTP_READ_FILE,
250
            handle->fd,
251
            false,
252
            &file_size,
253
            NULL,
254
            ip,
255
            (uint8_t *)name,
256
            NULL,
257
            false);
258
259
    if (status || file_size != expected_size) {
260
        pmm_free(handle->fd, handle->size);
261
        pmm_free(handle->path, handle->path_len);
262
        pmm_free(handle, sizeof(struct file_handle));
263
        return NULL;
264
    }
265
266
    return handle;
267
}
268
269
#endif
tab: 248 wrap: offon