Skip to content

Add ZSTD compression #367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ OPTION (WITH_OPENSSL "Use OpenSSL for TLS connections" OFF)
OPTION (WITH_SYSTEM_ABSEIL "Use system ABSEIL" OFF)
OPTION (WITH_SYSTEM_LZ4 "Use system LZ4" OFF)
OPTION (WITH_SYSTEM_CITYHASH "Use system cityhash" OFF)
OPTION (WITH_SYSTEM_ZSTD "Use system ZSTD" OFF)
OPTION (DEBUG_DEPENDENCIES "Print debug info about dependencies duting build" ON)
OPTION (CHECK_VERSION "Check that version number corresponds to git tag, usefull in CI/CD to validate that new version published on GitHub has same version in sources" OFF)

Expand Down Expand Up @@ -93,6 +94,13 @@ ELSE ()
SUBDIRS (contrib/cityhash/cityhash)
ENDIF ()

IF (WITH_SYSTEM_ZSTD)
FIND_PACKAGE(zstd REQUIRED)
ELSE ()
INCLUDE_DIRECTORIES (contrib/zstd/zstd)
SUBDIRS (contrib/zstd/zstd)
ENDIF ()

SUBDIRS (
clickhouse
)
Expand Down Expand Up @@ -141,4 +149,5 @@ if(DEBUG_DEPENDENCIES)
print_target_debug_info(absl::int128)
print_target_debug_info(cityhash::cityhash)
print_target_debug_info(lz4::lz4)
print_target_debug_info(zstd::zstd)
endif()
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Optional dependencies:
- openssl
- liblz4
- libabsl
- libzstd

## Building

Expand Down
1 change: 1 addition & 0 deletions clickhouse/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ TARGET_LINK_LIBRARIES (clickhouse-cpp-lib
absl::int128
cityhash::cityhash
lz4::lz4
zstd::zstd
)
TARGET_INCLUDE_DIRECTORIES (clickhouse-cpp-lib
PUBLIC ${PROJECT_SOURCE_DIR}
Expand Down
195 changes: 139 additions & 56 deletions clickhouse/base/compressed.cpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
#include "compressed.h"
#include "wire_format.h"
#include "output.h"
#include "../exceptions.h"
#include "clickhouse/exceptions.h"

#include <city.h>
#include <lz4.h>
#include <exception>
#include <zstd.h>
#include <stdexcept>
#include <system_error>

namespace {
constexpr size_t HEADER_SIZE = 9;
// see DB::CompressionMethodByte::LZ4 from src/Compression/CompressionInfo.h of ClickHouse project
constexpr uint8_t COMPRESSION_METHOD = 0x82;
// Documentation says that compression is faster when output buffer is larger than LZ4_compressBound estimation.

// see DB::CompressionMethodByte from src/Compression/CompressionInfo.h of ClickHouse project
// Method identifier stored in the first byte of the 9-byte compressed block header;
// CompressedOutput::Compress writes it at header offset 0, followed by the
// compressed size (offset 1) and the original size (offset 5).
enum class CompressionMethodByte : uint8_t {
// Not accepted on input: CompressedInput::Decompress throws CompressionError for it.
NONE = 0x02,
// Payload is decoded with LZ4_decompress_safe and produced with LZ4_compress_default.
LZ4 = 0x82,
// Payload is decoded with ZSTD_decompress and produced with ZSTD_compress.
ZSTD = 0x90,
};

// Documentation says that compression is faster when output buffer is larger than LZ4_compressBound/ZSTD_compressBound estimation.
constexpr size_t EXTRA_COMPRESS_BUFFER_SIZE = 4096;
constexpr size_t DBMS_MAX_COMPRESSED_SIZE = 0x40000000ULL; // 1GB
}
Expand All @@ -32,7 +39,7 @@ CompressedInput::~CompressedInput() {
#else
if (!std::uncaught_exceptions()) {
#endif
throw LZ4Error("some data was not read");
throw CompressionError("some data was not read");
}
}
}
Expand Down Expand Up @@ -60,55 +67,79 @@ bool CompressedInput::Decompress() {
return false;
}

if (method != COMPRESSION_METHOD) {
throw LZ4Error("unsupported compression method " + std::to_string(int(method)));
} else {
if (!WireFormat::ReadFixed(*input_, &compressed)) {
return false;
}
if (!WireFormat::ReadFixed(*input_, &original)) {
return false;
}
if (method != static_cast<uint8_t>(CompressionMethodByte::LZ4) && method != static_cast<uint8_t>(CompressionMethodByte::ZSTD)) {
throw CompressionError("unsupported compression method " + std::to_string((method)));
}

if (compressed > DBMS_MAX_COMPRESSED_SIZE) {
throw LZ4Error("compressed data too big");
}
if (!WireFormat::ReadFixed(*input_, &compressed)) {
return false;
}
if (!WireFormat::ReadFixed(*input_, &original)) {
return false;
}

if (compressed > DBMS_MAX_COMPRESSED_SIZE) {
throw CompressionError("compressed data too big");
}

Buffer tmp(compressed);
Buffer tmp(compressed);

// Data header
{
BufferOutput out(&tmp);
out.Write(&method, sizeof(method));
out.Write(&compressed, sizeof(compressed));
out.Write(&original, sizeof(original));
out.Flush();
// Data header
{
BufferOutput out(&tmp);
out.Write(&method, sizeof(method));
out.Write(&compressed, sizeof(compressed));
out.Write(&original, sizeof(original));
out.Flush();
}

if (!WireFormat::ReadBytes(*input_, tmp.data() + HEADER_SIZE, compressed - HEADER_SIZE)) {
return false;
} else {
if (hash != CityHash128((const char*)tmp.data(), compressed)) {
throw CompressionError("data was corrupted");
}
}

data_ = Buffer(original);

if (!WireFormat::ReadBytes(*input_, tmp.data() + HEADER_SIZE, compressed - HEADER_SIZE)) {
return false;
switch (method) {
case static_cast<uint8_t>(CompressionMethodByte::LZ4): {
if (LZ4_decompress_safe((const char*)tmp.data() + HEADER_SIZE, (char*)data_.data(), static_cast<int>(compressed - HEADER_SIZE), original) < 0) {
throw CompressionError("can't decompress LZ4-encoded data");
} else {
if (hash != CityHash128((const char*)tmp.data(), compressed)) {
throw LZ4Error("data was corrupted");
}
mem_.Reset(data_.data(), original);
}
return true;
}

data_ = Buffer(original);
case static_cast<uint8_t>(CompressionMethodByte::ZSTD): {
size_t res = ZSTD_decompress((char*)data_.data(), original, (const char*)tmp.data() + HEADER_SIZE, static_cast<int>(compressed - HEADER_SIZE));

if (LZ4_decompress_safe((const char*)tmp.data() + HEADER_SIZE, (char*)data_.data(), static_cast<int>(compressed - HEADER_SIZE), original) < 0) {
throw LZ4Error("can't decompress data");
if (ZSTD_isError(res)) {
throw CompressionError("can't decompress ZSTD-encoded data, ZSTD error: " + std::string(ZSTD_getErrorName(res)));
} else {
mem_.Reset(data_.data(), original);
}
return true;
}

case static_cast<uint8_t>(CompressionMethodByte::NONE): {
throw CompressionError("compression method not defined" + std::to_string((method)));
}
Copy link
Contributor

@Enmk Enmk Mar 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please handle invalid/unsupported compression methods too:

Suggested change
}
}
default:
throw CompressionError("Unknown or unsupported compression method " + std::to_string((method)));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

default: {
throw CompressionError("Unknown or unsupported compression method " + std::to_string((method)));
}
}

return true;
}


CompressedOutput::CompressedOutput(OutputStream * destination, size_t max_compressed_chunk_size)
CompressedOutput::CompressedOutput(OutputStream * destination, size_t max_compressed_chunk_size, CompressionMethod method)
: destination_(destination)
, max_compressed_chunk_size_(max_compressed_chunk_size)
, method_(method)
{
PreallocateCompressBuffer(max_compressed_chunk_size);
}
Expand Down Expand Up @@ -139,37 +170,89 @@ void CompressedOutput::DoFlush() {
}

void CompressedOutput::Compress(const void * data, size_t len) {
const auto compressed_size = LZ4_compress_default(
(const char*)data,
(char*)compressed_buffer_.data() + HEADER_SIZE,
static_cast<int>(len),
static_cast<int>(compressed_buffer_.size() - HEADER_SIZE));
if (compressed_size <= 0)
throw LZ4Error("Failed to compress chunk of " + std::to_string(len) + " bytes, "
"LZ4 error: " + std::to_string(compressed_size));
switch (method_) {
case clickhouse::CompressionMethod::LZ4: {
const auto compressed_size = LZ4_compress_default(
(const char*)data,
(char*)compressed_buffer_.data() + HEADER_SIZE,
static_cast<int>(len),
static_cast<int>(compressed_buffer_.size() - HEADER_SIZE));
if (compressed_size <= 0)
throw CompressionError("Failed to compress chunk of " + std::to_string(len) + " bytes, "
"LZ4 error: " + std::to_string(compressed_size));

{
auto header = compressed_buffer_.data();
WriteUnaligned(header, COMPRESSION_METHOD);
// Compressed data size with header
WriteUnaligned(header + 1, static_cast<uint32_t>(compressed_size + HEADER_SIZE));
// Original data size
WriteUnaligned(header + 5, static_cast<uint32_t>(len));
{
auto header = compressed_buffer_.data();
WriteUnaligned(header, CompressionMethodByte::LZ4);
// Compressed data size with header
WriteUnaligned(header + 1, static_cast<uint32_t>(compressed_size + HEADER_SIZE));
// Original data size
WriteUnaligned(header + 5, static_cast<uint32_t>(len));
}

WireFormat::WriteFixed(*destination_, CityHash128((const char*)compressed_buffer_.data(), compressed_size + HEADER_SIZE));
WireFormat::WriteBytes(*destination_, compressed_buffer_.data(), compressed_size + HEADER_SIZE);
break;
}

WireFormat::WriteFixed(*destination_, CityHash128(
(const char*)compressed_buffer_.data(), compressed_size + HEADER_SIZE));
WireFormat::WriteBytes(*destination_, compressed_buffer_.data(), compressed_size + HEADER_SIZE);
case clickhouse::CompressionMethod::ZSTD: {
const size_t compressed_size = ZSTD_compress(
(char*)compressed_buffer_.data() + HEADER_SIZE,
static_cast<int>(compressed_buffer_.size() - HEADER_SIZE),
(const char*)data,
static_cast<int>(len),
ZSTD_fast);
if (ZSTD_isError(compressed_size))
throw CompressionError("Failed to compress chunk of " + std::to_string(len) + " bytes, "
"ZSTD error: " + std::string(ZSTD_getErrorName(compressed_size)));

{
auto header = compressed_buffer_.data();
WriteUnaligned(header, CompressionMethodByte::ZSTD);
// Compressed data size with header
WriteUnaligned(header + 1, static_cast<uint32_t>(compressed_size + HEADER_SIZE));
// Original data size
WriteUnaligned(header + 5, static_cast<uint32_t>(len));
}

WireFormat::WriteFixed(*destination_, CityHash128((const char*)compressed_buffer_.data(), compressed_size + HEADER_SIZE));
WireFormat::WriteBytes(*destination_, compressed_buffer_.data(), compressed_size + HEADER_SIZE);
break;
}

case clickhouse::CompressionMethod::None: {
throw CompressionError("no compression defined");
}
}

destination_->Flush();
}

void CompressedOutput::PreallocateCompressBuffer(size_t input_size) {
const auto estimated_compressed_buffer_size = LZ4_compressBound(static_cast<int>(input_size));
if (estimated_compressed_buffer_size <= 0)
throw LZ4Error("Failed to estimate compressed buffer size, LZ4 error: " + std::to_string(estimated_compressed_buffer_size));
switch (method_) {
case clickhouse::CompressionMethod::LZ4: {
const auto estimated_compressed_buffer_size = LZ4_compressBound(static_cast<int>(input_size));
if (estimated_compressed_buffer_size <= 0)
throw CompressionError("Failed to estimate compressed buffer size, LZ4 error: " + std::to_string(estimated_compressed_buffer_size));

compressed_buffer_.resize(estimated_compressed_buffer_size + HEADER_SIZE + EXTRA_COMPRESS_BUFFER_SIZE);
break;
}

compressed_buffer_.resize(estimated_compressed_buffer_size + HEADER_SIZE + EXTRA_COMPRESS_BUFFER_SIZE);
case clickhouse::CompressionMethod::ZSTD: {
const size_t estimated_compressed_buffer_size = ZSTD_compressBound(static_cast<int>(input_size));
if (ZSTD_isError(estimated_compressed_buffer_size))
throw CompressionError("Failed to estimate compressed buffer size, ZSTD error: " + std::string(ZSTD_getErrorName(estimated_compressed_buffer_size)));

compressed_buffer_.resize(estimated_compressed_buffer_size + HEADER_SIZE + EXTRA_COMPRESS_BUFFER_SIZE);
break;
}

case clickhouse::CompressionMethod::None: {
/// do nothing
break;
}
}
}

}
5 changes: 4 additions & 1 deletion clickhouse/base/compressed.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "output.h"
#include "buffer.h"

#include "clickhouse/client.h"

namespace clickhouse {

class CompressedInput : public ZeroCopyInput {
Expand All @@ -25,7 +27,7 @@ class CompressedInput : public ZeroCopyInput {

class CompressedOutput : public OutputStream {
public:
explicit CompressedOutput(OutputStream * destination, size_t max_compressed_chunk_size = 0);
explicit CompressedOutput(OutputStream* destination, size_t max_compressed_chunk_size = 0, CompressionMethod method = CompressionMethod::LZ4);
~CompressedOutput() override;

protected:
Expand All @@ -40,6 +42,7 @@ class CompressedOutput : public OutputStream {
OutputStream * destination_;
const size_t max_compressed_chunk_size_;
Buffer compressed_buffer_;
CompressionMethod method_;
};

}
7 changes: 4 additions & 3 deletions clickhouse/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ std::ostream& operator<<(std::ostream& os, const ClientOptions& opt) {
<< " send_retries:" << opt.send_retries
<< " retry_timeout:" << opt.retry_timeout.count()
<< " compression_method:"
<< (opt.compression_method == CompressionMethod::LZ4 ? "LZ4" : "None");
<< (opt.compression_method == CompressionMethod::LZ4 ? "LZ4"
: opt.compression_method == CompressionMethod::ZSTD ? "ZSTD"
: "None");
#if defined(WITH_OPENSSL)
if (opt.ssl_options) {
const auto & ssl_options = *opt.ssl_options;
Expand Down Expand Up @@ -858,9 +860,8 @@ void Client::Impl::SendData(const Block& block) {
}

if (compression_ == CompressionState::Enable) {
assert(options_.compression_method == CompressionMethod::LZ4);

std::unique_ptr<OutputStream> compressed_output = std::make_unique<CompressedOutput>(output_.get(), options_.max_compression_chunk_size);
std::unique_ptr<OutputStream> compressed_output = std::make_unique<CompressedOutput>(output_.get(), options_.max_compression_chunk_size, options_.compression_method);
BufferedOutput buffered(std::move(compressed_output), options_.max_compression_chunk_size);

WriteBlock(block, buffered);
Expand Down
7 changes: 4 additions & 3 deletions clickhouse/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ struct ServerInfo {
};

/// Methods of block compression.
enum class CompressionMethod {
None = -1,
LZ4 = 1,
enum class CompressionMethod : int8_t {
None = -1,
LZ4 = 1,
ZSTD = 2,
};

struct Endpoint {
Expand Down
2 changes: 1 addition & 1 deletion clickhouse/exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class OpenSSLError : public Error {
using Error::Error;
};

class LZ4Error : public Error {
class CompressionError : public Error {
using Error::Error;
};

Expand Down
Loading