diff --git a/CMakeLists.txt b/CMakeLists.txt index efe1f04a5..7b9e7e3ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,7 @@ option(CODE_COVERAGE "Enable code coverage" OFF) SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DET_VERSION='\"${PROJECT_VERSION}\"'") # For easylogging, disable default log file, enable crash log, ensure thread safe, and catch c++ exceptions -SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_NO_DEFAULT_LOG_FILE -DELPP_FEATURE_CRASH_LOG -DELPP_THREAD_SAFE -DELPP_HANDLE_SIGABRT") +SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DELPP_NO_DEFAULT_LOG_FILE -DELPP_FEATURE_CRASH_LOG -DELPP_THREAD_SAFE") IF(CODE_COVERAGE) if(UNIX) SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage") @@ -97,6 +97,7 @@ include_directories( external/msgpack-c/include src/base src/terminal + src/terminal/forwarding src/htm ${PROTOBUF_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR} @@ -111,22 +112,52 @@ add_library( et-lib STATIC + external/easyloggingpp/src/easylogging++.h external/easyloggingpp/src/easylogging++.cc + src/base/BackedReader.hpp src/base/BackedReader.cpp + + src/base/BackedWriter.hpp src/base/BackedWriter.cpp + + src/base/ClientConnection.hpp src/base/ClientConnection.cpp + + src/base/Connection.hpp src/base/Connection.cpp + + src/base/CryptoHandler.hpp src/base/CryptoHandler.cpp + + src/base/ServerClientConnection.hpp src/base/ServerClientConnection.cpp + + src/base/ServerConnection.hpp src/base/ServerConnection.cpp + + src/base/SocketHandler.hpp src/base/SocketHandler.cpp + + src/base/PipeSocketHandler.hpp src/base/PipeSocketHandler.cpp + + src/base/TcpSocketHandler.hpp src/base/TcpSocketHandler.cpp + + src/base/UnixSocketHandler.hpp src/base/UnixSocketHandler.cpp + + src/base/LogHandler.hpp src/base/LogHandler.cpp + + src/base/DaemonCreator.hpp src/base/DaemonCreator.cpp + + src/base/SystemUtils.hpp src/base/SystemUtils.cpp + + src/base/RawSocketUtils.hpp src/base/RawSocketUtils.cpp ${ET_SRCS} @@ -141,14 +172,31 @@ add_library( TerminalCommon STATIC - src/terminal/PortForwardHandler.cpp - src/terminal/PortForwardSourceHandler.cpp - src/terminal/PortForwardDestinationHandler.cpp + src/terminal/forwarding/PortForwardHandler.hpp + src/terminal/forwarding/PortForwardHandler.cpp + + src/terminal/forwarding/ForwardSourceHandler.hpp + src/terminal/forwarding/ForwardSourceHandler.cpp + + src/terminal/forwarding/ForwardDestinationHandler.hpp + src/terminal/forwarding/ForwardDestinationHandler.cpp + + src/terminal/TerminalServer.hpp src/terminal/TerminalServer.cpp + + src/terminal/UserTerminalRouter.hpp src/terminal/UserTerminalRouter.cpp + + src/terminal/TerminalClient.hpp src/terminal/TerminalClient.cpp + + src/terminal/SshSetupHandler.hpp src/terminal/SshSetupHandler.cpp + + src/terminal/UserTerminalHandler.hpp src/terminal/UserTerminalHandler.cpp + + src/terminal/UserJumphostHandler.hpp src/terminal/UserJumphostHandler.cpp ${ETERMINAL_SRCS} ${ETERMINAL_HDRS} diff --git a/proto/ET.proto b/proto/ET.proto index 59c37db44..72113d241 100644 --- a/proto/ET.proto +++ b/proto/ET.proto @@ -6,6 +6,7 @@ enum EtPacketType { // Count down from 254 to avoid collisions HEARTBEAT = 254; INITIAL_PAYLOAD = 253; + INITIAL_RESPONSE = 252; } message ConnectRequest { @@ -25,10 +26,11 @@ message ConnectResponse { optional string error = 2; } -message SequenceHeader { - optional int32 sequenceNumber = 1; -} +message SequenceHeader { optional int32 sequenceNumber = 1; } + +message CatchupBuffer { repeated bytes buffer = 1; } -message CatchupBuffer { - repeated bytes buffer = 1; +message SocketEndpoint { + optional string name = 1; + optional int32 port = 2; } diff --git a/proto/ETerminal.proto b/proto/ETerminal.proto index 5bad1065f..c9092349c 100644 --- a/proto/ETerminal.proto +++ b/proto/ETerminal.proto @@ -2,16 +2,17 @@ syntax = "proto2"; package et; option optimize_for = LITE_RUNTIME; +import "ET.proto"; + enum TerminalPacketType { KEEP_ALIVE = 0; TERMINAL_BUFFER = 1; TERMINAL_INFO = 2; - PORT_FORWARD_SOURCE_REQUEST = 3; - PORT_FORWARD_SOURCE_RESPONSE = 4; PORT_FORWARD_DESTINATION_REQUEST = 5; PORT_FORWARD_DESTINATION_RESPONSE = 6; PORT_FORWARD_DATA = 7; - IDPASSKEY = 8; + TERMINAL_USER_INFO = 8; + TERMINAL_INIT = 9; } message TerminalBuffer { @@ -27,8 +28,9 @@ message TerminalInfo { } message PortForwardSourceRequest { - optional int32 sourceport = 1; - optional int32 destinationport = 2; + optional SocketEndpoint source = 1; + optional SocketEndpoint destination = 2; + optional string environmentvariable = 3; } message PortForwardSourceResponse { @@ -36,7 +38,7 @@ message PortForwardSourceResponse { } message PortForwardDestinationRequest { - optional int32 port = 1; + optional SocketEndpoint destination = 1; optional int32 fd = 2; } @@ -56,9 +58,27 @@ message PortForwardData { message InitialPayload { optional bool jumphost = 1 [default = false]; + repeated PortForwardSourceRequest reversetunnels = 2; +} + +message InitialResponse { + optional string error = 1; } message ConfigParams { optional int32 vlevel = 1; optional int32 minloglevel = 2; } + +message TermInit { + repeated string environmentnames = 1; + repeated string environmentvalues = 2; +} + +message TerminalUserInfo { + optional string id = 1; + optional string passkey = 2; + optional int64 uid = 3; + optional int64 gid = 4; + optional int64 fd = 5; +} \ No newline at end of file diff --git a/src/base/ClientConnection.hpp b/src/base/ClientConnection.hpp index 2a5c7222e..5c79e8131 100644 --- a/src/base/ClientConnection.hpp +++ b/src/base/ClientConnection.hpp @@ -4,7 +4,6 @@ #include "Headers.hpp" #include "Connection.hpp" -#include "SocketEndpoint.hpp" namespace et { extern const int NULL_CLIENT_ID; diff --git a/src/base/Headers.hpp b/src/base/Headers.hpp index 1f15c146f..22aaa1b43 100644 --- a/src/base/Headers.hpp +++ b/src/base/Headers.hpp @@ -57,6 +57,18 @@ #include #include "ET.pb.h" +#include "ETerminal.pb.h" +inline std::ostream& operator<<(std::ostream& os, + const et::SocketEndpoint& se) { + if (se.has_name()) { + os << se.name(); + } + if (se.has_port()) { + os << ":" << se.port(); + } + return os; +} + #include "easylogging++.h" #include diff --git a/src/base/PipeSocketHandler.cpp b/src/base/PipeSocketHandler.cpp index 95a9a5093..c591b7899 100644 --- a/src/base/PipeSocketHandler.cpp +++ b/src/base/PipeSocketHandler.cpp @@ -14,7 +14,7 @@ PipeSocketHandler::PipeSocketHandler() {} int PipeSocketHandler::connect(const SocketEndpoint& endpoint) { lock_guard guard(mutex); - string pipePath = endpoint.getName(); + string pipePath = endpoint.name(); sockaddr_un remote; int sockFd = ::socket(AF_UNIX, SOCK_STREAM, 0); @@ -81,9 +81,9 @@ int PipeSocketHandler::connect(const SocketEndpoint& endpoint) { set PipeSocketHandler::listen(const SocketEndpoint& endpoint) { lock_guard guard(mutex); - string pipePath = endpoint.getName(); + string pipePath = endpoint.name(); if (pipeServerSockets.find(pipePath) != pipeServerSockets.end()) { - LOG(FATAL) << "Tried to listen twice on the same path"; + throw runtime_error("Tried to listen twice on the same path"); } sockaddr_un local; @@ -97,7 +97,7 @@ set PipeSocketHandler::listen(const SocketEndpoint& endpoint) { FATAL_FAIL(::bind(fd, (struct sockaddr*)&local, sizeof(sockaddr_un))); ::listen(fd, 5); - chmod(local.sun_path, 0777); + FATAL_FAIL(::chmod(local.sun_path, S_IRUSR | S_IWUSR | S_IXUSR)); pipeServerSockets[pipePath] = set({fd}); return pipeServerSockets[pipePath]; @@ -106,7 +106,7 @@ set PipeSocketHandler::listen(const SocketEndpoint& endpoint) { set PipeSocketHandler::getEndpointFds(const SocketEndpoint& endpoint) { lock_guard guard(mutex); - string pipePath = endpoint.getName(); + string pipePath = endpoint.name(); if (pipeServerSockets.find(pipePath) == pipeServerSockets.end()) { LOG(FATAL) << "Tried to getPipeFd on a pipe without calling listen() first: " @@ -118,7 +118,7 @@ set PipeSocketHandler::getEndpointFds(const SocketEndpoint& endpoint) { void PipeSocketHandler::stopListening(const SocketEndpoint& endpoint) { lock_guard guard(mutex); - string pipePath = endpoint.getName(); + string pipePath = endpoint.name(); auto it = pipeServerSockets.find(pipePath); if (it == pipeServerSockets.end()) { LOG(FATAL) diff --git a/src/base/ServerConnection.hpp b/src/base/ServerConnection.hpp index 79b34b1b9..0288792c1 100644 --- a/src/base/ServerConnection.hpp +++ b/src/base/ServerConnection.hpp @@ -4,7 +4,6 @@ #include "Headers.hpp" #include "ServerClientConnection.hpp" -#include "SocketEndpoint.hpp" #include "SocketHandler.hpp" namespace et { diff --git a/src/base/SocketEndpoint.hpp b/src/base/SocketEndpoint.hpp deleted file mode 100644 index 80a209cdb..000000000 --- a/src/base/SocketEndpoint.hpp +++ /dev/null @@ -1,51 +0,0 @@ -#ifndef __ET_SOCKET_ENDPOINT__ -#define __ET_SOCKET_ENDPOINT__ - -#include "Headers.hpp" - -namespace et { -class SocketEndpoint { - public: - SocketEndpoint() : name(""), port(-1), is_jumphost(false) {} - - explicit SocketEndpoint(const string &_name) - : name(_name), port(-1), is_jumphost(false) {} - - explicit SocketEndpoint(const string &_name, bool _is_jumphost) - : name(_name), port(-1), is_jumphost(_is_jumphost) {} - - explicit SocketEndpoint(int _port) - : name(""), port(_port), is_jumphost(false) {} - - explicit SocketEndpoint(int _port, bool _is_jumphost) - : name(""), port(_port), is_jumphost(_is_jumphost) {} - - SocketEndpoint(const string &_name, int _port) - : name(_name), port(_port), is_jumphost(false) {} - - SocketEndpoint(const string &_name, int _port, bool _is_jumphost) - : name(_name), port(_port), is_jumphost(_is_jumphost) {} - - const string &getName() const { return name; } - - int getPort() const { return port; } - - bool isJumphost() const { return is_jumphost; } - - protected: - string name; - int port; - bool is_jumphost; -}; - -inline ostream &operator<<(ostream &os, const SocketEndpoint &self) { - if (self.getPort() >= 0) { - os << self.getName() << ":" << self.getPort(); - } else { - os << self.getName(); - } - return os; -} -} // namespace et - -#endif // __ET_SOCKET_ENDPOINT__ diff --git a/src/base/SocketHandler.hpp b/src/base/SocketHandler.hpp index 8ebed2a5c..1864ecec0 100644 --- a/src/base/SocketHandler.hpp +++ b/src/base/SocketHandler.hpp @@ -4,7 +4,6 @@ #include "Headers.hpp" #include "Packet.hpp" -#include "SocketEndpoint.hpp" namespace et { class SocketHandler { diff --git a/src/base/TcpSocketHandler.cpp b/src/base/TcpSocketHandler.cpp index de8ab2ec8..26e0280d0 100644 --- a/src/base/TcpSocketHandler.cpp +++ b/src/base/TcpSocketHandler.cpp @@ -25,8 +25,8 @@ int TcpSocketHandler::connect(const SocketEndpoint &endpoint) { #else hints.ai_flags = (AI_CANONNAME | AI_V4MAPPED | AI_ADDRCONFIG | AI_ALL); #endif - std::string portname = std::to_string(endpoint.getPort()); - std::string hostname = endpoint.getName(); + std::string portname = std::to_string(endpoint.port()); + std::string hostname = endpoint.name(); // (re)initialize the DNS system ::res_init(); @@ -140,7 +140,7 @@ int TcpSocketHandler::connect(const SocketEndpoint &endpoint) { set TcpSocketHandler::listen(const SocketEndpoint &endpoint) { lock_guard guard(mutex); - int port = endpoint.getPort(); + int port = endpoint.port(); if (portServerSockets.find(port) != portServerSockets.end()) { LOG(FATAL) << "Tried to listen twice on the same port"; } @@ -219,7 +219,7 @@ set TcpSocketHandler::listen(const SocketEndpoint &endpoint) { set TcpSocketHandler::getEndpointFds(const SocketEndpoint &endpoint) { lock_guard guard(mutex); - int port = endpoint.getPort(); + int port = endpoint.port(); if (portServerSockets.find(port) == portServerSockets.end()) { LOG(FATAL) << "Tried to getEndpointFds on a port without calling listen() first"; @@ -230,7 +230,7 @@ set TcpSocketHandler::getEndpointFds(const SocketEndpoint &endpoint) { void TcpSocketHandler::stopListening(const SocketEndpoint &endpoint) { lock_guard guard(mutex); - int port = endpoint.getPort(); + int port = endpoint.port(); auto it = portServerSockets.find(port); if (it == portServerSockets.end()) { LOG(FATAL) diff --git a/src/htm/HtmClientMain.cpp b/src/htm/HtmClientMain.cpp index 0af9cca47..3496e2f26 100644 --- a/src/htm/HtmClientMain.cpp +++ b/src/htm/HtmClientMain.cpp @@ -6,7 +6,6 @@ #include "MultiplexerState.hpp" #include "PipeSocketHandler.hpp" #include "RawSocketUtils.hpp" -#include "SocketEndpoint.hpp" using namespace et; @@ -96,7 +95,9 @@ int main(int argc, char** argv) { // This means we are the client to the daemon usleep(10 * 1000); // Sleep for 10ms to let the daemon come alive shared_ptr socketHandler(new PipeSocketHandler()); - HtmClient htmClient(socketHandler, SocketEndpoint(HtmServer::getPipeName())); + SocketEndpoint pipeEndpoint; + pipeEndpoint.set_name(HtmServer::getPipeName()); + HtmClient htmClient(socketHandler, pipeEndpoint); htmClient.run(); char buf[] = { diff --git a/src/htm/HtmServerMain.cpp b/src/htm/HtmServerMain.cpp index 73d93588b..e6ab21f37 100644 --- a/src/htm/HtmServerMain.cpp +++ b/src/htm/HtmServerMain.cpp @@ -26,7 +26,9 @@ int main(int argc, char **argv) { el::Loggers::reconfigureLogger("default", defaultConf); shared_ptr socketHandler(new PipeSocketHandler()); - HtmServer htm(socketHandler, SocketEndpoint(HtmServer::getPipeName())); + SocketEndpoint endpoint; + endpoint.set_name(HtmServer::getPipeName()); + HtmServer htm(socketHandler, endpoint); htm.run(); LOG(INFO) << "Server is shutting down"; diff --git a/src/htm/IpcPairClient.hpp b/src/htm/IpcPairClient.hpp index a44617578..154598b6e 100644 --- a/src/htm/IpcPairClient.hpp +++ b/src/htm/IpcPairClient.hpp @@ -4,7 +4,6 @@ #include "Headers.hpp" #include "IpcPairEndpoint.hpp" -#include "SocketEndpoint.hpp" #include "SocketHandler.hpp" namespace et { diff --git a/src/terminal/PsuedoUserTerminal.hpp b/src/terminal/PsuedoUserTerminal.hpp index 69f130fef..6d57b5388 100644 --- a/src/terminal/PsuedoUserTerminal.hpp +++ b/src/terminal/PsuedoUserTerminal.hpp @@ -58,7 +58,7 @@ class PsuedoUserTerminal : public UserTerminal { string terminal = string(::getenv("SHELL")); VLOG(1) << "Child process launching terminal " << terminal; setenv("ET_VERSION", ET_VERSION, 1); - execl(terminal.c_str(), terminal.c_str(), "--login", NULL); + FATAL_FAIL(execl(terminal.c_str(), terminal.c_str(), "--login", NULL)); } virtual void cleanup() { diff --git a/src/terminal/TerminalClient.cpp b/src/terminal/TerminalClient.cpp index 12af47e30..50737f2de 100644 --- a/src/terminal/TerminalClient.cpp +++ b/src/terminal/TerminalClient.cpp @@ -1,53 +1,6 @@ #include "TerminalClient.hpp" namespace et { -TerminalClient::TerminalClient(std::shared_ptr _socketHandler, - const SocketEndpoint& _socketEndpoint, - const string& id, const string& passkey, - shared_ptr _console) - : console(_console), shuttingDown(false) { - portForwardHandler = - shared_ptr(new PortForwardHandler(_socketHandler)); - InitialPayload payload; - if (_socketEndpoint.isJumphost()) { - payload.set_jumphost(true); - } - - connection = shared_ptr( - new ClientConnection(_socketHandler, _socketEndpoint, id, passkey)); - - int connectFailCount = 0; - while (true) { - try { - if (connection->connect()) { - connection->writePacket( - Packet(EtPacketType::INITIAL_PAYLOAD, protoToString(payload))); - break; - } else { - LOG(ERROR) << "Connecting to server failed: Connect timeout"; - connectFailCount++; - if (connectFailCount == 3) { - throw std::runtime_error("Connect Timeout"); - } - } - } catch (const runtime_error& err) { - LOG(INFO) << "Could not make initial connection to server"; - cout << "Could not make initial connection to " << _socketEndpoint << ": " - << err.what() << endl; - exit(1); - } - break; - } - VLOG(1) << "Client created with id: " << connection->getId(); -}; - -TerminalClient::~TerminalClient() { - connection->shutdown(); - console.reset(); - portForwardHandler.reset(); - connection.reset(); -} - vector> parseRangesToPairs(const string& input) { vector> pairs; auto j = split(input, ','); @@ -92,36 +45,29 @@ vector> parseRangesToPairs(const string& input) { return pairs; } -void TerminalClient::run(const string& command, const string& tunnels, - const string& reverseTunnels) { - if (console) { - console->setup(); - } - -// TE sends/receives data to/from the shell one char at a time. -#define BUF_SIZE (16 * 1024) - char b[BUF_SIZE]; - - time_t keepaliveTime = time(NULL) + CLIENT_KEEP_ALIVE_DURATION; - bool waitingOnKeepalive = false; - - if (command.length()) { - LOG(INFO) << "Got command: " << command; - et::TerminalBuffer tb; - tb.set_buffer(command + "; exit\n"); - - connection->writePacket( - Packet(TerminalPacketType::TERMINAL_BUFFER, protoToString(tb))); - } +TerminalClient::TerminalClient(shared_ptr _socketHandler, + shared_ptr _pipeSocketHandler, + const SocketEndpoint& _socketEndpoint, + const string& id, const string& passkey, + shared_ptr _console, bool jumphost, + const string& tunnels, + const string& reverseTunnels, + bool forwardSshAgent) + : console(_console), shuttingDown(false) { + portForwardHandler = shared_ptr( + new PortForwardHandler(_socketHandler, _pipeSocketHandler)); + InitialPayload payload; + payload.set_jumphost(jumphost); try { if (tunnels.length()) { auto pairs = parseRangesToPairs(tunnels); for (auto& pair : pairs) { PortForwardSourceRequest pfsr; - pfsr.set_sourceport(pair.first); - pfsr.set_destinationport(pair.second); - auto pfsresponse = portForwardHandler->createSource(pfsr); + pfsr.mutable_source()->set_port(pair.first); + pfsr.mutable_destination()->set_port(pair.second); + auto pfsresponse = + portForwardHandler->createSource(pfsr, nullptr, -1, -1); if (pfsresponse.has_error()) { throw std::runtime_error(pfsresponse.error()); } @@ -131,19 +77,116 @@ void TerminalClient::run(const string& command, const string& tunnels, auto pairs = parseRangesToPairs(reverseTunnels); for (auto& pair : pairs) { PortForwardSourceRequest pfsr; - pfsr.set_sourceport(pair.first); - pfsr.set_destinationport(pair.second); - - connection->writePacket( - Packet(TerminalPacketType::PORT_FORWARD_SOURCE_REQUEST, - protoToString(pfsr))); + pfsr.mutable_source()->set_port(pair.first); + pfsr.mutable_destination()->set_port(pair.second); + *(payload.add_reversetunnels()) = pfsr; } } + if (forwardSshAgent) { + PortForwardSourceRequest pfsr; + auto authSockEnv = getenv("SSH_AUTH_SOCK"); + if (!authSockEnv) { + cerr << "Missing environment variable SSH_AUTH_SOCK. Are you sure you " + "ran ssh-agent first?" + << endl; + exit(1); + } + string authSock = string(authSockEnv); + pfsr.mutable_destination()->set_name(authSock); + pfsr.set_environmentvariable("SSH_AUTH_SOCK"); + *(payload.add_reversetunnels()) = pfsr; + } } catch (const std::runtime_error& ex) { cout << "Error establishing port forward: " << ex.what() << endl; LOG(FATAL) << "Error establishing port forward: " << ex.what(); } + connection = shared_ptr( + new ClientConnection(_socketHandler, _socketEndpoint, id, passkey)); + + int connectFailCount = 0; + while (true) { + try { + bool fail = true; + if (connection->connect()) { + connection->writePacket( + Packet(EtPacketType::INITIAL_PAYLOAD, protoToString(payload))); + fd_set rfd; + timeval tv; + + for (int a = 0; a < 3; a++) { + FD_ZERO(&rfd); + int clientFd = connection->getSocketFd(); + FD_SET(clientFd, &rfd); + tv.tv_sec = 1; + tv.tv_usec = 0; + select(clientFd + 1, &rfd, NULL, NULL, &tv); + if (FD_ISSET(clientFd, &rfd)) { + Packet initialResponsePacket; + if (connection->readPacket(&initialResponsePacket)) { + if (initialResponsePacket.getHeader() != + EtPacketType::INITIAL_RESPONSE) { + LOG(FATAL) << "Missing initial response!"; + } + auto initialResponse = stringToProto( + initialResponsePacket.getPayload()); + if (initialResponse.has_error()) { + cout << "Error initializing connection: " + << initialResponse.error() << endl; + exit(1); + } + fail = false; + break; + } + } + } + } + if (fail) { + LOG(ERROR) << "Connecting to server failed: Connect timeout"; + connectFailCount++; + if (connectFailCount == 3) { + throw std::runtime_error("Connect Timeout"); + } + } + } catch (const runtime_error& err) { + LOG(INFO) << "Could not make initial connection to server"; + cout << "Could not make initial connection to " << _socketEndpoint << ": " + << err.what() << endl; + exit(1); + } + break; + } + VLOG(1) << "Client created with id: " << connection->getId(); +}; + +TerminalClient::~TerminalClient() { + connection->shutdown(); + console.reset(); + portForwardHandler.reset(); + connection.reset(); +} + +void TerminalClient::run(const string& command) { + if (console) { + console->setup(); + } + +// TE sends/receives data to/from the shell one char at a time. +#define BUF_SIZE (16 * 1024) + char b[BUF_SIZE]; + + time_t keepaliveTime = time(NULL) + CLIENT_KEEP_ALIVE_DURATION; + bool waitingOnKeepalive = false; + + if (command.length()) { + LOG(INFO) << "Got command: " << command; + et::TerminalBuffer tb; + tb.set_buffer(command + "; exit\n"); + + connection->writePacket( + Packet(TerminalPacketType::TERMINAL_BUFFER, protoToString(tb))); + } + TerminalInfo lastTerminalInfo; if (!console.get()) { @@ -207,10 +250,6 @@ void TerminalClient::run(const string& command, const string& tunnels, } char packetType = packet.getHeader(); if (packetType == et::TerminalPacketType::PORT_FORWARD_DATA || - packetType == - et::TerminalPacketType::PORT_FORWARD_SOURCE_REQUEST || - packetType == - et::TerminalPacketType::PORT_FORWARD_SOURCE_RESPONSE || packetType == et::TerminalPacketType::PORT_FORWARD_DESTINATION_REQUEST || packetType == diff --git a/src/terminal/TerminalClient.hpp b/src/terminal/TerminalClient.hpp index 0eee21100..f01daf603 100644 --- a/src/terminal/TerminalClient.hpp +++ b/src/terminal/TerminalClient.hpp @@ -4,10 +4,10 @@ #include "ClientConnection.hpp" #include "Console.hpp" #include "CryptoHandler.hpp" +#include "ForwardSourceHandler.hpp" #include "Headers.hpp" #include "LogHandler.hpp" #include "PortForwardHandler.hpp" -#include "PortForwardSourceHandler.hpp" #include "RawSocketUtils.hpp" #include "ServerConnection.hpp" #include "SshSetupHandler.hpp" @@ -25,18 +25,18 @@ namespace et { class TerminalClient { public: TerminalClient(std::shared_ptr _socketHandler, + std::shared_ptr _pipeSocketHandler, const SocketEndpoint& _socketEndpoint, const string& id, - const string& passkey, shared_ptr _console); + const string& passkey, shared_ptr _console, + bool jumphost, const string& tunnels, + const string& reverseTunnels, bool forwardSshAgent); virtual ~TerminalClient(); void setUpTunnel(const string& tunnels); void setUpReverseTunnels(const string& reverseTunnels); void handleWindowChanged(winsize* win); // void handlePfwPacket(char packetType); - void run(const string& command, const string& tunnels, - const string& reverseTunnels); - void shutdown() { - shuttingDown = true; - } + void run(const string& command); + void shutdown() { shuttingDown = true; } protected: shared_ptr console; diff --git a/src/terminal/TerminalClientMain.cpp b/src/terminal/TerminalClientMain.cpp index c2cd6ecf7..cffb72c48 100644 --- a/src/terminal/TerminalClientMain.cpp +++ b/src/terminal/TerminalClientMain.cpp @@ -1,6 +1,7 @@ #include "TerminalClient.hpp" #include "ParseConfigFile.hpp" +#include "PipeSocketHandler.hpp" #include "PsuedoTerminalConsole.hpp" using namespace et; @@ -55,10 +56,11 @@ int main(int argc, char** argv) { ("x,kill-other-sessions", "kill all old sessions belonging to the user") // ("v,verbose", "Enable verbose logging", - cxxopts::value()->default_value("0")) // - ("logtostdout", "Write log to stdout") // - ("silent", "Disable logging") // - ("N,no-terminal", "Do not create a terminal") // + cxxopts::value()->default_value("0")) // + ("logtostdout", "Write log to stdout") // + ("silent", "Disable logging") // + ("N,no-terminal", "Do not create a terminal") // + ("f,forward-ssh-agent", "Forward ssh-agent socket") // ; options.parse_positional({"host", "positional"}); @@ -188,8 +190,11 @@ int main(int argc, char** argv) { port = result["jport"].as(); LOG(INFO) << "Setting port to jumphost port"; } - SocketEndpoint socketEndpoint = SocketEndpoint(host, port, is_jumphost); + SocketEndpoint socketEndpoint; + socketEndpoint.set_name(host); + socketEndpoint.set_port(port); shared_ptr clientSocket(new TcpSocketHandler()); + shared_ptr clientPipeSocket(new PipeSocketHandler()); if (!ping(socketEndpoint, clientSocket)) { cout << "Could not reach the ET server: " << host << ":" << port << endl; @@ -222,12 +227,12 @@ int main(int argc, char** argv) { console.reset(new PsuedoTerminalConsole()); } - TerminalClient terminalClient = - TerminalClient(clientSocket, socketEndpoint, id, passkey, console); - terminalClient.run( - result.count("command") ? result["command"].as() : "", - result.count("t") ? result["t"].as() : "", - result.count("r") ? result["r"].as() : ""); + TerminalClient terminalClient = TerminalClient( + clientSocket, clientPipeSocket, socketEndpoint, id, passkey, console, + is_jumphost, result.count("t") ? result["t"].as() : "", + result.count("r") ? result["r"].as() : "", result.count("f")); + terminalClient.run(result.count("command") ? result["command"].as() + : ""); } catch (cxxopts::OptionException& oe) { cout << "Exception: " << oe.what() << "\n" << endl; cout << options.help({}) << endl; diff --git a/src/terminal/TerminalMain.cpp b/src/terminal/TerminalMain.cpp index 646f06629..d3203f239 100644 --- a/src/terminal/TerminalMain.cpp +++ b/src/terminal/TerminalMain.cpp @@ -134,12 +134,15 @@ int main(int argc, char** argv) { if (DaemonCreator::createSessionLeader() == -1) { LOG(FATAL) << "Error creating daemon: " << strerror(errno); } + SocketEndpoint routerFifoEndpoint; + routerFifoEndpoint.set_name(ROUTER_FIFO_NAME); + SocketEndpoint destinationEndpoint; + destinationEndpoint.set_name(result["dsthost"].as()); + destinationEndpoint.set_port(result["dstport"].as()); shared_ptr jumpClientSocketHandler(new TcpSocketHandler()); UserJumphostHandler ujh(jumpClientSocketHandler, idpasskey, - SocketEndpoint(result["dsthost"].as(), - result["dstport"].as()), - ipcSocketHandler, - SocketEndpoint(ROUTER_FIFO_NAME)); + destinationEndpoint, ipcSocketHandler, + routerFifoEndpoint); ujh.run(); // Uninstall log rotation callback @@ -160,8 +163,10 @@ int main(int argc, char** argv) { // Install log rotation callback el::Helpers::installPreRollOutCallback(LogHandler::rolloutHandler); - UserTerminalHandler uth(ipcSocketHandler, term, true, - SocketEndpoint(ROUTER_FIFO_NAME), idpasskey); + SocketEndpoint routerEndpoint; + routerEndpoint.set_name(ROUTER_FIFO_NAME); + UserTerminalHandler uth(ipcSocketHandler, term, true, routerEndpoint, + idpasskey); cout << "IDPASSKEY:" << idpasskey << endl; if (DaemonCreator::createSessionLeader() == -1) { LOG(FATAL) << "Error creating daemon: " << strerror(errno); diff --git a/src/terminal/TerminalServer.cpp b/src/terminal/TerminalServer.cpp index e49a3f289..49db117d8 100644 --- a/src/terminal/TerminalServer.cpp +++ b/src/terminal/TerminalServer.cpp @@ -77,12 +77,16 @@ void TerminalServer::run() { void TerminalServer::runJumpHost( shared_ptr serverClientState) { + InitialResponse response; + serverClientState->writePacket( + Packet(uint8_t(EtPacketType::INITIAL_RESPONSE), protoToString(response))); // set thread name el::Helpers::setThreadName(serverClientState->getId()); bool run = true; bool b[BUF_SIZE]; - int terminalFd = terminalRouter->getFd(serverClientState->getId()); + int terminalFd = + terminalRouter->getInfoForId(serverClientState->getId()).fd(); shared_ptr terminalSocketHandler = terminalRouter->getSocketHandler(); @@ -150,7 +154,35 @@ void TerminalServer::runJumpHost( } void TerminalServer::runTerminal( - shared_ptr serverClientState) { + shared_ptr serverClientState, + const InitialPayload &payload) { + auto userInfo = terminalRouter->getInfoForId(serverClientState->getId()); + InitialResponse response; + shared_ptr serverSocketHandler = getSocketHandler(); + shared_ptr pipeSocketHandler(new PipeSocketHandler()); + shared_ptr portForwardHandler( + new PortForwardHandler(serverSocketHandler, pipeSocketHandler)); + map environmentVariables; + vector pipePaths; + for (const PortForwardSourceRequest &pfsr : payload.reversetunnels()) { + string sourceName; + PortForwardSourceResponse pfsresponse = portForwardHandler->createSource( + pfsr, &sourceName, userInfo.uid(), userInfo.gid()); + if (pfsresponse.has_error()) { + InitialResponse response; + response.set_error(pfsresponse.error()); + serverClientState->writePacket(Packet( + uint8_t(EtPacketType::INITIAL_RESPONSE), protoToString(response))); + return; + } + if (pfsr.has_environmentvariable()) { + environmentVariables[pfsr.environmentvariable()] = sourceName; + pipePaths.push_back(sourceName); + } + } + serverClientState->writePacket( + Packet(uint8_t(EtPacketType::INITIAL_RESPONSE), protoToString(response))); + // Set thread name el::Helpers::setThreadName(serverClientState->getId()); // Whether the TE should keep running. @@ -159,13 +191,19 @@ void TerminalServer::runTerminal( // TE sends/receives data to/from the shell one char at a time. char b[BUF_SIZE]; - shared_ptr serverSocketHandler = getSocketHandler(); - PortForwardHandler portForwardHandler(serverSocketHandler); - - int terminalFd = terminalRouter->getFd(serverClientState->getId()); + int terminalFd = userInfo.fd(); shared_ptr terminalSocketHandler = terminalRouter->getSocketHandler(); + TermInit termInit; + for (auto &it : environmentVariables) { + *(termInit.add_environmentnames()) = it.first; + *(termInit.add_environmentvalues()) = it.second; + } + terminalSocketHandler->writePacket( + terminalFd, + Packet(TerminalPacketType::TERMINAL_INIT, protoToString(termInit))); + while (run) { { lock_guard guard(terminalThreadMutex); @@ -217,7 +255,7 @@ void TerminalServer::runTerminal( vector requests; vector dataToSend; - portForwardHandler.update(&requests, &dataToSend); + portForwardHandler->update(&requests, &dataToSend); for (auto &pfr : requests) { serverClientState->writePacket( Packet(TerminalPacketType::PORT_FORWARD_DESTINATION_REQUEST, @@ -238,15 +276,11 @@ void TerminalServer::runTerminal( } char packetType = packet.getHeader(); if (packetType == et::TerminalPacketType::PORT_FORWARD_DATA || - packetType == - et::TerminalPacketType::PORT_FORWARD_SOURCE_REQUEST || - packetType == - et::TerminalPacketType::PORT_FORWARD_SOURCE_RESPONSE || packetType == et::TerminalPacketType::PORT_FORWARD_DESTINATION_REQUEST || packetType == et::TerminalPacketType::PORT_FORWARD_DESTINATION_RESPONSE) { - portForwardHandler.handlePacket(packet, serverClientState); + portForwardHandler->handlePacket(packet, serverClientState); continue; } switch (packetType) { @@ -301,7 +335,7 @@ void TerminalServer::runTerminal( serverClientState.reset(); removeClient(id); } -} // namespace et +} void TerminalServer::handleConnection( shared_ptr serverClientState) { @@ -318,7 +352,7 @@ void TerminalServer::handleConnection( if (payload.jumphost()) { runJumpHost(serverClientState); } else { - runTerminal(serverClientState); + runTerminal(serverClientState, payload); } } diff --git a/src/terminal/TerminalServer.hpp b/src/terminal/TerminalServer.hpp index 7445a246d..54ace111a 100644 --- a/src/terminal/TerminalServer.hpp +++ b/src/terminal/TerminalServer.hpp @@ -52,7 +52,8 @@ class TerminalServer : public ServerConnection { const SocketEndpoint &_routerEndpoint); virtual ~TerminalServer(); void runJumpHost(shared_ptr serverClientState); - void runTerminal(shared_ptr serverClientState); + void runTerminal(shared_ptr serverClientState, + const InitialPayload& payload); void handleConnection(shared_ptr serverClientState); virtual bool newClient(shared_ptr serverClientState); diff --git a/src/terminal/TerminalServerMain.cpp b/src/terminal/TerminalServerMain.cpp index 2fbb11115..8a0861228 100644 --- a/src/terminal/TerminalServerMain.cpp +++ b/src/terminal/TerminalServerMain.cpp @@ -21,12 +21,13 @@ int main(int argc, char **argv) { ("version", "Print version") // ("port", "Port to listen on", cxxopts::value()->default_value("0")) // - ("daemon", "Daemonize the server") // + ("daemon", "Daemonize the server") // ("cfgfile", "Location of the config file", cxxopts::value()->default_value("")) // - ("logtostdout", "log to stdout") // + ("logtostdout", "log to stdout") // ("pidfile", "Location of the pid file", - cxxopts::value()->default_value("/var/run/etserver.pid")) // + cxxopts::value()->default_value( + "/var/run/etserver.pid")) // ("v,verbose", "Enable verbose logging", cxxopts::value()->default_value("0"), "LEVEL") // ; @@ -114,16 +115,20 @@ int main(int argc, char **argv) { // Install log rotation callback el::Helpers::installPreRollOutCallback(LogHandler::rolloutHandler); std::shared_ptr tcpSocketHandler(new TcpSocketHandler()); - std::shared_ptr pipeSocketHandler(new PipeSocketHandler()); + std::shared_ptr pipeSocketHandler( + new PipeSocketHandler()); LOG(INFO) << "In child, about to start server."; - TerminalServer terminalServer(tcpSocketHandler, SocketEndpoint(port), - pipeSocketHandler, - SocketEndpoint(ROUTER_FIFO_NAME)); + SocketEndpoint serverEndpoint; + serverEndpoint.set_port(port); + SocketEndpoint routerFifo; + routerFifo.set_name(ROUTER_FIFO_NAME); + TerminalServer terminalServer(tcpSocketHandler, serverEndpoint, + pipeSocketHandler, routerFifo); terminalServer.run(); - } catch (cxxopts::OptionException& oe) { + } catch (cxxopts::OptionException &oe) { cout << "Exception: " << oe.what() << "\n" << endl; cout << options.help({}) << endl; exit(1); diff --git a/src/terminal/UserJumphostHandler.cpp b/src/terminal/UserJumphostHandler.cpp index 2908efccf..4b2a48fe5 100644 --- a/src/terminal/UserJumphostHandler.cpp +++ b/src/terminal/UserJumphostHandler.cpp @@ -12,6 +12,11 @@ UserJumphostHandler::UserJumphostHandler( auto idpasskey_splited = split(idpasskey, '/'); string id = idpasskey_splited[0]; string passkey = idpasskey_splited[1]; + TerminalUserInfo tui; + tui.set_id(id); + tui.set_passkey(passkey); + tui.set_uid(getuid()); + tui.set_gid(getgid()); routerFd = routerSocketHandler->connect(routerEndpoint); @@ -29,7 +34,8 @@ UserJumphostHandler::UserJumphostHandler( try { routerSocketHandler->writePacket( - routerFd, Packet(TerminalPacketType::IDPASSKEY, idpasskey)); + routerFd, + Packet(TerminalPacketType::TERMINAL_USER_INFO, protoToString(tui))); } catch (const std::runtime_error &re) { LOG(FATAL) << "Cannot send idpasskey to router: " << re.what(); } diff --git a/src/terminal/UserJumphostHandler.hpp b/src/terminal/UserJumphostHandler.hpp index 67b7758b0..0f0483b1b 100644 --- a/src/terminal/UserJumphostHandler.hpp +++ b/src/terminal/UserJumphostHandler.hpp @@ -1,7 +1,6 @@ #include "Headers.hpp" #include "ClientConnection.hpp" -#include "SocketEndpoint.hpp" #include "SocketHandler.hpp" namespace et { diff --git a/src/terminal/UserTerminalHandler.cpp b/src/terminal/UserTerminalHandler.cpp index 73bd022e5..3916bdd21 100644 --- a/src/terminal/UserTerminalHandler.cpp +++ b/src/terminal/UserTerminalHandler.cpp @@ -37,6 +37,14 @@ UserTerminalHandler::UserTerminalHandler( routerEndpoint(_routerEndpoint), shuttingDown(false) { routerFd = socketHandler->connect(routerEndpoint); + auto idpasskey_splited = split(idPasskey, '/'); + string id = idpasskey_splited[0]; + string passkey = idpasskey_splited[1]; + TerminalUserInfo tui; + tui.set_id(id); + tui.set_passkey(passkey); + tui.set_uid(getuid()); + tui.set_gid(getgid()); if (routerFd < 0) { if (errno == ECONNREFUSED) { @@ -52,13 +60,26 @@ UserTerminalHandler::UserTerminalHandler( try { socketHandler->writePacket( - routerFd, Packet(TerminalPacketType::IDPASSKEY, idPasskey)); + routerFd, + Packet(TerminalPacketType::TERMINAL_USER_INFO, protoToString(tui))); + } catch (const std::runtime_error &re) { LOG(FATAL) << "Error connecting to router: " << re.what(); } } void UserTerminalHandler::run() { + Packet termInitPacket = socketHandler->readPacket(routerFd); + if (termInitPacket.getHeader() != TerminalPacketType::TERMINAL_INIT) { + LOG(FATAL) << "Invalid terminal init packet header: " + << termInitPacket.getHeader(); + } + TermInit ti = stringToProto(termInitPacket.getPayload()); + for (int a = 0; a < ti.environmentnames_size(); a++) { + setenv(ti.environmentnames(a).c_str(), ti.environmentvalues(a).c_str(), + true); + } + int masterfd = term->setup(routerFd); VLOG(1) << "pty opened " << masterfd; runUserTerminal(masterfd); diff --git a/src/terminal/UserTerminalRouter.cpp b/src/terminal/UserTerminalRouter.cpp index 6fd788f15..7b91936a8 100644 --- a/src/terminal/UserTerminalRouter.cpp +++ b/src/terminal/UserTerminalRouter.cpp @@ -8,6 +8,10 @@ UserTerminalRouter::UserTerminalRouter( const SocketEndpoint &_routerEndpoint) : socketHandler(_socketHandler) { serverFd = *(socketHandler->listen(_routerEndpoint).begin()); + FATAL_FAIL(::chown(_routerEndpoint.name().c_str(), getuid(), getgid())); + FATAL_FAIL(::chmod(_routerEndpoint.name().c_str(), + S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IWGRP | S_IXGRP | + S_IROTH | S_IWOTH | S_IXOTH)); } IdKeyPair UserTerminalRouter::acceptNewConnection() { @@ -25,21 +29,14 @@ IdKeyPair UserTerminalRouter::acceptNewConnection() { try { Packet packet = socketHandler->readPacket(terminalFd); - if (packet.getHeader() != TerminalPacketType::IDPASSKEY) { + if (packet.getHeader() != TerminalPacketType::TERMINAL_USER_INFO) { LOG(FATAL) << "Got an invalid packet header: " << int(packet.getHeader()); } - string buf = packet.getPayload(); - VLOG(1) << "Got passkey: " << buf; - size_t slashIndex = buf.find("/"); - if (slashIndex == string::npos) { - LOG(ERROR) << "Invalid idPasskey id/key pair: " << buf; - close(terminalFd); - } else { - string id = buf.substr(0, slashIndex); - string key = buf.substr(slashIndex + 1); - idFdMap[id] = terminalFd; - return IdKeyPair({id, key}); - } + TerminalUserInfo tui = stringToProto(packet.getPayload()); + VLOG(1) << "Got id/passkey: " << tui.id() << "/" << tui.passkey(); + tui.set_fd(terminalFd); + idInfoMap[tui.id()] = tui; + return IdKeyPair({tui.id(), tui.passkey()}); } catch (const std::runtime_error &re) { LOG(FATAL) << "Router can't talk to terminal: " << re.what(); } @@ -47,9 +44,9 @@ IdKeyPair UserTerminalRouter::acceptNewConnection() { return IdKeyPair({"", ""}); } -int UserTerminalRouter::getFd(const string &id) { - auto it = idFdMap.find(id); - if (it == idFdMap.end()) { +TerminalUserInfo UserTerminalRouter::getInfoForId(const string &id) { + auto it = idInfoMap.find(id); + if (it == idInfoMap.end()) { LOG(FATAL) << " Tried to read from an id that no longer exists"; } return it->second; diff --git a/src/terminal/UserTerminalRouter.hpp b/src/terminal/UserTerminalRouter.hpp index 4707b01a6..511e207ae 100644 --- a/src/terminal/UserTerminalRouter.hpp +++ b/src/terminal/UserTerminalRouter.hpp @@ -15,14 +15,14 @@ class UserTerminalRouter { const SocketEndpoint& _routerEndpoint); inline int getServerFd() { return serverFd; } IdKeyPair acceptNewConnection(); - int getFd(const string& id); + TerminalUserInfo getInfoForId(const string& id); inline shared_ptr getSocketHandler() { return socketHandler; } protected: int serverFd; - unordered_map idFdMap; + unordered_map idInfoMap; shared_ptr socketHandler; }; } // namespace et diff --git a/src/terminal/PortForwardDestinationHandler.cpp b/src/terminal/forwarding/ForwardDestinationHandler.cpp similarity index 81% rename from src/terminal/PortForwardDestinationHandler.cpp rename to src/terminal/forwarding/ForwardDestinationHandler.cpp index d0fcdee59..c9cd52dff 100644 --- a/src/terminal/PortForwardDestinationHandler.cpp +++ b/src/terminal/forwarding/ForwardDestinationHandler.cpp @@ -1,18 +1,18 @@ -#include "PortForwardDestinationHandler.hpp" +#include "ForwardDestinationHandler.hpp" namespace et { -PortForwardDestinationHandler::PortForwardDestinationHandler( +ForwardDestinationHandler::ForwardDestinationHandler( shared_ptr _socketHandler, int _fd, int _socketId) : socketHandler(_socketHandler), fd(_fd), socketId(_socketId) {} -void PortForwardDestinationHandler::close() { socketHandler->close(fd); } +void ForwardDestinationHandler::close() { socketHandler->close(fd); } -void PortForwardDestinationHandler::write(const string& s) { +void ForwardDestinationHandler::write(const string& s) { VLOG(1) << "Writing " << s.length() << " bytes to port destination"; socketHandler->writeAllOrReturn(fd, s.c_str(), s.length()); } -void PortForwardDestinationHandler::update(vector* retval) { +void ForwardDestinationHandler::update(vector* retval) { if (fd == -1) { return; } diff --git a/src/terminal/PortForwardDestinationHandler.hpp b/src/terminal/forwarding/ForwardDestinationHandler.hpp similarity index 74% rename from src/terminal/PortForwardDestinationHandler.hpp rename to src/terminal/forwarding/ForwardDestinationHandler.hpp index 0b01ee9fc..154cc2b57 100644 --- a/src/terminal/PortForwardDestinationHandler.hpp +++ b/src/terminal/forwarding/ForwardDestinationHandler.hpp @@ -8,11 +8,10 @@ #include "ETerminal.pb.h" namespace et { -class PortForwardDestinationHandler { +class ForwardDestinationHandler { public: - PortForwardDestinationHandler(shared_ptr _socketHandler, - int _fd, int _socketId); - + ForwardDestinationHandler(shared_ptr _socketHandler, int _fd, + int _socketId); void write(const string& s); void update(vector* retval); diff --git a/src/terminal/PortForwardSourceHandler.cpp b/src/terminal/forwarding/ForwardSourceHandler.cpp similarity index 72% rename from src/terminal/PortForwardSourceHandler.cpp rename to src/terminal/forwarding/ForwardSourceHandler.cpp index 319d7ed0b..8818ae147 100644 --- a/src/terminal/PortForwardSourceHandler.cpp +++ b/src/terminal/forwarding/ForwardSourceHandler.cpp @@ -1,21 +1,21 @@ -#include "PortForwardSourceHandler.hpp" +#include "ForwardSourceHandler.hpp" namespace et { -PortForwardSourceHandler::PortForwardSourceHandler( - shared_ptr _socketHandler, int _sourcePort, - int _destinationPort) +ForwardSourceHandler::ForwardSourceHandler( + shared_ptr _socketHandler, const SocketEndpoint& _source, + const SocketEndpoint& _destination) : socketHandler(_socketHandler), - sourcePort(_sourcePort), - destinationPort(_destinationPort) { - socketHandler->listen(SocketEndpoint(sourcePort)); + source(_source), + destination(_destination) { + socketHandler->listen(source); } -int PortForwardSourceHandler::listen() { +int ForwardSourceHandler::listen() { // TODO: Replace with select - for (int i : socketHandler->getEndpointFds(SocketEndpoint(sourcePort))) { + for (int i : socketHandler->getEndpointFds(source)) { int fd = socketHandler->accept(i); if (fd > -1) { - LOG(INFO) << "Tunnel " << sourcePort << " -> " << destinationPort + LOG(INFO) << "Tunnel " << source << " -> " << destination << " socket created with fd " << fd; unassignedFds.insert(fd); return fd; @@ -24,7 +24,7 @@ int PortForwardSourceHandler::listen() { return -1; } -void PortForwardSourceHandler::update(vector* data) { +void ForwardSourceHandler::update(vector* data) { vector socketsToRemove; for (auto& it : socketFdMap) { @@ -65,11 +65,11 @@ void PortForwardSourceHandler::update(vector* data) { } } -bool PortForwardSourceHandler::hasUnassignedFd(int fd) { +bool ForwardSourceHandler::hasUnassignedFd(int fd) { return unassignedFds.find(fd) != unassignedFds.end(); } -void PortForwardSourceHandler::closeUnassignedFd(int fd) { +void ForwardSourceHandler::closeUnassignedFd(int fd) { if (unassignedFds.find(fd) == unassignedFds.end()) { LOG(ERROR) << "Tried to close an unassigned fd that doesn't exist"; return; @@ -78,7 +78,7 @@ void PortForwardSourceHandler::closeUnassignedFd(int fd) { unassignedFds.erase(fd); } -void PortForwardSourceHandler::addSocket(int socketId, int sourceFd) { +void ForwardSourceHandler::addSocket(int socketId, int sourceFd) { if (unassignedFds.find(sourceFd) == unassignedFds.end()) { LOG(ERROR) << "Tried to close an unassigned fd that doesn't exist " << sourceFd; @@ -88,9 +88,7 @@ void PortForwardSourceHandler::addSocket(int socketId, int sourceFd) { unassignedFds.erase(sourceFd); socketFdMap[socketId] = sourceFd; } - -void PortForwardSourceHandler::sendDataOnSocket(int socketId, - const string& data) { +void ForwardSourceHandler::sendDataOnSocket(int socketId, const string& data) { if (socketFdMap.find(socketId) == socketFdMap.end()) { LOG(ERROR) << "Tried to write to a socket that no longer exists!"; return; @@ -102,7 +100,7 @@ void PortForwardSourceHandler::sendDataOnSocket(int socketId, socketHandler->writeAllOrReturn(fd, buf, count); } -void PortForwardSourceHandler::closeSocket(int socketId) { +void ForwardSourceHandler::closeSocket(int socketId) { auto it = socketFdMap.find(socketId); if (it == socketFdMap.end()) { LOG(ERROR) << "Tried to remove a socket that no longer exists!"; diff --git a/src/terminal/PortForwardSourceHandler.hpp b/src/terminal/forwarding/ForwardSourceHandler.hpp similarity index 53% rename from src/terminal/PortForwardSourceHandler.hpp rename to src/terminal/forwarding/ForwardSourceHandler.hpp index f4618d81f..755c1f48f 100644 --- a/src/terminal/PortForwardSourceHandler.hpp +++ b/src/terminal/forwarding/ForwardSourceHandler.hpp @@ -1,16 +1,16 @@ -#ifndef __PORT_FORWARD_SOURCE_LISTENER_H__ -#define __PORT_FORWARD_SOURCE_LISTENER_H__ +#ifndef __FORWARD_SOURCE_HANDLER_H__ +#define __FORWARD_SOURCE_HANDLER_H__ #include "Headers.hpp" -#include "ETerminal.pb.h" #include "SocketHandler.hpp" namespace et { -class PortForwardSourceHandler { +class ForwardSourceHandler { public: - PortForwardSourceHandler(shared_ptr _socketHandler, - int _sourcePort, int _destinationPort); + ForwardSourceHandler(shared_ptr _socketHandler, + const SocketEndpoint& _source, + const SocketEndpoint& _destination); int listen(); @@ -26,15 +26,15 @@ class PortForwardSourceHandler { void sendDataOnSocket(int socketId, const string& data); - inline int getDestinationPort() { return destinationPort; } + inline SocketEndpoint getDestination() { return destination; } protected: shared_ptr socketHandler; - int sourcePort; - int destinationPort; + SocketEndpoint source; + SocketEndpoint destination; unordered_set unassignedFds; unordered_map socketFdMap; }; } // namespace et -#endif // __PORT_FORWARD_SOURCE_LISTENER_H__ +#endif // __FORWARD_SOURCE_HANDLER_H__ diff --git a/src/terminal/PortForwardHandler.cpp b/src/terminal/forwarding/PortForwardHandler.cpp similarity index 65% rename from src/terminal/PortForwardHandler.cpp rename to src/terminal/forwarding/PortForwardHandler.cpp index b1158585b..2f308dd85 100644 --- a/src/terminal/PortForwardHandler.cpp +++ b/src/terminal/forwarding/PortForwardHandler.cpp @@ -1,8 +1,11 @@ #include "PortForwardHandler.hpp" namespace et { -PortForwardHandler::PortForwardHandler(shared_ptr _socketHandler) - : socketHandler(_socketHandler) {} +PortForwardHandler::PortForwardHandler( + shared_ptr _networkSocketHandler, + shared_ptr _pipeSocketHandler) + : networkSocketHandler(_networkSocketHandler), + pipeSocketHandler(_pipeSocketHandler) {} void PortForwardHandler::update(vector* requests, vector* dataToSend) { @@ -11,7 +14,7 @@ void PortForwardHandler::update(vector* requests, int fd = it->listen(); if (fd >= 0) { PortForwardDestinationRequest pfr; - pfr.set_port(it->getDestinationPort()); + *(pfr.mutable_destination()) = it->getDestination(); pfr.set_fd(fd); requests->push_back(pfr); } @@ -29,13 +32,52 @@ void PortForwardHandler::update(vector* requests, } PortForwardSourceResponse PortForwardHandler::createSource( - const PortForwardSourceRequest& pfsr) { + const PortForwardSourceRequest& pfsr, string* sourceName, uid_t userid, + gid_t groupid) { try { - auto handler = - shared_ptr(new PortForwardSourceHandler( - socketHandler, pfsr.sourceport(), pfsr.destinationport())); - sourceHandlers.push_back(handler); - return PortForwardSourceResponse(); + if (pfsr.has_source() && !pfsr.source().has_port()) { + throw runtime_error("Do not set a source when forwarding named pipes"); + } + SocketEndpoint source; + if (pfsr.has_source()) { + source = pfsr.source(); + } else { + // Make a random file to forward the pipe + string sourcePattern = string("/tmp/et_forward_sock_XXXXXX"); + string sourceDirectory = string(mkdtemp(&sourcePattern[0])); + FATAL_FAIL(::chmod(sourceDirectory.c_str(), S_IRUSR | S_IWUSR | S_IXUSR)); + FATAL_FAIL(::chown(sourceDirectory.c_str(), userid, groupid)); + string sourcePath = string(sourceDirectory) + "/sock"; + + source.set_name(sourcePath); + if (sourceName == nullptr) { + LOG(FATAL) + << "Tried to create a pipe but without a place to put the name!"; + } + *sourceName = sourcePath; + LOG(INFO) << "Creating pipe at " << sourcePath; + } + if (pfsr.source().has_port()) { + if (sourceName != nullptr) { + LOG(FATAL) << "Tried to create a port forward but with a place to put " + "the name!"; + } + auto handler = shared_ptr(new ForwardSourceHandler( + networkSocketHandler, source, pfsr.destination())); + sourceHandlers.push_back(handler); + return PortForwardSourceResponse(); + } else { + if (userid < 0 || groupid < 0) { + LOG(FATAL) + << "Tried to create a unix socket forward with no userid/groupid"; + } + auto handler = shared_ptr(new ForwardSourceHandler( + pipeSocketHandler, source, pfsr.destination())); + FATAL_FAIL(::chmod(source.name().c_str(), S_IRUSR | S_IWUSR | S_IXUSR)); + FATAL_FAIL(::chown(source.name().c_str(), userid, groupid)); + sourceHandlers.push_back(handler); + return PortForwardSourceResponse(); + } } catch (const std::runtime_error& ex) { PortForwardSourceResponse pfsr; pfsr.set_error(ex.what()); @@ -45,11 +87,24 @@ PortForwardSourceResponse PortForwardHandler::createSource( PortForwardDestinationResponse PortForwardHandler::createDestination( const PortForwardDestinationRequest& pfdr) { - // Try ipv6 first - int fd = socketHandler->connect(SocketEndpoint("::1", pfdr.port())); - if (fd == -1) { - // Try ipv4 next - fd = socketHandler->connect(SocketEndpoint("127.0.0.1", pfdr.port())); + int fd = -1; + bool isTcp = pfdr.destination().has_port(); + if (pfdr.destination().has_port()) { + // Try ipv6 first + SocketEndpoint ipv6Localhost; + ipv6Localhost.set_name("::1"); + ipv6Localhost.set_port(pfdr.destination().port()); + + fd = networkSocketHandler->connect(ipv6Localhost); + if (fd == -1) { + SocketEndpoint ipv4Localhost; + ipv4Localhost.set_name("127.0.0.1"); + ipv4Localhost.set_port(pfdr.destination().port()); + // Try ipv4 next + fd = networkSocketHandler->connect(ipv4Localhost); + } + } else { + fd = pipeSocketHandler->connect(pfdr.destination()); } PortForwardDestinationResponse pfdresponse; pfdresponse.set_clientfd(pfdr.fd()); @@ -68,8 +123,9 @@ PortForwardDestinationResponse PortForwardHandler::createDestination( } if (!pfdresponse.has_error()) { LOG(INFO) << "Created socket/fd pair: " << socketId << ' ' << fd; - destinationHandlers[socketId] = shared_ptr( - new PortForwardDestinationHandler(socketHandler, fd, socketId)); + destinationHandlers[socketId] = + shared_ptr(new ForwardDestinationHandler( + isTcp ? networkSocketHandler : pipeSocketHandler, fd, socketId)); pfdresponse.set_socketid(socketId); } } @@ -115,33 +171,11 @@ void PortForwardHandler::handlePacket(const Packet& packet, } break; } - case TerminalPacketType::PORT_FORWARD_SOURCE_REQUEST: { - LOG(INFO) << "Got new port source request"; - PortForwardSourceRequest pfsr = - stringToProto(packet.getPayload()); - PortForwardSourceResponse pfsresponse = createSource(pfsr); - Packet sendPacket( - uint8_t(TerminalPacketType::PORT_FORWARD_SOURCE_RESPONSE), - protoToString(pfsresponse)); - connection->writePacket(sendPacket); - break; - } - case TerminalPacketType::PORT_FORWARD_SOURCE_RESPONSE: { - LOG(INFO) << "Got port source response"; - PortForwardSourceResponse pfsresponse = - stringToProto(packet.getPayload()); - if (pfsresponse.has_error()) { - cout << "FATAL: A reverse tunnel has failed (probably because someone " - "else is already using that port on the destination server" - << endl; - LOG(FATAL) << "Reverse tunnel request failed: " << pfsresponse.error(); - } - break; - } case TerminalPacketType::PORT_FORWARD_DESTINATION_REQUEST: { PortForwardDestinationRequest pfdr = stringToProto(packet.getPayload()); - LOG(INFO) << "Got new port destination request for port " << pfdr.port(); + LOG(INFO) << "Got new port destination request for " + << pfdr.destination(); PortForwardDestinationResponse pfdresponse = createDestination(pfdr); Packet sendPacket( uint8_t(TerminalPacketType::PORT_FORWARD_DESTINATION_RESPONSE), diff --git a/src/terminal/PortForwardHandler.hpp b/src/terminal/forwarding/PortForwardHandler.hpp similarity index 54% rename from src/terminal/PortForwardHandler.hpp rename to src/terminal/forwarding/PortForwardHandler.hpp index cf335aee0..c7b2350e9 100644 --- a/src/terminal/PortForwardHandler.hpp +++ b/src/terminal/forwarding/PortForwardHandler.hpp @@ -4,18 +4,21 @@ #include "ETerminal.pb.h" #include "Connection.hpp" -#include "PortForwardDestinationHandler.hpp" -#include "PortForwardSourceHandler.hpp" +#include "ForwardDestinationHandler.hpp" +#include "ForwardSourceHandler.hpp" #include "SocketHandler.hpp" namespace et { class PortForwardHandler { public: - explicit PortForwardHandler(shared_ptr _socketHandler); + explicit PortForwardHandler(shared_ptr _networkSocketHandler, + shared_ptr _pipeSocketHandler); void update(vector* requests, vector* dataToSend); void handlePacket(const Packet& packet, shared_ptr connection); - PortForwardSourceResponse createSource(const PortForwardSourceRequest& pfsr); + PortForwardSourceResponse createSource(const PortForwardSourceRequest& pfsr, + string* sourceName, uid_t userid, + gid_t groupid); PortForwardDestinationResponse createDestination( const PortForwardDestinationRequest& pfdr); @@ -25,13 +28,12 @@ class PortForwardHandler { void sendDataToSourceOnSocket(int socketId, const string& data); protected: - shared_ptr socketHandler; - unordered_map> - destinationHandlers; + shared_ptr networkSocketHandler; + shared_ptr pipeSocketHandler; + unordered_map> destinationHandlers; - vector> sourceHandlers; - unordered_map> - socketIdSourceHandlerMap; + vector> sourceHandlers; + unordered_map> socketIdSourceHandlerMap; }; } // namespace et diff --git a/test/BackedTest.cpp b/test/BackedTest.cpp index f8892d5e3..59c5bf9c8 100644 --- a/test/BackedTest.cpp +++ b/test/BackedTest.cpp @@ -113,7 +113,8 @@ TEST_CASE("BackedTest", "[BackedTest]") { string tmpPath = string("/tmp/et_test_XXXXXXXX"); pipeDirectory = string(mkdtemp(&tmpPath[0])); pipePath = string(pipeDirectory) + "/pipe"; - SocketEndpoint endpoint(pipePath); + SocketEndpoint endpoint; + endpoint.set_name(pipePath); int serverClientFd = -1; std::thread serverListenThread(listenFn, serverSocketHandler, endpoint, &serverClientFd); diff --git a/test/ConnectionTest.cpp b/test/ConnectionTest.cpp index 561354dc9..ed86cf74f 100644 --- a/test/ConnectionTest.cpp +++ b/test/ConnectionTest.cpp @@ -262,7 +262,8 @@ TEST_CASE("ConnectionTest", "[ConnectionTest]") { string tmpPath = string("/tmp/et_test_XXXXXXXX"); pipeDirectory = string(mkdtemp(&tmpPath[0])); pipePath = string(pipeDirectory) + "/pipe"; - endpoint = SocketEndpoint(pipePath); + endpoint = SocketEndpoint(); + endpoint.set_name(pipePath); serverConnection.reset( new TestServerConnection(serverSocketHandler, endpoint)); diff --git a/test/FakeConsole.hpp b/test/FakeConsole.hpp index b0c26267f..b54edee47 100644 --- a/test/FakeConsole.hpp +++ b/test/FakeConsole.hpp @@ -47,7 +47,8 @@ class FakeConsole : public Console { string tmpPath = string("/tmp/et_test_console_XXXXXXXX"); pipeDirectory = string(mkdtemp(&tmpPath[0])); pipePath = string(pipeDirectory) + "/pipe"; - SocketEndpoint endpoint(pipePath); + SocketEndpoint endpoint; + endpoint.set_name(pipePath); { lock_guard lock(_mutex); serverClientFd = -1; @@ -141,7 +142,8 @@ class FakeUserTerminal : public UserTerminal { string tmpPath = string("/tmp/et_test_userterminal_XXXXXXXX"); pipeDirectory = string(mkdtemp(&tmpPath[0])); pipePath = string(pipeDirectory) + "/pipe"; - SocketEndpoint endpoint(pipePath); + SocketEndpoint endpoint; + endpoint.set_name(pipePath); serverClientFd = -1; std::thread serverListenThread(&FakeUserTerminal::listenFn, this, socketHandler, endpoint, &serverClientFd); diff --git a/test/TerminalTest.cpp b/test/TerminalTest.cpp index 199f0eaf6..0252916dd 100644 --- a/test/TerminalTest.cpp +++ b/test/TerminalTest.cpp @@ -90,6 +90,7 @@ void readWriteTest(const string& clientId, shared_ptr fakeUserTerminal, SocketEndpoint serverEndpoint, shared_ptr clientSocketHandler, + shared_ptr clientPipeSocketHandler, shared_ptr fakeConsole, const SocketEndpoint& routerEndpoint) { auto uth = shared_ptr( @@ -99,9 +100,10 @@ void readWriteTest(const string& clientId, sleep(1); shared_ptr terminalClient(new TerminalClient( - clientSocketHandler, serverEndpoint, clientId, CRYPTO_KEY, fakeConsole)); + clientSocketHandler, clientPipeSocketHandler, serverEndpoint, clientId, + CRYPTO_KEY, fakeConsole, false, "", "", false)); thread terminalClientThread( - [terminalClient]() { terminalClient->run("", "", ""); }); + [terminalClient]() { terminalClient->run(""); }); sleep(3); string s(1024, '\0'); @@ -143,6 +145,7 @@ TEST_CASE("EndToEndTest", "[EndToEndTest]") { shared_ptr serverSocketHandler; shared_ptr clientSocketHandler; + shared_ptr clientPipeSocketHandler; SocketEndpoint serverEndpoint; @@ -153,6 +156,7 @@ TEST_CASE("EndToEndTest", "[EndToEndTest]") { srand(1); clientSocketHandler.reset(new PipeSocketHandler()); + clientPipeSocketHandler.reset(new PipeSocketHandler()); serverSocketHandler.reset(new PipeSocketHandler()); routerSocketHandler.reset(new PipeSocketHandler()); el::Helpers::setThreadName("Main"); @@ -167,10 +171,11 @@ TEST_CASE("EndToEndTest", "[EndToEndTest]") { pipeDirectory = string(mkdtemp(&tmpPath[0])); string routerPipePath = string(pipeDirectory) + "/pipe_router"; - auto routerEndpoint = SocketEndpoint(routerPipePath); + SocketEndpoint routerEndpoint; + routerEndpoint.set_name(routerPipePath); string serverPipePath = string(pipeDirectory) + "/pipe_server"; - serverEndpoint = SocketEndpoint(serverPipePath); + serverEndpoint.set_name(serverPipePath); auto server = shared_ptr( new TerminalServer(serverSocketHandler, serverEndpoint, @@ -179,7 +184,7 @@ TEST_CASE("EndToEndTest", "[EndToEndTest]") { sleep(1); readWriteTest("1234567890123456", routerSocketHandler, fakeUserTerminal, - serverEndpoint, clientSocketHandler, fakeConsole, + serverEndpoint, clientSocketHandler, clientPipeSocketHandler, fakeConsole, routerEndpoint); server->shutdown(); t_server.join(); @@ -188,6 +193,7 @@ TEST_CASE("EndToEndTest", "[EndToEndTest]") { userTerminalSocketHandler.reset(); serverSocketHandler.reset(); clientSocketHandler.reset(); + clientPipeSocketHandler.reset(); routerSocketHandler.reset(); FATAL_FAIL(::remove(routerPipePath.c_str())); FATAL_FAIL(::remove(serverPipePath.c_str()));