From: Geoffrey Allott Date: Sun, 28 Aug 2022 20:56:47 +0000 (+0100) Subject: working tANS with stree implementation X-Git-Url: https://git.pointlesshacks.com/?a=commitdiff_plain;h=c58d772aa68665396a95e56a0875263281cbb419;p=tANS.git working tANS with stree implementation --- diff --git a/Makefile b/Makefile index 1a3dab9..2c1469f 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ -CFLAGS = -Isrc -g -Wall -Wextra -Wconversion -fsanitize=undefined -fsanitize=address -ftrivial-auto-var-init=pattern +#CFLAGS = -Isrc -g -Wall -Wextra -Wconversion -fsanitize=undefined -fsanitize=address -ftrivial-auto-var-init=pattern +CFLAGS = -Isrc -Wall -Wextra -Wconversion -O3 -DNDEBUG LDFLAGS = -lasan -lubsan -lm BIN = src/tANS @@ -21,6 +22,8 @@ test/test_tANS.o: src/tANS_decode_st.h src/tANS_decode_tbl.h src/tANS_encode_st. test/test_tANS: src/tANS_decode_st.o src/tANS_decode_tbl.o src/tANS_encode_st.o src/tANS_encode_tbl.o src/tANS_symbol_tbl.o src/tANS_freq_tbl.o test/test_stree.o: src/stree.h test/test_stree: src/stree.o +test/test_rle.o: src/rle.h +test/test_rle: src/rle.o clean: rm -f $(OBJS) diff --git a/src/stree.c b/src/stree.c index 61edfa1..81122f3 100644 --- a/src/stree.c +++ b/src/stree.c @@ -162,17 +162,17 @@ static size_t stree_max_size(size_t len) static uint8_t stree_aux(size_t rem) { - if (rem == 0) return 0; + if (rem <= 0) return 0; if (rem <= 1) return 1; if (rem <= 2) return 2; - if (rem <= 4) return 3; - if (rem <= 8) return 4; - if (rem <= 16) return 5; - if (rem <= 32) return 6; + if (rem <= 3) return 3; + if (rem <= 4) return 4; + if (rem <= 5) return 5; + if (rem <= 6) return 6; return 7; } -int stree_encode(size_t len, const uint8_t *in, uint8_t *out, size_t *aux) +int stree_encode(size_t len, const uint8_t *in, uint8_t *out, uint8_t *aux) { struct node *nodes; struct node *root; @@ -265,7 +265,7 @@ int stree_encode(size_t len, const uint8_t *in, uint8_t *out, size_t *aux) return 0; } -int stree_decode(size_t len, const uint8_t *in, uint8_t *out, size_t *aux) +int stree_decode(size_t len, const uint8_t *in, uint8_t *out, uint8_t *aux) { struct node *nodes; struct node *root; diff --git a/src/stree.h b/src/stree.h index ddff5e7..3e62403 100644 --- a/src/stree.h +++ b/src/stree.h @@ -3,5 +3,5 @@ #include #include -int stree_encode(size_t len, const uint8_t *in, uint8_t *out, size_t *aux); -int stree_decode(size_t len, const uint8_t *in, uint8_t *out, size_t *aux); +int stree_encode(size_t len, const uint8_t *in, uint8_t *out, uint8_t *aux); +int stree_decode(size_t len, const uint8_t *in, uint8_t *out, uint8_t *aux); diff --git a/src/tANS.c b/src/tANS.c index 89916cf..e4ce4d4 100644 --- a/src/tANS.c +++ b/src/tANS.c @@ -19,15 +19,23 @@ ref: https://arxiv.org/abs/1311.2540 //FIXME #include -#define MAX_BUFSZ 1048576 +#define MAX_BUFSZ 1048576 +#define N_SYMBOLS 256 +#define N_AUX 3 static void usage(void) { printf( - "usage: tANS [-hcd] [file...]\n" + "usage: tANS [-hcdz] [-S .suf] [file...]\n" "\n" - "Compress the given files using tabled Asymmetric Numeral Systems\n" - "If `-d' is given, decompress instead.\n" + "Compress the given files using suffix trees and tabled Asymmetric Numeral Systems\n" + "\n" + " -h - Show this help text\n" + " -c - Send output to stdout\n" + " -d - Decompress\n" + " -z - Compress (default)\n" + "\n" + " -S .suf - Use the given suffix instead of \".ans\"\n" ); } @@ -36,87 +44,110 @@ static size_t tANS_max_compressed_size(size_t len) return len * 2; } +static void tANS_set_default_probabilities(double p[static 3][256]) +{ + int i; + + p[0][0] = 25.0; + p[0][1] = 15.0; + p[0][2] = 5.0; + p[0][3] = 4.0; + p[0][4] = 3.0; + p[0][5] = 2.0; + p[0][6] = 1.0; + + for (i = 7; i < 256; ++i) + p[0][i] = 0.0; + + p[1][0] = 40.0; + p[1][1] = 20.0; + p[1][2] = 15.0; + p[1][3] = 5.0; + p[1][4] = 4.0; + p[1][5] = 3.0; + p[1][6] = 2.0; + p[1][7] = 1.0; + + for (i = 8; i < 256; ++i) + p[1][i] = 0.0; + + p[2][0] = 0.0; + p[2][1] = 20.0; + p[2][2] = 10.0; + p[2][3] = 5.0; + p[2][4] = 4.0; + p[2][5] = 3.0; + p[2][6] = 2.0; + p[2][7] = 1.0; + + for (i = 8; i < 256; ++i) + p[0][i] = 0.0; +} + static int tANS_compress_file(FILE* input, FILE *output) { uint32_t i, len, bits; uint8_t *read_buf; uint8_t *enc_buf; - size_t *aux_buf; + uint8_t *aux_buf; uint8_t *write_buf; - double p[256] = {0}; - double p_aux[256][256] = {0}; - struct tANS_freq_tbl freq_tbl; - struct tANS_symbol_tbl symbol_tbl; - struct tANS_encode_st st; + double p[N_AUX][N_SYMBOLS] = {0}; + struct tANS_freq_tbl *freq_tbls; + struct tANS_symbol_tbl *symbol_tbls; + struct tANS_rl_encode_st *st; const uint16_t log2_tblsz = 10; - uint32_t read_sz = 8; - - double total_len = 0; - double compressed_bits = 0; - double compressed_len = 0; + uint32_t read_sz = 1024; + uint32_t magic = TANS_MAGIC; + freq_tbls = malloc(sizeof(struct tANS_freq_tbl) * N_AUX); + symbol_tbls = malloc(sizeof(struct tANS_symbol_tbl) * N_AUX); + st = malloc(sizeof(struct tANS_rl_encode_st)); read_buf = malloc(MAX_BUFSZ); enc_buf = malloc(MAX_BUFSZ); - aux_buf = malloc(MAX_BUFSZ * sizeof(size_t)); + aux_buf = malloc(MAX_BUFSZ); write_buf = calloc(tANS_max_compressed_size(MAX_BUFSZ), 1); - if (!read_buf || !enc_buf || !aux_buf || !write_buf) goto fail; + if (!freq_tbls || !symbol_tbls || !st || !read_buf || !enc_buf || !aux_buf || !write_buf) goto fail; + + tANS_set_default_probabilities(p); + + if (fwrite(&magic, sizeof magic, 1, output) != 1) goto fail; while (!feof(input)) { - if (tANS_freq_tbl_init(&freq_tbl, 256, p, log2_tblsz) != 0) goto fail; - if (tANS_symbol_tbl_init(&symbol_tbl, &freq_tbl) != 0) goto fail; - tANS_encode_st_init(&st, &symbol_tbl); + for (i = 0; i < N_AUX; ++i) { + if (tANS_freq_tbl_init(freq_tbls + i, N_SYMBOLS, p[i], log2_tblsz) != 0) goto fail; + if (tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i) != 0) goto fail; + } + tANS_rl_encode_st_init(st, symbol_tbls); len = (uint32_t) fread(read_buf, 1, read_sz, input); if (stree_encode(len, read_buf, enc_buf, aux_buf) != 0) goto fail; - for (i = 0; i < len; ++i) ++p[enc_buf[i]]; - for (i = 0; i < len; ++i) ++p_aux[aux_buf[i] > 255 ? 255 : aux_buf[i]][enc_buf[i]]; + for (i = 0; i < len; ++i) { + ++p[0][enc_buf[i]]; + if (enc_buf[i] == 0) { + size_t count = 0; + while (enc_buf[++i] == 0 && count < 255) ++count; + ++p[1][count]; + --i; + } else { + ++p[2][enc_buf[i]]; + } + } if (fwrite(&len, sizeof len, 1, output) != 1) goto fail; - bits = tANS_encode(&st, enc_buf, len, write_buf); + bits = tANS_rl_encode(st, enc_buf, len, write_buf); if (fwrite(&bits, sizeof bits, 1, output) != 1) goto fail; - if (fwrite(&st.x, sizeof st.x, 1, output) != 1) goto fail; + if (fwrite(&st->x, sizeof st->x, 1, output) != 1) goto fail; if (fwrite(write_buf, (bits + 7) / 8, 1, output) != 1) goto fail; memset(write_buf, 0, (bits + 7) / 8); read_sz *= 2; if (read_sz > MAX_BUFSZ) read_sz = MAX_BUFSZ; - - total_len += len; - compressed_bits += bits; - compressed_len += sizeof len; - compressed_len += sizeof bits; - compressed_len += sizeof st.x; - compressed_len += (bits + 7) / 8; - } - - double true_bits_req = 0.0; - - for (i = 0; i < 256; ++i) - fprintf(stderr, "p[%u] = %f\n", i, p[i]); - for (size_t j = 0; j < 8; ++j) { - double H = 0.0; - double p_total = 0.0; - for (i = 0; i < 256; ++i) - p_total += p_aux[j][i]; - for (i = 0; i < 256; ++i) - if (p_aux[j][i] > 0) - H -= log2(p_aux[j][i]/p_total) * (p_aux[j][i]/p_total); - fprintf(stderr, "Η = %12.08f\n", H); - true_bits_req += H * p_total; } - double H = 0.0; - double p_total = 0.0; - for (i = 0; i < 256; ++i) - p_total += p[i]; - fprintf(stderr, "\n"); - H = true_bits_req / p_total; - fprintf(stderr, "Η = %12.08f\n", H); - fprintf(stderr, "ΔΗ_tANS = %12.08f\n", compressed_bits / total_len - H); - fprintf(stderr, "ΔΗ_total = %12.08f\n", compressed_len * 8 / total_len - H); - fprintf(stderr, "compression_ratio = %12.08f\n", compressed_len / total_len); - fprintf(stderr, "Η-optimal ratio = %12.08f\n", H / 8); + free(freq_tbls); + free(symbol_tbls); + free(st); free(read_buf); free(enc_buf); free(aux_buf); @@ -124,6 +155,9 @@ static int tANS_compress_file(FILE* input, FILE *output) return 0; fail: + free(freq_tbls); + free(symbol_tbls); + free(st); free(read_buf); free(enc_buf); free(aux_buf); @@ -136,41 +170,68 @@ static int tANS_decompress_file(FILE* input, FILE *output) uint32_t i, len, bits; uint8_t *read_buf; uint8_t *enc_buf; - size_t *aux_buf; + uint8_t *aux_buf; uint8_t *write_buf; - double p[256] = {0}; - struct tANS_freq_tbl freq_tbl; - struct tANS_symbol_tbl symbol_tbl; - struct tANS_decode_st st; + double p[N_AUX][N_SYMBOLS] = {0}; + struct tANS_freq_tbl *freq_tbls; + struct tANS_symbol_tbl *symbol_tbls; + struct tANS_rl_decode_st *st; const uint16_t log2_tblsz = 10; + uint32_t magic; - read_buf = malloc(MAX_BUFSZ); + freq_tbls = malloc(sizeof(struct tANS_freq_tbl) * N_AUX); + symbol_tbls = malloc(sizeof(struct tANS_symbol_tbl) * N_AUX); + st = malloc(sizeof(struct tANS_rl_encode_st)); + read_buf = malloc(tANS_max_compressed_size(MAX_BUFSZ)); enc_buf = malloc(MAX_BUFSZ); - aux_buf = malloc(MAX_BUFSZ * sizeof(size_t)); - write_buf = malloc(tANS_max_compressed_size(MAX_BUFSZ)); + aux_buf = malloc(MAX_BUFSZ); + write_buf = malloc(MAX_BUFSZ); - if (!read_buf || !enc_buf || !aux_buf || !write_buf) goto fail; + if (!freq_tbls || !symbol_tbls || !st || !read_buf || !enc_buf || !aux_buf || !write_buf) goto fail; + + if (fread(&magic, sizeof magic, 1, input) != 1) goto fail; + if (magic != TANS_MAGIC) { + fprintf(stderr, "tANS: not a valid tANS file\n"); + goto fail; + } + + tANS_set_default_probabilities(p); while (!feof(input)) { - if (tANS_freq_tbl_init(&freq_tbl, 256, p, log2_tblsz) != 0) goto fail; - if (tANS_symbol_tbl_init(&symbol_tbl, &freq_tbl) != 0) goto fail; - tANS_decode_st_init(&st, &symbol_tbl); + for (i = 0; i < N_AUX; ++i) { + if (tANS_freq_tbl_init(freq_tbls + i, N_SYMBOLS, p[i], log2_tblsz) != 0) goto fail; + if (tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i) != 0) goto fail; + } + tANS_rl_decode_st_init(st, symbol_tbls); if (fread(&len, sizeof len, 1, input) != 1) goto fail; if (fread(&bits, sizeof bits, 1, input) != 1) goto fail; - if (fread(&st.x, sizeof st.x, 1, input) != 1) goto fail; + if (fread(&st->x, sizeof st->x, 1, input) != 1) goto fail; if (fread(read_buf + 4, (bits + 7) / 8, 1, input) != 1) goto fail; - st.x &= symbol_tbl.tblsz - 1; - bits = tANS_decode(&st, enc_buf, len, read_buf + 4, bits); + st->x &= symbol_tbls[0].tblsz - 1; + bits = tANS_rl_decode(st, enc_buf, len, read_buf + 4, bits); if (bits != 0) { fprintf(stderr, "tANS: corrupted file\n"); goto fail; } if (stree_decode(len, enc_buf, write_buf, aux_buf) != 0) goto fail; if (fwrite(write_buf, len, 1, output) != 1) goto fail; - for (i = 0; i < len; ++i) ++p[enc_buf[i]]; + for (i = 0; i < len; ++i) { + ++p[0][enc_buf[i]]; + if (enc_buf[i] == 0) { + size_t count = 0; + while (enc_buf[++i] == 0 && count < 255) ++count; + ++p[1][count]; + --i; + } else { + ++p[2][enc_buf[i]]; + } + } } + free(freq_tbls); + free(symbol_tbls); + free(st); free(read_buf); free(enc_buf); free(aux_buf); @@ -178,6 +239,9 @@ static int tANS_decompress_file(FILE* input, FILE *output) return 0; fail: + free(freq_tbls); + free(symbol_tbls); + free(st); free(read_buf); free(enc_buf); free(aux_buf); @@ -187,53 +251,70 @@ fail: int main(int argc, char *argv[]) { - int ret, opt, compress = 1; + int ret, opt, to_stdout = 0, compress = 1; FILE *input = stdin, *output = stdout; + const char *suffix = ".ans"; + char outpath[1024]; - while ((opt = getopt(argc, argv, "hcd")) != -1) { + while ((opt = getopt(argc, argv, "hcdzS:")) != -1) { switch (opt) { case 'h': usage(); return 0; case 'c': - compress = 1; + to_stdout = 1; break; case 'd': compress = 0; break; + case 'z': + compress = 1; + break; + case 'S': + suffix = optarg; + break; } } argv += optind; argc -= optind; - if (argc > 2) - return 3; - - if (argc >= 1) { - input = fopen(argv[0], "rb"); - if (!input) { - fprintf(stderr, "tANS: fopen: %s: %s\n", argv[0], strerror(errno)); - return 2; + if (argc == 0) { + if (compress) { + return tANS_compress_file(input, output) != 0; + } else { + return tANS_decompress_file(input, output) != 0; } - } + } else { + for (; argc >= 1; --argc, ++argv) { + input = fopen(argv[0], "rb"); + if (!input) { + fprintf(stderr, "tANS: fopen: %s: %s\n", argv[0], strerror(errno)); + return 2; + } + + if (!to_stdout) { + strncpy(outpath, argv[0], sizeof outpath - 1); + strncpy(outpath + strlen(outpath), suffix, sizeof outpath - 1 - strlen(outpath)); + output = fopen(outpath, "wb"); + if (!output) { + fprintf(stderr, "tANS: fopen: %s: %s\n", argv[0], strerror(errno)); + return 2; + } + } + + if (compress) { + ret = tANS_compress_file(input, output); + } else { + ret = tANS_decompress_file(input, output); + } - if (argc >= 2) { - output = fopen(argv[1], "wb"); - if (!output) { - fprintf(stderr, "tANS: fopen: %s: %s\n", argv[0], strerror(errno)); fclose(input); - return 2; - } - } + if (!to_stdout) fclose(output); - if (compress) { - ret = tANS_compress_file(input, output) == 0; - } else { - ret = tANS_decompress_file(input, output) == 0; + if (ret != 0) return 1; + } } - fclose(input); - fclose(output); - return ret; + return 0; } diff --git a/src/tANS_constants.h b/src/tANS_constants.h index 4a69e34..462c18b 100644 --- a/src/tANS_constants.h +++ b/src/tANS_constants.h @@ -1,5 +1,7 @@ #pragma once +#define TANS_MAGIC 0xfac0162a #define TANS_LOG2_MAX_TBLSZ 12 #define TANS_MAX_TBLSZ (1 << TANS_LOG2_MAX_TBLSZ) -#define TANS_MAX_SYMBOLS 256 +#define TANS_MAX_SYMBOLS 1024 +#define TANS_MULTI_MAX_AUX 8 diff --git a/src/tANS_decode_st.c b/src/tANS_decode_st.c index 6f3c4b2..23e3503 100644 --- a/src/tANS_decode_st.c +++ b/src/tANS_decode_st.c @@ -26,9 +26,96 @@ uint32_t tANS_decode(struct tANS_decode_st *self, uint8_t *data, uint32_t len, u value >>= (24 + bit - t.nb_bits) & 31; value &= (uint32_t) ((1 << t.nb_bits) - 1); self->x = (uint16_t) (t.new_x + value); - data[len-i-1] = t.symbol; + data[len-i-1] = (uint8_t) t.symbol; bits -= t.nb_bits; } return bits; } + +int tANS_multi_decode_st_init(struct tANS_multi_decode_st *self, struct tANS_symbol_tbl *symbol_tbls, uint16_t n_aux) +{ + uint16_t i; + + if (n_aux == 0 || n_aux > TANS_MULTI_MAX_AUX) return -1; + + for (i = 0; i < n_aux; ++i) { + tANS_decode_tbl_init(self->decode_tbls + i, symbol_tbls + i); + } + self->n_aux = n_aux; + self->x = 0; + + return 0; +} + +uint32_t tANS_multi_decode(struct tANS_multi_decode_st *self, uint8_t *data, uint32_t len, uint8_t *buf, uint32_t bits, const uint8_t *aux) +{ + uint32_t i; + struct tANS_decode_tbl_entry t; + uint32_t bit, byte, value; + + for (i = 0; i < len; ++i) { + t = self->decode_tbls[aux[i]].entries[self->x]; + bit = bits & 7; + bit += (uint32_t) 8 * (bit == 0); + byte = (uint32_t) ((bits + 7) >> 3); + value = get_uint32(buf - 4 + byte); + value >>= (24 + bit - t.nb_bits) & 31; + value &= (uint32_t) ((1 << t.nb_bits) - 1); + self->x = (uint16_t) (t.new_x + value); + data[len-i-1] = (uint8_t) t.symbol; + bits -= t.nb_bits; + } + + return bits; +} + +void tANS_rl_decode_st_init(struct tANS_rl_decode_st *self, struct tANS_symbol_tbl symbol_tbls[static 3]) +{ + tANS_decode_tbl_init(self->decode_tbls + 0, symbol_tbls + 0); + tANS_decode_tbl_init(self->decode_tbls + 1, symbol_tbls + 1); + tANS_decode_tbl_init(self->decode_tbls + 2, symbol_tbls + 2); + + self->x = 0; +} + +enum rl_aux { + rl_aux_symbol = 0, + rl_aux_len = 1, + rl_aux_nonzero = 2, +}; + +uint32_t tANS_rl_decode(struct tANS_rl_decode_st *self, uint8_t *data, uint32_t len, uint8_t *buf, uint32_t bits) +{ + uint32_t i, j; + struct tANS_decode_tbl_entry t; + uint32_t bit, byte, value; + enum rl_aux aux = rl_aux_symbol; + + for (i = 0; i < len; ++i) { + t = self->decode_tbls[aux].entries[self->x]; + bit = bits & 7; + bit += (uint32_t) 8 * (bit == 0); + byte = (uint32_t) ((bits + 7) >> 3); + value = get_uint32(buf - 4 + byte); + value >>= (24 + bit - t.nb_bits) & 31; + value &= (uint32_t) ((1 << t.nb_bits) - 1); + self->x = (uint16_t) (t.new_x + value); + bits -= t.nb_bits; + if (aux == rl_aux_symbol && t.symbol == 0) { + data[len-i-1] = 0; + aux = rl_aux_len; + } else if (aux == rl_aux_len) { + for (j = 0; j < t.symbol; ++j) + data[len-i-1-j] = 0; + i += t.symbol; + --i; + aux = t.symbol == 255 ? rl_aux_len : rl_aux_nonzero; + } else { /* if (aux == rl_nonzero || aux == rl_aux_symbol && t.symbol != 0) */ + data[len-i-1] = (uint8_t) t.symbol; + aux = rl_aux_symbol; + } + } + + return bits; +} diff --git a/src/tANS_decode_st.h b/src/tANS_decode_st.h index 67cd065..4d9d4a9 100644 --- a/src/tANS_decode_st.h +++ b/src/tANS_decode_st.h @@ -9,3 +9,20 @@ struct tANS_decode_st { void tANS_decode_st_init(struct tANS_decode_st *self, struct tANS_symbol_tbl *symbol_tbl); uint32_t tANS_decode(struct tANS_decode_st *self, uint8_t *data, uint32_t len, uint8_t *buf, uint32_t bits); + +struct tANS_multi_decode_st { + struct tANS_decode_tbl decode_tbls[TANS_MULTI_MAX_AUX]; + uint16_t n_aux; + uint16_t x; +}; + +int tANS_multi_decode_st_init(struct tANS_multi_decode_st *self, struct tANS_symbol_tbl *symbol_tbls, uint16_t n_aux); +uint32_t tANS_multi_decode(struct tANS_multi_decode_st *self, uint8_t *data, uint32_t len, uint8_t *buf, uint32_t bits, const uint8_t *aux); + +struct tANS_rl_decode_st { + struct tANS_decode_tbl decode_tbls[3]; + uint16_t x; +}; + +void tANS_rl_decode_st_init(struct tANS_rl_decode_st *self, struct tANS_symbol_tbl symbol_tbls[static 3]); +uint32_t tANS_rl_decode(struct tANS_rl_decode_st *self, uint8_t *data, uint32_t len, uint8_t *buf, uint32_t bits); diff --git a/src/tANS_decode_tbl.c b/src/tANS_decode_tbl.c index 4f813c6..26d4de6 100644 --- a/src/tANS_decode_tbl.c +++ b/src/tANS_decode_tbl.c @@ -6,14 +6,14 @@ void tANS_decode_tbl_init(struct tANS_decode_tbl *self, struct tANS_symbol_tbl * { uint16_t i; uint16_t x; - uint8_t s; + uint16_t s; self->tblsz = symbol_tbl->tblsz; for (i = 0; i < self->tblsz; ++i) { s = self->entries[i].symbol = symbol_tbl->symbol[i]; x = symbol_tbl->entries[s].next++; - self->entries[i].nb_bits = (uint8_t) (symbol_tbl->log2_tblsz - floor_log2(x)); + self->entries[i].nb_bits = (uint16_t) (symbol_tbl->log2_tblsz - floor_log2(x)); self->entries[i].new_x = (uint16_t) ((x << self->entries[i].nb_bits) - self->tblsz); } } diff --git a/src/tANS_decode_tbl.h b/src/tANS_decode_tbl.h index 734d0f4..babfe07 100644 --- a/src/tANS_decode_tbl.h +++ b/src/tANS_decode_tbl.h @@ -3,8 +3,8 @@ #include "tANS_symbol_tbl.h" struct tANS_decode_tbl_entry { - uint8_t symbol; - uint8_t nb_bits; + uint16_t symbol; + uint16_t nb_bits; uint16_t new_x; }; diff --git a/src/tANS_encode_st.c b/src/tANS_encode_st.c index 763c284..21a0aca 100644 --- a/src/tANS_encode_st.c +++ b/src/tANS_encode_st.c @@ -22,10 +22,10 @@ static inline void set_uint32(uint8_t buf[static 4], uint32_t value) uint32_t tANS_encode(struct tANS_encode_st *self, uint8_t *data, uint32_t len, uint8_t *buf) { - uint8_t nb_bits; + uint16_t nb_bits; uint32_t i, written = 0; uint32_t bit, byte, value; - uint8_t symbol; + uint16_t symbol; for (i = 0; i < len; ++i) { symbol = data[i]; @@ -43,3 +43,121 @@ uint32_t tANS_encode(struct tANS_encode_st *self, uint8_t *data, uint32_t len, u return written; } + +int tANS_multi_encode_st_init(struct tANS_multi_encode_st *self, const struct tANS_symbol_tbl *symbol_tbls, uint16_t n_aux) +{ + uint16_t i; + + if (n_aux == 0 || n_aux > TANS_MULTI_MAX_AUX) return -1; + + for (i = 0; i < n_aux; ++i) { + self->symbol_tbls[i] = symbol_tbls[i]; + tANS_encode_tbl_init(self->encode_tbls + i, self->symbol_tbls + i); + } + self->n_aux = n_aux; + self->x = self->encode_tbls[0].tblsz; + + return 0; +} + +uint32_t tANS_multi_encode(struct tANS_multi_encode_st *self, uint8_t *data, uint32_t len, uint8_t *buf, const uint8_t *aux) +{ + uint16_t nb_bits; + uint32_t i, written = 0; + uint32_t bit, byte, value; + uint16_t symbol; + const struct tANS_symbol_tbl *sym_tbl; + const struct tANS_encode_tbl *enc_tbl; + + for (i = 0; i < len; ++i) { + symbol = data[i]; + sym_tbl = self->symbol_tbls + aux[i]; + enc_tbl = self->encode_tbls + aux[i]; + nb_bits = (uint8_t) ((self->x + sym_tbl->entries[symbol].nb) >> (sym_tbl->log2_tblsz + 1)); + bit = written & 7; + byte = (uint32_t) (written >> 3); + value = (uint32_t) self->x; + value &= (uint32_t) ((1 << nb_bits) - 1); + value <<= bit; + value |= get_uint32(buf + byte); + set_uint32(buf + byte, value); + written += nb_bits; + self->x = enc_tbl->entries[(uint16_t) (sym_tbl->entries[symbol].start + (self->x >> nb_bits))].x; + } + + return written; +} + +void tANS_rl_encode_st_init(struct tANS_rl_encode_st *self, const struct tANS_symbol_tbl symbol_tbls[static 3]) +{ + uint16_t i; + + for (i = 0; i < 3; ++i) { + self->symbol_tbls[i] = symbol_tbls[i]; + tANS_encode_tbl_init(self->encode_tbls + i, self->symbol_tbls + i); + } + self->x = self->encode_tbls[0].tblsz; +} + +enum rl_aux { + rl_aux_symbol = 0, + rl_aux_len = 1, + rl_aux_nonzero = 2, +}; + +uint32_t tANS_rl_encode(struct tANS_rl_encode_st *self, uint8_t *data, uint32_t len, uint8_t *buf) +{ + uint16_t nb_bits; + uint32_t i, written = 0, count = 0; + uint32_t bit, byte, value; + uint16_t symbol; + const struct tANS_symbol_tbl *sym_tbl; + const struct tANS_encode_tbl *enc_tbl; + enum rl_aux aux = rl_aux_symbol; + int first = 1; + + for (i = 0; i < len; ++i) { + if ((i < len - 1 && data[i] != 0 && data[i+1] != 0) || (aux == rl_aux_len && count == 0)) { + aux = rl_aux_symbol; + symbol = data[i]; + } else if (aux == rl_aux_len) { + aux = rl_aux_len; + --i; + symbol = 255; + count -= symbol; + } else if (data[i] == 0) { + aux = rl_aux_len; + count = 0; + while (++i, i < len && data[i] == 0) ++count; + i -= 2; + if (first && count % 255 == 0 && count > 0) { + symbol = 255; + } else if (first && count == 0) { + ++i; + symbol = 0; + aux = rl_aux_symbol; + } else { + symbol = (uint8_t) (count % 255); + } + count -= symbol; + } else { + aux = i < len - 1 ? rl_aux_nonzero : rl_aux_symbol; + symbol = data[i]; + } + first = 0; + sym_tbl = self->symbol_tbls + aux; + enc_tbl = self->encode_tbls + aux; + nb_bits = (uint8_t) ((self->x + sym_tbl->entries[symbol].nb) >> (sym_tbl->log2_tblsz + 1)); + bit = written & 7; + byte = (uint32_t) (written >> 3); + value = (uint32_t) self->x; + value &= (uint32_t) ((1 << nb_bits) - 1); + value <<= bit; + value |= get_uint32(buf + byte); + set_uint32(buf + byte, value); + written += nb_bits; + self->x = enc_tbl->entries[(uint16_t) (sym_tbl->entries[symbol].start + (self->x >> nb_bits))].x; + } + + return written; +} diff --git a/src/tANS_encode_st.h b/src/tANS_encode_st.h index 8e3eed9..86478fe 100644 --- a/src/tANS_encode_st.h +++ b/src/tANS_encode_st.h @@ -10,3 +10,22 @@ struct tANS_encode_st { void tANS_encode_st_init(struct tANS_encode_st *self, const struct tANS_symbol_tbl *symbol_tbl); uint32_t tANS_encode(struct tANS_encode_st *self, uint8_t *data, uint32_t len, uint8_t *buf); + +struct tANS_multi_encode_st { + struct tANS_symbol_tbl symbol_tbls[TANS_MULTI_MAX_AUX]; + struct tANS_encode_tbl encode_tbls[TANS_MULTI_MAX_AUX]; + uint16_t n_aux; + uint16_t x; +}; + +int tANS_multi_encode_st_init(struct tANS_multi_encode_st *self, const struct tANS_symbol_tbl *symbol_tbls, uint16_t n_aux); +uint32_t tANS_multi_encode(struct tANS_multi_encode_st *self, uint8_t *data, uint32_t len, uint8_t *buf, const uint8_t *aux); + +struct tANS_rl_encode_st { + struct tANS_symbol_tbl symbol_tbls[3]; + struct tANS_encode_tbl encode_tbls[3]; + uint16_t x; +}; + +void tANS_rl_encode_st_init(struct tANS_rl_encode_st *self, const struct tANS_symbol_tbl symbol_tbls[static 3]); +uint32_t tANS_rl_encode(struct tANS_rl_encode_st *self, uint8_t *data, uint32_t len, uint8_t *buf); diff --git a/src/tANS_encode_tbl.c b/src/tANS_encode_tbl.c index 7be8889..8ad4899 100644 --- a/src/tANS_encode_tbl.c +++ b/src/tANS_encode_tbl.c @@ -2,7 +2,7 @@ void tANS_encode_tbl_init(struct tANS_encode_tbl *self, struct tANS_symbol_tbl *symbol_tbl) { - uint8_t s; + uint16_t s; uint16_t i, x; self->tblsz = symbol_tbl->tblsz; diff --git a/src/tANS_symbol_tbl.c b/src/tANS_symbol_tbl.c index 22a4b8d..dfdec5f 100644 --- a/src/tANS_symbol_tbl.c +++ b/src/tANS_symbol_tbl.c @@ -22,7 +22,7 @@ int tANS_symbol_tbl_init(struct tANS_symbol_tbl *self, const struct tANS_freq_tb self->entries[s].nb = (uint32_t) ((k << (self->log2_tblsz + 1)) - (freq << k)); start += freq; for (i = 0; i < freq; ++i) { - self->symbol[x] = (uint8_t) s; + self->symbol[x] = (uint16_t) s; x += step; x &= self->tblsz - 1; } diff --git a/src/tANS_symbol_tbl.h b/src/tANS_symbol_tbl.h index 78c06e9..d1d35bf 100644 --- a/src/tANS_symbol_tbl.h +++ b/src/tANS_symbol_tbl.h @@ -13,7 +13,7 @@ struct tANS_symbol_tbl { uint16_t n_symbols; uint16_t log2_tblsz; uint16_t tblsz; - uint8_t symbol[TANS_MAX_TBLSZ]; + uint16_t symbol[TANS_MAX_TBLSZ]; struct tANS_symbol_tbl_entry entries[TANS_MAX_SYMBOLS]; }; diff --git a/test/test_stree.c b/test/test_stree.c index acae1ca..5ba1575 100644 --- a/test/test_stree.c +++ b/test/test_stree.c @@ -6,7 +6,7 @@ static enum test_result test_stree_encode_empty(void) { const uint8_t *in = (const uint8_t *) ""; uint8_t out[1]; - size_t aux[1]; + uint8_t aux[1]; ASSERT_EQ(0, stree_encode(0, in, out, aux)); return TEST_SUCCESS; } @@ -15,7 +15,7 @@ static enum test_result test_stree_encode_simple(void) { const uint8_t *in = (const uint8_t *) "abc"; uint8_t out[3]; - size_t aux[3]; + uint8_t aux[3]; ASSERT_EQ(0, stree_encode(3, in, out, aux)); ASSERT_EQ('a', out[0]); @@ -33,7 +33,7 @@ static enum test_result test_stree_encode_nontrivial(void) { const uint8_t *in = (const uint8_t *) "abaaa"; uint8_t out[5]; - size_t aux[5]; + uint8_t aux[5]; ASSERT_EQ(0, stree_encode(5, in, out, aux)); ASSERT_EQ('a', out[0]); @@ -49,7 +49,7 @@ static enum test_result test_stree_encode_so_example(void) { const uint8_t *in = (const uint8_t *) "abcabxabcd"; uint8_t out[10]; - size_t aux[10]; + uint8_t aux[10]; ASSERT_EQ(0, stree_encode(10, in, out, aux)); ASSERT_EQ('a', out[0]); @@ -70,7 +70,7 @@ static enum test_result test_stree_tricky_suffix_link(void) { const uint8_t *in = (const uint8_t *) "cdddcdc"; uint8_t out[7]; - size_t aux[7]; + uint8_t aux[7]; ASSERT_EQ(0, stree_encode(7, in, out, aux)); ASSERT_EQ('c', out[0]); @@ -88,7 +88,7 @@ static enum test_result test_stree_minimal(void) { const uint8_t *in = (const uint8_t *) "abcdeabacacabb"; uint8_t out[14]; - size_t aux[14]; + uint8_t aux[14]; ASSERT_EQ(0, stree_encode(14, in, out, aux)); return TEST_SUCCESS; @@ -98,7 +98,7 @@ static enum test_result test_stree_minimal_2(void) { const uint8_t *in = (const uint8_t *) "abcdeabacacabbabcdcccacc"; uint8_t out[24]; - size_t aux[24]; + uint8_t aux[24]; ASSERT_EQ(0, stree_encode(24, in, out, aux)); return TEST_SUCCESS; @@ -108,7 +108,7 @@ static enum test_result test_stree_minimal_3(void) { const uint8_t *in = (const uint8_t *) "abcdeabacacabbabcdcccaccccaa"; uint8_t out[28]; - size_t aux[28]; + uint8_t aux[28]; ASSERT_EQ(0, stree_encode(28, in, out, aux)); return TEST_SUCCESS; @@ -118,7 +118,7 @@ static enum test_result test_stree_minimal_4(void) { const uint8_t *in = (const uint8_t *) "dbbcaccbdbdbde"; uint8_t out[14]; - size_t aux[14]; + uint8_t aux[14]; ASSERT_EQ(0, stree_encode(14, in, out, aux)); return TEST_SUCCESS; @@ -129,8 +129,8 @@ static enum test_result test_stree_minimal_5(void) const uint8_t *in = (const uint8_t *) "acaaeabdecd"; uint8_t enc[11]; uint8_t dec[11]; - size_t aux1[11]; - size_t aux2[11]; + uint8_t aux1[11]; + uint8_t aux2[11]; ASSERT_EQ(0, stree_encode(11, in, enc, aux1)); ASSERT_EQ(0, stree_decode(11, enc, dec, aux2)); @@ -166,8 +166,8 @@ static enum test_result test_stree_minimal_6(void) const uint8_t *in = (const uint8_t *) "ccdebdaabddaeeadeeccacbecdddaaddcccdebbddaeccbdaeebbdcaaeaadadda"; uint8_t enc[64]; uint8_t dec[64]; - size_t aux1[64]; - size_t aux2[64]; + uint8_t aux1[64]; + uint8_t aux2[64]; size_t i; ASSERT_EQ(0, stree_encode(44, in, enc, aux1)); @@ -185,7 +185,7 @@ static enum test_result test_stree_repeating(void) { const uint8_t *in = (const uint8_t *) "abcabcabcabcabcdabcabcabcabcabcdababababab"; uint8_t out[42]; - size_t aux[42]; + uint8_t aux[42]; ASSERT_EQ(0, stree_encode(42, in, out, aux)); return TEST_SUCCESS; @@ -195,7 +195,7 @@ static enum test_result test_stree_long(void) { const uint8_t *in = (const uint8_t *) "abcdeabacacabbabcdcccaccccaabbbaababdadbaccabbdadbadbabaccacbbbcbadddbdababddddddabdabddddddabbbbccc"; uint8_t out[100]; - size_t aux[100]; + uint8_t aux[100]; ASSERT_EQ(0, stree_encode(100, in, out, aux)); return TEST_SUCCESS; @@ -1231,8 +1231,8 @@ static enum test_result test_stree_very_long(void) ; uint8_t enc[65536]; uint8_t dec[65536]; - size_t aux1[65536]; - size_t aux2[65536]; + uint8_t aux1[65536]; + uint8_t aux2[65536]; size_t i, j; for (i = 0; i < 1024; ++i) { @@ -1259,7 +1259,7 @@ static enum test_result test_stree_decode_empty(void) { const uint8_t *in = (const uint8_t *) ""; uint8_t out[1]; - size_t aux[1]; + uint8_t aux[1]; ASSERT_EQ(0, stree_decode(0, in, out, aux)); return TEST_SUCCESS; } @@ -1268,7 +1268,7 @@ static enum test_result test_stree_decode_simple(void) { const uint8_t *in = (const uint8_t *) "abc"; uint8_t out[3]; - size_t aux[3]; + uint8_t aux[3]; ASSERT_EQ(0, stree_decode(3, in, out, aux)); ASSERT_EQ('a', out[0]); @@ -1286,7 +1286,7 @@ static enum test_result test_stree_decode_nontrivial(void) { const uint8_t *in = (const uint8_t *) "ab\1\1\1"; uint8_t out[5]; - size_t aux[5]; + uint8_t aux[5]; ASSERT_EQ(0, stree_decode(5, in, out, aux)); ASSERT_EQ('a', out[0]); @@ -1303,8 +1303,8 @@ static enum test_result test_stree_roundtrip_so_example(void) const uint8_t *in = (const uint8_t *) "abcabxabcd"; uint8_t enc[10]; uint8_t dec[10]; - size_t aux1[10]; - size_t aux2[10]; + uint8_t aux1[10]; + uint8_t aux2[10]; ASSERT_EQ(0, stree_encode(10, in, enc, aux1)); ASSERT_EQ(0, stree_decode(10, enc, dec, aux2)); @@ -1338,8 +1338,8 @@ static enum test_result test_stree_roundtrip_iterated(void) const uint8_t *in = (const uint8_t *) "1231313223"; uint8_t enc[10]; uint8_t dec[10]; - size_t aux1[10]; - size_t aux2[10]; + uint8_t aux1[10]; + uint8_t aux2[10]; size_t i; ASSERT_EQ(0, stree_encode(10, in, enc, aux1)); diff --git a/test/test_tANS.c b/test/test_tANS.c index 6e1d919..b37356c 100644 --- a/test/test_tANS.c +++ b/test/test_tANS.c @@ -21,7 +21,7 @@ enum test_result test_tANS_encode_equal_freq(void) ASSERT_EQ(tANS_symbol_tbl_init(&symbol_tbl, &freq_tbl), 0); tANS_encode_st_init(&encode_st, &symbol_tbl); - ASSERT_EQ(tANS_encode(&encode_st, data, sizeof data, buf + 4), sizeof data * 8); + ASSERT_EQ(tANS_encode(&encode_st, data, 8, buf + 4), 8 * 8); return TEST_SUCCESS; @@ -45,7 +45,7 @@ enum test_result test_tANS_encode_high_zero_probability(void) ASSERT_EQ(tANS_symbol_tbl_init(&symbol_tbl, &freq_tbl), 0); tANS_encode_st_init(&encode_st, &symbol_tbl); - ASSERT_LT(tANS_encode(&encode_st, data, sizeof data, buf + 4), 24); + ASSERT_LT(tANS_encode(&encode_st, data, 8, buf + 4), 24); return TEST_SUCCESS; } @@ -69,13 +69,13 @@ enum test_result test_tANS_encode_decode_equal_freq(void) ASSERT_EQ(tANS_symbol_tbl_init(&symbol_tbl, &freq_tbl), 0); tANS_encode_st_init(&encode_st, &symbol_tbl); - ASSERT_EQ(tANS_encode(&encode_st, data, sizeof data, buf + 4), 64); + ASSERT_EQ(tANS_encode(&encode_st, data, 8, buf + 4), 64); tANS_decode_st_init(&decode_st, &symbol_tbl); decode_st.x = (uint16_t) (encode_st.x - (1 << log2_tblsz)); - ASSERT_EQ(tANS_decode(&decode_st, rec, sizeof rec, buf + 4, 64), 0); + ASSERT_EQ(tANS_decode(&decode_st, rec, 8, buf + 4, 64), 0); - for (i = 0; i < sizeof rec; ++i) { + for (i = 0; i < 8; ++i) { ASSERT_EQ(data[i], rec[i]); } @@ -104,13 +104,13 @@ enum test_result test_tANS_encode_decode_high_zero_probability(void) tANS_encode_st_init(&encode_st, &symbol_tbl); tANS_decode_st_init(&decode_st, &symbol_tbl); - bits = tANS_encode(&encode_st, data, sizeof data, buf + 4); + bits = tANS_encode(&encode_st, data, 8, buf + 4); decode_st.x = (uint16_t) (encode_st.x - (1 << log2_tblsz)); - ASSERT_EQ(tANS_decode(&decode_st, rec, sizeof rec, buf + 4, bits), 0); + ASSERT_EQ(tANS_decode(&decode_st, rec, 8, buf + 4, bits), 0); - for (i = 0; i < sizeof rec; ++i) { + for (i = 0; i < 8; ++i) { ASSERT_EQ(data[i], rec[i]); } @@ -143,12 +143,183 @@ enum test_result test_tANS_encode_decode_long_stream(void) data[i] = (uint8_t) (i % 4 == 3 ? i / 4 : 0); } - bits = (uint32_t) tANS_encode(&encode_st, data, sizeof data, buf + 4); + bits = (uint32_t) tANS_encode(&encode_st, data, 65536, buf + 4); decode_st.x = (uint16_t) (encode_st.x - (1 << log2_tblsz)); - ASSERT_EQ(tANS_decode(&decode_st, rec, sizeof rec, buf + 4, bits), 0); + ASSERT_EQ(tANS_decode(&decode_st, rec, 65536, buf + 4, bits), 0); - for (i = 0; i < sizeof rec; ++i) { + for (i = 0; i < 65536; ++i) { + ASSERT_EQ(data[i], rec[i]); + } + + return TEST_SUCCESS; +} + +enum test_result test_tANS_multi_encode_decode_long_stream(void) +{ + struct tANS_freq_tbl freq_tbls[3]; + struct tANS_symbol_tbl symbol_tbls[3]; + struct tANS_multi_encode_st encode_st; + struct tANS_multi_decode_st decode_st; + uint8_t data[65536]; + uint8_t aux[65536]; + double p[256]; + uint8_t rec[65536]; + uint8_t buf[4 + 32768] = {0}; + uint16_t n_symbols = 256; + uint16_t log2_tblsz = 10; + uint16_t n_aux = 3; + uint32_t i; + uint32_t bits; + + p[0] = 0.75; + for (i = 1; i < n_symbols; ++i) p[i] = 0.25 / n_symbols; + for (i = 0; i < n_aux; ++i) { + ASSERT_NE(tANS_freq_tbl_init(freq_tbls + i, n_symbols, p, log2_tblsz), -1); + ASSERT_EQ(tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i), 0); + } + tANS_multi_encode_st_init(&encode_st, symbol_tbls, n_aux); + tANS_multi_decode_st_init(&decode_st, symbol_tbls, n_aux); + + for (i = 0; i < 65536; ++i) { + data[i] = (uint8_t) (i % 4 == 3 ? i / 4 : 0); + aux[i] = (uint8_t) (i % 3); + } + + bits = (uint32_t) tANS_multi_encode(&encode_st, data, 65536, buf + 4, aux); + decode_st.x = (uint16_t) (encode_st.x - (1 << log2_tblsz)); + + ASSERT_EQ(tANS_multi_decode(&decode_st, rec, 65536, buf + 4, bits, aux), 0); + + for (i = 0; i < 65536; ++i) { + ASSERT_EQ(data[i], rec[i]); + } + + return TEST_SUCCESS; +} + +enum test_result test_tANS_rl_encode_decode_long_stream(void) +{ + struct tANS_freq_tbl freq_tbls[3]; + struct tANS_symbol_tbl symbol_tbls[3]; + struct tANS_rl_encode_st encode_st; + struct tANS_rl_decode_st decode_st; + uint8_t data[65536]; + double p[3][256]; + uint8_t rec[65536]; + uint8_t buf[4 + 65536] = {0}; + uint16_t n_symbols = 256; + uint16_t log2_tblsz = 10; + uint32_t i; + uint32_t bits; + + p[0][0] = 0.5; + p[1][0] = 0.1; + p[2][0] = 0.0; + for (i = 1; i < n_symbols; ++i) p[0][i] = p[1][i] = p[2][i] = 0.25 / n_symbols; + for (i = 0; i < 3; ++i) { + ASSERT_NE(tANS_freq_tbl_init(freq_tbls + i, n_symbols, p[i], log2_tblsz), -1); + ASSERT_EQ(tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i), 0); + } + tANS_rl_encode_st_init(&encode_st, symbol_tbls); + tANS_rl_decode_st_init(&decode_st, symbol_tbls); + + for (i = 0; i < 65536; ++i) { + data[i] = (uint8_t) (i % 4 == 3 ? i / 4 : 0); + } + + bits = (uint32_t) tANS_rl_encode(&encode_st, data, 65536, buf + 4); + decode_st.x = (uint16_t) (encode_st.x - (1 << log2_tblsz)); + + ASSERT_EQ(tANS_rl_decode(&decode_st, rec, 65536, buf + 4, bits), 0); + + for (i = 0; i < 65536; ++i) { + ASSERT_EQ(data[i], rec[i]); + } + + return TEST_SUCCESS; +} + +enum test_result test_tANS_rl_encode_decode_zero(void) +{ + struct tANS_freq_tbl freq_tbls[3]; + struct tANS_symbol_tbl symbol_tbls[3]; + struct tANS_rl_encode_st encode_st; + struct tANS_rl_decode_st decode_st; + uint8_t data[1024]; + double p[3][256]; + uint8_t rec[1024]; + uint8_t buf[4 + 1024] = {0}; + uint16_t n_symbols = 256; + uint16_t log2_tblsz = 10; + uint32_t i, len, bits; + + p[0][0] = 0.5; + p[1][0] = 0.1; + p[2][0] = 0.0; + for (i = 1; i < n_symbols; ++i) p[0][i] = p[1][i] = p[2][i] = 0.25 / n_symbols; + for (i = 0; i < 3; ++i) { + ASSERT_NE(tANS_freq_tbl_init(freq_tbls + i, n_symbols, p[i], log2_tblsz), -1); + ASSERT_EQ(tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i), 0); + } + tANS_rl_encode_st_init(&encode_st, symbol_tbls); + tANS_rl_decode_st_init(&decode_st, symbol_tbls); + + for (i = 0; i < 1024; ++i) { + data[i] = (uint8_t) 0; + } + + for (len = 0; len < 1024; ++len) { + for (i = 0; i < sizeof buf; ++i) { + buf[i] = 0; + } + bits = (uint32_t) tANS_rl_encode(&encode_st, data, len, buf + 4); + decode_st.x = (uint16_t) (encode_st.x - (1 << log2_tblsz)); + + ASSERT_EQ(tANS_rl_decode(&decode_st, rec, len, buf + 4, bits), 0); + + for (i = 0; i < len; ++i) { + ASSERT_EQ(data[i], rec[i]); + } + } + + return TEST_SUCCESS; +} + +enum test_result test_tANS_rl_encode_decode_binary_example(void) +{ + struct tANS_freq_tbl freq_tbls[3]; + struct tANS_symbol_tbl symbol_tbls[3]; + struct tANS_rl_encode_st encode_st; + struct tANS_rl_decode_st decode_st; + uint8_t data[2] = { 0x00, 0x90, }; + double p[3][256]; + uint8_t rec[2]; + uint8_t buf[16] = {0}; + uint16_t n_symbols = 256; + uint16_t log2_tblsz = 10; + uint32_t i, bits; + + p[0][0] = 0.5; + p[1][0] = 0.1; + p[2][0] = 0.0; + for (i = 1; i < n_symbols; ++i) p[0][i] = p[1][i] = p[2][i] = 0.25 / n_symbols; + for (i = 0; i < 3; ++i) { + ASSERT_NE(tANS_freq_tbl_init(freq_tbls + i, n_symbols, p[i], log2_tblsz), -1); + ASSERT_EQ(tANS_symbol_tbl_init(symbol_tbls + i, freq_tbls + i), 0); + } + tANS_rl_encode_st_init(&encode_st, symbol_tbls); + tANS_rl_decode_st_init(&decode_st, symbol_tbls); + + for (i = 0; i < sizeof buf; ++i) { + buf[i] = 0; + } + bits = (uint32_t) tANS_rl_encode(&encode_st, data, 2, buf + 4); + decode_st.x = (uint16_t) (encode_st.x - (1 << log2_tblsz)); + + ASSERT_EQ(tANS_rl_decode(&decode_st, rec, 2, buf + 4, bits), 0); + + for (i = 0; i < 2; ++i) { ASSERT_EQ(data[i], rec[i]); } @@ -162,4 +333,8 @@ int main(void) RUN_TEST(test_tANS_encode_decode_equal_freq); RUN_TEST(test_tANS_encode_decode_high_zero_probability); RUN_TEST(test_tANS_encode_decode_long_stream); + RUN_TEST(test_tANS_multi_encode_decode_long_stream); + RUN_TEST(test_tANS_rl_encode_decode_long_stream); + RUN_TEST(test_tANS_rl_encode_decode_zero); + RUN_TEST(test_tANS_rl_encode_decode_binary_example); }