:: commit 800beb74ab89b2d20d368822b296aabb9e822664

YukariWaffle <115485512+YukariWaffle@users.noreply.github.com> — 2022-10-11 19:05

parents: cac4d8599e

Improved IO error checking (#55)

* Error reporting on write errors. Very minor cleanup. Better corrupt file detection.

* Forgot to check one fread

* Exit on read error or unexpected eof
diff --git a/src/main.c b/src/main.c
index 7f801f9..5182319 100644
--- a/src/main.c
+++ b/src/main.c
@@ -73,6 +73,78 @@ static void help() {
             "Report bugs to: https://github.com/kspalaiologos/bzip3\n");
 }
 
+static void xwrite(const void * data, size_t size, size_t len, FILE * des) {
+    if (fwrite(data, size, len, des) != len) {
+        fprintf(stderr, "Write error: %s\n", strerror(errno));
+        exit(1);
+    }
+}
+
+/* Read any amount of items (from 0 to len) as long as there is no error */
+static size_t xread(void * data, size_t size, size_t len, FILE * des) {
+    size_t written = fread(data, size, len, des);
+    if (ferror(des)) {
+        fprintf(stderr, "Read error: %s\n", strerror(errno));
+        exit(1);
+    }
+    return written;
+}
+
+/* Either read 0 (due to eof) items or exactly len items */
+static size_t xread_eofcheck(void * data, size_t size, size_t len, FILE * des) {
+    size_t written = xread(data, size, len, des);
+    /* feof will be true */
+    if (!written)
+        return 0;
+    if (feof (des)) {
+        fprintf(stderr, "Error: Corrupt file\n");
+        exit(1);
+    }
+    return written;
+}
+
+/* Always read len items */
+static void xread_noeof(void * data, size_t size, size_t len, FILE * des) {
+    if (!xread_eofcheck (data, size, len, des)) {
+        fprintf(stderr, "Error: Corrupt file\n");
+        exit(1);
+    }
+}
+
+static void close_out_file(FILE * des) {
+    if (des) {
+        int outfd = fileno(des);
+
+        if (fflush(des)) {
+            fprintf(stderr, "Error: Failed on fflush: %s\n", strerror(errno));
+            exit(1);
+        }
+
+        /* would have to use outfd != -1 && !FlushFileBuffers(_get_osfhandle(outfd)) and then use GetLastError + FormatMessage(A?) */
+#ifndef __MSVCRT__
+        while (1) {
+            int status;
+            status = fsync(outfd);
+            if (status == -1) {
+                if (errno == EINVAL)
+                    break;
+                if (errno == EINTR)
+                    continue;
+                fprintf(stderr, "Error: Failed on fsync: %s\n", strerror(errno));
+                exit(1);
+            }
+            break;
+        }
+#endif
+
+        if (des != stdout
+            && fclose(des)) {
+            fprintf(stderr, "Error: Failed on fclose: %s\n", strerror(errno));
+            exit(1);
+        }
+    }
+}
+
 static int process(FILE * input_des, FILE * output_des, int mode, int block_size, int workers) {
     if ((mode == MODE_ENCODE && isatty(fileno(output_des))) ||
         ((mode == MODE_DECODE || mode == MODE_TEST) && isatty(fileno(input_des)))) {
@@ -84,25 +156,22 @@ static int process(FILE * input_des, FILE * output_des, int mode, int block_size
 
     switch (mode) {
         case MODE_ENCODE:
-            fwrite("BZ3v1", 5, 1, output_des);
+            xwrite("BZ3v1", 5, 1, output_des);
 
             write_neutral_s32(byteswap_buf, block_size);
-            fwrite(byteswap_buf, 4, 1, output_des);
+            xwrite(byteswap_buf, 4, 1, output_des);
             break;
         case MODE_DECODE:
         case MODE_TEST: {
             char signature[5];
 
-            fread(signature, 5, 1, input_des);
-            if (strncmp(signature, "BZ3v1", 5) != 0) {
+            if (xread(signature, 5, 1, input_des) != 1
+                || strncmp(signature, "BZ3v1", 5) != 0) {
                 fprintf(stderr, "Invalid signature.\n");
                 return 1;
             }
 
-            if (fread(byteswap_buf, 4, 1, input_des) != 1) {
-                fprintf(stderr, "I/O error.\n");
-                return 1;
-            }
+            xread_noeof(byteswap_buf, 4, 1, input_des);
 
             block_size = read_neutral_s32(byteswap_buf);
 
@@ -142,7 +211,7 @@ static int process(FILE * input_des, FILE * output_des, int mode, int block_size
         if (mode == MODE_ENCODE) {
             s32 read_count;
             while (!feof(input_des)) {
-                read_count = fread(buffer, 1, block_size, input_des);
+                read_count = xread(buffer, 1, block_size, input_des);
 
                 s32 new_size = bz3_encode_block(state, buffer, read_count);
                 if (new_size == -1) {
@@ -151,53 +220,38 @@ static int process(FILE * input_des, FILE * output_des, int mode, int block_size
                 }
 
                 write_neutral_s32(byteswap_buf, new_size);
-                fwrite(byteswap_buf, 4, 1, output_des);
+                xwrite(byteswap_buf, 4, 1, output_des);
                 write_neutral_s32(byteswap_buf, read_count);
-                fwrite(byteswap_buf, 4, 1, output_des);
-                fwrite(buffer, new_size, 1, output_des);
+                xwrite(byteswap_buf, 4, 1, output_des);
+                xwrite(buffer, new_size, 1, output_des);
             }
             fflush(output_des);
         } else if (mode == MODE_DECODE) {
             s32 new_size, old_size;
             while (!feof(input_des)) {
-                if (fread(&byteswap_buf, 1, 4, input_des) != 4) {
-                    // Assume that the file has no more data.
-                    break;
-                }
+                if (!xread_eofcheck(&byteswap_buf, 1, 4, input_des))
+                    continue;
+
                 new_size = read_neutral_s32(byteswap_buf);
-                if (fread(&byteswap_buf, 1, 4, input_des) != 4) {
-                    fprintf(stderr, "I/O error.\n");
-                    return 1;
-                }
+                xread_noeof(&byteswap_buf, 1, 4, input_des);
                 old_size = read_neutral_s32(byteswap_buf);
-                if (fread(buffer, 1, new_size, input_des) != new_size) {
-                    fprintf(stderr, "I/O error.\n");
-                    return 1;
-                }
+                xread_noeof(buffer, 1, new_size, input_des);
                 if (bz3_decode_block(state, buffer, new_size, old_size) == -1) {
                     fprintf(stderr, "Failed to decode a block: %s\n", bz3_strerror(state));
                     return 1;
                 }
-                fwrite(buffer, old_size, 1, output_des);
+                xwrite(buffer, old_size, 1, output_des);
             }
             fflush(output_des);
         } else if (mode == MODE_TEST) {
             s32 new_size, old_size;
             while (!feof(input_des)) {
-                if (fread(&byteswap_buf, 1, 4, input_des) != 4) {
-                    // Assume that the file has no more data.
-                    break;
-                }
+                if (!xread_eofcheck(&byteswap_buf, 1, 4, input_des))
+                    continue;
                 new_size = read_neutral_s32(byteswap_buf);
-                if (fread(&byteswap_buf, 1, 4, input_des) != 4) {
-                    fprintf(stderr, "I/O error.\n");
-                    return 1;
-                }
+                xread_noeof(&byteswap_buf, 1, 4, input_des);
                 old_size = read_neutral_s32(byteswap_buf);
-                if (fread(buffer, 1, new_size, input_des) != new_size) {
-                    fprintf(stderr, "I/O error.\n");
-                    return 1;
-                }
+                xread_noeof(buffer, 1, new_size, input_des);
                 if (bz3_decode_block(state, buffer, new_size, old_size) == -1) {
                     fprintf(stderr, "Failed to decode a block: %s\n", bz3_strerror(state));
                     return 1;
@@ -236,7 +290,7 @@ static int process(FILE * input_des, FILE * output_des, int mode, int block_size
             while (!feof(input_des)) {
                 s32 i = 0;
                 for (; i < workers; i++) {
-                    size_t read_count = fread(buffers[i], 1, block_size, input_des);
+                    size_t read_count = xread(buffers[i], 1, block_size, input_des);
                     sizes[i] = old_sizes[i] = read_count;
                     if (read_count < block_size) {
                         i++;
@@ -252,10 +306,10 @@ static int process(FILE * input_des, FILE * output_des, int mode, int block_size
                 }
                 for (s32 j = 0; j < i; j++) {
                     write_neutral_s32(byteswap_buf, sizes[j]);
-                    fwrite(byteswap_buf, 4, 1, output_des);
+                    xwrite(byteswap_buf, 4, 1, output_des);
                     write_neutral_s32(byteswap_buf, old_sizes[j]);
-                    fwrite(byteswap_buf, 4, 1, output_des);
-                    fwrite(buffers[j], sizes[j], 1, output_des);
+                    xwrite(byteswap_buf, 4, 1, output_des);
+                    xwrite(buffers[j], sizes[j], 1, output_des);
                 }
             }
             fflush(output_des);
@@ -263,17 +317,12 @@ static int process(FILE * input_des, FILE * output_des, int mode, int block_size
             while (!feof(input_des)) {
                 s32 i = 0;
                 for (; i < workers; i++) {
-                    if (fread(&byteswap_buf, 1, 4, input_des) != 4) break;
+                    if (!xread_eofcheck(&byteswap_buf, 1, 4, input_des))
+                        break;
                     sizes[i] = read_neutral_s32(byteswap_buf);
-                    if (fread(&byteswap_buf, 1, 4, input_des) != 4) {
-                        fprintf(stderr, "I/O error.\n");
-                        return 1;
-                    }
+                    xread_noeof(&byteswap_buf, 1, 4, input_des);
                     old_sizes[i] = read_neutral_s32(byteswap_buf);
-                    if (fread(buffers[i], 1, sizes[i], input_des) != sizes[i]) {
-                        fprintf(stderr, "I/O error.\n");
-                        return 1;
-                    }
+                    xread_noeof(buffers[i], 1, sizes[i], input_des);
                 }
                 bz3_decode_blocks(states, buffers, sizes, old_sizes, i);
                 for (s32 j = 0; j < i; j++) {
@@ -283,7 +332,7 @@ static int process(FILE * input_des, FILE * output_des, int mode, int block_size
                     }
                 }
                 for (s32 j = 0; j < i; j++) {
-                    fwrite(buffers[j], old_sizes[j], 1, output_des);
+                    xwrite(buffers[j], old_sizes[j], 1, output_des);
                 }
             }
             fflush(output_des);
@@ -291,17 +340,12 @@ static int process(FILE * input_des, FILE * output_des, int mode, int block_size
             while (!feof(input_des)) {
                 s32 i = 0;
                 for (; i < workers; i++) {
-                    if (fread(&byteswap_buf, 1, 4, input_des) != 4) break;
+                    if (!xread_eofcheck(&byteswap_buf, 1, 4, input_des))
+                        break;
                     sizes[i] = read_neutral_s32(byteswap_buf);
-                    if (fread(&byteswap_buf, 1, 4, input_des) != 4) {
-                        fprintf(stderr, "I/O error.\n");
-                        return 1;
-                    }
+                    xread_noeof(&byteswap_buf, 1, 4, input_des);
                     old_sizes[i] = read_neutral_s32(byteswap_buf);
-                    if (fread(buffers[i], 1, sizes[i], input_des) != sizes[i]) {
-                        fprintf(stderr, "I/O error.\n");
-                        return 1;
-                    }
+                    xread_noeof(buffers[i], 1, sizes[i], input_des);
                 }
                 bz3_decode_blocks(states, buffers, sizes, old_sizes, i);
                 for (s32 j = 0; j < i; j++) {
@@ -335,7 +379,7 @@ static int is_numeric(const char * str) {
     return 1;
 }
 
-FILE * open_output(char * output, int force) {
+static FILE * open_output(char * output, int force) {
     FILE * output_des = NULL;
 
     if (output != NULL) {
@@ -363,7 +407,7 @@ FILE * open_output(char * output, int force) {
     return output_des;
 }
 
-FILE * open_input(char * input) {
+static FILE * open_input(char * input) {
     FILE * input_des = NULL;
 
     if (input != NULL) {
@@ -384,10 +428,6 @@ FILE * open_input(char * input) {
     return input_des;
 }
 
-void close_data_file(FILE * des) {
-    if (des != NULL && des != stdin && des != stdout) fclose(des);
-}
-
 int main(int argc, char * argv[]) {
     int mode = MODE_ENCODE;
 
@@ -499,7 +539,7 @@ int main(int argc, char * argv[]) {
                     if (force_stdstreams)
                         output_name = NULL;
                     else {
-                        output_name = (char *)malloc(strlen(arg) + 5);
+                        output_name = malloc(strlen(arg) + 5);
                         strcpy(output_name, arg);
                         strcat(output_name, ".bz3");
                     }
@@ -507,8 +547,8 @@ int main(int argc, char * argv[]) {
                     FILE * output_des = open_output(output_name, force);
                     process(input_des, output_des, mode, block_size, workers);
 
-                    close_data_file(input_des);
-                    close_data_file(output_des);
+                    fclose(input_des);
+                    close_out_file(output_des);
                     if (!force_stdstreams) free(output_name);
                 }
                 break;
@@ -522,7 +562,7 @@ int main(int argc, char * argv[]) {
                     if (force_stdstreams)
                         output_name = NULL;
                     else {
-                        output_name = (char *)malloc(strlen(arg) + 1);
+                        output_name = malloc(strlen(arg) + 1);
                         strcpy(output_name, arg);
                         if (strlen(output_name) > 4 && !strcmp(output_name + strlen(output_name) - 4, ".bz3"))
                             output_name[strlen(output_name) - 4] = 0;
@@ -535,8 +575,8 @@ int main(int argc, char * argv[]) {
                     FILE * output_des = open_output(output_name, force);
                     process(input_des, output_des, mode, block_size, workers);
 
-                    close_data_file(input_des);
-                    close_data_file(output_des);
+                    fclose(input_des);
+                    close_out_file(output_des);
                     if (!force_stdstreams) free(output_name);
                 }
                 break;
@@ -547,10 +587,16 @@ int main(int argc, char * argv[]) {
 
                     FILE * input_des = open_input(arg);
                     process(input_des, NULL, mode, block_size, workers);
-                    close_data_file(input_des);
+                    fclose(input_des);
                 }
                 break;
         }
+
+        if (fclose(stdout)) {
+            fprintf(stderr, "Error: Failed on fclose(stdout): %s\n", strerror(errno));
+            return 1;
+        }
+
         return 0;
     }
 
@@ -581,7 +627,7 @@ int main(int argc, char * argv[]) {
                 if (force_stdstreams)
                     output = NULL;
                 else {
-                    output = (char *)malloc(strlen(f1) + 5);
+                    output = malloc(strlen(f1) + 5);
                     strcpy(output, f1);
                     strcat(output, ".bz3");
                 }
@@ -597,7 +643,7 @@ int main(int argc, char * argv[]) {
                 if (force_stdstreams)
                     output = NULL;
                 else {
-                    output = (char *)malloc(strlen(f1) + 1);
+                    output = malloc(strlen(f1) + 1);
                     strcpy(output, f1);
                     if (strlen(output) > 4 && !strcmp(output + strlen(output) - 4, ".bz3"))
                         output[strlen(output) - 4] = 0;
@@ -621,8 +667,12 @@ int main(int argc, char * argv[]) {
 
     int r = process(input_des, output_des, mode, block_size, workers);
 
-    close_data_file(input_des);
-    close_data_file(output_des);
+    fclose(input_des);
+    close_out_file(output_des);
+    if (fclose(stdout)) {
+        fprintf(stderr, "Error: Failed on fclose(stdout): %s\n", strerror(errno));
+        return 1;
+    }
 
     return r;
 }
tab: 248 wrap: offon