Make zstd implementation match more closely bzip2 one.
Encryption still does not work.
diff --git a/lib/zip_algorithm_zstd.c b/lib/zip_algorithm_zstd.c
index 6676c29..8542b64 100644
--- a/lib/zip_algorithm_zstd.c
+++ b/lib/zip_algorithm_zstd.c
@@ -41,7 +41,7 @@
struct ctx {
zip_error_t *error;
bool compress;
- zip_uint32_t compression_flags;
+ int compression_flags;
bool end_of_input;
ZSTD_DStream *zdstream;
ZSTD_CStream *zcstream;
@@ -51,10 +51,10 @@
static void *
-allocate(bool compress, int compression_flags, zip_error_t *error, zip_uint16_t method) {
+allocate(bool compress, int compression_flags, zip_error_t *error) {
struct ctx *ctx;
- if (compression_flags < 0) {
+ if (compression_flags < 0 || compression_flags > INT_MAX) {
zip_error_set(error, ZIP_ER_INVAL, 0);
return NULL;
}
@@ -66,22 +66,31 @@
ctx->error = error;
ctx->compress = compress;
- ctx->compression_flags = (zip_uint32_t)compression_flags;
+ ctx->compression_flags = compression_flags;
ctx->end_of_input = false;
+ ctx->zdstream = NULL;
+ ctx->zcstream = NULL;
+ ctx->in.src = NULL;
+ ctx->in.pos = 0;
+ ctx->in.size = 0;
+ ctx->out.dst = NULL;
+ ctx->out.pos = 0;
+ ctx->out.size = 0;
+
return ctx;
}
static void *
compress_allocate(zip_uint16_t method, int compression_flags, zip_error_t *error) {
- return allocate(true, compression_flags, error, method);
+ return allocate(true, compression_flags, error);
}
static void *
decompress_allocate(zip_uint16_t method, int compression_flags, zip_error_t *error) {
- return allocate(false, compression_flags, error, method);
+ return allocate(false, compression_flags, error);
}
@@ -103,12 +112,6 @@
switch (ret) {
case ZSTD_error_no_error:
return ZIP_ER_OK;
- case ZSTD_error_memory_allocation:
- return ZIP_ER_MEMORY;
-
- case ZSTD_error_parameter_unsupported:
- case ZSTD_error_parameter_outOfBound:
- return ZIP_ER_INVAL;
case ZSTD_error_corruption_detected:
case ZSTD_error_checksum_wrong:
@@ -116,6 +119,13 @@
case ZSTD_error_dictionary_wrong:
return ZIP_ER_COMPRESSED_DATA;
+ case ZSTD_error_memory_allocation:
+ return ZIP_ER_MEMORY;
+
+ case ZSTD_error_parameter_unsupported:
+ case ZSTD_error_parameter_outOfBound:
+ return ZIP_ER_INVAL;
+
default:
return ZIP_ER_INTERNAL;
}
@@ -133,9 +143,15 @@
ctx->out.size = 0;
if (ctx->compress) {
ctx->zcstream = ZSTD_createCStream();
+ if (ctx->zcstream == NULL) {
+ return false;
+ }
}
else {
ctx->zdstream = ZSTD_createDStream();
+ if (ctx->zdstream == NULL) {
+ return false;
+ }
}
return true;
@@ -145,15 +161,21 @@
static bool
end(void *ud) {
struct ctx *ctx = (struct ctx *)ud;
+ size_t ret;
+
if (ctx->compress) {
- ZSTD_freeCStream(ctx->zcstream);
+ ret = ZSTD_freeCStream(ctx->zcstream);
ctx->zcstream = NULL;
}
else {
- ZSTD_freeDStream(ctx->zdstream);
+ ret = ZSTD_freeDStream(ctx->zdstream);
ctx->zdstream = NULL;
}
+ if (ZSTD_isError(ret)) {
+ zip_error_set(ctx->error, map_error(ret), 0);
+ return false;
+ }
return true;
}
@@ -161,12 +183,12 @@
static bool
input(void *ud, zip_uint8_t *data, zip_uint64_t length) {
struct ctx *ctx = (struct ctx *)ud;
- if (ctx->in.pos != 0 && ctx->in.pos != ctx->in.size) {
+ if (length > SIZE_MAX || (ctx->in.pos != 0 && ctx->in.pos != ctx->in.size)) {
zip_error_set(ctx->error, ZIP_ER_INVAL, 0);
return false;
}
ctx->in.src = (const void *)data;
- ctx->in.size = length;
+ ctx->in.size = (size_t)length;
ctx->in.pos = 0;
return true;
}
@@ -188,7 +210,7 @@
ctx->out.dst = data;
ctx->out.pos = 0;
- ctx->out.size = ZIP_MIN(UINT_MAX, *length);
+ ctx->out.size = ZIP_MIN(SIZE_MAX, *length);
if (ctx->compress) {
ret = ZSTD_compressStream(ctx->zcstream, &ctx->out, &ctx->in);