From: Geoffrey Allott Date: Sat, 3 Sep 2022 20:17:55 +0000 (+0100) Subject: consolidate symbol_tbl-setting logic X-Git-Url: https://git.pointlesshacks.com/?a=commitdiff_plain;h=007b05f7b6adacd4ac0bba2db7ecad0cffcecac1;p=tANS.git consolidate symbol_tbl-setting logic --- diff --git a/src/tANS.c b/src/tANS.c index ea6a03b..b55c47e 100644 --- a/src/tANS.c +++ b/src/tANS.c @@ -45,22 +45,83 @@ static size_t tANS_max_compressed_size(size_t len) return len * 2; } +static int tANS_init_symbol_tbls(struct tANS_symbol_tbl symbol_tbls[static 3], const uint8_t *buf, uint32_t len) +{ + uint32_t i, j, count, sum, zeroes; + struct tANS_freq_tbl tbls[3]; + + tbls[0].n_symbols = N_SYMBOLS; + tbls[1].n_symbols = N_SYMBOLS; + tbls[2].n_symbols = N_SYMBOLS; + + tbls[0].log2_tblsz = LOG2_TBLSZ; + tbls[1].log2_tblsz = LOG2_TBLSZ; + tbls[2].log2_tblsz = LOG2_TBLSZ; + + for (i = 0; i < N_SYMBOLS; ++i) { + tbls[0].freq[i] = 0; + tbls[1].freq[i] = 0; + tbls[2].freq[i] = 0; + } + + for (i = 0; i < len; ++i) { + ++tbls[0].freq[buf[i]]; + if (buf[i] == 0) { + do { + count = 0; + while (i < len - 1 && buf[++i] == 0 && count < 255) ++count; + ++tbls[1].freq[count]; + } while (i < len - 1 && buf[i] == 0); + if (i < len) + ++tbls[2].freq[buf[i]]; + } + } + + for (j = 0; j < 3; ++j) { + for (i = 0, sum = 0, zeroes = 0; i < N_SYMBOLS; ++i) { + sum += tbls[j].freq[i]; + if (!(j == 2 && i == 0)) + zeroes += tbls[j].freq[i] == 0; + } + if (sum == 0) { + for (i = 0; i < N_SYMBOLS; ++i) + tbls[j].freq[i] = (1 << LOG2_TBLSZ) / N_SYMBOLS; + } else { + for (i = 0; i < N_SYMBOLS; ++i) + tbls[j].freq[i] = (uint16_t) ((tbls[j].freq[i] * ((1 << LOG2_TBLSZ) - zeroes) + sum / 2) / sum); + for (i = 0; i < N_SYMBOLS; ++i) + if (tbls[j].freq[i] == 0 && !(j == 2 && i == 0)) + tbls[j].freq[i] = 1; + for (i = 0, sum = 0, zeroes = 0; i < N_SYMBOLS; ++i) + sum += tbls[j].freq[i]; + for (i = 0; i < N_SYMBOLS && sum < 1 << LOG2_TBLSZ; ++i) + if (!(j == 2 && i == 0)) + ++tbls[j].freq[i], ++sum; + for (i = 0; i < N_SYMBOLS && sum > 1 << LOG2_TBLSZ; ++i) + if (tbls[j].freq[i] > 1) + --tbls[j].freq[i], --sum; + } + } + + if (tANS_symbol_tbl_init(&symbol_tbls[0], &tbls[0]) != 0) return -1; + if (tANS_symbol_tbl_init(&symbol_tbls[1], &tbls[1]) != 0) return -1; + if (tANS_symbol_tbl_init(&symbol_tbls[2], &tbls[2]) != 0) return -1; + + return 0; +} + static int stree_tANS_compress_file(FILE* input, FILE *output) { - uint32_t i, len, bits, count; + uint32_t len, bits; uint8_t *read_buf; uint8_t *enc_buf; uint8_t *aux_buf; uint8_t *write_buf; - double p[N_AUX][N_SYMBOLS] = {0}; - struct tANS_freq_tbl *freq_tbls; struct tANS_symbol_tbl *symbol_tbls; struct tANS_rl_encode_st *st; - const uint16_t log2_tblsz = LOG2_TBLSZ; uint32_t read_sz = 1024; uint32_t magic = TANS_MAGIC; - freq_tbls = malloc(sizeof(struct tANS_freq_tbl) * N_AUX); symbol_tbls = malloc(sizeof(struct tANS_symbol_tbl) * N_AUX); st = malloc(sizeof(struct tANS_rl_encode_st)); read_buf = malloc(MAX_BUFSZ); @@ -68,34 +129,18 @@ static int stree_tANS_compress_file(FILE* input, FILE *output) aux_buf = malloc(MAX_BUFSZ); write_buf = calloc(tANS_max_compressed_size(MAX_BUFSZ), 1); - if (!freq_tbls || !symbol_tbls || !st || !read_buf || !enc_buf || !aux_buf || !write_buf) goto fail; + if (!symbol_tbls || !st || !read_buf || !enc_buf || !aux_buf || !write_buf) goto fail; if (fwrite(&magic, sizeof magic, 1, output) != 1) goto fail; + if (tANS_init_symbol_tbls(symbol_tbls, (const uint8_t *) "", 0) != 0) goto fail; + while (!feof(input)) { - for (i = 0; i < N_AUX; ++i) { - if (tANS_freq_tbl_init(freq_tbls + i, N_SYMBOLS, p[i], log2_tblsz) != 0) goto fail; - if (tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i) != 0) goto fail; - } tANS_rl_encode_st_init(st, symbol_tbls); len = (uint32_t) fread(read_buf, 1, read_sz, input); if (len == 0) break; if (stree_encode(len, read_buf, enc_buf, aux_buf) != 0) goto fail; - - for (i = 0; i < len; ++i) { - ++p[0][enc_buf[i]]; - if (enc_buf[i] == 0) { - do { - count = 0; - while (i < len - 1 && enc_buf[++i] == 0 && count < 255) ++count; - ++p[1][count]; - } while (i < len - 1 && enc_buf[i] == 0); - if (i < len) - ++p[2][enc_buf[i]]; - } - } - if (fwrite(&len, sizeof len, 1, output) != 1) goto fail; bits = tANS_rl_encode(st, enc_buf, len, write_buf); if (fwrite(&bits, sizeof bits, 1, output) != 1) goto fail; @@ -106,11 +151,11 @@ static int stree_tANS_compress_file(FILE* input, FILE *output) read_sz *= 2; if (read_sz > MAX_BUFSZ) read_sz = MAX_BUFSZ; + if (tANS_init_symbol_tbls(symbol_tbls, enc_buf, len) != 0) goto fail; } if (ferror(input)) goto fail; - free(freq_tbls); free(symbol_tbls); free(st); free(read_buf); @@ -120,7 +165,6 @@ static int stree_tANS_compress_file(FILE* input, FILE *output) return 0; fail: - free(freq_tbls); free(symbol_tbls); free(st); free(read_buf); @@ -132,19 +176,15 @@ fail: static int stree_tANS_decompress_file(FILE* input, FILE *output) { - uint32_t i, len, bits, count; + uint32_t len, bits; uint8_t *read_buf; uint8_t *enc_buf; uint8_t *aux_buf; uint8_t *write_buf; - double p[N_AUX][N_SYMBOLS] = {0}; - struct tANS_freq_tbl *freq_tbls; struct tANS_symbol_tbl *symbol_tbls; struct tANS_rl_decode_st *st; - const uint16_t log2_tblsz = LOG2_TBLSZ; uint32_t magic; - freq_tbls = malloc(sizeof(struct tANS_freq_tbl) * N_AUX); symbol_tbls = malloc(sizeof(struct tANS_symbol_tbl) * N_AUX); st = malloc(sizeof(struct tANS_rl_decode_st)); read_buf = malloc(tANS_max_compressed_size(MAX_BUFSZ)); @@ -152,7 +192,7 @@ static int stree_tANS_decompress_file(FILE* input, FILE *output) aux_buf = malloc(MAX_BUFSZ); write_buf = malloc(MAX_BUFSZ); - if (!freq_tbls || !symbol_tbls || !st || !read_buf || !enc_buf || !aux_buf || !write_buf) goto fail; + if (!symbol_tbls || !st || !read_buf || !enc_buf || !aux_buf || !write_buf) goto fail; if (fread(&magic, sizeof magic, 1, input) != 1) goto fail; if (magic != TANS_MAGIC) { @@ -160,11 +200,9 @@ static int stree_tANS_decompress_file(FILE* input, FILE *output) goto fail; } + if (tANS_init_symbol_tbls(symbol_tbls, (const uint8_t *) "", 0) != 0) goto fail; + while (!feof(input)) { - for (i = 0; i < N_AUX; ++i) { - if (tANS_freq_tbl_init(freq_tbls + i, N_SYMBOLS, p[i], log2_tblsz) != 0) goto fail; - if (tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i) != 0) goto fail; - } tANS_rl_decode_st_init(st, symbol_tbls); if (fread(&len, sizeof len, 1, input) != 1) break; @@ -179,21 +217,9 @@ static int stree_tANS_decompress_file(FILE* input, FILE *output) } if (stree_decode(len, enc_buf, write_buf, aux_buf) != 0) goto fail; if (fwrite(write_buf, len, 1, output) != 1) goto fail; - for (i = 0; i < len; ++i) { - ++p[0][enc_buf[i]]; - if (enc_buf[i] == 0) { - do { - count = 0; - while (i < len - 1 && enc_buf[++i] == 0 && count < 255) ++count; - ++p[1][count]; - } while (i < len - 1 && enc_buf[i] == 0); - if (i < len) - ++p[2][enc_buf[i]]; - } - } + if (tANS_init_symbol_tbls(symbol_tbls, enc_buf, len) != 0) goto fail; } - free(freq_tbls); free(symbol_tbls); free(st); free(read_buf); @@ -203,7 +229,6 @@ static int stree_tANS_decompress_file(FILE* input, FILE *output) return 0; fail: - free(freq_tbls); free(symbol_tbls); free(st); free(read_buf); @@ -215,53 +240,32 @@ fail: static int tANS_compress_file(FILE* input, FILE *output) { - uint32_t i, len, bits, count, bits_and_x; + uint32_t len, bits, bits_and_x; uint16_t u16_len; uint8_t *read_buf; uint8_t *write_buf; - double p[N_AUX][N_SYMBOLS] = {0}; - struct tANS_freq_tbl *freq_tbls; struct tANS_symbol_tbl *symbol_tbls; struct tANS_rl_encode_st *st; - const uint16_t log2_tblsz = LOG2_TBLSZ; uint32_t read_sz = INIT_READSZ; uint32_t magic = TANS_ONLY_MAGIC; - freq_tbls = malloc(sizeof(struct tANS_freq_tbl) * N_AUX); symbol_tbls = malloc(sizeof(struct tANS_symbol_tbl) * N_AUX); st = malloc(sizeof(struct tANS_rl_encode_st)); read_buf = malloc(MAX_BUFSZ); write_buf = calloc(tANS_max_compressed_size(MAX_BUFSZ), 1); - if (!freq_tbls || !symbol_tbls || !st || !read_buf || !write_buf) goto fail; + if (!symbol_tbls || !st || !read_buf || !write_buf) goto fail; if (fwrite(&magic, sizeof magic, 1, output) != 1) goto fail; + if (tANS_init_symbol_tbls(symbol_tbls, (const uint8_t *) "", 0) != 0) goto fail; + while (!feof(input)) { - for (i = 0; i < N_AUX; ++i) { - if (tANS_freq_tbl_init(freq_tbls + i, N_SYMBOLS, p[i], log2_tblsz) != 0) goto fail; - if (tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i) != 0) goto fail; - } tANS_rl_encode_st_init(st, symbol_tbls); len = (uint32_t) fread(read_buf, 1, read_sz, input); if (len == 0) break; - memset(p, 0, sizeof p); - - for (i = 0; i < len; ++i) { - ++p[0][read_buf[i]]; - if (read_buf[i] == 0) { - do { - count = 0; - while (i < len - 1 && read_buf[++i] == 0 && count < 255) ++count; - ++p[1][count]; - } while (i < len - 1 && read_buf[i] == 0); - if (i < len) - ++p[2][read_buf[i]]; - } - } - st->x += read_buf[len-1]; bits = tANS_rl_encode(st, read_buf, len - 1, write_buf); u16_len = (uint16_t) (len - 1); @@ -274,11 +278,11 @@ static int tANS_compress_file(FILE* input, FILE *output) read_sz *= 2; if (read_sz > MAX_BUFSZ) read_sz = MAX_BUFSZ; + if (tANS_init_symbol_tbls(symbol_tbls, read_buf, len - 1) != 0) goto fail; } if (ferror(input)) goto fail; - free(freq_tbls); free(symbol_tbls); free(st); free(read_buf); @@ -286,7 +290,6 @@ static int tANS_compress_file(FILE* input, FILE *output) return 0; fail: - free(freq_tbls); free(symbol_tbls); free(st); free(read_buf); @@ -296,24 +299,20 @@ fail: static int tANS_decompress_file(FILE* input, FILE *output) { - uint32_t i, len, bits, count, bits_and_x; + uint32_t len, bits, bits_and_x; uint16_t u16_len; uint8_t *read_buf; uint8_t *write_buf; - double p[N_AUX][N_SYMBOLS] = {0}; - struct tANS_freq_tbl *freq_tbls; struct tANS_symbol_tbl *symbol_tbls; struct tANS_rl_decode_st *st; - const uint16_t log2_tblsz = LOG2_TBLSZ; uint32_t magic; - freq_tbls = malloc(sizeof(struct tANS_freq_tbl) * N_AUX); symbol_tbls = malloc(sizeof(struct tANS_symbol_tbl) * N_AUX); st = malloc(sizeof(struct tANS_rl_decode_st)); read_buf = malloc(tANS_max_compressed_size(MAX_BUFSZ)); write_buf = malloc(MAX_BUFSZ); - if (!freq_tbls || !symbol_tbls || !st || !read_buf || !write_buf) goto fail; + if (!symbol_tbls || !st || !read_buf || !write_buf) goto fail; if (fread(&magic, sizeof magic, 1, input) != 1) goto fail; if (magic != TANS_ONLY_MAGIC) { @@ -321,13 +320,10 @@ static int tANS_decompress_file(FILE* input, FILE *output) goto fail; } + if (tANS_init_symbol_tbls(symbol_tbls, (const uint8_t *) "", 0) != 0) goto fail; + while (!feof(input)) { - for (i = 0; i < N_AUX; ++i) { - if (tANS_freq_tbl_init(freq_tbls + i, N_SYMBOLS, p[i], log2_tblsz) != 0) goto fail; - if (tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i) != 0) goto fail; - } tANS_rl_decode_st_init(st, symbol_tbls); - if (fread(&u16_len, sizeof u16_len, 1, input) != 1) break; if (fread(&bits_and_x, sizeof bits_and_x, 1, input) != 1) goto fail; len = (uint32_t) u16_len + 1; @@ -341,23 +337,9 @@ static int tANS_decompress_file(FILE* input, FILE *output) } write_buf[len-1] = (uint8_t) st->x; if (fwrite(write_buf, len, 1, output) != 1) goto fail; - memset(p, 0, sizeof p); - - for (i = 0; i < len; ++i) { - ++p[0][write_buf[i]]; - if (write_buf[i] == 0) { - do { - count = 0; - while (i < len - 1 && write_buf[++i] == 0 && count < 255) ++count; - ++p[1][count]; - } while (i < len - 1 && write_buf[i] == 0); - if (i < len) - ++p[2][write_buf[i]]; - } - } + if (tANS_init_symbol_tbls(symbol_tbls, write_buf, len - 1) != 0) goto fail; } - free(freq_tbls); free(symbol_tbls); free(st); free(read_buf); @@ -365,7 +347,6 @@ static int tANS_decompress_file(FILE* input, FILE *output) return 0; fail: - free(freq_tbls); free(symbol_tbls); free(st); free(read_buf);