Skip to content

Multiple endpoints for connection. #310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 10, 2023
1 change: 1 addition & 0 deletions clickhouse/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ SET ( clickhouse-cpp-lib-src
base/platform.cpp
base/socket.cpp
base/wire_format.cpp
base/endpoints_iterator.cpp

columns/array.cpp
columns/column.cpp
Expand Down
42 changes: 42 additions & 0 deletions clickhouse/base/endpoints_iterator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "endpoints_iterator.h"
#include <clickhouse/client.h>

namespace clickhouse {

// Builds an iterator over `endpoint_list`, positioned on its first element
// and with the iteration count at zero.
RoundRobinEndpointsIterator::RoundRobinEndpointsIterator(const std::vector<Endpoint>& endpoint_list)
    : endpoints(endpoint_list)
    , current_index(0)
    , iteration_counter(0)
{
}

// Returns the host of the endpoint the iterator currently points at.
// Throws std::out_of_range if the endpoint list is empty, rather than
// reading past the end of the vector.
std::string RoundRobinEndpointsIterator::GetHostAddr() const
{
    return endpoints.at(current_index).host;
}

// Returns the port of the endpoint the iterator currently points at.
// Throws std::out_of_range if the endpoint list is empty, rather than
// reading past the end of the vector.
unsigned int RoundRobinEndpointsIterator::GetPort() const
{
    return endpoints.at(current_index).port;
}

// Starts a new pass: zeroes the count of Next() calls that NextIsExist()
// compares against the number of endpoints. The current position is kept.
void RoundRobinEndpointsIterator::ResetIterations()
{
iteration_counter = 0;
}

// Moves to the next endpoint, wrapping around to the first one after the
// last, and counts the step towards the current pass.
void RoundRobinEndpointsIterator::Next()
{
    // With no endpoints there is nothing to advance to, and taking the index
    // modulo a zero size would be a division by zero.
    if (endpoints.empty()) {
        return;
    }

    // Explicit wrap instead of a modulo, which would mix signed and unsigned operands.
    ++current_index;
    if (static_cast<size_t>(current_index) >= endpoints.size()) {
        current_index = 0;
    }
    ++iteration_counter;
}

// Reports whether one more Next() would still land on an endpoint that has
// not been visited since the last ResetIterations().
bool RoundRobinEndpointsIterator::NextIsExist() const
{
    const auto endpoint_count = endpoints.size();
    const auto steps_after_next = iteration_counter + 1;
    return steps_after_next < endpoint_count;
}

// Out-of-line defaulted destructor; the header declares it with `override`.
RoundRobinEndpointsIterator::~RoundRobinEndpointsIterator() = default;

}
48 changes: 48 additions & 0 deletions clickhouse/base/endpoints_iterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#pragma once

#include "clickhouse/client.h"
#include <vector>

namespace clickhouse {

struct ClientOptions;

/**
* Base class for iterating through endpoints.
*
* The current endpoint is read with GetHostAddr() / GetPort(). Next() moves to
* another endpoint; the order is defined by the implementation.
* ResetIterations() and NextIsExist() let a caller bound one pass over the
* endpoints.
*/
class EndpointsIteratorBase
{
public:
virtual ~EndpointsIteratorBase() = default;

// Move to the next endpoint.
virtual void Next() = 0;
// Get the address of current endpoint.
virtual std::string GetHostAddr() const = 0;

// Get the port of current endpoint.
virtual unsigned int GetPort() const = 0;

// Reset iterations.
virtual void ResetIterations() = 0;
// Whether a further endpoint is still available in the current pass,
// i.e. since the last ResetIterations().
virtual bool NextIsExist() const = 0;
};

/**
 * Endpoints iterator that walks the endpoint list in order and wraps around
 * to the first endpoint after the last one.
 */
class RoundRobinEndpointsIterator : public EndpointsIteratorBase
{
public:
    // Copies `endpoints`, so the iterator stays valid even when the argument
    // is a temporary or is destroyed before the iterator.
    explicit RoundRobinEndpointsIterator(const std::vector<Endpoint>& endpoints);

    // Host / port of the endpoint the iterator currently points at.
    std::string GetHostAddr() const override;
    unsigned int GetPort() const override;

    // Starts a new pass by zeroing the iteration count; the position is kept.
    void ResetIterations() override;

    // True while one more Next() is still within the current pass.
    bool NextIsExist() const override;

    // Advances to the next endpoint (wrapping around) and counts the step.
    void Next() override;

    ~RoundRobinEndpointsIterator() override;

private:
    // Owned copy rather than a reference, to rule out dangling.
    const std::vector<Endpoint> endpoints;
    // Index into `endpoints`; unsigned to match the size_t arithmetic done on it.
    size_t current_index;
    // Number of Next() calls since the last ResetIterations().
    size_t iteration_counter;
};

}
4 changes: 2 additions & 2 deletions clickhouse/base/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,9 @@ std::unique_ptr<OutputStream> Socket::makeOutputStream() const {

NonSecureSocketFactory::~NonSecureSocketFactory() {}

std::unique_ptr<SocketBase> NonSecureSocketFactory::connect(const ClientOptions &opts) {
const auto address = NetworkAddress(opts.host, std::to_string(opts.port));
std::unique_ptr<SocketBase> NonSecureSocketFactory::connect(const ClientOptions &opts, const std::string& host, const std::string& port) {

const auto address = NetworkAddress(host, port);
auto socket = doConnect(address, opts);
setSocketOptions(*socket, opts);

Expand Down
5 changes: 3 additions & 2 deletions clickhouse/base/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "platform.h"
#include "input.h"
#include "output.h"
#include "endpoints_iterator.h"

#include <cstddef>
#include <string>
Expand Down Expand Up @@ -88,7 +89,7 @@ class SocketFactory {

// TODO: move connection-related options to ConnectionOptions structure.

virtual std::unique_ptr<SocketBase> connect(const ClientOptions& opts) = 0;
virtual std::unique_ptr<SocketBase> connect(const ClientOptions& opts, const std::string& host, const std::string& port) = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
virtual std::unique_ptr<SocketBase> connect(const ClientOptions& opts, const std::string& host, const std::string& port) = 0;
std::unique_ptr<SocketBase> connect(const ClientOptions& opts, const Endpoint & endpoint) = 0;


virtual void sleepFor(const std::chrono::milliseconds& duration);
};
Expand Down Expand Up @@ -135,7 +136,7 @@ class NonSecureSocketFactory : public SocketFactory {
public:
~NonSecureSocketFactory() override;

std::unique_ptr<SocketBase> connect(const ClientOptions& opts) override;
std::unique_ptr<SocketBase> connect(const ClientOptions& opts, const std::string& host, const std::string& port) override;
Copy link
Contributor

@Enmk Enmk Jun 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::unique_ptr<SocketBase> connect(const ClientOptions& opts, const std::string& host, const std::string& port) override;
std::unique_ptr<SocketBase> connect(const ClientOptions& opts, const Endpoint & endpoint) override;

Also please modify the call site.


protected:
virtual std::unique_ptr<Socket> doConnect(const NetworkAddress& address, const ClientOptions& opts);
Expand Down
122 changes: 106 additions & 16 deletions clickhouse/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,12 @@ struct ClientInfo {
};

std::ostream& operator<<(std::ostream& os, const ClientOptions& opt) {
os << "Client(" << opt.user << '@' << opt.host << ":" << opt.port
<< " ping_before_query:" << opt.ping_before_query
os << "Client(";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That ignores opt.host and opt.port completely, and this operator<< might be used on unmodified ClientOptions (by tests or users), making output confusing.

for (size_t i = 0; i < opt.endpoints.size(); i++)
os << opt.user << '@' << opt.endpoints[i].host << ":" << opt.endpoints[i].port
<< ((i == opt.endpoints.size() - 1) ? "" : ", ");

os << " ping_before_query:" << opt.ping_before_query
<< " send_retries:" << opt.send_retries
<< " retry_timeout:" << opt.retry_timeout.count()
<< " compression_method:"
Expand Down Expand Up @@ -111,6 +115,10 @@ std::unique_ptr<SocketFactory> GetSocketFactory(const ClientOptions& opts) {
return std::make_unique<NonSecureSocketFactory>();
}

std::unique_ptr<EndpointsIteratorBase> GetEndpointsIterator(const std::vector<Endpoint>& endpoints) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::unique_ptr<EndpointsIteratorBase> GetEndpointsIterator(const std::vector<Endpoint>& endpoints) {
std::unique_ptr<EndpointsIteratorBase> GetEndpointsIterator(const ClientOptions& options) {

return std::make_unique<RoundRobinEndpointsIterator>(endpoints);
}

}

class Client::Impl {
Expand All @@ -130,8 +138,12 @@ class Client::Impl {

void ResetConnection();

void ResetConnectionEndpoint();

const ServerInfo& GetServerInfo() const;

const std::optional<Endpoint>& GetCurrentEndpoint() const;

private:
bool Handshake();

Expand All @@ -155,12 +167,16 @@ class Client::Impl {

void WriteBlock(const Block& block, OutputStream& output);

void CreateConnection();

void InitializeStreams(std::unique_ptr<SocketBase>&& socket);

private:
/// In case of network errors tries to reconnect to server and
/// call func several times.
void RetryGuard(std::function<void()> func);

void RetryConnectToTheEndpoint(std::function<void()>& func);

private:
class EnsureNull {
Expand Down Expand Up @@ -194,32 +210,36 @@ class Client::Impl {
std::unique_ptr<InputStream> input_;
std::unique_ptr<OutputStream> output_;
std::unique_ptr<SocketBase> socket_;
std::unique_ptr<EndpointsIteratorBase> endpoints_iterator;

std::optional<Endpoint> current_endpoint_;

ServerInfo server_info_;
};

ClientOptions modifyClientOptions(ClientOptions opts)
{
if (opts.host.empty())
return opts;

Endpoint endpoint_single({opts.host, opts.port});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Endpoint endpoint_single({opts.host, opts.port});
Endpoint default_endpoint({opts.host, opts.port});

if (std::find(opts.endpoints.begin(), opts.endpoints.end(), endpoint_single) == std::end(opts.endpoints)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it is best to make default_endpoint first endpoint unconditionally, so it is simple, easy and predictable

opts.endpoints.emplace(opts.endpoints.begin(),endpoint_single);
}
return opts;
}

// Convenience constructor: delegates to the two-argument constructor, passing
// the socket factory returned by GetSocketFactory(opts).
Client::Impl::Impl(const ClientOptions& opts)
: Impl(opts, GetSocketFactory(opts)) {}

Client::Impl::Impl(const ClientOptions& opts,
std::unique_ptr<SocketFactory> socket_factory)
: options_(opts)
: options_(modifyClientOptions(opts))
, events_(nullptr)
, socket_factory_(std::move(socket_factory))
, endpoints_iterator(GetEndpointsIterator(options_.endpoints))
{
for (unsigned int i = 0; ; ) {
try {
ResetConnection();
break;
} catch (const std::system_error&) {
if (++i > options_.send_retries) {
throw;
}

socket_factory_->sleepFor(options_.retry_timeout);
}
}
CreateConnection();

if (options_.compression_method != CompressionMethod::None) {
compression_ = CompressionState::Enable;
Expand Down Expand Up @@ -329,17 +349,60 @@ void Client::Impl::Ping() {
}

// (Re)connects to the endpoint the endpoints iterator currently points at:
// opens a socket through the socket factory, installs the streams and
// performs the handshake.
// Throws ProtocolError if the handshake fails; exceptions raised while
// connecting propagate to the caller.
void Client::Impl::ResetConnection() {
    const std::string host = endpoints_iterator->GetHostAddr();
    const std::string port = std::to_string(endpoints_iterator->GetPort());

    InitializeStreams(socket_factory_->connect(options_, host, port));

    if (!Handshake()) {
        // Name the endpoint actually tried: options_.host may be empty or
        // may differ from it when several endpoints are configured.
        throw ProtocolError("fail to connect to " + host);
    }
}

// Moves on from the current endpoint and reconnects: advances the endpoints
// iterator by one step, then calls CreateConnection(), which tries endpoints
// starting from the new position.
void Client::Impl::ResetConnectionEndpoint() {
endpoints_iterator->ResetIterations();
endpoints_iterator->Next();
// CreateConnection() resets the iteration count again before its first attempt.
CreateConnection();
}

// Establishes a connection, trying endpoints one after another starting from
// the iterator's current position. Each endpoint gets up to
// options_.send_retries retries, with options_.retry_timeout between attempts.
// On success current_endpoint_ is set to the connected endpoint; when no
// further endpoint is available the last std::system_error is rethrown.
void Client::Impl::CreateConnection() {
    // Until a connection succeeds there is no current endpoint.
    current_endpoint_.reset();

    // Connects to the endpoint the iterator points at, retrying on
    // std::system_error until the retry budget is exhausted.
    const auto connect_with_retries = [this]() {
        unsigned int failed_attempts = 0;
        while (true) {
            try {
                ResetConnection();
                return;
            } catch (const std::system_error&) {
                ++failed_attempts;
                if (failed_attempts > options_.send_retries) {
                    throw;
                }
                socket_factory_->sleepFor(options_.retry_timeout);
            }
        }
    };

    endpoints_iterator->ResetIterations();
    while (true) {
        try {
            connect_with_retries();
            current_endpoint_ = {endpoints_iterator->GetHostAddr(), endpoints_iterator->GetPort()};
            return;
        } catch (const std::system_error&) {
            // Give up only once every endpoint in this pass has been tried.
            if (!endpoints_iterator->NextIsExist()) {
                throw;
            }
        }
        endpoints_iterator->Next();
    }
}

// Returns a reference to the stored server_info_ member.
const ServerInfo& Client::Impl::GetServerInfo() const {
return server_info_;
}


// Returns the endpoint recorded by the last successful connection; empty
// while CreateConnection() is in progress or after it has failed.
const std::optional<Endpoint>& Client::Impl::GetCurrentEndpoint() const {
return current_endpoint_;
}

bool Client::Impl::Handshake() {
if (!SendHello()) {
return false;
Expand Down Expand Up @@ -861,6 +924,25 @@ bool Client::Impl::ReceiveHello() {
}

// Runs `func` through RetryConnectToTheEndpoint(), switching to the next
// endpoint whenever a std::system_error escapes it. When no further endpoint
// is available the error is rethrown to the caller.
void Client::Impl::RetryGuard(std::function<void()> func) {
    endpoints_iterator->ResetIterations();
    while (true) {
        try {
            RetryConnectToTheEndpoint(func);
            // Record the endpoint that served the call if a previous
            // failure had cleared it.
            if (!current_endpoint_.has_value()) {
                current_endpoint_ = {endpoints_iterator->GetHostAddr(), endpoints_iterator->GetPort()};
            }
            return;
        } catch (const std::system_error&) {
            if (!endpoints_iterator->NextIsExist()) {
                throw;
            }
            // An exception caught here means the current endpoint has to be
            // changed, so forget it before moving on.
            current_endpoint_.reset();
        }
        endpoints_iterator->Next();
    }
}

void Client::Impl::RetryConnectToTheEndpoint(std::function<void()>& func) {
for (unsigned int i = 0; ; ++i) {
try {
func();
Expand Down Expand Up @@ -938,6 +1020,14 @@ void Client::ResetConnection() {
impl_->ResetConnection();
}

// Forwards to Impl::ResetConnectionEndpoint().
void Client::ResetConnectionEndpoint() {
impl_->ResetConnectionEndpoint();
}

// Forwards to Impl::GetCurrentEndpoint(); the returned reference points into
// the implementation object.
const std::optional<Endpoint>& Client::GetCurrentEndpoint() const {
return impl_->GetCurrentEndpoint();
}

// Forwards to Impl::GetServerInfo().
const ServerInfo& Client::GetServerInfo() const {
return impl_->GetServerInfo();
}
Expand Down
26 changes: 26 additions & 0 deletions clickhouse/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ enum class CompressionMethod {
LZ4 = 1,
};

struct Endpoint {
std::string host;
unsigned int port = 9000;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
unsigned int port = 9000;
uint16_t port = 9000;

Just to make sure that nobody tries to connect to port >65535

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I think this is a good idea, even though this alters public interface


inline bool operator==(const Endpoint& right) const {
return host == right.host && port == right.port;
}
};

// Names the available endpoint iteration algorithms; RoundRobin is the only one.
// NOTE(review): nothing in view reads this enum yet; presumably it is meant to
// select an EndpointsIteratorBase implementation. Confirm.
enum class EndpointsIterationAlgorithm {
RoundRobin = 0,
};

struct ClientOptions {
// Setter goes first, so it is possible to apply 'deprecated' annotation safely.
#define DECLARE_FIELD(name, type, setter, default_value) \
Expand All @@ -58,6 +71,14 @@ struct ClientOptions {
/// Service port.
DECLARE_FIELD(port, unsigned int, SetPort, 9000);

/** Set endpoints (host+port), only one is used.
* Client tries to connect to those endpoints one by one, on the round-robin basis:
* first the default endpoint (set via SetHost() + SetPort()), then each of the endpoints, from begin() to end();
* the first one to establish a connection is used for the rest of the session.
* If port isn't specified, the default value (9000) will be used.
*/
DECLARE_FIELD(endpoints, std::vector<Endpoint>, SetEndpoints, {});

/// Default database.
DECLARE_FIELD(default_database, std::string, SetDefaultDatabase, "default");
/// User name.
Expand Down Expand Up @@ -240,6 +261,11 @@ class Client {

const ServerInfo& GetServerInfo() const;

/// Get the currently connected endpoint.
/// If the client is not connected to any endpoint, nullopt is returned.
const std::optional<Endpoint>& GetCurrentEndpoint() const;

void ResetConnectionEndpoint();
private:
const ClientOptions options_;

Expand Down
Loading