incorporate stree encoding into tANS
authorGeoffrey Allott <geoffrey@allott.email>
Tue, 23 Aug 2022 19:47:05 +0000 (20:47 +0100)
committerGeoffrey Allott <geoffrey@allott.email>
Tue, 23 Aug 2022 19:47:05 +0000 (20:47 +0100)
src/tANS.c

index 3f0f99dc68688429a0b05ac080b7c567f9a28704..003a8663b316bab3f36daf9962c11db8ac5069b2 100644 (file)
@@ -5,13 +5,18 @@ author: Geoffrey Allott <geoffrey@allott.email>
 ref: https://arxiv.org/abs/1311.2540
 */
 
+#include "stree.h"
 #include "tANS_encode_st.h"
 #include "tANS_decode_st.h"
 
+#include <errno.h>
 #include <stdio.h>
+#include <stdlib.h>
 #include <string.h>
 #include <unistd.h>
 
+#define MAX_BUFSZ 1048576
+
 static void usage(void)
 {
     printf(
@@ -22,76 +27,124 @@ static void usage(void)
     );
 }
 
+static size_t tANS_max_compressed_size(size_t len)
+{
+    return len * 2;
+}
+
 static int tANS_compress_file(FILE* input, FILE *output)
 {
     uint32_t i, len, bits;
-    uint8_t read_buf[1048576];
-    uint8_t write_buf[2097152] = {0};
+    uint8_t *read_buf;
+    uint8_t *enc_buf;
+    size_t *aux_buf;
+    uint8_t *write_buf;
     double p[256] = {0};
     struct tANS_freq_tbl freq_tbl;
     struct tANS_symbol_tbl symbol_tbl;
     struct tANS_encode_st st;
-    const uint16_t log2_tblsz = 12;
+    const uint16_t log2_tblsz = 10;
     uint32_t read_sz = 8;
 
+    read_buf = malloc(MAX_BUFSZ);
+    enc_buf = malloc(MAX_BUFSZ);
+    aux_buf = malloc(MAX_BUFSZ * sizeof(size_t));
+    write_buf = calloc(tANS_max_compressed_size(MAX_BUFSZ), 1);
+
+    if (!read_buf || !enc_buf || !aux_buf || !write_buf) goto fail;
+
     while (!feof(input)) {
-        if (tANS_freq_tbl_init(&freq_tbl, 256, p, log2_tblsz) != 0) return -1;
-        if (tANS_symbol_tbl_init(&symbol_tbl, &freq_tbl) != 0) return -1;
+        if (tANS_freq_tbl_init(&freq_tbl, 256, p, log2_tblsz) != 0) goto fail;
+        if (tANS_symbol_tbl_init(&symbol_tbl, &freq_tbl) != 0) goto fail;
         tANS_encode_st_init(&st, &symbol_tbl);
 
         len = (uint32_t) fread(read_buf, 1, read_sz, input);
-        for (i = 0; i < len; ++i) ++p[read_buf[i]];
-        if (fwrite(&len, sizeof len, 1, output) != 1) return -1;
-        bits = tANS_encode(&st, read_buf, len, write_buf);
-        if (fwrite(&bits, sizeof bits, 1, output) != 1) return -1;
-        if (fwrite(&st.x, sizeof st.x, 1, output) != 1) return -1;
-        if (fwrite(write_buf, (bits + 7) / 8, 1, output) != 1) return -1;
+        if (stree_encode(len, read_buf, enc_buf, aux_buf) != 0) goto fail;
+        for (i = 0; i < len; ++i) ++p[enc_buf[i]];
+        if (fwrite(&len, sizeof len, 1, output) != 1) goto fail;
+        bits = tANS_encode(&st, enc_buf, len, write_buf);
+        if (fwrite(&bits, sizeof bits, 1, output) != 1) goto fail;
+        if (fwrite(&st.x, sizeof st.x, 1, output) != 1) goto fail;
+        if (fwrite(write_buf, (bits + 7) / 8, 1, output) != 1) goto fail;
 
         memset(write_buf, 0, (bits + 7) / 8);
 
         read_sz *= 2;
-        if (read_sz > sizeof read_buf) read_sz = sizeof read_buf;
+        if (read_sz > MAX_BUFSZ) read_sz = MAX_BUFSZ;
     }
 
+    free(read_buf);
+    free(enc_buf);
+    free(aux_buf);
+    free(write_buf);
     return 0;
+
+fail:
+    free(read_buf);
+    free(enc_buf);
+    free(aux_buf);
+    free(write_buf);
+    return -1;
 }
 
 static int tANS_decompress_file(FILE* input, FILE *output)
 {
     uint32_t i, len, bits;
-    uint8_t read_buf[2097152];
-    uint8_t write_buf[1048576];
+    uint8_t *read_buf;
+    uint8_t *enc_buf;
+    size_t *aux_buf;
+    uint8_t *write_buf;
     double p[256] = {0};
     struct tANS_freq_tbl freq_tbl;
     struct tANS_symbol_tbl symbol_tbl;
     struct tANS_decode_st st;
-    const uint16_t log2_tblsz = 12;
+    const uint16_t log2_tblsz = 10;
+
+    read_buf = malloc(MAX_BUFSZ);
+    enc_buf = malloc(MAX_BUFSZ);
+    aux_buf = malloc(MAX_BUFSZ * sizeof(size_t));
+    write_buf = malloc(tANS_max_compressed_size(MAX_BUFSZ));
+
+    if (!read_buf || !enc_buf || !aux_buf || !write_buf) goto fail;
 
     while (!feof(input)) {
-        if (tANS_freq_tbl_init(&freq_tbl, 256, p, log2_tblsz) != 0) return -1;
-        if (tANS_symbol_tbl_init(&symbol_tbl, &freq_tbl) != 0) return -1;
+        if (tANS_freq_tbl_init(&freq_tbl, 256, p, log2_tblsz) != 0) goto fail;
+        if (tANS_symbol_tbl_init(&symbol_tbl, &freq_tbl) != 0) goto fail;
         tANS_decode_st_init(&st, &symbol_tbl);
 
-        if (fread(&len, sizeof len, 1, input) != 1) return -1;
-        if (fread(&bits, sizeof bits, 1, input) != 1) return -1;
-        if (fread(&st.x, sizeof st.x, 1, input) != 1) return -1;
-        if (fread(read_buf + 4, (bits + 7) / 8, 1, input) != 1) return -1;
+        if (fread(&len, sizeof len, 1, input) != 1) goto fail;
+        if (fread(&bits, sizeof bits, 1, input) != 1) goto fail;
+        if (fread(&st.x, sizeof st.x, 1, input) != 1) goto fail;
+        if (fread(read_buf + 4, (bits + 7) / 8, 1, input) != 1) goto fail;
         st.x &= symbol_tbl.tblsz - 1;
-        bits = tANS_decode(&st, write_buf, len, read_buf + 4, bits);
+        bits = tANS_decode(&st, enc_buf, len, read_buf + 4, bits);
         if (bits != 0) {
             fprintf(stderr, "tANS: corrupted file\n");
-            return -1;
+            goto fail;
         }
-        if (fwrite(write_buf, len, 1, output) != 1) return -1;
-        for (i = 0; i < len; ++i) ++p[write_buf[i]];
+        if (stree_decode(len, enc_buf, write_buf, aux_buf) != 0) goto fail;
+        if (fwrite(write_buf, len, 1, output) != 1) goto fail;
+        for (i = 0; i < len; ++i) ++p[enc_buf[i]];
     }
 
+    free(read_buf);
+    free(enc_buf);
+    free(aux_buf);
+    free(write_buf);
     return 0;
+
+fail:
+    free(read_buf);
+    free(enc_buf);
+    free(aux_buf);
+    free(write_buf);
+    return -1;
 }
 
 int main(int argc, char *argv[])
 {
-    int opt, compress = 1;
+    int ret, opt, compress = 1;
+    FILE *input = stdin, *output = stdout;
 
     while ((opt = getopt(argc, argv, "hcd")) != -1) {
         switch (opt) {
@@ -110,11 +163,33 @@ int main(int argc, char *argv[])
     argv += optind;
     argc -= optind;
 
-    if (argc == 0) {
-        if (compress) {
-            return tANS_compress_file(stdin, stdout) == 0;
-        } else {
-            return tANS_decompress_file(stdin, stdout) == 0;
+    if (argc > 2)
+        return 3;
+
+    if (argc >= 1) {
+        input = fopen(argv[0], "rb");
+        if (!input) {
+            fprintf(stderr, "tANS: fopen: %s: %s\n", argv[0], strerror(errno));
+            return 2;
+        }
+    }
+
+    if (argc >= 2) {
+        output = fopen(argv[1], "wb");
+        if (!output) {
+            fprintf(stderr, "tANS: fopen: %s: %s\n", argv[0], strerror(errno));
+            fclose(input);
+            return 2;
         }
     }
+
+    if (compress) {
+        ret = tANS_compress_file(input, output) == 0;
+    } else {
+        ret = tANS_decompress_file(input, output) == 0;
+    }
+
+    fclose(input);
+    fclose(output);
+    return ret;
 }