diff --git a/clickhouse/base/compressed.cpp b/clickhouse/base/compressed.cpp index 18d89f03..f1c3a569 100644 --- a/clickhouse/base/compressed.cpp +++ b/clickhouse/base/compressed.cpp @@ -51,20 +51,20 @@ bool CompressedInput::Decompress() { uint32_t original = 0; uint8_t method = 0; - if (!WireFormat::ReadFixed(input_, &hash)) { + if (!WireFormat::ReadFixed(*input_, &hash)) { return false; } - if (!WireFormat::ReadFixed(input_, &method)) { + if (!WireFormat::ReadFixed(*input_, &method)) { return false; } if (method != COMPRESSION_METHOD) { throw std::runtime_error("unsupported compression method " + std::to_string(int(method))); } else { - if (!WireFormat::ReadFixed(input_, &compressed)) { + if (!WireFormat::ReadFixed(*input_, &compressed)) { return false; } - if (!WireFormat::ReadFixed(input_, &original)) { + if (!WireFormat::ReadFixed(*input_, &original)) { return false; } @@ -80,9 +80,10 @@ bool CompressedInput::Decompress() { out.Write(&method, sizeof(method)); out.Write(&compressed, sizeof(compressed)); out.Write(&original, sizeof(original)); + out.Flush(); } - if (!WireFormat::ReadBytes(input_, tmp.data() + HEADER_SIZE, compressed - HEADER_SIZE)) { + if (!WireFormat::ReadBytes(*input_, tmp.data() + HEADER_SIZE, compressed - HEADER_SIZE)) { return false; } else { if (hash != CityHash128((const char*)tmp.data(), compressed)) { @@ -110,9 +111,7 @@ CompressedOutput::CompressedOutput(OutputStream * destination, size_t max_compre PreallocateCompressBuffer(max_compressed_chunk_size); } -CompressedOutput::~CompressedOutput() { - Flush(); -} +CompressedOutput::~CompressedOutput() { } size_t CompressedOutput::DoWrite(const void* data, size_t len) { const size_t original_len = len; @@ -156,9 +155,9 @@ void CompressedOutput::Compress(const void * data, size_t len) { WriteUnaligned(header + 5, static_cast(len)); } - WireFormat::WriteFixed(destination_, CityHash128( + WireFormat::WriteFixed(*destination_, CityHash128( (const char*)compressed_buffer_.data(), compressed_size + HEADER_SIZE)); - WireFormat::WriteBytes(destination_, compressed_buffer_.data(), compressed_size + HEADER_SIZE); + WireFormat::WriteBytes(*destination_, compressed_buffer_.data(), compressed_size + HEADER_SIZE); destination_->Flush(); } diff --git a/clickhouse/base/input.cpp b/clickhouse/base/input.cpp index e1c409da..e704fe58 100644 --- a/clickhouse/base/input.cpp +++ b/clickhouse/base/input.cpp @@ -56,8 +56,8 @@ size_t ArrayInput::DoNext(const void** ptr, size_t len) { } -BufferedInput::BufferedInput(InputStream* slave, size_t buflen) - : slave_(slave) +BufferedInput::BufferedInput(std::unique_ptr source, size_t buflen) + : source_(std::move(source)) , array_input_(nullptr, 0) , buffer_(buflen) { @@ -72,7 +72,7 @@ void BufferedInput::Reset() { size_t BufferedInput::DoNext(const void** ptr, size_t len) { if (array_input_.Exhausted()) { array_input_.Reset( - buffer_.data(), slave_->Read(buffer_.data(), buffer_.size()) + buffer_.data(), source_->Read(buffer_.data(), buffer_.size()) ); } @@ -82,11 +82,11 @@ size_t BufferedInput::DoNext(const void** ptr, size_t len) { size_t BufferedInput::DoRead(void* buf, size_t len) { if (array_input_.Exhausted()) { if (len > buffer_.size() / 2) { - return slave_->Read(buf, len); + return source_->Read(buf, len); } array_input_.Reset( - buffer_.data(), slave_->Read(buffer_.data(), buffer_.size()) + buffer_.data(), source_->Read(buffer_.data(), buffer_.size()) ); } diff --git a/clickhouse/base/input.h b/clickhouse/base/input.h index 9f35ddda..a8885b30 100644 --- a/clickhouse/base/input.h +++ b/clickhouse/base/input.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace clickhouse { @@ -84,7 +85,7 @@ class ArrayInput : public ZeroCopyInput { class BufferedInput : public ZeroCopyInput { public: - BufferedInput(InputStream* slave, size_t buflen = 8192); + BufferedInput(std::unique_ptr source, size_t buflen = 8192); ~BufferedInput() override; void Reset(); @@ -94,7 +95,7 @@ class BufferedInput : public ZeroCopyInput { size_t DoNext(const void** ptr, size_t len) override; private: - InputStream* const slave_; + std::unique_ptr const source_; ArrayInput array_input_; std::vector buffer_; }; diff --git a/clickhouse/base/output.cpp b/clickhouse/base/output.cpp index fd751167..86b9fbdd 100644 --- a/clickhouse/base/output.cpp +++ b/clickhouse/base/output.cpp @@ -66,26 +66,14 @@ size_t BufferOutput::DoNext(void** data, size_t len) { } -BufferedOutput::BufferedOutput(OutputStream* slave, size_t buflen) - : slave_(slave) +BufferedOutput::BufferedOutput(std::unique_ptr destination, size_t buflen) + : destination_(std::move(destination)) , buffer_(buflen) , array_output_(buffer_.data(), buflen) { } -BufferedOutput::~BufferedOutput() { - try - { - Flush(); - } - catch (...) - { - // That means we've failed to flush some data e.g. to the socket, - // but there is nothing we can do at this point (can't bring the socket back), - // and throwing in destructor is really a bad idea. - // The best we can do is to log the error and ignore it, but currently there is no logging subsystem. - } -} +BufferedOutput::~BufferedOutput() { } void BufferedOutput::Reset() { array_output_.Reset(buffer_.data(), buffer_.size()); @@ -93,8 +81,8 @@ void BufferedOutput::Reset() { void BufferedOutput::DoFlush() { if (array_output_.Data() != buffer_.data()) { - slave_->Write(buffer_.data(), array_output_.Data() - buffer_.data()); - slave_->Flush(); + destination_->Write(buffer_.data(), array_output_.Data() - buffer_.data()); + destination_->Flush(); array_output_.Reset(buffer_.data(), buffer_.size()); } @@ -114,7 +102,7 @@ size_t BufferedOutput::DoWrite(const void* data, size_t len) { Flush(); if (len > buffer_.size() / 2) { - return slave_->Write(data, len); + return destination_->Write(data, len); } } diff --git a/clickhouse/base/output.h b/clickhouse/base/output.h index b23cb08b..bb804ce4 100644 --- a/clickhouse/base/output.h +++ b/clickhouse/base/output.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace clickhouse { @@ -92,11 +93,13 @@ class ArrayOutput : public ZeroCopyOutput { /** * A ZeroCopyOutput stream backed by a vector. + * + * Doesn't Flush() in destructor, client must ensure to do it manually at some point. */ class BufferOutput : public ZeroCopyOutput { public: BufferOutput(Buffer* buf); - ~BufferOutput(); + ~BufferOutput() override; protected: size_t DoNext(void** data, size_t len) override; @@ -106,10 +109,16 @@ class BufferOutput : public ZeroCopyOutput { size_t pos_; }; - +/** BufferedOutput writes data to internal buffer first. + * + * Any data goes to underlying stream only if internal buffer is full + * or when client invokes Flush() on this. + * + * Doesn't Flush() in destructor, client must ensure to do it manually at some point. + */ class BufferedOutput : public ZeroCopyOutput { public: - BufferedOutput(OutputStream* slave, size_t buflen = 8192); + explicit BufferedOutput(std::unique_ptr destination, size_t buflen = 8192); ~BufferedOutput() override; void Reset(); @@ -120,7 +129,7 @@ class BufferedOutput : public ZeroCopyOutput { size_t DoWrite(const void* data, size_t len) override; private: - OutputStream* const slave_; + std::unique_ptr const destination_; Buffer buffer_; ArrayOutput array_output_; }; diff --git a/clickhouse/base/streamstack.h b/clickhouse/base/streamstack.h deleted file mode 100644 index a8bfeee5..00000000 --- a/clickhouse/base/streamstack.h +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace clickhouse { - -/** Collection of owned OutputStream or InputStream instances. - * Simplifies building chains or trees of streams, like: - * - * A => B => C => F - * ^ - * / - * D ====> E - * - * Streams are destroyed in LIFO order, allowing proper flushing of internal buffers. - */ -template -class Streams -{ -public: - Streams() = default; - Streams(Streams&&) = default; - Streams& operator=(Streams&&) = default; - - ~Streams() { - while (!streams_.empty()) { - streams_.pop(); - } - } - - template - inline ConcreteStreamType * Add(std::unique_ptr && stream) { - auto ret = stream.get(); - streams_.emplace(std::move(stream)); - return ret; - } - - template - inline ConcreteStreamType * AddNew(Args&&... args) { - return Add(std::make_unique(std::forward(args)...)); - } - - inline StreamType * Top() const { - return streams_.top().get(); - } - -private: - std::stack> streams_; -}; - -class OutputStream; -class InputStream; - -using OutputStreams = Streams; -using InputStreams = Streams; - -} diff --git a/clickhouse/base/wire_format.cpp b/clickhouse/base/wire_format.cpp index c0f09ecc..00a806f8 100644 --- a/clickhouse/base/wire_format.cpp +++ b/clickhouse/base/wire_format.cpp @@ -11,12 +11,12 @@ constexpr int MAX_VARINT_BYTES = 10; namespace clickhouse { -bool WireFormat::ReadAll(InputStream * input, void* buf, size_t len) { +bool WireFormat::ReadAll(InputStream& input, void* buf, size_t len) { uint8_t* p = static_cast(buf); size_t read_previously = 1; // 1 to execute loop at least once while (len > 0 && read_previously) { - read_previously = input->Read(p, len); + read_previously = input.Read(p, len); p += read_previously; len -= read_previously; @@ -25,13 +25,13 @@ bool WireFormat::ReadAll(InputStream * input, void* buf, size_t len) { return !len; } -void WireFormat::WriteAll(OutputStream* output, const void* buf, size_t len) { +void WireFormat::WriteAll(OutputStream& output, const void* buf, size_t len) { const size_t original_len = len; const uint8_t* p = static_cast(buf); size_t written_previously = 1; // 1 to execute loop at least once while (len > 0 && written_previously) { - written_previously = output->Write(p, len); + written_previously = output.Write(p, len); p += written_previously; len -= written_previously; @@ -43,13 +43,13 @@ void WireFormat::WriteAll(OutputStream* output, const void* buf, size_t len) { } } -bool WireFormat::ReadVarint64(InputStream* input, uint64_t* value) { +bool WireFormat::ReadVarint64(InputStream& input, uint64_t* value) { *value = 0; for (size_t i = 0; i < MAX_VARINT_BYTES; ++i) { uint8_t byte = 0; - if (!input->ReadByte(&byte)) { + if (!input.ReadByte(&byte)) { return false; } else { *value |= uint64_t(byte & 0x7F) << (7 * i); @@ -64,7 +64,7 @@ bool WireFormat::ReadVarint64(InputStream* input, uint64_t* value) { return false; } -void WireFormat::WriteVarint64(OutputStream* output, uint64_t value) { +void WireFormat::WriteVarint64(OutputStream& output, uint64_t value) { uint8_t bytes[MAX_VARINT_BYTES]; int size = 0; @@ -84,14 +84,14 @@ void WireFormat::WriteVarint64(OutputStream* output, uint64_t value) { WriteAll(output, bytes, size); } -bool WireFormat::SkipString(InputStream* input) { +bool WireFormat::SkipString(InputStream& input) { uint64_t len = 0; if (ReadVarint64(input, &len)) { if (len > 0x00FFFFFFULL) return false; - return input->Skip((size_t)len); + return input.Skip((size_t)len); } return false; diff --git a/clickhouse/base/wire_format.h b/clickhouse/base/wire_format.h index ade785df..9bbf7959 100644 --- a/clickhouse/base/wire_format.h +++ b/clickhouse/base/wire_format.h @@ -10,31 +10,31 @@ class OutputStream; class WireFormat { public: template - static bool ReadFixed(InputStream* input, T* value); - static bool ReadString(InputStream* input, std::string* value); - static bool SkipString(InputStream* input); - static bool ReadBytes(InputStream* input, void* buf, size_t len); - static bool ReadUInt64(InputStream* input, uint64_t* value); - static bool ReadVarint64(InputStream* output, uint64_t* value); + static bool ReadFixed(InputStream& input, T* value); + static bool ReadString(InputStream& input, std::string* value); + static bool SkipString(InputStream& input); + static bool ReadBytes(InputStream& input, void* buf, size_t len); + static bool ReadUInt64(InputStream& input, uint64_t* value); + static bool ReadVarint64(InputStream& output, uint64_t* value); template - static void WriteFixed(OutputStream* output, const T& value); - static void WriteBytes(OutputStream* output, const void* buf, size_t len); - static void WriteString(OutputStream* output, std::string_view value); - static void WriteUInt64(OutputStream* output, const uint64_t value); - static void WriteVarint64(OutputStream* output, uint64_t value); + static void WriteFixed(OutputStream& output, const T& value); + static void WriteBytes(OutputStream& output, const void* buf, size_t len); + static void WriteString(OutputStream& output, std::string_view value); + static void WriteUInt64(OutputStream& output, const uint64_t value); + static void WriteVarint64(OutputStream& output, uint64_t value); private: - static bool ReadAll(InputStream * input, void* buf, size_t len); - static void WriteAll(OutputStream* output, const void* buf, size_t len); + static bool ReadAll(InputStream& input, void* buf, size_t len); + static void WriteAll(OutputStream& output, const void* buf, size_t len); }; template -inline bool WireFormat::ReadFixed(InputStream* input, T* value) { +inline bool WireFormat::ReadFixed(InputStream& input, T* value) { return ReadAll(input, value, sizeof(T)); } -inline bool WireFormat::ReadString(InputStream* input, std::string* value) { +inline bool WireFormat::ReadString(InputStream& input, std::string* value) { uint64_t len = 0; if (ReadVarint64(input, &len)) { if (len > 0x00FFFFFFULL) { @@ -47,29 +47,29 @@ inline bool WireFormat::ReadString(InputStream* input, std::string* value) { return false; } -inline bool WireFormat::ReadBytes(InputStream* input, void* buf, size_t len) { +inline bool WireFormat::ReadBytes(InputStream& input, void* buf, size_t len) { return ReadAll(input, buf, len); } -inline bool WireFormat::ReadUInt64(InputStream* input, uint64_t* value) { +inline bool WireFormat::ReadUInt64(InputStream& input, uint64_t* value) { return ReadVarint64(input, value); } template -inline void WireFormat::WriteFixed(OutputStream* output, const T& value) { +inline void WireFormat::WriteFixed(OutputStream& output, const T& value) { WriteAll(output, &value, sizeof(T)); } -inline void WireFormat::WriteBytes(OutputStream* output, const void* buf, size_t len) { +inline void WireFormat::WriteBytes(OutputStream& output, const void* buf, size_t len) { WriteAll(output, buf, len); } -inline void WireFormat::WriteString(OutputStream* output, std::string_view value) { +inline void WireFormat::WriteString(OutputStream& output, std::string_view value) { WriteVarint64(output, value.size()); WriteAll(output, value.data(), value.size()); } -inline void WireFormat::WriteUInt64(OutputStream* output, const uint64_t value) { +inline void WireFormat::WriteUInt64(OutputStream& output, const uint64_t value) { WriteVarint64(output, value); } diff --git a/clickhouse/client.cpp b/clickhouse/client.cpp index 4603cccd..ecf93fae 100644 --- a/clickhouse/client.cpp +++ b/clickhouse/client.cpp @@ -3,7 +3,6 @@ #include "base/compressed.h" #include "base/socket.h" -#include "base/streamstack.h" #include "base/wire_format.h" #include "columns/factory.h" @@ -109,7 +108,7 @@ class Client::Impl { bool SendHello(); - bool ReadBlock(Block* block, InputStream* input); + bool ReadBlock(InputStream& input, Block* block); bool ReceiveHello(); @@ -119,7 +118,7 @@ class Client::Impl { /// Reads exception packet form input stream. bool ReceiveException(bool rethrow = false); - void WriteBlock(const Block& block, OutputStream* output); + void WriteBlock(const Block& block, OutputStream& output); private: /// In case of network errors tries to reconnect to server and @@ -153,12 +152,8 @@ class Client::Impl { QueryEvents* events_; int compression_ = CompressionState::Disable; - InputStreams input_streams_; - InputStream* input_; - - OutputStreams output_streams_; - OutputStream* output_; - + std::unique_ptr input_; + std::unique_ptr output_; std::unique_ptr socket_; #if defined(WITH_OPENSSL) @@ -283,7 +278,7 @@ void Client::Impl::Insert(const std::string& table_name, const Block& block) { } void Client::Impl::Ping() { - WireFormat::WriteUInt64(output_, ClientCodes::Ping); + WireFormat::WriteUInt64(*output_, ClientCodes::Ping); output_->Flush(); uint64_t server_packet; @@ -336,19 +331,12 @@ void Client::Impl::ResetConnection() { socket->SetTcpNoDelay(options_.tcp_nodelay); } - OutputStreams output_streams; - auto socket_output = output_streams.Add(socket->makeOutputStream()); - auto output = output_streams.AddNew(socket_output); - - InputStreams input_streams; - auto socket_input = input_streams.Add(socket->makeInputStream()); - auto input = input_streams.AddNew(socket_input); + std::unique_ptr output = std::make_unique(socket->makeOutputStream()); + std::unique_ptr input = std::make_unique(socket->makeInputStream()); - std::swap(output_streams, output_streams_); - std::swap(input_streams, input_streams_); + std::swap(input, input_); + std::swap(output, output_); std::swap(socket, socket_); - output_ = output; - input_ = input; #if defined(WITH_OPENSSL) std::swap(ssl_context_, ssl_context); @@ -376,7 +364,7 @@ bool Client::Impl::Handshake() { bool Client::Impl::ReceivePacket(uint64_t* server_packet) { uint64_t packet_type = 0; - if (!WireFormat::ReadVarint64(input_, &packet_type)) { + if (!WireFormat::ReadVarint64(*input_, &packet_type)) { return false; } if (server_packet) { @@ -399,22 +387,22 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) { case ServerCodes::ProfileInfo: { Profile profile; - if (!WireFormat::ReadUInt64(input_, &profile.rows)) { + if (!WireFormat::ReadUInt64(*input_, &profile.rows)) { return false; } - if (!WireFormat::ReadUInt64(input_, &profile.blocks)) { + if (!WireFormat::ReadUInt64(*input_, &profile.blocks)) { return false; } - if (!WireFormat::ReadUInt64(input_, &profile.bytes)) { + if (!WireFormat::ReadUInt64(*input_, &profile.bytes)) { return false; } - if (!WireFormat::ReadFixed(input_, &profile.applied_limit)) { + if (!WireFormat::ReadFixed(*input_, &profile.applied_limit)) { return false; } - if (!WireFormat::ReadUInt64(input_, &profile.rows_before_limit)) { + if (!WireFormat::ReadUInt64(*input_, &profile.rows_before_limit)) { return false; } - if (!WireFormat::ReadFixed(input_, &profile.calculated_rows_before_limit)) { + if (!WireFormat::ReadFixed(*input_, &profile.calculated_rows_before_limit)) { return false; } @@ -428,14 +416,14 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) { case ServerCodes::Progress: { Progress info; - if (!WireFormat::ReadUInt64(input_, &info.rows)) { + if (!WireFormat::ReadUInt64(*input_, &info.rows)) { return false; } - if (!WireFormat::ReadUInt64(input_, &info.bytes)) { + if (!WireFormat::ReadUInt64(*input_, &info.bytes)) { return false; } if (REVISION >= DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS) { - if (!WireFormat::ReadUInt64(input_, &info.total_rows)) { + if (!WireFormat::ReadUInt64(*input_, &info.total_rows)) { return false; } } @@ -466,7 +454,7 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) { return false; } -bool Client::Impl::ReadBlock(Block* block, InputStream* input) { +bool Client::Impl::ReadBlock(InputStream& input, Block* block) { // Additional information about block. if (REVISION >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) { uint64_t num; @@ -516,7 +504,7 @@ bool Client::Impl::ReadBlock(Block* block, InputStream* input) { } if (ColumnRef col = CreateColumnByType(type, create_column_settings)) { - if (num_rows && !col->Load(input, num_rows)) { + if (num_rows && !col->Load(&input, num_rows)) { throw std::runtime_error("can't load"); } @@ -533,18 +521,18 @@ bool Client::Impl::ReceiveData() { Block block; if (REVISION >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) { - if (!WireFormat::SkipString(input_)) { + if (!WireFormat::SkipString(*input_)) { return false; } } if (compression_ == CompressionState::Enable) { - CompressedInput compressed(input_); - if (!ReadBlock(&block, &compressed)) { + CompressedInput compressed(input_.get()); + if (!ReadBlock(compressed, &block)) { return false; } } else { - if (!ReadBlock(&block, input_)) { + if (!ReadBlock(*input_, &block)) { return false; } } @@ -567,23 +555,23 @@ bool Client::Impl::ReceiveException(bool rethrow) { do { bool has_nested = false; - if (!WireFormat::ReadFixed(input_, ¤t->code)) { + if (!WireFormat::ReadFixed(*input_, ¤t->code)) { exception_received = false; break; } - if (!WireFormat::ReadString(input_, ¤t->name)) { + if (!WireFormat::ReadString(*input_, ¤t->name)) { exception_received = false; break; } - if (!WireFormat::ReadString(input_, ¤t->display_text)) { + if (!WireFormat::ReadString(*input_, ¤t->display_text)) { exception_received = false; break; } - if (!WireFormat::ReadString(input_, ¤t->stack_trace)) { + if (!WireFormat::ReadString(*input_, ¤t->stack_trace)) { exception_received = false; break; } - if (!WireFormat::ReadFixed(input_, &has_nested)) { + if (!WireFormat::ReadFixed(*input_, &has_nested)) { exception_received = false; break; } @@ -608,13 +596,13 @@ bool Client::Impl::ReceiveException(bool rethrow) { } void Client::Impl::SendCancel() { - WireFormat::WriteUInt64(output_, ClientCodes::Cancel); + WireFormat::WriteUInt64(*output_, ClientCodes::Cancel); output_->Flush(); } void Client::Impl::SendQuery(const std::string& query) { - WireFormat::WriteUInt64(output_, ClientCodes::Query); - WireFormat::WriteString(output_, std::string()); + WireFormat::WriteUInt64(*output_, ClientCodes::Query); + WireFormat::WriteString(*output_, std::string()); /// Client info. if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) { @@ -627,23 +615,23 @@ void Client::Impl::SendQuery(const std::string& query) { info.client_revision = REVISION; - WireFormat::WriteFixed(output_, info.query_kind); - WireFormat::WriteString(output_, info.initial_user); - WireFormat::WriteString(output_, info.initial_query_id); - WireFormat::WriteString(output_, info.initial_address); - WireFormat::WriteFixed(output_, info.iface_type); + WireFormat::WriteFixed(*output_, info.query_kind); + WireFormat::WriteString(*output_, info.initial_user); + WireFormat::WriteString(*output_, info.initial_query_id); + WireFormat::WriteString(*output_, info.initial_address); + WireFormat::WriteFixed(*output_, info.iface_type); - WireFormat::WriteString(output_, info.os_user); - WireFormat::WriteString(output_, info.client_hostname); - WireFormat::WriteString(output_, info.client_name); - WireFormat::WriteUInt64(output_, info.client_version_major); - WireFormat::WriteUInt64(output_, info.client_version_minor); - WireFormat::WriteUInt64(output_, info.client_revision); + WireFormat::WriteString(*output_, info.os_user); + WireFormat::WriteString(*output_, info.client_hostname); + WireFormat::WriteString(*output_, info.client_name); + WireFormat::WriteUInt64(*output_, info.client_version_major); + WireFormat::WriteUInt64(*output_, info.client_version_minor); + WireFormat::WriteUInt64(*output_, info.client_revision); if (server_info_.revision >= DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO) - WireFormat::WriteString(output_, info.quota_key); + WireFormat::WriteString(*output_, info.quota_key); if (server_info_.revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH) { - WireFormat::WriteUInt64(output_, info.client_version_patch); + WireFormat::WriteUInt64(*output_, info.client_version_patch); } } @@ -651,11 +639,11 @@ void Client::Impl::SendQuery(const std::string& query) { //if (settings) // settings->serialize(*out); //else - WireFormat::WriteString(output_, std::string()); + WireFormat::WriteString(*output_, std::string()); - WireFormat::WriteUInt64(output_, Stages::Complete); - WireFormat::WriteUInt64(output_, compression_); - WireFormat::WriteString(output_, query); + WireFormat::WriteUInt64(*output_, Stages::Complete); + WireFormat::WriteUInt64(*output_, compression_); + WireFormat::WriteString(*output_, query); // Send empty block as marker of // end of data SendData(Block()); @@ -664,7 +652,7 @@ void Client::Impl::SendQuery(const std::string& query) { } -void Client::Impl::WriteBlock(const Block& block, OutputStream* output) { +void Client::Impl::WriteBlock(const Block& block, OutputStream& output) { // Additional information about block. if (server_info_.revision >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) { WireFormat::WriteUInt64(output, 1); @@ -681,39 +669,41 @@ void Client::Impl::WriteBlock(const Block& block, OutputStream* output) { WireFormat::WriteString(output, bi.Name()); WireFormat::WriteString(output, bi.Type()->GetName()); - bi.Column()->Save(output); + bi.Column()->Save(&output); } - output->Flush(); + output.Flush(); } void Client::Impl::SendData(const Block& block) { - WireFormat::WriteUInt64(output_, ClientCodes::Data); + WireFormat::WriteUInt64(*output_, ClientCodes::Data); if (server_info_.revision >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) { - WireFormat::WriteString(output_, std::string()); + WireFormat::WriteString(*output_, std::string()); } if (compression_ == CompressionState::Enable) { assert(options_.compression_method == CompressionMethod::LZ4); - CompressedOutput compressed_ouput(output_, options_.max_compression_chunk_size); - BufferedOutput buffered(&compressed_ouput, options_.max_compression_chunk_size); - WriteBlock(block, &buffered); + + std::unique_ptr compressed_ouput = std::make_unique(output_.get(), options_.max_compression_chunk_size); + BufferedOutput buffered(std::move(compressed_ouput), options_.max_compression_chunk_size); + + WriteBlock(block, buffered); } else { - WriteBlock(block, output_); + WriteBlock(block, *output_); } output_->Flush(); } bool Client::Impl::SendHello() { - WireFormat::WriteUInt64(output_, ClientCodes::Hello); - WireFormat::WriteString(output_, std::string(DBMS_NAME) + " client"); - WireFormat::WriteUInt64(output_, DBMS_VERSION_MAJOR); - WireFormat::WriteUInt64(output_, DBMS_VERSION_MINOR); - WireFormat::WriteUInt64(output_, REVISION); - WireFormat::WriteString(output_, options_.default_database); - WireFormat::WriteString(output_, options_.user); - WireFormat::WriteString(output_, options_.password); + WireFormat::WriteUInt64(*output_, ClientCodes::Hello); + WireFormat::WriteString(*output_, std::string(DBMS_NAME) + " client"); + WireFormat::WriteUInt64(*output_, DBMS_VERSION_MAJOR); + WireFormat::WriteUInt64(*output_, DBMS_VERSION_MINOR); + WireFormat::WriteUInt64(*output_, REVISION); + WireFormat::WriteString(*output_, options_.default_database); + WireFormat::WriteString(*output_, options_.user); + WireFormat::WriteString(*output_, options_.password); output_->Flush(); @@ -723,38 +713,38 @@ bool Client::Impl::SendHello() { bool Client::Impl::ReceiveHello() { uint64_t packet_type = 0; - if (!WireFormat::ReadVarint64(input_, &packet_type)) { + if (!WireFormat::ReadVarint64(*input_, &packet_type)) { return false; } if (packet_type == ServerCodes::Hello) { - if (!WireFormat::ReadString(input_, &server_info_.name)) { + if (!WireFormat::ReadString(*input_, &server_info_.name)) { return false; } - if (!WireFormat::ReadUInt64(input_, &server_info_.version_major)) { + if (!WireFormat::ReadUInt64(*input_, &server_info_.version_major)) { return false; } - if (!WireFormat::ReadUInt64(input_, &server_info_.version_minor)) { + if (!WireFormat::ReadUInt64(*input_, &server_info_.version_minor)) { return false; } - if (!WireFormat::ReadUInt64(input_, &server_info_.revision)) { + if (!WireFormat::ReadUInt64(*input_, &server_info_.revision)) { return false; } if (server_info_.revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE) { - if (!WireFormat::ReadString(input_, &server_info_.timezone)) { + if (!WireFormat::ReadString(*input_, &server_info_.timezone)) { return false; } } if (server_info_.revision >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME) { - if (!WireFormat::ReadString(input_, &server_info_.display_name)) { + if (!WireFormat::ReadString(*input_, &server_info_.display_name)) { return false; } } if (server_info_.revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH) { - if (!WireFormat::ReadUInt64(input_, &server_info_.version_patch)) { + if (!WireFormat::ReadUInt64(*input_, &server_info_.version_patch)) { return false; } } diff --git a/clickhouse/columns/enum.cpp b/clickhouse/columns/enum.cpp index bc908dda..fc07e629 100644 --- a/clickhouse/columns/enum.cpp +++ b/clickhouse/columns/enum.cpp @@ -76,12 +76,12 @@ void ColumnEnum::Append(ColumnRef column) { template bool ColumnEnum::Load(InputStream* input, size_t rows) { data_.resize(rows); - return WireFormat::ReadBytes(input, data_.data(), data_.size() * sizeof(T)); + return WireFormat::ReadBytes(*input, data_.data(), data_.size() * sizeof(T)); } template void ColumnEnum::Save(OutputStream* output) { - WireFormat::WriteBytes(output, data_.data(), data_.size() * sizeof(T)); + WireFormat::WriteBytes(*output, data_.data(), data_.size() * sizeof(T)); } template diff --git a/clickhouse/columns/lowcardinality.cpp b/clickhouse/columns/lowcardinality.cpp index bcbbeaaf..1cb21fef 100644 --- a/clickhouse/columns/lowcardinality.cpp +++ b/clickhouse/columns/lowcardinality.cpp @@ -191,7 +191,7 @@ void ColumnLowCardinality::Append(ColumnRef col) { namespace { -auto Load(ColumnRef new_dictionary_column, InputStream* input, size_t rows) { +auto Load(ColumnRef new_dictionary_column, InputStream& input, size_t rows) { // This code tries to follow original implementation of ClickHouse's LowCardinality serialization with // NativeBlockOutputStream::writeData() for DataTypeLowCardinality // (see corresponding serializeBinaryBulkStateSuffix, serializeBinaryBulkStatePrefix, and serializeBinaryBulkWithMultipleStreams), @@ -224,7 +224,7 @@ auto Load(ColumnRef new_dictionary_column, InputStream* input, size_t rows) { if (!WireFormat::ReadFixed(input, &number_of_keys)) throw std::runtime_error("Failed to read number of rows in dictionary column."); - if (!new_dictionary_column->Load(input, number_of_keys)) + if (!new_dictionary_column->Load(&input, number_of_keys)) throw std::runtime_error("Failed to read values of dictionary column."); uint64_t number_of_rows; @@ -234,7 +234,7 @@ auto Load(ColumnRef new_dictionary_column, InputStream* input, size_t rows) { if (number_of_rows != rows) throw std::runtime_error("LowCardinality column must be read in full."); - new_index_column->Load(input, number_of_rows); + new_index_column->Load(&input, number_of_rows); ColumnLowCardinality::UniqueItems new_unique_items_map; for (size_t i = 0; i < new_dictionary_column->Size(); ++i) { @@ -252,7 +252,7 @@ auto Load(ColumnRef new_dictionary_column, InputStream* input, size_t rows) { bool ColumnLowCardinality::Load(InputStream* input, size_t rows) { try { - auto [new_dictionary, new_index, new_unique_items_map] = ::Load(dictionary_column_->Slice(0, 0), input, rows); + auto [new_dictionary, new_index, new_unique_items_map] = ::Load(dictionary_column_->Slice(0, 0), *input, rows); dictionary_column_->Swap(*new_dictionary); index_column_.swap(new_index); @@ -267,18 +267,18 @@ bool ColumnLowCardinality::Load(InputStream* input, size_t rows) { void ColumnLowCardinality::Save(OutputStream* output) { // prefix const uint64_t version = static_cast(KeySerializationVersion::SharedDictionariesWithAdditionalKeys); - WireFormat::WriteFixed(output, version); + WireFormat::WriteFixed(*output, version); // body const uint64_t index_serialization_type = indexTypeFromIndexColumn(*index_column_) | IndexFlag::HasAdditionalKeysBit; - WireFormat::WriteFixed(output, index_serialization_type); + WireFormat::WriteFixed(*output, index_serialization_type); const uint64_t number_of_keys = dictionary_column_->Size(); - WireFormat::WriteFixed(output, number_of_keys); + WireFormat::WriteFixed(*output, number_of_keys); dictionary_column_->Save(output); const uint64_t number_of_rows = index_column_->Size(); - WireFormat::WriteFixed(output, number_of_rows); + WireFormat::WriteFixed(*output, number_of_rows); index_column_->Save(output); // suffix diff --git a/clickhouse/columns/numeric.cpp b/clickhouse/columns/numeric.cpp index b101281c..479d1e79 100644 --- a/clickhouse/columns/numeric.cpp +++ b/clickhouse/columns/numeric.cpp @@ -64,12 +64,12 @@ template bool ColumnVector::Load(InputStream* input, size_t rows) { data_.resize(rows); - return WireFormat::ReadBytes(input, data_.data(), data_.size() * sizeof(T)); + return WireFormat::ReadBytes(*input, data_.data(), data_.size() * sizeof(T)); } template void ColumnVector::Save(OutputStream* output) { - WireFormat::WriteBytes(output, data_.data(), data_.size() * sizeof(T)); + WireFormat::WriteBytes(*output, data_.data(), data_.size() * sizeof(T)); } template diff --git a/clickhouse/columns/string.cpp b/clickhouse/columns/string.cpp index ff042ce5..a2138b96 100644 --- a/clickhouse/columns/string.cpp +++ b/clickhouse/columns/string.cpp @@ -79,7 +79,7 @@ void ColumnFixedString::Append(ColumnRef column) { bool ColumnFixedString::Load(InputStream * input, size_t rows) { data_.resize(string_size_ * rows); - if (!WireFormat::ReadBytes(input, &data_[0], data_.size())) { + if (!WireFormat::ReadBytes(*input, &data_[0], data_.size())) { return false; } @@ -87,7 +87,7 @@ bool ColumnFixedString::Load(InputStream * input, size_t rows) { } void ColumnFixedString::Save(OutputStream* output) { - WireFormat::WriteBytes(output, data_.data(), data_.size()); + WireFormat::WriteBytes(*output, data_.data(), data_.size()); } size_t ColumnFixedString::Size() const { @@ -230,13 +230,13 @@ bool ColumnString::Load(InputStream* input, size_t rows) { // TODO(performance): unroll a loop to a first row (to get rid of `blocks_.size() == 0` check) and the rest. for (size_t i = 0; i < rows; ++i) { uint64_t len; - if (!WireFormat::ReadUInt64(input, &len)) + if (!WireFormat::ReadUInt64(*input, &len)) return false; if (blocks_.size() == 0 || len > block->GetAvailble()) block = &blocks_.emplace_back(std::max(DEFAULT_BLOCK_SIZE, len)); - if (!WireFormat::ReadBytes(input, block->GetCurrentWritePos(), len)) + if (!WireFormat::ReadBytes(*input, block->GetCurrentWritePos(), len)) return false; items_.emplace_back(block->ConsumeTailAsStringViewUnsafe(len)); @@ -247,7 +247,7 @@ bool ColumnString::Load(InputStream* input, size_t rows) { void ColumnString::Save(OutputStream* output) { for (const auto & item : items_) { - WireFormat::WriteString(output, item); + WireFormat::WriteString(*output, item); } } diff --git a/ut/stream_ut.cpp b/ut/stream_ut.cpp index 997f0768..e139d8bc 100644 --- a/ut/stream_ut.cpp +++ b/ut/stream_ut.cpp @@ -11,13 +11,14 @@ TEST(CodedStreamCase, Varint64) { { BufferOutput output(&buf); - WireFormat::WriteVarint64(&output, 18446744071965638648ULL); + WireFormat::WriteVarint64(output, 18446744071965638648ULL); + output.Flush(); } { ArrayInput input(buf.data(), buf.size()); uint64_t value = 0; - ASSERT_TRUE(WireFormat::ReadVarint64(&input, &value)); + ASSERT_TRUE(WireFormat::ReadVarint64(input, &value)); ASSERT_EQ(value, 18446744071965638648ULL); } }