/*
 * BZip3 - A spiritual successor to BZip2.
 * Copyright (C) 2022-2024 Kamila Szewczyk
 *
 * This program is free software: you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the Free
 * Software Foundation, either version 3 of the License, or (at your option)
 * any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU Lesser General Public License along with
 * this program.  If not, see <http://www.gnu.org/licenses/>.
 */
19
20
#include "libbz3.h"
21
#include <stdlib.h>
22
#include <string.h>
23
#include "libsais.h"
24
25
/* Branch-prediction hints. Compilers that define __GNUC__ or __clang__ get
   __builtin_expect; everywhere else the macros collapse to the bare condition. */
#if !defined(__GNUC__) && !defined(__clang__)
    #define LIKELY(x)   (x)
    #define UNLIKELY(x) (x)
#else
    #define LIKELY(x)   __builtin_expect(!!(x), 1)
    #define UNLIKELY(x) __builtin_expect(!!(x), 0)
#endif
32
33
/* CRC32 implementation. Since CRC32 generally takes less than 1% of the runtime on real-world data (e.g. the
   Silesia corpus), I decided against using hardware CRC32. This implementation is simple, fast, fool-proof and
   good enough to be used with bzip3. */
36
37
/* Byte-indexed lookup table consumed by crc32sum(): entry i is the checksum contribution of the
   byte value i. Entry 128 is 0x82F63B78, which matches the reflected CRC-32C (Castagnoli)
   polynomial. The values must not be altered - existing archives depend on them. */
static const u32 crc32Table[256] = {
    0x00000000L, 0xF26B8303L, 0xE13B70F7L, 0x1350F3F4L, 0xC79A971FL, 0x35F1141CL, 0x26A1E7E8L, 0xD4CA64EBL, 0x8AD958CFL,
    0x78B2DBCCL, 0x6BE22838L, 0x9989AB3BL, 0x4D43CFD0L, 0xBF284CD3L, 0xAC78BF27L, 0x5E133C24L, 0x105EC76FL, 0xE235446CL,
    0xF165B798L, 0x030E349BL, 0xD7C45070L, 0x25AFD373L, 0x36FF2087L, 0xC494A384L, 0x9A879FA0L, 0x68EC1CA3L, 0x7BBCEF57L,
    0x89D76C54L, 0x5D1D08BFL, 0xAF768BBCL, 0xBC267848L, 0x4E4DFB4BL, 0x20BD8EDEL, 0xD2D60DDDL, 0xC186FE29L, 0x33ED7D2AL,
    0xE72719C1L, 0x154C9AC2L, 0x061C6936L, 0xF477EA35L, 0xAA64D611L, 0x580F5512L, 0x4B5FA6E6L, 0xB93425E5L, 0x6DFE410EL,
    0x9F95C20DL, 0x8CC531F9L, 0x7EAEB2FAL, 0x30E349B1L, 0xC288CAB2L, 0xD1D83946L, 0x23B3BA45L, 0xF779DEAEL, 0x05125DADL,
    0x1642AE59L, 0xE4292D5AL, 0xBA3A117EL, 0x4851927DL, 0x5B016189L, 0xA96AE28AL, 0x7DA08661L, 0x8FCB0562L, 0x9C9BF696L,
    0x6EF07595L, 0x417B1DBCL, 0xB3109EBFL, 0xA0406D4BL, 0x522BEE48L, 0x86E18AA3L, 0x748A09A0L, 0x67DAFA54L, 0x95B17957L,
    0xCBA24573L, 0x39C9C670L, 0x2A993584L, 0xD8F2B687L, 0x0C38D26CL, 0xFE53516FL, 0xED03A29BL, 0x1F682198L, 0x5125DAD3L,
    0xA34E59D0L, 0xB01EAA24L, 0x42752927L, 0x96BF4DCCL, 0x64D4CECFL, 0x77843D3BL, 0x85EFBE38L, 0xDBFC821CL, 0x2997011FL,
    0x3AC7F2EBL, 0xC8AC71E8L, 0x1C661503L, 0xEE0D9600L, 0xFD5D65F4L, 0x0F36E6F7L, 0x61C69362L, 0x93AD1061L, 0x80FDE395L,
    0x72966096L, 0xA65C047DL, 0x5437877EL, 0x4767748AL, 0xB50CF789L, 0xEB1FCBADL, 0x197448AEL, 0x0A24BB5AL, 0xF84F3859L,
    0x2C855CB2L, 0xDEEEDFB1L, 0xCDBE2C45L, 0x3FD5AF46L, 0x7198540DL, 0x83F3D70EL, 0x90A324FAL, 0x62C8A7F9L, 0xB602C312L,
    0x44694011L, 0x5739B3E5L, 0xA55230E6L, 0xFB410CC2L, 0x092A8FC1L, 0x1A7A7C35L, 0xE811FF36L, 0x3CDB9BDDL, 0xCEB018DEL,
    0xDDE0EB2AL, 0x2F8B6829L, 0x82F63B78L, 0x709DB87BL, 0x63CD4B8FL, 0x91A6C88CL, 0x456CAC67L, 0xB7072F64L, 0xA457DC90L,
    0x563C5F93L, 0x082F63B7L, 0xFA44E0B4L, 0xE9141340L, 0x1B7F9043L, 0xCFB5F4A8L, 0x3DDE77ABL, 0x2E8E845FL, 0xDCE5075CL,
    0x92A8FC17L, 0x60C37F14L, 0x73938CE0L, 0x81F80FE3L, 0x55326B08L, 0xA759E80BL, 0xB4091BFFL, 0x466298FCL, 0x1871A4D8L,
    0xEA1A27DBL, 0xF94AD42FL, 0x0B21572CL, 0xDFEB33C7L, 0x2D80B0C4L, 0x3ED04330L, 0xCCBBC033L, 0xA24BB5A6L, 0x502036A5L,
    0x4370C551L, 0xB11B4652L, 0x65D122B9L, 0x97BAA1BAL, 0x84EA524EL, 0x7681D14DL, 0x2892ED69L, 0xDAF96E6AL, 0xC9A99D9EL,
    0x3BC21E9DL, 0xEF087A76L, 0x1D63F975L, 0x0E330A81L, 0xFC588982L, 0xB21572C9L, 0x407EF1CAL, 0x532E023EL, 0xA145813DL,
    0x758FE5D6L, 0x87E466D5L, 0x94B49521L, 0x66DF1622L, 0x38CC2A06L, 0xCAA7A905L, 0xD9F75AF1L, 0x2B9CD9F2L, 0xFF56BD19L,
    0x0D3D3E1AL, 0x1E6DCDEEL, 0xEC064EEDL, 0xC38D26C4L, 0x31E6A5C7L, 0x22B65633L, 0xD0DDD530L, 0x0417B1DBL, 0xF67C32D8L,
    0xE52CC12CL, 0x1747422FL, 0x49547E0BL, 0xBB3FFD08L, 0xA86F0EFCL, 0x5A048DFFL, 0x8ECEE914L, 0x7CA56A17L, 0x6FF599E3L,
    0x9D9E1AE0L, 0xD3D3E1ABL, 0x21B862A8L, 0x32E8915CL, 0xC083125FL, 0x144976B4L, 0xE622F5B7L, 0xF5720643L, 0x07198540L,
    0x590AB964L, 0xAB613A67L, 0xB831C993L, 0x4A5A4A90L, 0x9E902E7BL, 0x6CFBAD78L, 0x7FAB5E8CL, 0x8DC0DD8FL, 0xE330A81AL,
    0x115B2B19L, 0x020BD8EDL, 0xF0605BEEL, 0x24AA3F05L, 0xD6C1BC06L, 0xC5914FF2L, 0x37FACCF1L, 0x69E9F0D5L, 0x9B8273D6L,
    0x88D28022L, 0x7AB90321L, 0xAE7367CAL, 0x5C18E4C9L, 0x4F48173DL, 0xBD23943EL, 0xF36E6F75L, 0x0105EC76L, 0x12551F82L,
    0xE03E9C81L, 0x34F4F86AL, 0xC69F7B69L, 0xD5CF889DL, 0x27A40B9EL, 0x79B737BAL, 0x8BDCB4B9L, 0x988C474DL, 0x6AE7C44EL,
    0xBE2DA0A5L, 0x4C4623A6L, 0x5F16D052L, 0xAD7D5351L
};
68
69
/* Folds `size` bytes starting at `buf` into the running checksum `crc`, one table lookup per
   byte, and returns the updated checksum. No initial or final inversion is applied here; the
   callers in this file seed the checksum with 1. */
static u32 crc32sum(u32 crc, u8 * RESTRICT buf, size_t size) {
    for (size_t pos = 0; pos < size; ++pos) {
        const u8 slot = (u8)((u8)crc ^ buf[pos]);
        crc = crc32Table[slot] ^ (crc >> 8);
    }
    return crc;
}
73
74
/* LZP code. These constants were manually tuned to give the best compression ratio while using relatively
   few resources. The LZP dictionary is only around 1MiB in size and the minimum match length was chosen
   so that LZP would not interfere too much with the Burrows-Wheeler transform and the arithmetic coder, and
   would just collapse long redundant data instead (for a major speed-up at a low compression ratio cost - in
   fact, LZP preprocessing improves compression in some cases). */
79
80
/* A heavily modified version of libbsc's LZP predictor w/ unaligned accesses follows. This one has single thread
   performance and provides better compression ratio. It is also mostly UB-free and less brittle during
   AFL fuzzing. */
83
84
/* log2 of the number of slots in the LZP hash table (the table holds 1 << LZP_DICTIONARY entries). */
#define LZP_DICTIONARY 18
/* Candidate matches shorter than this many bytes are emitted as literals instead. */
#define LZP_MIN_MATCH 40

/* Byte value that introduces a match (or, followed by 255, an escaped literal) in the LZP stream. */
#define MATCH 0xf2
88
89
/* Loads sizeof(u32) bytes from a possibly unaligned address in native byte order.
   memcpy is used instead of a pointer cast so the load is free of alignment and aliasing UB. */
static u32 lzp_upcast(const u8 * ptr) {
    u32 word;
    memcpy(&word, ptr, sizeof word);
    return word;
}
95
96
/**
97
 * @brief Check if the buffer size is sufficient for decoding a bz3 block
98
 * 
99
 * Data passed to the last step can be one of the following:
100
 * - original data
101
 * - original data + LZP
102
 * - original data + RLE
103
 * - original data + RLE + LZP
104
 *
105
 * We must ensure `buffer_size` is large enough to store the data at every step 
106
 * when walking backwards. The required size may be stored in  either `lzp_size`,
107
 * `rle_size` OR `orig_size`.
108
 *
109
 * @param buffer_size Size of the output buffer
110
 * @param lzp_size Size after LZP decompression (-1 if LZP not used)
111
 * @param rle_size Size after RLE decompression (-1 if RLE not used) 
112
 * @return 1 if buffer size is sufficient, 0 otherwise
113
 */
114
static int bz3_check_buffer_size(size_t buffer_size, s32 lzp_size, s32 rle_size, s32 orig_size) {
115
    // Handle -1 cases to avoid implicit conversion issues
116
    size_t effective_lzp_size = lzp_size < 0 ? 0 : (size_t)lzp_size;
117
    size_t effective_rle_size = rle_size < 0 ? 0 : (size_t)rle_size;
118
    size_t effective_orig_size = orig_size < 0 ? 0 : (size_t)orig_size;
119
120
    // Check if buffer can hold intermediate results
121
    return (effective_lzp_size <= buffer_size) && (effective_rle_size <= buffer_size) && (effective_orig_size <= buffer_size);
122
}
123
124
/* LZP-encodes [in, in_end) into [out, out_end) using the hash table `lut`
   (1 << LZP_DICTIONARY entries, expected to be zeroed by the caller).

   Stream layout: the first 4 bytes are copied verbatim. After that, for every position whose
   hash slot is non-empty, a MATCH byte followed by length bytes (254 = "continue", final byte
   < 254) encodes a match against the predicted position, and a literal MATCH byte is escaped
   as MATCH, 255. Positions with an empty slot are always plain literals.

   Returns the number of bytes written, or -1 when the output came within 8 bytes of out_end.
   The caller must supply at least LZP_MIN_MATCH + 32 input bytes. */
static s32 lzp_encode_block(const u8 * RESTRICT in, const u8 * in_end, u8 * RESTRICT out, u8 * out_end,
                            s32 * RESTRICT lut) {
    const u8 * const in_base = in;
    const u8 * const out_base = out;
    const u8 * const out_limit = out_end - 8;
    const u8 * const match_limit = in_end - LZP_MIN_MATCH - 32;
    const u32 hash_mask = (u32)((1 << LZP_DICTIONARY) - 1);
    const u8 * heur = in;  // positions below this are already known not to start a long match

    u32 ctx;

    memcpy(out, in, 4);
    out += 4;
    in += 4;

    ctx = (((u32)in[-4]) << 24) | (((u32)in[-3]) << 16) | (((u32)in[-2]) << 8) | ((u32)in[-1]);

    while (in < match_limit && out < out_limit) {
        const u32 slot = ((ctx >> 15) ^ ctx ^ (ctx >> 3)) & hash_mask;
        const s32 prev = lut[slot];
        lut[slot] = (s32)(in - in_base);

        if (prev <= 0) {
            // Empty slot: plain literal, no escaping (the decoder sees the same empty slot).
            const u8 lit = *in++;
            *out++ = lit;
            ctx = (ctx << 8) | lit;
            continue;
        }

        const u8 * ref = in_base + prev;
        s32 len = 0;  // stays 0 unless a match of at least LZP_MIN_MATCH bytes is confirmed

        // Cheap filter first: the last and first 4 bytes of a minimal match must agree,
        // and the position remembered in `heur` must not contradict the candidate.
        if (memcmp(in + LZP_MIN_MATCH - 4, ref + LZP_MIN_MATCH - 4, sizeof(u32)) == 0 &&
            memcmp(in, ref, sizeof(u32)) == 0 &&
            !(heur > in && lzp_upcast(heur) != lzp_upcast(ref + (heur - in)))) {
            len = 4;
            while (in + len < match_limit && lzp_upcast(in + len) == lzp_upcast(ref + len)) {
                len += sizeof(u32);
            }

            if (len < LZP_MIN_MATCH) {
                if (heur < in + len) heur = in + len;
                len = 0;
            }
        }

        if (len == 0) {
            // Literal at a position with a live slot: a raw MATCH byte needs the 255 escape.
            const u8 next = *in++;
            *out++ = next;
            ctx = ctx << 8 | next;
            if (next == MATCH) *out++ = 255;
            continue;
        }

        // The word-wise loop stops on a 4-byte boundary; pick up to three more bytes.
        len += in[len] == ref[len];
        len += in[len] == ref[len];
        len += in[len] == ref[len];

        in += len;
        ctx = (((u32)in[-4]) << 24) | (((u32)in[-3]) << 16) | (((u32)in[-2]) << 8) | ((u32)in[-1]);

        *out++ = MATCH;

        len -= LZP_MIN_MATCH;
        while (len >= 254) {
            len -= 254;
            *out++ = 254;
            if (out >= out_limit) break;
        }

        *out++ = len;
    }

    ctx = (((u32)in[-4]) << 24) | (((u32)in[-3]) << 16) | (((u32)in[-2]) << 8) | ((u32)in[-1]);

    // Tail: too close to the end to search for matches, but the table and the escapes
    // must still be maintained so the decoder stays in sync.
    while (in < in_end && out < out_limit) {
        const u32 slot = ((ctx >> 15) ^ ctx ^ (ctx >> 3)) & hash_mask;
        const s32 prev = lut[slot];
        lut[slot] = (s32)(in - in_base);

        const u8 next = *in++;
        *out++ = next;
        ctx = ctx << 8 | next;
        if (next == MATCH && prev > 0) *out++ = 255;
    }

    if (out >= out_limit) return -1;
    return (s32)(out - out_base);
}
199
200
/* Decodes the LZP stream [in, in_end) into [out, out_end) using the hash table `lut`
   (1 << LZP_DICTIONARY entries, expected to be zeroed by the caller).

   Returns the number of bytes produced, or -1 if the input ends in the middle of a match
   token. Output is silently capped at out_end. The caller must supply at least 4 input
   bytes and at least 4 bytes of output space.

   The match length is accumulated in a size_t that saturates just above the remaining
   output space. This keeps a crafted run of 254 length bytes from overflowing a signed
   counter, and the end-of-copy pointer never leaves the output buffer. */
static s32 lzp_decode_block(const u8 * RESTRICT in, const u8 * in_end, s32 * RESTRICT lut, u8 * RESTRICT out,
                            const u8 * out_end) {
    const u8 * const out_base = out;
    const u32 hash_mask = (u32)((1 << LZP_DICTIONARY) - 1);

    for (s32 i = 0; i < 4; ++i) *out++ = *in++;

    u32 ctx = ((u32)out[-1]) | (((u32)out[-2]) << 8) | (((u32)out[-3]) << 16) | (((u32)out[-4]) << 24);

    while (in < in_end && out < out_end) {
        const u32 idx = ((ctx >> 15) ^ ctx ^ (ctx >> 3)) & hash_mask;
        const s32 val = lut[idx];  // SAFETY: guaranteed to be in-bounds by & mask.
        lut[idx] = (s32)(out - out_base);

        if (*in != MATCH || val <= 0) {
            // Plain literal.
            ctx = (ctx << 8) | (*out++ = *in++);
            continue;
        }

        in++;
        // SAFETY: 'in' is advanced here, but it may have been at last index in the case of untrusted bad data.
        if (UNLIKELY(in == in_end)) return -1;

        if (*in == 255) {
            // Escaped literal MATCH byte.
            in++;
            ctx = (ctx << 8) | (*out++ = MATCH);
            continue;
        }

        // Match token: sum the length bytes. Every byte is consumed, but the sum stops
        // growing once it exceeds the space left (the copy is capped there anyway).
        const size_t room = (size_t)(out_end - out);
        size_t len = LZP_MIN_MATCH;
        for (;;) {
            if (UNLIKELY(in == in_end)) return -1;
            const u8 chunk = *in++;
            if (len <= room) len += chunk;
            if (chunk != 254) break;
        }
        if (UNLIKELY(len > room)) len = room;

        // `val` was stored on an earlier iteration, so ref lies behind out; the regions may
        // overlap and the byte-by-byte forward copy is intentional.
        const u8 * ref = out_base + val;
        const u8 * const oe = out + len;
        while (out < oe) *out++ = *ref++;

        ctx = ((u32)out[-1]) | (((u32)out[-2]) << 8) | (((u32)out[-3]) << 16) | (((u32)out[-4]) << 24);
    }

    return (s32)(out - out_base);
}
242
243
/* Runs the LZP encoder over `n` bytes of `in`, writing to `out` (treated as `n` bytes long).
   Clears `lut` first. Returns the encoded size, or -1 when the input is shorter than
   LZP_MIN_MATCH + 32 bytes or the encoder reports failure. */
static s32 lzp_compress(const u8 * RESTRICT in, u8 * RESTRICT out, s32 n, s32 * RESTRICT lut) {
    const size_t lut_bytes = sizeof(s32) * (1 << LZP_DICTIONARY);

    if (n >= LZP_MIN_MATCH + 32) {
        memset(lut, 0, lut_bytes);
        return lzp_encode_block(in, in + n, out, out + n, lut);
    }

    return -1;
}
250
251
/* Runs the LZP decoder over `n` bytes of `in`, writing at most `max` bytes to `out`.
   Clears `lut` first. Returns the decoded size, or -1 when fewer than 4 input bytes are
   given or the decoder reports failure. */
static s32 lzp_decompress(const u8 * RESTRICT in, u8 * RESTRICT out, s32 n, s32 max, s32 * RESTRICT lut) {
    const size_t lut_bytes = sizeof(s32) * (1 << LZP_DICTIONARY);

    if (n >= 4) {
        memset(lut, 0, lut_bytes);
        return lzp_decode_block(in, in + n, lut, out, out + max);
    }

    return -1;
}
258
259
/* RLE code. Unlike RLE in other compressors, we collapse all runs if they yield a net gain
   for a given character and encode this as a set bit in the RLE metadata. This improves the
   performance and reduces the amount of collapsing done in normal blocks (so that BWT+AC can
   be more efficient) while we still filter out all the pathological data. */
263
264
/* Run-length encodes `inlen` bytes of `in` into `out` and returns the number of bytes written.

   Pass 1 scores every byte value: each repeat inside a run adds one (except every 255th),
   each start of a run subtracts one. Values with a positive score are "collapsed".
   The output begins with a 32-byte bitmap (bit j of byte i set <=> value i*8+j is collapsed).
   Pass 2 then emits, for collapsed values, the byte followed by its run length
   (255 = "add 255 and continue", final byte = remaining length - 1); all other values are
   copied through unchanged. */
static s32 mrlec(u8 * in, s32 inlen, u8 * out) {
    s32 gain[256] = { 0 };
    s32 op = 0;
    s32 prev = -1;
    s32 run = 0;

    // Pass 1: decide per byte value whether collapsing its runs pays off.
    for (s32 i = 0; i < inlen; ++i) {
        const s32 cur = in[i];
        if (cur == prev) {
            ++run;
            if (run % 255 != 0) ++gain[cur];
        } else {
            --gain[cur];
            run = 0;
        }
        prev = cur;
    }

    // Header: 256-bit map of collapsed values.
    for (s32 byte = 0; byte < 32; ++byte) {
        s32 bits = 0;
        for (s32 bit = 0; bit < 8; ++bit) {
            if (gain[byte * 8 + bit] > 0) bits |= 1 << bit;
        }
        out[op++] = bits;
    }

    // Pass 2: emit the data. A trailing sentinel of -1 flushes the last run.
    s32 pos = 0;
    prev = -1;
    run = 0;
    for (;;) {
        const s32 cur = pos < inlen ? in[pos++] : -1;

        if (cur == prev) {
            ++run;
        } else if (run > 0 && gain[prev] > 0) {
            out[op++] = prev;
            while (run > 255) {
                out[op++] = 255;
                run -= 255;
            }
            out[op++] = run - 1;
            run = 1;
        } else {
            ++run;
            while (run > 1) {
                out[op++] = prev;
                --run;
            }
        }

        prev = cur;
        if (cur == -1) break;
    }

    return op;
}
302
303
/* Decodes the RLE stream produced by mrlec().

   `in` holds `maxin` bytes: a 32-byte bitmap of collapsed byte values followed by the data.
   Exactly `outlen` bytes are expected in `out`.

   Returns 0 on success and 1 on failure: input shorter than the bitmap, a run length cut
   off by the end of the input, or a decoded size different from `outlen`.

   A truncated run length is an error rather than a reuse of the last length byte seen.
   The run counter also stops growing once it exceeds the space left in `out`, so a crafted
   sequence of 255 bytes cannot overflow it. */
static int mrled(u8 * RESTRICT in, u8 * RESTRICT out, s32 outlen, s32 maxin) {
    s32 op = 0, ip = 0;
    u8 collapsed[256];

    if (maxin < 32) return 1;

    // Unpack the bitmap: bit j of byte i tells whether value i*8+j is run-length coded.
    for (s32 i = 0; i < 32; ++i) {
        const u8 bits = in[ip++];
        for (s32 j = 0; j < 8; ++j) collapsed[i * 8 + j] = (bits >> j) & 1;
    }

    while (op < outlen && ip < maxin) {
        const u8 c = in[ip++];

        if (!collapsed[c]) {
            out[op++] = c;
            continue;
        }

        // Read the run length: any number of 255 bytes, then one terminating byte < 255.
        const s32 room = outlen - op;
        s32 run = 0;
        s32 len_byte;
        do {
            if (ip >= maxin) return 1;  // length bytes missing: malformed input
            len_byte = in[ip++];
            if (len_byte == 255 && run < room) run += 255;
        } while (len_byte == 255);
        run += len_byte + 1;

        // Never write past outlen.
        if (run > room) run = room;
        memset(out + op, c, (size_t)run);
        op += run;
    }

    return op != outlen;
}
330
331
/* The entropy coder. Uses an arithmetic coder implementation outlined in Matt Mahoney's DCE. */
332
333
typedef struct {
334
    /* Input/output. */
335
    u8 *in_queue, *out_queue;
336
    s32 input_ptr, output_ptr, input_max;
337
338
    /* C0, C1 - used for making the initial prediction, C2 used for an APM with a slightly low
339
       learning rate (6) and 512 contexts. kanzi merges C0 and C1, uses slightly different
340
       counter initialisation code and prediction code which from my tests tends to be suboptimal. */
341
    u16 C0[256], C1[256][256], C2[512][17];
342
} state;
343
344
/* Append byte `c` to the coder's output queue. No bounds check is performed here. */
#define write_out(s, c) ((s)->out_queue[(s)->output_ptr++] = (c))
/* Fetch the next input byte, or -1 once input_max bytes have been consumed. */
#define read_in(s) ((s)->input_ptr < (s)->input_max ? (s)->in_queue[(s)->input_ptr++] : -1)

/* Adaptive probability updates: move the 16-bit counter `p` towards 0 (update0) or towards
   65535 (update1) by 1/2^x of the remaining distance. The shift amount is parenthesized so
   that any expression can be passed safely. */
#define update0(p, x) ((p) = ((p) - ((p) >> (x))))
#define update1(p, x) ((p) = ((p) + (((p) ^ 65535) >> (x))))
349
350
/* Resets the model to its initial state: every C0/C1 counter starts at 1 << 15, and each of
   the 512 C2 rows is a ramp of 17 values k << 12, with the last one reduced by one so it
   fits in 16 bits. */
static void begin(state * s) {
    prefetch(s);

    for (int ctx = 0; ctx < 256; ctx++) {
        s->C0[ctx] = 1 << 15;
        for (int sub = 0; sub < 256; sub++) s->C1[ctx][sub] = 1 << 15;
    }

    for (int row = 0; row < 512; row++) {
        for (int k = 0; k < 17; k++) {
            s->C2[row][k] = (k << 12) - (k == 16);  // Firm difference from stdpack.
        }
    }
}
359
360
/* Arithmetic-codes `size` bytes of `buf` into the state's output queue, most significant bit
   of each byte first. The bit prediction mixes C0 (order 0) with C1 indexed by the two
   previous bytes, then refines it through the C2 table selected by the partial byte and a
   "long run" flag. Four flush bytes are written at the end. */
static void encode_bytes(state * s, u8 * buf, s32 size) {
    u32 high = 0xFFFFFFFF, low = 0, c1 = 0, c2 = 0, run = 0;

    for (s32 i = 0; i < size; i++) {
        u8 c = buf[i];

        // Count how long the two previous bytes have been equal.
        run = (c1 == c2) ? run + 1 : 0;
        const int f = run > 2;

        int ctx = 1;

        while (ctx < 256) {
            u16 * const m0 = &s->C0[ctx];
            u16 * const m1 = &s->C1[c1][ctx];
            u16 * const apm = s->C2[2 * ctx + f];

            const int p0 = *m0;
            const int p1 = *m1;
            const int p2 = s->C1[c2][ctx];
            const int p = ((p0 + p1) * 7 + p2 + p2) >> 4;

            // Interpolate between two neighbouring C2 entries.
            const int j = p >> 12;
            const int x1 = apm[j];
            const int x2 = apm[j + 1];
            const int ssep = x1 + (((x2 - x1) * (p & 4095)) >> 12);

            const int bit = (c & 128) != 0;

            if (bit)
                high = low + (((u64)(high - low) * (ssep * 3 + p)) >> 18);
            else
                low += (((u64)(high - low) * (ssep * 3 + p)) >> 18) + 1;

            // Write identical leading bytes (low >> 24 equals high >> 24 here).
            while ((low ^ high) < (1 << 24)) {
                write_out(s, low >> 24);
                low <<= 8;
                high = (high << 8) + 0xFF;
            }

            if (bit) {
                update1(*m0, 2);
                update1(*m1, 4);
                update1(apm[j], 6);
                update1(apm[j + 1], 6);
                ctx += ctx + 1;
            } else {
                update0(*m0, 2);
                update0(*m1, 4);
                update0(apm[j], 6);
                update0(apm[j + 1], 6);
                ctx += ctx;
            }

            c <<= 1;
        }

        c2 = c1;
        c1 = ctx & 255;
    }

    // Flush the four bytes of `low`.
    for (int k = 0; k < 4; k++) {
        write_out(s, low >> 24);
        low <<= 8;
    }
}
434
435
/* Inverse of encode_bytes(): reconstructs `size` bytes into `c` from the state's input queue,
   rebuilding the same model as the encoder. Reads past the end of the input yield -1 from
   read_in() and are folded into the code value unchanged. */
static void decode_bytes(state * s, u8 * c, s32 size) {
    u32 high = 0xFFFFFFFF, low = 0, c1 = 0, c2 = 0, run = 0, code = 0;

    // Prime the code register with the first four input bytes.
    for (int k = 0; k < 4; k++) code = (code << 8) + read_in(s);

    for (s32 i = 0; i < size; i++) {
        run = (c1 == c2) ? run + 1 : 0;
        const int f = run > 2;

        int ctx = 1;

        while (ctx < 256) {
            u16 * const m0 = &s->C0[ctx];
            u16 * const m1 = &s->C1[c1][ctx];
            u16 * const apm = s->C2[2 * ctx + f];

            const int p0 = *m0;
            const int p1 = *m1;
            const int p2 = s->C1[c2][ctx];
            const int p = ((p0 + p1) * 7 + p2 + p2) >> 4;

            const int j = p >> 12;
            const int x1 = apm[j];
            const int x2 = apm[j + 1];
            const int ssep = x1 + (((x2 - x1) * (p & 4095)) >> 12);

            const u32 mid = low + (((u64)(high - low) * (ssep * 3 + p)) >> 18);
            const u8 bit = code <= mid;

            if (bit)
                high = mid;
            else
                low = mid + 1;

            // Shift out identical leading bytes and pull in fresh input.
            while ((low ^ high) < (1 << 24)) {
                low <<= 8;
                high = (high << 8) + 255;
                code = (code << 8) + read_in(s);
            }

            if (bit) {
                update1(*m0, 2);
                update1(*m1, 4);
                update1(apm[j], 6);
                update1(apm[j + 1], 6);
                ctx += ctx + 1;
            } else {
                update0(*m0, 2);
                update0(*m1, 4);
                update0(apm[j], 6);
                update0(apm[j + 1], 6);
                ctx += ctx;
            }
        }

        c2 = c1;
        c1 = ctx & 255;
        c[i] = c1;
    }
}
495
496
/* Public API. */
497
498
struct bz3_state {
499
    u8 * swap_buffer;
500
    s32 block_size;
501
    s32 *sais_array, *lzp_lut;
502
    state * cm_state;
503
    s8 last_error;
504
};
505
506
/* Returns the error code recorded in `state` (BZ3_OK when no error is stored). */
BZIP3_API s8 bz3_last_error(struct bz3_state * state) {
    const s8 code = state->last_error;
    return code;
}
507
508
/* Returns the library version string (the build-time VERSION macro). */
BZIP3_API const char * bz3_version(void) {
    return VERSION;
}
509
510
/* Buffer size reserved for a block of `input_size` bytes: the input size plus 2% plus 32 bytes. */
BZIP3_API size_t bz3_bound(size_t input_size) {
    const size_t slack = input_size / 50 + 32;
    return input_size + slack;
}
511
512
BZIP3_API const char * bz3_strerror(struct bz3_state * state) {
513
    switch (state->last_error) {
514
        case BZ3_OK:
515
            return "No error";
516
        case BZ3_ERR_OUT_OF_BOUNDS:
517
            return "Data index out of bounds";
518
        case BZ3_ERR_BWT:
519
            return "Burrows-Wheeler transform failed";
520
        case BZ3_ERR_CRC:
521
            return "CRC32 check failed";
522
        case BZ3_ERR_MALFORMED_HEADER:
523
            return "Malformed header";
524
        case BZ3_ERR_TRUNCATED_DATA:
525
            return "Truncated data";
526
        case BZ3_ERR_DATA_TOO_BIG:
527
            return "Too much data";
528
        case BZ3_ERR_DATA_SIZE_TOO_SMALL:
529
            return "Size of buffer `buffer_size` passed to the block decoder (bz3_decode_block) is too small. See function docs for details.";
530
        default:
531
            return "Unknown error";
532
    }
533
}
534
535
/* Allocates a coder state for blocks of up to `block_size` bytes.

   `block_size` must lie between KiB(65) and MiB(511) inclusive. Returns NULL when it does
   not, or when any allocation fails (nothing is leaked in that case). The returned state
   must be released with bz3_free().

   The libsais work array is obtained zeroed from calloc. Clearing it with memset before
   the NULL check would dereference a failed allocation. */
BZIP3_API struct bz3_state * bz3_new(s32 block_size) {
    if (block_size < KiB(65) || block_size > MiB(511)) {
        return NULL;
    }

    struct bz3_state * bz3_state = malloc(sizeof *bz3_state);

    if (!bz3_state) {
        return NULL;
    }

    bz3_state->cm_state = malloc(sizeof *bz3_state->cm_state);
    bz3_state->swap_buffer = malloc(bz3_bound(block_size));
    bz3_state->sais_array = calloc(BWT_BOUND(block_size), sizeof(s32));
    bz3_state->lzp_lut = calloc(1 << LZP_DICTIONARY, sizeof(s32));

    if (!bz3_state->cm_state || !bz3_state->swap_buffer || !bz3_state->sais_array || !bz3_state->lzp_lut) {
        // free(NULL) is a no-op, so the members can be released unconditionally.
        free(bz3_state->cm_state);
        free(bz3_state->swap_buffer);
        free(bz3_state->sais_array);
        free(bz3_state->lzp_lut);
        free(bz3_state);
        return NULL;
    }

    bz3_state->block_size = block_size;
    bz3_state->last_error = BZ3_OK;

    return bz3_state;
}
569
570
BZIP3_API void bz3_free(struct bz3_state * state) {
571
    free(state->swap_buffer);
572
    free(state->sais_array);
573
    free(state->cm_state);
574
    free(state->lzp_lut);
575
    free(state);
576
}
577
578
/* Exchanges two `u8 *` lvalues. Wrapped in do { } while (0) so that `swap(a, b);` is a single
   statement and stays correct inside unbraced if/else bodies. */
#define swap(x, y)            \
    do {                      \
        u8 * swap_tmp_ = (x); \
        (x) = (y);            \
        (y) = swap_tmp_;      \
    } while (0)
584
585
/* Compresses one block in place.

   `buffer` holds `data_size` input bytes on entry and the compressed block on return.
   Returns the compressed size, or -1 on failure with the reason stored in
   state->last_error (BZ3_ERR_DATA_TOO_BIG when data_size exceeds the state's block size,
   BZ3_ERR_BWT when the Burrows-Wheeler transform fails). On every successful return
   last_error is BZ3_OK, including the small-block path.

   Block layout: CRC32 (4 bytes), BWT index (4 bytes; -1 marks a stored block), then for
   non-stored blocks a model byte, the optional LZP and RLE sizes (4 bytes each) and the
   entropy-coded payload.

   NOTE(review): the buffer has to be large enough for the output (data_size + 8 for stored
   blocks); presumably callers size it with bz3_bound() - confirm against the header docs.
   data_size is assumed to be non-negative. */
BZIP3_API s32 bz3_encode_block(struct bz3_state * state, u8 * buffer, s32 data_size) {
    u8 *b1 = buffer, *b2 = state->swap_buffer;

    if (data_size > state->block_size) {
        state->last_error = BZ3_ERR_DATA_TOO_BIG;
        return -1;
    }

    u32 crc32 = crc32sum(1, b1, data_size);

    // Ignore small blocks. They won't benefit from the entropy coding step.
    if (data_size < 64) {
        memmove(b1 + 8, b1, data_size);
        write_neutral_s32(b1, crc32);
        write_neutral_s32(b1 + 4, -1);
        // Success must clear any error left over from an earlier call on this state.
        state->last_error = BZ3_OK;
        return data_size + 8;
    }

    // Model flags: value 2 = LZP applied, value 4 = RLE applied.
    s8 model = 0;
    s32 lzp_size, rle_size;

    // Each stage is kept only if it actually shrinks the data.
    rle_size = mrlec(b1, data_size, b2);
    if (rle_size < data_size) {
        swap(b1, b2);
        data_size = rle_size;
        model |= 4;
    }

    lzp_size = lzp_compress(b1, b2, data_size, state->lzp_lut);
    if (lzp_size > 0 && lzp_size < data_size) {
        swap(b1, b2);
        data_size = lzp_size;
        model |= 2;
    }

    s32 bwt_idx = libsais_bwt(b1, b2, state->sais_array, data_size, 0, NULL);
    if (bwt_idx < 0) {
        state->last_error = BZ3_ERR_BWT;
        return -1;
    }

    // Compute the amount of overhead dwords.
    s32 overhead = 2;           // CRC32 + BWT index
    if (model & 2) overhead++;  // LZP
    if (model & 4) overhead++;  // RLE

    // The BWT output lives in b2; entropy-code it into b1 right behind the header.
    begin(state->cm_state);
    state->cm_state->out_queue = b1 + overhead * 4 + 1;
    state->cm_state->output_ptr = 0;
    encode_bytes(state->cm_state, b2, data_size);
    data_size = state->cm_state->output_ptr;

    // Write the header. Starting with common entries.
    write_neutral_s32(b1, crc32);
    write_neutral_s32(b1 + 4, bwt_idx);
    b1[8] = model;

    s32 p = 0;
    if (model & 2) write_neutral_s32(b1 + 9 + 4 * p++, lzp_size);
    if (model & 4) write_neutral_s32(b1 + 9 + 4 * p++, rle_size);

    state->last_error = BZ3_OK;

    const s32 total_size = data_size + overhead * 4 + 1;

    // The result may have ended up in the swap buffer; the caller expects it in `buffer`.
    if (b1 != buffer) memcpy(buffer, b1, total_size);

    return total_size;
}
655
656
/* Decompresses one block in place.

   `buffer` (capacity `buffer_size`) holds `compressed_size` bytes of block data on entry and
   the decoded data on return; `orig_size` is the expected decoded size. Returns the decoded
   size, or -1 on failure with the reason stored in state->last_error. On every successful
   return last_error is BZ3_OK, including the stored-block path.

   The model-dependent header is 9 bytes plus 4 bytes per enabled stage (LZP, RLE). Both the
   buffer and compressed_size must cover that header before the stage sizes are read. */
BZIP3_API s32 bz3_decode_block(struct bz3_state * state, u8 * buffer, size_t buffer_size, s32 compressed_size, s32 orig_size) {
    // Need minimum bytes for initial header, and compressed_size needs to fit within claimed buffer size.
    if (buffer_size < 9 || buffer_size < compressed_size) {
        state->last_error = BZ3_ERR_DATA_SIZE_TOO_SMALL;
        return -1;
    }

    // Read the header.
    u32 crc32 = read_neutral_s32(buffer);
    s32 bwt_idx = read_neutral_s32(buffer + 4);

    if (compressed_size > bz3_bound(state->block_size) || compressed_size < 0) {
        state->last_error = BZ3_ERR_MALFORMED_HEADER;
        return -1;
    }

    // A BWT index of -1 marks a stored (uncompressed) small block.
    if (bwt_idx == -1) {
        if (compressed_size - 8 > 64 || compressed_size < 8) {
            state->last_error = BZ3_ERR_MALFORMED_HEADER;
            return -1;
        }

        // Ensure there's enough space for the raw copied data.
        if (compressed_size - 8 > buffer_size) {
            state->last_error = BZ3_ERR_DATA_SIZE_TOO_SMALL;
            return -1;
        }

        memmove(buffer, buffer + 8, compressed_size - 8);

        if (crc32sum(1, buffer, compressed_size - 8) != crc32) {
            state->last_error = BZ3_ERR_CRC;
            return -1;
        }

        // Success must clear any error left over from an earlier call on this state.
        state->last_error = BZ3_OK;
        return compressed_size - 8;
    }

    s8 model = buffer[8];

    // Ensure we have sufficient bytes for the rle/lzp sizes: 4 bytes per enabled stage.
    const int has_lzp = (model & 2) != 0;
    const int has_rle = (model & 4) != 0;
    size_t needed_header_size = 9 + 4 * (size_t)(has_lzp + has_rle);
    if (buffer_size < needed_header_size) {
        state->last_error = BZ3_ERR_DATA_SIZE_TOO_SMALL;
        return -1;
    }

    // The header must also lie inside the compressed data itself (compressed_size >= 0 here).
    if ((size_t)compressed_size < needed_header_size) {
        state->last_error = BZ3_ERR_MALFORMED_HEADER;
        return -1;
    }

    s32 lzp_size = -1, rle_size = -1, p = 0;
    if (model & 2) lzp_size = read_neutral_s32(buffer + 9 + 4 * p++);
    if (model & 4) rle_size = read_neutral_s32(buffer + 9 + 4 * p++);
    p += 2;

    // From here on compressed_size is the size of the entropy-coded payload only.
    compressed_size -= p * 4 + 1;

    if (((model & 2) && (lzp_size > bz3_bound(state->block_size) || lzp_size < 0)) ||
        ((model & 4) && (rle_size > bz3_bound(state->block_size) || rle_size < 0))) {
        state->last_error = BZ3_ERR_MALFORMED_HEADER;
        return -1;
    }

    if (orig_size > bz3_bound(state->block_size) || orig_size < 0) {
        state->last_error = BZ3_ERR_MALFORMED_HEADER;
        return -1;
    }

    // Size that undoing BWT+BCM should decompress into.
    s32 size_before_bwt;

    if (model & 2)
        size_before_bwt = lzp_size;
    else if (model & 4)
        size_before_bwt = rle_size;
    else
        size_before_bwt = orig_size;

    // Note(sewer): It's technically valid within the spec to create a bzip3 block
    // where the size after LZP/RLE is larger than the original input. Some earlier encoders
    // even (mistakenly?) were able to do this.
    if (!bz3_check_buffer_size(buffer_size, lzp_size, rle_size, orig_size)) {
        state->last_error = BZ3_ERR_DATA_SIZE_TOO_SMALL;
        return -1;
    }

    // Decode the data.
    u8 *b1 = buffer, *b2 = state->swap_buffer;

    begin(state->cm_state);
    state->cm_state->in_queue = b1 + p * 4 + 1;
    state->cm_state->input_ptr = 0;
    state->cm_state->input_max = compressed_size;

    decode_bytes(state->cm_state, b2, size_before_bwt);
    swap(b1, b2);

    if (bwt_idx > size_before_bwt) {
        state->last_error = BZ3_ERR_MALFORMED_HEADER;
        return -1;
    }

    // Undo BWT
    memset(state->sais_array, 0, sizeof(s32) * BWT_BOUND(state->block_size));
    memset(b2, 0, size_before_bwt);  // buffer b2, swap b1
    if (libsais_unbwt(b1, b2, state->sais_array, size_before_bwt, NULL, bwt_idx) < 0) {
        state->last_error = BZ3_ERR_BWT;
        return -1;
    }
    swap(b1, b2);

    s32 size_src = size_before_bwt;

    // Undo LZP
    if (model & 2) {
        size_src = lzp_decompress(b1, b2, lzp_size, bz3_bound(state->block_size), state->lzp_lut);
        if (size_src == -1) {
            state->last_error = BZ3_ERR_CRC;
            return -1;
        }
        // SAFETY(sewer): An attacker formed bzip3 data which decompresses as valid lzp.
        // The headers above were set to ones that pass validation (size within bounds), but the
        // data itself tries to escape buffer_size. Don't allow it to.
        if (size_src > buffer_size) {
            state->last_error = BZ3_ERR_DATA_SIZE_TOO_SMALL;
            return -1;
        }
        swap(b1, b2);
    }

    // Undo RLE
    if (model & 4) {
        // SAFETY: mrled is capped at orig_size, which is in bounds.
        int err = mrled(b1, b2, orig_size, size_src);
        if (err) {
            state->last_error = BZ3_ERR_CRC;
            return -1;
        }
        size_src = orig_size;
        swap(b1, b2);
    }

    state->last_error = BZ3_OK;

    if (size_src > state->block_size || size_src < 0) {
        state->last_error = BZ3_ERR_MALFORMED_HEADER;
        return -1;
    }

    // The result may have ended up in the swap buffer; the caller expects it in `buffer`.
    if (b1 != buffer) memcpy(buffer, b1, size_src);

    if (crc32 != crc32sum(1, buffer, size_src)) {
        state->last_error = BZ3_ERR_CRC;
        return -1;
    }

    return size_src;
}
810
811
#undef swap
812
813
#ifdef PTHREAD
814
815
    #include <pthread.h>
816
817
typedef struct {
818
    struct bz3_state * state;
819
    u8 * buffer;
820
    s32 size;
821
} encode_thread_msg;
822
823
typedef struct {
824
    struct bz3_state * state;
825
    u8 * buffer;
826
    size_t buffer_size;
827
    s32 size;
828
    s32 orig_size;
829
} decode_thread_msg;
830
831
static void * bz3_init_encode_thread(void * _msg) {
832
    encode_thread_msg * msg = _msg;
833
    msg->size = bz3_encode_block(msg->state, msg->buffer, msg->size);
834
    pthread_exit(NULL);
835
    return NULL;  // unreachable
836
}
837
838
static void * bz3_init_decode_thread(void * _msg) {
839
    decode_thread_msg * msg = _msg;
840
    bz3_decode_block(msg->state, msg->buffer, msg->buffer_size, msg->size, msg->orig_size);
841
    pthread_exit(NULL);
842
    return NULL;  // unreachable
843
}
844
845
/**
 * @brief Encode `n` blocks concurrently, one thread per block.
 *
 * Block `i` is encoded with `states[i]` from `buffers[i]`; `sizes[i]` supplies the input size
 * and receives the value returned by bz3_encode_block for that block.
 *
 * If a thread cannot be created for a block, that block is encoded on the calling thread
 * instead, so every block is always processed and only threads that actually exist are joined.
 *
 * @param states   per-block encoder states
 * @param buffers  per-block data buffers
 * @param sizes    in: per-block input sizes; out: per-block results of bz3_encode_block
 * @param n        number of blocks; values <= 0 make the call a no-op
 */
BZIP3_API void bz3_encode_blocks(struct bz3_state * states[], u8 * buffers[], s32 sizes[], s32 n) {
    /* A variable-length array must have a strictly positive length. */
    if (n <= 0) return;

    encode_thread_msg messages[n];
    pthread_t threads[n];
    u8 spawned[n];

    for (s32 i = 0; i < n; i++) {
        messages[i].state = states[i];
        messages[i].buffer = buffers[i];
        messages[i].size = sizes[i];
        spawned[i] = pthread_create(&threads[i], NULL, bz3_init_encode_thread, &messages[i]) == 0;
        if (!spawned[i]) {
            /* threads[i] is indeterminate after a failed create; do the work inline instead. */
            messages[i].size = bz3_encode_block(messages[i].state, messages[i].buffer, messages[i].size);
        }
    }

    for (s32 i = 0; i < n; i++) {
        if (spawned[i]) pthread_join(threads[i], NULL);
    }
    for (s32 i = 0; i < n; i++) sizes[i] = messages[i].size;
}
857
858
/**
 * @brief Decode `n` blocks concurrently, one thread per block.
 *
 * Block `i` is decoded in place in `buffers[i]` using `states[i]`. The return value of
 * bz3_decode_block is not propagated by this function.
 *
 * If a thread cannot be created for a block, that block is decoded on the calling thread
 * instead, so every block is always processed and only threads that actually exist are joined.
 *
 * @param states        per-block decoder states
 * @param buffers       per-block data buffers
 * @param buffer_sizes  capacity of each entry of `buffers`
 * @param sizes         per-block compressed sizes
 * @param orig_sizes    per-block original sizes
 * @param n             number of blocks; values <= 0 make the call a no-op
 */
BZIP3_API void bz3_decode_blocks(struct bz3_state * states[], u8 * buffers[], size_t buffer_sizes[], s32 sizes[], s32 orig_sizes[], s32 n) {
    /* A variable-length array must have a strictly positive length. */
    if (n <= 0) return;

    decode_thread_msg messages[n];
    pthread_t threads[n];
    u8 spawned[n];

    for (s32 i = 0; i < n; i++) {
        messages[i].state = states[i];
        messages[i].buffer = buffers[i];
        messages[i].buffer_size = buffer_sizes[i];
        messages[i].size = sizes[i];
        messages[i].orig_size = orig_sizes[i];
        spawned[i] = pthread_create(&threads[i], NULL, bz3_init_decode_thread, &messages[i]) == 0;
        if (!spawned[i]) {
            /* threads[i] is indeterminate after a failed create; do the work inline instead. */
            bz3_decode_block(messages[i].state, messages[i].buffer, messages[i].buffer_size, messages[i].size,
                             messages[i].orig_size);
        }
    }

    for (s32 i = 0; i < n; i++) {
        if (spawned[i]) pthread_join(threads[i], NULL);
    }
}
871
872
#endif
873
874
/* High level API implementations. */
875
876
/**
 * @brief Compress an in-memory buffer into a single self-contained frame.
 *
 * Frame layout written to `out`:
 *   - 5 bytes  "BZ3v1"
 *   - 4 bytes  block size
 *   - 4 bytes  number of blocks
 *   - per block: 4 bytes encoded size, 4 bytes original size, then the encoded payload.
 *
 * @param block_size  requested block size; replaced by bz3_bound(in_size) when larger than the
 *                    input, and raised to KiB(65) when not above that value
 * @param in          input data
 * @param out         output buffer
 * @param in_size     number of bytes at `in`
 * @param out_size    in: capacity of `out`; out: number of bytes written so far (set to 0 once the
 *                    working buffers have been allocated, then advanced as data is written)
 * @return BZ3_OK on success; BZ3_ERR_INIT if the state or working buffer cannot be allocated;
 *         BZ3_ERR_DATA_TOO_BIG if `out` is too small; otherwise the encoder's last error.
 */
BZIP3_API int bz3_compress(u32 block_size, const u8 * const in, u8 * out, size_t in_size, size_t * out_size) {
    if (block_size > in_size) block_size = bz3_bound(in_size);
    block_size = block_size <= KiB(65) ? KiB(65) : block_size;

    struct bz3_state * state = bz3_new(block_size);
    if (!state) return BZ3_ERR_INIT;

    int status = BZ3_OK;
    size_t buf_max = 0;
    size_t in_offset = 0;
    u32 n_blocks = 0;

    u8 * compression_buf = malloc(bz3_bound(block_size));
    if (!compression_buf) {
        status = BZ3_ERR_INIT;
        goto done;
    }

    buf_max = *out_size;
    *out_size = 0;

    n_blocks = in_size / block_size;
    if (in_size % block_size) n_blocks++;

    if (buf_max < 13 || buf_max < bz3_bound(in_size)) {
        status = BZ3_ERR_DATA_TOO_BIG;
        goto done;
    }

    out[0] = 'B';
    out[1] = 'Z';
    out[2] = '3';
    out[3] = 'v';
    out[4] = '1';
    write_neutral_s32(out + 5, block_size);
    write_neutral_s32(out + 9, n_blocks);
    *out_size += 13;

    // Compress and write the blocks.
    for (u32 i = 0; i < n_blocks; i++) {
        // Every block is full-sized except possibly the last one, which takes whatever is left.
        // Deriving the size from the remaining byte count (rather than in_size % block_size)
        // keeps the final block intact when in_size is an exact multiple of block_size.
        const size_t remaining = in_size - in_offset;
        const s32 size = (s32)(remaining < block_size ? remaining : block_size);

        memcpy(compression_buf, in + in_offset, size);
        s32 out_size_block = bz3_encode_block(state, compression_buf, size);
        if (bz3_last_error(state) != BZ3_OK) {
            status = state->last_error;
            goto done;
        }

        // The up-front capacity check does not account for per-block framing, so verify that
        // this block (8 bytes of sizes + payload) really fits before touching `out`.
        if (out_size_block < 0 || buf_max - *out_size < (size_t)out_size_block + 8) {
            status = BZ3_ERR_DATA_TOO_BIG;
            goto done;
        }

        memcpy(out + *out_size + 8, compression_buf, out_size_block);
        write_neutral_s32(out + *out_size, out_size_block);
        write_neutral_s32(out + *out_size + 4, size);
        *out_size += out_size_block + 8;
        in_offset += size;
    }

done:
    bz3_free(state);
    free(compression_buf);
    return status;
}
934
935
BZIP3_API int bz3_decompress(const uint8_t * in, uint8_t * out, size_t in_size, size_t * out_size) {
936
    if (in_size < 13) return BZ3_ERR_MALFORMED_HEADER;
937
    if (in[0] != 'B' || in[1] != 'Z' || in[2] != '3' || in[3] != 'v' || in[4] != '1') {
938
        return BZ3_ERR_MALFORMED_HEADER;
939
    }
940
    u32 block_size = read_neutral_s32(in + 5);
941
    u32 n_blocks = read_neutral_s32(in + 9);
942
    in_size -= 13;
943
    in += 13;
944
945
    struct bz3_state * state = bz3_new(block_size);
946
    if (!state) return BZ3_ERR_INIT;
947
948
    size_t compression_buf_size = bz3_bound(block_size);
949
    u8 * compression_buf = malloc(compression_buf_size);
950
    if (!compression_buf) {
951
        bz3_free(state);
952
        return BZ3_ERR_INIT;
953
    }
954
955
    size_t buf_max = *out_size;
956
    *out_size = 0;
957
958
    for (u32 i = 0; i < n_blocks; i++) {
959
        if (in_size < 8) {
960
        malformed_header:
961
            bz3_free(state);
962
            free(compression_buf);
963
            return BZ3_ERR_MALFORMED_HEADER;
964
        }
965
        s32 size = read_neutral_s32(in);
966
        if (size < 0 || size > block_size) goto malformed_header;
967
        if (in_size < size + 8) {
968
            bz3_free(state);
969
            free(compression_buf);
970
            return BZ3_ERR_TRUNCATED_DATA;
971
        }
972
        s32 orig_size = read_neutral_s32(in + 4);
973
        if (orig_size < 0) goto malformed_header;
974
        if (buf_max < *out_size + orig_size) {
975
            bz3_free(state);
976
            free(compression_buf);
977
            return BZ3_ERR_DATA_TOO_BIG;
978
        }
979
        memcpy(compression_buf, in + 8, size);
980
        bz3_decode_block(state, compression_buf, compression_buf_size, size, orig_size);
981
        if (bz3_last_error(state) != BZ3_OK) {
982
            s8 last_error = state->last_error;
983
            bz3_free(state);
984
            free(compression_buf);
985
            return last_error;
986
        }
987
        memcpy(out + *out_size, compression_buf, orig_size);
988
        *out_size += orig_size;
989
        in += size + 8;
990
        in_size -= size + 8;
991
    }
992
993
    bz3_free(state);
994
    free(compression_buf);
995
996
    return BZ3_OK;
997
}
998
999
/**
 * @brief Sum the sizes of the allocations associated with a state of the given block size.
 *
 * The individual terms mirror bz3_new (per the original author's note).
 *
 * @param block_size  block size in bytes; must lie within [KiB(65), MiB(511)]
 * @return the total number of bytes, or 0 when `block_size` is outside the accepted range
 */
BZIP3_API size_t bz3_min_memory_needed(int32_t block_size) {
    const int in_range = block_size >= KiB(65) && block_size <= MiB(511);
    if (!in_range) return 0;

    /* Core state structure. */
    const size_t core_bytes = sizeof(struct bz3_state);
    /* cm_state. */
    const size_t cm_bytes = sizeof(state);
    /* Swap buffer (swap_buffer); sized for the expanded block. */
    const size_t swap_bytes = bz3_bound(block_size);
    /* SAIS array. */
    const size_t sais_bytes = BWT_BOUND(block_size) * sizeof(int32_t);
    /* LZP lookup table (lzp_lut). */
    const size_t lzp_bytes = (1 << LZP_DICTIONARY) * sizeof(int32_t);

    return core_bytes + cm_bytes + swap_bytes + sais_bytes + lzp_bytes;
}
1023
1024
1025
/**
 * @brief Inspect an encoded block's header and check whether a buffer of `orig_size` bytes
 *        passes bz3_check_buffer_size for it.
 *
 * Header layout read here: 4 bytes CRC, 4 bytes BWT index, 1 model byte, then an optional 4-byte
 * LZP size (model bit 2) followed by an optional 4-byte RLE size (model bit 4).
 *
 * @param block       start of the encoded block
 * @param block_size  number of bytes available at `block`
 * @param orig_size   original (decoded) size of the block
 * @return -1 if `block` is too short to hold the header it announces; 1 if the block holds
 *         uncompressed literals (BWT index of -1); otherwise the result of bz3_check_buffer_size.
 */
BZIP3_API int bz3_orig_size_sufficient_for_decode(const u8 * block, size_t block_size, s32 orig_size) {
    // Need at least 9 bytes for the initial header (4 bytes CRC + 4 bytes BWT index + 1 byte model)
    if (block_size < 9) {
        return -1;
    }

    s32 bwt_idx = read_neutral_s32(block + 4);
    if (bwt_idx == -1) {
        // Uncompressed literals.
        // Original size always sufficient for uncompressed blocks
        return 1;
    }

    const s8 model = block[8];
    const int has_lzp = (model & 2) != 0;
    const int has_rle = (model & 4) != 0;
    s32 lzp_size = -1, rle_size = -1;
    size_t header_size = 9;  // Start after model byte

    // Each optional size field occupies exactly 4 bytes. Normalise the flag bits to 0/1 first:
    // multiplying the raw masked values (2 and 4) by 4 would demand 8 and 16 bytes instead and
    // reject valid short blocks.
    const size_t needed_header_size = 9 + 4 * (size_t)has_lzp + 4 * (size_t)has_rle;
    if (block_size < needed_header_size) {
        return -1;
    }

    if (has_lzp) {
        lzp_size = read_neutral_s32(block + header_size);
        header_size += 4;
    }
    if (has_rle) rle_size = read_neutral_s32(block + header_size);
    return bz3_check_buffer_size((size_t)orig_size, lzp_size, rle_size, orig_size);
}
tab: 248 wrap: offon