Skip to content

Commit

Permalink
Fix blocking issues and errno on windows (#435)
Browse files Browse the repository at this point in the history
  • Loading branch information
MisterTea authored May 3, 2021
1 parent 4459c6b commit 7fd0b19
Show file tree
Hide file tree
Showing 22 changed files with 136 additions and 94 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "NetBSD")
set(NETBSD TRUE)
endif(CMAKE_SYSTEM_NAME MATCHES "NetBSD")

option(DISABLE_SENTRY "Disable Sentry crash logging" OFF)

if(FREEBSD OR NETBSD OR DISABLE_SENTRY)
# Sentry doesn't work on BSD
else()
Expand Down
4 changes: 2 additions & 2 deletions src/base/BackedReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ int BackedReader::read(Packet* packet) {
// Connection is closed. Instead of closing the socket, set EPIPE.
// In EternalTCP, the server needs to explictly tell the client that
// the session is over.
errno = EPIPE;
SetErrno(EPIPE);
return -1;
} else if (bytesRead > 0) {
partialMessage.append(tmpBuf, bytesRead);
Expand All @@ -75,7 +75,7 @@ int BackedReader::read(Packet* packet) {
// Connection is closed. Instead of closing the socket, set EPIPE.
// In EternalTCP, the server needs to explictly tell the client that
// the session is over.
errno = EPIPE;
SetErrno(EPIPE);
return -1;
} else if (bytesRead == -1) {
VLOG(2) << "Error while reading";
Expand Down
4 changes: 2 additions & 2 deletions src/base/Connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ bool Connection::read(Packet* packet) {
lock_guard<std::recursive_mutex> guard(connectionMutex);
VLOG(4) << "After read get connectionMutex";
ssize_t messagesRead = reader->read(packet);
auto localErrno = errno;
auto localErrno = GetErrno();
if (messagesRead == -1) {
if (isSkippableError(localErrno)) {
// Close the socket and invalidate, then return 0 messages
Expand All @@ -178,7 +178,7 @@ bool Connection::write(const Packet& packet) {
}

BackedWriterWriteState bwws = writer->write(packet);
auto writeErrno = errno;
auto writeErrno = GetErrno();

if (bwws == BackedWriterWriteState::SKIPPED) {
VLOG(4) << "Write skipped";
Expand Down
60 changes: 49 additions & 11 deletions src/base/Headers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,33 +155,71 @@ const int SERVER_KEEP_ALIVE_DURATION = 11;
#define STERROR LOG(ERROR) << "Stack Trace: " << endl << ust::generate()
#endif

inline int GetErrno() {
#ifdef WIN32
auto retval = WSAGetLastError();
if (retval >= 10000) {
// Do some translation
switch (retval) {
case WSAEWOULDBLOCK: return EWOULDBLOCK;
case WSAEADDRINUSE: return EADDRINUSE;
case WSAEINPROGRESS: return EINPROGRESS;
default:
STFATAL << "Unmapped WSA error: " << retval;
}
}
return retval;
#else
return errno;
#endif
}

inline void SetErrno(int e) {
#ifdef WIN32
WSASetLastError(e);
#else
errno = e;
#endif
}

#ifdef WIN32
inline string WindowsErrnoToString() {
const int BUFSIZE = 4096;
char buf[BUFSIZE];
FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, WSAGetLastError(),
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), buf, BUFSIZE, NULL);
string s(buf, BUFSIZE);
return s;
auto charsWritten = FormatMessageA(
FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL,
WSAGetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), buf,
BUFSIZE, NULL);
if (charsWritten) {
string s(buf, charsWritten + 1);
return s;
}
return "Unknown Error";
}
#define FATAL_FAIL(X) \
if (((X) == -1)) \
LOG(FATAL) << "Error: (" << WSAGetLastError() \
<< "): " << WindowsErrnoToString();

#define FATAL_FAIL_UNLESS_ZERO(X) \
if (((X) != 0)) \
LOG(FATAL) << "Error: (" << WSAGetLastError() \
<< "): " << WindowsErrnoToString();

#define FATAL_FAIL_UNLESS_EINVAL(X) FATAL_FAIL(X)

#else
#define FATAL_FAIL(X) \
if (((X) == -1)) STFATAL << "Error: (" << errno << "): " << strerror(errno);
#define FATAL_FAIL(X) \
if (((X) == -1)) \
STFATAL << "Error: (" << GetErrno() \
<< "): " << strerror(GetErrno());

// On BSD/OSX we can get EINVAL if the remote side has closed the connection
// before we have initialized it.
#define FATAL_FAIL_UNLESS_EINVAL(X) \
if (((X) == -1) && errno != EINVAL) \
STFATAL << "Error: (" << errno << "): " << strerror(errno);
#define FATAL_FAIL_UNLESS_EINVAL(X) \
if (((X) == -1) && GetErrno() != EINVAL) \
STFATAL << "Error: (" << GetErrno() \
<< "): " << strerror(GetErrno());
#endif

#ifndef ET_VERSION
Expand Down
6 changes: 3 additions & 3 deletions src/base/PipeSocketHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ int PipeSocketHandler::connect(const SocketEndpoint& endpoint) {
VLOG(3) << "Connecting to " << endpoint << " with fd " << sockFd;
int result =
::connect(sockFd, (struct sockaddr*)&remote, sizeof(sockaddr_un));
auto localErrno = errno;
auto localErrno = GetErrno();
if (result < 0 && localErrno != EINPROGRESS) {
VLOG(3) << "Connection result: " << result << " (" << strerror(localErrno)
<< ")";
Expand All @@ -33,7 +33,7 @@ int PipeSocketHandler::connect(const SocketEndpoint& endpoint) {
FATAL_FAIL(::close(sockFd));
#endif
sockFd = -1;
errno = localErrno;
SetErrno(localErrno);
return sockFd;
}

Expand Down Expand Up @@ -72,7 +72,7 @@ int PipeSocketHandler::connect(const SocketEndpoint& endpoint) {
sockFd = -1;
}
} else {
auto localErrno = errno;
auto localErrno = GetErrno();
LOG(INFO) << "Error connecting to " << endpoint << ": " << localErrno << " "
<< strerror(localErrno);
#ifdef _MSC_VER
Expand Down
6 changes: 3 additions & 3 deletions src/base/RawSocketUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ void RawSocketUtils::writeAll(int fd, const char* buf, size_t count) {
int rc = ::write(fd, buf + bytesWritten, count - bytesWritten);
#endif
if (rc < 0) {
auto localErrno = errno;
auto localErrno = GetErrno();
if (localErrno == EAGAIN || localErrno == EWOULDBLOCK) {
// This is fine, just keep retrying
std::this_thread::sleep_for(std::chrono::microseconds(100*1000));
std::this_thread::sleep_for(std::chrono::microseconds(100 * 1000));
continue;
}
STERROR << "Cannot write to raw socket: " << strerror(localErrno);
Expand All @@ -38,7 +38,7 @@ void RawSocketUtils::readAll(int fd, char* buf, size_t count) {
int rc = ::read(fd, buf + bytesRead, count - bytesRead);
#endif
if (rc < 0) {
auto localErrno = errno;
auto localErrno = GetErrno();
if (localErrno == EAGAIN || localErrno == EWOULDBLOCK) {
// This is fine, just keep retrying
continue;
Expand Down
11 changes: 6 additions & 5 deletions src/base/SocketHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ void SocketHandler::readAll(int fd, void* buf, size_t count, bool timeout) {
// Connection is closed. Instead of closing the socket, set EPIPE.
// In EternalTCP, the server needs to explictly tell the client that
// the session is over.
errno = EPIPE;
SetErrno(EPIPE);
bytesRead = -1;
}
if (bytesRead < 0) {
auto localErrno = errno;
auto localErrno = GetErrno();
if (localErrno == EAGAIN || localErrno == EWOULDBLOCK) {
// This is fine, just keep retrying
LOG(INFO) << "Got EAGAIN, waiting...";
Expand All @@ -50,7 +50,7 @@ int SocketHandler::writeAllOrReturn(int fd, const void* buf, size_t count) {
return -1;
}
ssize_t bytesWritten = write(fd, ((const char*)buf) + pos, count - pos);
auto localErrno = errno;
auto localErrno = GetErrno();
if (bytesWritten < 0) {
if (localErrno == EAGAIN || localErrno == EWOULDBLOCK) {
LOG(INFO) << "Got EAGAIN, waiting...";
Expand Down Expand Up @@ -81,7 +81,7 @@ void SocketHandler::writeAllOrThrow(int fd, const void* buf, size_t count,
throw std::runtime_error("Socket Timeout");
}
ssize_t bytesWritten = write(fd, ((const char*)buf) + pos, count - pos);
auto localErrno = errno;
auto localErrno = GetErrno();
if (bytesWritten < 0) {
if (localErrno == EAGAIN || localErrno == EWOULDBLOCK) {
LOG(INFO) << "Got EAGAIN, waiting...";
Expand Down Expand Up @@ -119,7 +119,8 @@ void SocketHandler::readB64(int fd, char* buf, size_t count) {
}
}

void SocketHandler::readB64EncodedLength(int fd, string* out, size_t encodedLength) {
void SocketHandler::readB64EncodedLength(int fd, string* out,
size_t encodedLength) {
string s(encodedLength, '\0');
readAll(fd, &s[0], s.length(), false);
if (!Base64::Decode(s, out)) {
Expand Down
36 changes: 13 additions & 23 deletions src/base/TcpSocketHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,27 @@ int TcpSocketHandler::connect(const SocketEndpoint &endpoint) {
// loop through all the results and connect to the first we can
for (p = results; p != NULL; p = p->ai_next) {
if ((sockFd = socket(p->ai_family, p->ai_socktype, p->ai_protocol)) == -1) {
auto localErrno = errno;
auto localErrno = GetErrno();
LOG(INFO) << "Error creating socket: " << localErrno << " "
<< strerror(localErrno);
continue;
}

// Allow non-blocking connect
setBlocking(sockFd, false);

if (::connect(sockFd, p->ai_addr, p->ai_addrlen) == -1 &&
errno != EINPROGRESS) {
auto localErrno = errno;
GetErrno() != EINPROGRESS && GetErrno() != EWOULDBLOCK
) {
auto localErrno = GetErrno();
if (p->ai_canonname) {
LOG(INFO) << "Error connecting with " << p->ai_canonname << ": "
<< localErrno << " " << strerror(localErrno);
} else {
LOG(INFO) << "Error connecting: " << localErrno << " "
<< strerror(localErrno);
}
setBlocking(sockFd, true);
#ifdef _MSC_VER
FATAL_FAIL(::closesocket(sockFd));
#else
Expand Down Expand Up @@ -96,23 +101,7 @@ int TcpSocketHandler::connect(const SocketEndpoint &endpoint) {
}
// Make sure that socket becomes blocking once it's attached to a
// server.
#ifdef WIN32
{
u_long iMode = 0;
auto result = ioctlsocket(sockFd, FIONBIO, &iMode);
if (result != NO_ERROR) {
STFATAL << result;
}
}
#else
{
int opts;
opts = fcntl(sockFd, F_GETFL);
FATAL_FAIL(opts);
opts &= (~O_NONBLOCK);
FATAL_FAIL(fcntl(sockFd, F_SETFL, opts));
}
#endif
setBlocking(sockFd, true);
break; // if we get here, we must have connected successfully
} else {
if (p->ai_canonname) {
Expand All @@ -122,6 +111,7 @@ int TcpSocketHandler::connect(const SocketEndpoint &endpoint) {
LOG(INFO) << "Error connecting to " << endpoint << ": " << so_error
<< " " << strerror(so_error);
}
setBlocking(sockFd, true);
#ifdef _MSC_VER
FATAL_FAIL(::closesocket(sockFd));
#else
Expand All @@ -131,7 +121,7 @@ int TcpSocketHandler::connect(const SocketEndpoint &endpoint) {
continue;
}
} else {
auto localErrno = errno;
auto localErrno = GetErrno();
if (p->ai_canonname) {
LOG(INFO) << "Error connecting with " << p->ai_canonname << ": "
<< localErrno << " " << strerror(localErrno);
Expand Down Expand Up @@ -188,7 +178,7 @@ set<int> TcpSocketHandler::listen(const SocketEndpoint &endpoint) {
for (p = servinfo; p != NULL; p = p->ai_next) {
int sockFd;
if ((sockFd = socket(p->ai_family, p->ai_socktype, p->ai_protocol)) == -1) {
auto localErrno = errno;
auto localErrno = GetErrno();
LOG(INFO) << "Error creating socket " << p->ai_family << "/"
<< p->ai_socktype << "/" << p->ai_protocol << ": " << localErrno
<< " " << strerror(localErrno);
Expand All @@ -207,7 +197,7 @@ set<int> TcpSocketHandler::listen(const SocketEndpoint &endpoint) {

if (::bind(sockFd, p->ai_addr, p->ai_addrlen) == -1) {
// This most often happens because the port is in use.
auto localErrno = errno;
auto localErrno = GetErrno();
LOG(WARNING) << "Error binding " << p->ai_family << "/" << p->ai_socktype
<< "/" << p->ai_protocol << ": " << localErrno << " "
<< strerror(localErrno);
Expand Down
Loading

0 comments on commit 7fd0b19

Please sign in to comment.