consolidate symbol_tbl-setting logic
authorGeoffrey Allott <geoffrey@allott.email>
Sat, 3 Sep 2022 20:17:55 +0000 (21:17 +0100)
committerGeoffrey Allott <geoffrey@allott.email>
Sat, 3 Sep 2022 20:17:55 +0000 (21:17 +0100)
src/tANS.c

index ea6a03bfa2a84e430d306e29ccc864625641f180..b55c47e69f6294a9f7e6f94d8c67f09c965af496 100644 (file)
@@ -45,22 +45,83 @@ static size_t tANS_max_compressed_size(size_t len)
     return len * 2;
 }
 
+static int tANS_init_symbol_tbls(struct tANS_symbol_tbl symbol_tbls[static 3], const uint8_t *buf, uint32_t len)
+{
+    uint32_t i, j, count, sum, zeroes;
+    struct tANS_freq_tbl tbls[3];
+
+    tbls[0].n_symbols = N_SYMBOLS;
+    tbls[1].n_symbols = N_SYMBOLS;
+    tbls[2].n_symbols = N_SYMBOLS;
+
+    tbls[0].log2_tblsz = LOG2_TBLSZ;
+    tbls[1].log2_tblsz = LOG2_TBLSZ;
+    tbls[2].log2_tblsz = LOG2_TBLSZ;
+
+    for (i = 0; i < N_SYMBOLS; ++i) {
+        tbls[0].freq[i] = 0;
+        tbls[1].freq[i] = 0;
+        tbls[2].freq[i] = 0;
+    }
+
+    for (i = 0; i < len; ++i) {
+        ++tbls[0].freq[buf[i]];
+        if (buf[i] == 0) {
+            do {
+                count = 0;
+                while (i < len - 1 && buf[++i] == 0 && count < 255) ++count;
+                ++tbls[1].freq[count];
+            } while (i < len - 1 && buf[i] == 0);
+            if (i < len)
+                ++tbls[2].freq[buf[i]];
+        }
+    }
+
+    for (j = 0; j < 3; ++j) {
+        for (i = 0, sum = 0, zeroes = 0; i < N_SYMBOLS; ++i) {
+            sum += tbls[j].freq[i];
+            if (!(j == 2 && i == 0))
+                zeroes += tbls[j].freq[i] == 0;
+        }
+        if (sum == 0) {
+            for (i = 0; i < N_SYMBOLS; ++i)
+                tbls[j].freq[i] = (1 << LOG2_TBLSZ) / N_SYMBOLS;
+        } else {
+            for (i = 0; i < N_SYMBOLS; ++i)
+                tbls[j].freq[i] = (uint16_t) ((tbls[j].freq[i] * ((1 << LOG2_TBLSZ) - zeroes) + sum / 2) / sum);
+            for (i = 0; i < N_SYMBOLS; ++i)
+                if (tbls[j].freq[i] == 0 && !(j == 2 && i == 0))
+                    tbls[j].freq[i] = 1;
+            for (i = 0, sum = 0, zeroes = 0; i < N_SYMBOLS; ++i)
+                sum += tbls[j].freq[i];
+            for (i = 0; i < N_SYMBOLS && sum < 1 << LOG2_TBLSZ; ++i)
+                if (!(j == 2 && i == 0))
+                    ++tbls[j].freq[i], ++sum;
+            for (i = 0; i < N_SYMBOLS && sum > 1 << LOG2_TBLSZ; ++i)
+                if (tbls[j].freq[i] > 1)
+                    --tbls[j].freq[i], --sum;
+        }
+    }
+
+    if (tANS_symbol_tbl_init(&symbol_tbls[0], &tbls[0]) != 0) return -1;
+    if (tANS_symbol_tbl_init(&symbol_tbls[1], &tbls[1]) != 0) return -1;
+    if (tANS_symbol_tbl_init(&symbol_tbls[2], &tbls[2]) != 0) return -1;
+
+    return 0;
+}
+
 static int stree_tANS_compress_file(FILE* input, FILE *output)
 {
-    uint32_t i, len, bits, count;
+    uint32_t len, bits;
     uint8_t *read_buf;
     uint8_t *enc_buf;
     uint8_t *aux_buf;
     uint8_t *write_buf;
-    double p[N_AUX][N_SYMBOLS] = {0};
-    struct tANS_freq_tbl *freq_tbls;
     struct tANS_symbol_tbl *symbol_tbls;
     struct tANS_rl_encode_st *st;
-    const uint16_t log2_tblsz = LOG2_TBLSZ;
     uint32_t read_sz = 1024;
     uint32_t magic = TANS_MAGIC;
 
-    freq_tbls = malloc(sizeof(struct tANS_freq_tbl) * N_AUX);
     symbol_tbls = malloc(sizeof(struct tANS_symbol_tbl) * N_AUX);
     st = malloc(sizeof(struct tANS_rl_encode_st));
     read_buf = malloc(MAX_BUFSZ);
@@ -68,34 +129,18 @@ static int stree_tANS_compress_file(FILE* input, FILE *output)
     aux_buf = malloc(MAX_BUFSZ);
     write_buf = calloc(tANS_max_compressed_size(MAX_BUFSZ), 1);
 
-    if (!freq_tbls || !symbol_tbls || !st || !read_buf || !enc_buf || !aux_buf || !write_buf) goto fail;
+    if (!symbol_tbls || !st || !read_buf || !enc_buf || !aux_buf || !write_buf) goto fail;
 
     if (fwrite(&magic, sizeof magic, 1, output) != 1) goto fail;
 
+    if (tANS_init_symbol_tbls(symbol_tbls, (const uint8_t *) "", 0) != 0) goto fail;
+
     while (!feof(input)) {
-        for (i = 0; i < N_AUX; ++i) {
-            if (tANS_freq_tbl_init(freq_tbls + i, N_SYMBOLS, p[i], log2_tblsz) != 0) goto fail;
-            if (tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i) != 0) goto fail;
-        }
         tANS_rl_encode_st_init(st, symbol_tbls);
 
         len = (uint32_t) fread(read_buf, 1, read_sz, input);
         if (len == 0) break;
         if (stree_encode(len, read_buf, enc_buf, aux_buf) != 0) goto fail;
-
-        for (i = 0; i < len; ++i) {
-            ++p[0][enc_buf[i]];
-            if (enc_buf[i] == 0) {
-                do {
-                    count = 0;
-                    while (i < len - 1 && enc_buf[++i] == 0 && count < 255) ++count;
-                    ++p[1][count];
-                } while (i < len - 1 && enc_buf[i] == 0);
-                if (i < len)
-                    ++p[2][enc_buf[i]];
-            }
-        }
-
         if (fwrite(&len, sizeof len, 1, output) != 1) goto fail;
         bits = tANS_rl_encode(st, enc_buf, len, write_buf);
         if (fwrite(&bits, sizeof bits, 1, output) != 1) goto fail;
@@ -106,11 +151,11 @@ static int stree_tANS_compress_file(FILE* input, FILE *output)
 
         read_sz *= 2;
         if (read_sz > MAX_BUFSZ) read_sz = MAX_BUFSZ;
+        if (tANS_init_symbol_tbls(symbol_tbls, enc_buf, len) != 0) goto fail;
     }
 
     if (ferror(input)) goto fail;
 
-    free(freq_tbls);
     free(symbol_tbls);
     free(st);
     free(read_buf);
@@ -120,7 +165,6 @@ static int stree_tANS_compress_file(FILE* input, FILE *output)
     return 0;
 
 fail:
-    free(freq_tbls);
     free(symbol_tbls);
     free(st);
     free(read_buf);
@@ -132,19 +176,15 @@ fail:
 
 static int stree_tANS_decompress_file(FILE* input, FILE *output)
 {
-    uint32_t i, len, bits, count;
+    uint32_t len, bits;
     uint8_t *read_buf;
     uint8_t *enc_buf;
     uint8_t *aux_buf;
     uint8_t *write_buf;
-    double p[N_AUX][N_SYMBOLS] = {0};
-    struct tANS_freq_tbl *freq_tbls;
     struct tANS_symbol_tbl *symbol_tbls;
     struct tANS_rl_decode_st *st;
-    const uint16_t log2_tblsz = LOG2_TBLSZ;
     uint32_t magic;
 
-    freq_tbls = malloc(sizeof(struct tANS_freq_tbl) * N_AUX);
     symbol_tbls = malloc(sizeof(struct tANS_symbol_tbl) * N_AUX);
     st = malloc(sizeof(struct tANS_rl_decode_st));
     read_buf = malloc(tANS_max_compressed_size(MAX_BUFSZ));
@@ -152,7 +192,7 @@ static int stree_tANS_decompress_file(FILE* input, FILE *output)
     aux_buf = malloc(MAX_BUFSZ);
     write_buf = malloc(MAX_BUFSZ);
 
-    if (!freq_tbls || !symbol_tbls || !st || !read_buf || !enc_buf || !aux_buf || !write_buf) goto fail;
+    if (!symbol_tbls || !st || !read_buf || !enc_buf || !aux_buf || !write_buf) goto fail;
 
     if (fread(&magic, sizeof magic, 1, input) != 1) goto fail;
     if (magic != TANS_MAGIC) {
@@ -160,11 +200,9 @@ static int stree_tANS_decompress_file(FILE* input, FILE *output)
         goto fail;
     }
 
+    if (tANS_init_symbol_tbls(symbol_tbls, (const uint8_t *) "", 0) != 0) goto fail;
+
     while (!feof(input)) {
-        for (i = 0; i < N_AUX; ++i) {
-            if (tANS_freq_tbl_init(freq_tbls + i, N_SYMBOLS, p[i], log2_tblsz) != 0) goto fail;
-            if (tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i) != 0) goto fail;
-        }
         tANS_rl_decode_st_init(st, symbol_tbls);
 
         if (fread(&len, sizeof len, 1, input) != 1) break;
@@ -179,21 +217,9 @@ static int stree_tANS_decompress_file(FILE* input, FILE *output)
         }
         if (stree_decode(len, enc_buf, write_buf, aux_buf) != 0) goto fail;
         if (fwrite(write_buf, len, 1, output) != 1) goto fail;
-        for (i = 0; i < len; ++i) {
-            ++p[0][enc_buf[i]];
-            if (enc_buf[i] == 0) {
-                do {
-                    count = 0;
-                    while (i < len - 1 && enc_buf[++i] == 0 && count < 255) ++count;
-                    ++p[1][count];
-                } while (i < len - 1 && enc_buf[i] == 0);
-                if (i < len)
-                    ++p[2][enc_buf[i]];
-            }
-        }
+        if (tANS_init_symbol_tbls(symbol_tbls, enc_buf, len) != 0) goto fail;
     }
 
-    free(freq_tbls);
     free(symbol_tbls);
     free(st);
     free(read_buf);
@@ -203,7 +229,6 @@ static int stree_tANS_decompress_file(FILE* input, FILE *output)
     return 0;
 
 fail:
-    free(freq_tbls);
     free(symbol_tbls);
     free(st);
     free(read_buf);
@@ -215,53 +240,32 @@ fail:
 
 static int tANS_compress_file(FILE* input, FILE *output)
 {
-    uint32_t i, len, bits, count, bits_and_x;
+    uint32_t len, bits, bits_and_x;
     uint16_t u16_len;
     uint8_t *read_buf;
     uint8_t *write_buf;
-    double p[N_AUX][N_SYMBOLS] = {0};
-    struct tANS_freq_tbl *freq_tbls;
     struct tANS_symbol_tbl *symbol_tbls;
     struct tANS_rl_encode_st *st;
-    const uint16_t log2_tblsz = LOG2_TBLSZ;
     uint32_t read_sz = INIT_READSZ;
     uint32_t magic = TANS_ONLY_MAGIC;
 
-    freq_tbls = malloc(sizeof(struct tANS_freq_tbl) * N_AUX);
     symbol_tbls = malloc(sizeof(struct tANS_symbol_tbl) * N_AUX);
     st = malloc(sizeof(struct tANS_rl_encode_st));
     read_buf = malloc(MAX_BUFSZ);
     write_buf = calloc(tANS_max_compressed_size(MAX_BUFSZ), 1);
 
-    if (!freq_tbls || !symbol_tbls || !st || !read_buf || !write_buf) goto fail;
+    if (!symbol_tbls || !st || !read_buf || !write_buf) goto fail;
 
     if (fwrite(&magic, sizeof magic, 1, output) != 1) goto fail;
 
+    if (tANS_init_symbol_tbls(symbol_tbls, (const uint8_t *) "", 0) != 0) goto fail;
+
     while (!feof(input)) {
-        for (i = 0; i < N_AUX; ++i) {
-            if (tANS_freq_tbl_init(freq_tbls + i, N_SYMBOLS, p[i], log2_tblsz) != 0) goto fail;
-            if (tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i) != 0) goto fail;
-        }
         tANS_rl_encode_st_init(st, symbol_tbls);
 
         len = (uint32_t) fread(read_buf, 1, read_sz, input);
         if (len == 0) break;
 
-        memset(p, 0, sizeof p);
-
-        for (i = 0; i < len; ++i) {
-            ++p[0][read_buf[i]];
-            if (read_buf[i] == 0) {
-                do {
-                    count = 0;
-                    while (i < len - 1 && read_buf[++i] == 0 && count < 255) ++count;
-                    ++p[1][count];
-                } while (i < len - 1 && read_buf[i] == 0);
-                if (i < len)
-                    ++p[2][read_buf[i]];
-            }
-        }
-
         st->x += read_buf[len-1];
         bits = tANS_rl_encode(st, read_buf, len - 1, write_buf);
         u16_len = (uint16_t) (len - 1);
@@ -274,11 +278,11 @@ static int tANS_compress_file(FILE* input, FILE *output)
 
         read_sz *= 2;
         if (read_sz > MAX_BUFSZ) read_sz = MAX_BUFSZ;
+        if (tANS_init_symbol_tbls(symbol_tbls, read_buf, len - 1) != 0) goto fail;
     }
 
     if (ferror(input)) goto fail;
 
-    free(freq_tbls);
     free(symbol_tbls);
     free(st);
     free(read_buf);
@@ -286,7 +290,6 @@ static int tANS_compress_file(FILE* input, FILE *output)
     return 0;
 
 fail:
-    free(freq_tbls);
     free(symbol_tbls);
     free(st);
     free(read_buf);
@@ -296,24 +299,20 @@ fail:
 
 static int tANS_decompress_file(FILE* input, FILE *output)
 {
-    uint32_t i, len, bits, count, bits_and_x;
+    uint32_t len, bits, bits_and_x;
     uint16_t u16_len;
     uint8_t *read_buf;
     uint8_t *write_buf;
-    double p[N_AUX][N_SYMBOLS] = {0};
-    struct tANS_freq_tbl *freq_tbls;
     struct tANS_symbol_tbl *symbol_tbls;
     struct tANS_rl_decode_st *st;
-    const uint16_t log2_tblsz = LOG2_TBLSZ;
     uint32_t magic;
 
-    freq_tbls = malloc(sizeof(struct tANS_freq_tbl) * N_AUX);
     symbol_tbls = malloc(sizeof(struct tANS_symbol_tbl) * N_AUX);
     st = malloc(sizeof(struct tANS_rl_decode_st));
     read_buf = malloc(tANS_max_compressed_size(MAX_BUFSZ));
     write_buf = malloc(MAX_BUFSZ);
 
-    if (!freq_tbls || !symbol_tbls || !st || !read_buf || !write_buf) goto fail;
+    if (!symbol_tbls || !st || !read_buf || !write_buf) goto fail;
 
     if (fread(&magic, sizeof magic, 1, input) != 1) goto fail;
     if (magic != TANS_ONLY_MAGIC) {
@@ -321,13 +320,10 @@ static int tANS_decompress_file(FILE* input, FILE *output)
         goto fail;
     }
 
+    if (tANS_init_symbol_tbls(symbol_tbls, (const uint8_t *) "", 0) != 0) goto fail;
+
     while (!feof(input)) {
-        for (i = 0; i < N_AUX; ++i) {
-            if (tANS_freq_tbl_init(freq_tbls + i, N_SYMBOLS, p[i], log2_tblsz) != 0) goto fail;
-            if (tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i) != 0) goto fail;
-        }
         tANS_rl_decode_st_init(st, symbol_tbls);
-
         if (fread(&u16_len, sizeof u16_len, 1, input) != 1) break;
         if (fread(&bits_and_x, sizeof bits_and_x, 1, input) != 1) goto fail;
         len = (uint32_t) u16_len + 1;
@@ -341,23 +337,9 @@ static int tANS_decompress_file(FILE* input, FILE *output)
         }
         write_buf[len-1] = (uint8_t) st->x;
         if (fwrite(write_buf, len, 1, output) != 1) goto fail;
-        memset(p, 0, sizeof p);
-
-        for (i = 0; i < len; ++i) {
-            ++p[0][write_buf[i]];
-            if (write_buf[i] == 0) {
-                do {
-                    count = 0;
-                    while (i < len - 1 && write_buf[++i] == 0 && count < 255) ++count;
-                    ++p[1][count];
-                } while (i < len - 1 && write_buf[i] == 0);
-                if (i < len)
-                    ++p[2][write_buf[i]];
-            }
-        }
+        if (tANS_init_symbol_tbls(symbol_tbls, write_buf, len - 1) != 0) goto fail;
     }
 
-    free(freq_tbls);
     free(symbol_tbls);
     free(st);
     free(read_buf);
@@ -365,7 +347,6 @@ static int tANS_decompress_file(FILE* input, FILE *output)
     return 0;
 
 fail:
-    free(freq_tbls);
     free(symbol_tbls);
     free(st);
     free(read_buf);