Skip to content

Commit 0ef73e2

Browse files
authored
Merge pull request #228 from Enmk/fix_leaking_sockets
Socket RAII wrapper to prevent leaking socket
2 parents c1bf280 + 038548e commit 0ef73e2

File tree

1 file changed

+49
-21
lines changed

1 file changed

+49
-21
lines changed

clickhouse/base/socket.cpp

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -144,19 +144,53 @@ ssize_t Poll(struct pollfd* fds, int nfds, int timeout) noexcept {
144144
#endif
145145
}
146146

147+
#ifndef INVALID_SOCKET
148+
const SOCKET INVALID_SOCKET = -1;
149+
#endif
150+
151+
void CloseSocket(SOCKET socket) {
152+
if (socket == INVALID_SOCKET)
153+
return;
154+
155+
#if defined(_win_)
156+
closesocket(socket);
157+
#else
158+
close(socket);
159+
#endif
160+
}
161+
162+
struct SocketRAIIWrapper {
163+
SOCKET socket = INVALID_SOCKET;
164+
165+
~SocketRAIIWrapper() {
166+
CloseSocket(socket);
167+
}
168+
169+
SOCKET operator*() const {
170+
return socket;
171+
}
172+
173+
SOCKET release() {
174+
auto result = socket;
175+
socket = INVALID_SOCKET;
176+
177+
return result;
178+
}
179+
};
180+
147181
SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params) {
148182
int last_err = 0;
149183
for (auto res = addr.Info(); res != nullptr; res = res->ai_next) {
150-
SOCKET s(socket(res->ai_family, res->ai_socktype, res->ai_protocol));
184+
SocketRAIIWrapper s{socket(res->ai_family, res->ai_socktype, res->ai_protocol)};
151185

152-
if (s == -1) {
186+
if (*s == INVALID_SOCKET) {
153187
continue;
154188
}
155189

156-
SetNonBlock(s, true);
157-
SetTimeout(s, timeout_params);
190+
SetNonBlock(*s, true);
191+
SetTimeout(*s, timeout_params);
158192

159-
if (connect(s, res->ai_addr, (int)res->ai_addrlen) != 0) {
193+
if (connect(*s, res->ai_addr, (int)res->ai_addrlen) != 0) {
160194
int err = getSocketErrorCode();
161195
if (
162196
err == EINPROGRESS || err == EAGAIN || err == EWOULDBLOCK
@@ -165,7 +199,7 @@ SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& time
165199
#endif
166200
) {
167201
pollfd fd;
168-
fd.fd = s;
202+
fd.fd = *s;
169203
fd.events = POLLOUT;
170204
fd.revents = 0;
171205
ssize_t rval = Poll(&fd, 1, 5000);
@@ -175,18 +209,18 @@ SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& time
175209
}
176210
if (rval > 0) {
177211
socklen_t len = sizeof(err);
178-
getsockopt(s, SOL_SOCKET, SO_ERROR, (char*)&err, &len);
212+
getsockopt(*s, SOL_SOCKET, SO_ERROR, (char*)&err, &len);
179213

180214
if (!err) {
181-
SetNonBlock(s, false);
182-
return s;
215+
SetNonBlock(*s, false);
216+
return s.release();
183217
}
184218
last_err = err;
185219
}
186220
}
187221
} else {
188-
SetNonBlock(s, false);
189-
return s;
222+
SetNonBlock(*s, false);
223+
return s.release();
190224
}
191225
}
192226
if (last_err > 0) {
@@ -265,15 +299,15 @@ Socket::Socket(const NetworkAddress & addr)
265299
Socket::Socket(Socket&& other) noexcept
266300
: handle_(other.handle_)
267301
{
268-
other.handle_ = -1;
302+
other.handle_ = INVALID_SOCKET;
269303
}
270304

271305
Socket& Socket::operator=(Socket&& other) noexcept {
272306
if (this != &other) {
273307
Close();
274308

275309
handle_ = other.handle_;
276-
other.handle_ = -1;
310+
other.handle_ = INVALID_SOCKET;
277311
}
278312

279313
return *this;
@@ -284,14 +318,8 @@ Socket::~Socket() {
284318
}
285319

286320
void Socket::Close() {
287-
if (handle_ != -1) {
288-
#if defined(_win_)
289-
closesocket(handle_);
290-
#else
291-
close(handle_);
292-
#endif
293-
handle_ = -1;
294-
}
321+
CloseSocket(handle_);
322+
handle_ = INVALID_SOCKET;
295323
}
296324

297325
void Socket::SetTcpKeepAlive(int idle, int intvl, int cnt) noexcept {

0 commit comments

Comments
 (0)