// @file sock.cpp /* Copyright 2009 10gen Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "pch.h" #include "sock.h" #include "../background.h" #if !defined(_WIN32) # include # include # include # include # include # include # include # include # include # if defined(__openbsd__) # include # endif #endif #ifdef MONGO_SSL #include #include #endif namespace mongo { static bool ipv6 = false; void enableIPv6(bool state) { ipv6 = state; } bool IPv6Enabled() { return ipv6; } void setSockTimeouts(int sock, double secs) { struct timeval tv; tv.tv_sec = (int)secs; tv.tv_usec = (int)((long long)(secs*1000*1000) % (1000*1000)); bool report = logLevel > 3; // solaris doesn't provide these DEV report = true; bool ok = setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *) &tv, sizeof(tv) ) == 0; if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl; ok = setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, sizeof(tv) ) == 0; DEV if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl; } #if defined(_WIN32) void disableNagle(int sock) { int x = 1; if ( setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *) &x, sizeof(x)) ) error() << "disableNagle failed" << endl; if ( setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char *) &x, sizeof(x)) ) error() << "SO_KEEPALIVE failed" << endl; } #else void disableNagle(int sock) { int x = 1; #ifdef SOL_TCP int level = SOL_TCP; #else int level = SOL_SOCKET; #endif if ( setsockopt(sock, level, TCP_NODELAY, (char *) &x, sizeof(x)) ) error() << "disableNagle failed: " << errnoWithDescription() << endl; #ifdef SO_KEEPALIVE if ( setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char *) &x, sizeof(x)) ) error() << "SO_KEEPALIVE failed: " << errnoWithDescription() << endl; # ifdef __linux__ socklen_t len = sizeof(x); if ( getsockopt(sock, level, TCP_KEEPIDLE, (char *) &x, &len) ) error() << "can't get TCP_KEEPIDLE: " << errnoWithDescription() << endl; if (x > 300) { x = 300; if ( setsockopt(sock, level, TCP_KEEPIDLE, (char *) &x, sizeof(x)) ) { error() << "can't set TCP_KEEPIDLE: " << errnoWithDescription() << endl; } } len = sizeof(x); // just in case it changed if ( getsockopt(sock, level, TCP_KEEPINTVL, (char *) &x, &len) ) error() << "can't get TCP_KEEPINTVL: " << errnoWithDescription() << endl; if (x > 300) { x = 300; if ( setsockopt(sock, level, TCP_KEEPINTVL, (char *) &x, sizeof(x)) ) { error() << "can't set TCP_KEEPINTVL: " << errnoWithDescription() << endl; } } # endif #endif } #endif string getAddrInfoStrError(int code) { #if !defined(_WIN32) return gai_strerror(code); #else /* gai_strerrorA is not threadsafe on windows. don't use it. */ return errnoWithDescription(code); #endif } // --- SockAddr SockAddr::SockAddr(int sourcePort) { memset(as().sin_zero, 0, sizeof(as().sin_zero)); as().sin_family = AF_INET; as().sin_port = htons(sourcePort); as().sin_addr.s_addr = htonl(INADDR_ANY); addressSize = sizeof(sockaddr_in); } SockAddr::SockAddr(const char * iporhost , int port) { if (!strcmp(iporhost, "localhost")) iporhost = "127.0.0.1"; if (strchr(iporhost, '/')) { #ifdef _WIN32 uassert(13080, "no unix socket support on windows", false); #endif uassert(13079, "path to unix socket too long", strlen(iporhost) < sizeof(as().sun_path)); as().sun_family = AF_UNIX; strcpy(as().sun_path, iporhost); addressSize = sizeof(sockaddr_un); } else { addrinfo* addrs = NULL; addrinfo hints; memset(&hints, 0, sizeof(addrinfo)); hints.ai_socktype = SOCK_STREAM; //hints.ai_flags = AI_ADDRCONFIG; // This is often recommended but don't do it. SERVER-1579 hints.ai_flags |= AI_NUMERICHOST; // first pass tries w/o DNS lookup hints.ai_family = (IPv6Enabled() ? AF_UNSPEC : AF_INET); StringBuilder ss; ss << port; int ret = getaddrinfo(iporhost, ss.str().c_str(), &hints, &addrs); // old C compilers on IPv6-capable hosts return EAI_NODATA error #ifdef EAI_NODATA int nodata = (ret == EAI_NODATA); #else int nodata = false; #endif if (ret == EAI_NONAME || nodata) { // iporhost isn't an IP address, allow DNS lookup hints.ai_flags &= ~AI_NUMERICHOST; ret = getaddrinfo(iporhost, ss.str().c_str(), &hints, &addrs); } if (ret) { // don't log if this as it is a CRT construction and log() may not work yet. if( strcmp("0.0.0.0", iporhost) ) { log() << "getaddrinfo(\"" << iporhost << "\") failed: " << gai_strerror(ret) << endl; } *this = SockAddr(port); } else { //TODO: handle other addresses in linked list; assert(addrs->ai_addrlen <= sizeof(sa)); memcpy(&sa, addrs->ai_addr, addrs->ai_addrlen); addressSize = addrs->ai_addrlen; freeaddrinfo(addrs); } } } bool SockAddr::isLocalHost() const { switch (getType()) { case AF_INET: return getAddr() == "127.0.0.1"; case AF_INET6: return getAddr() == "::1"; case AF_UNIX: return true; default: return false; } assert(false); return false; } string SockAddr::toString(bool includePort) const { string out = getAddr(); if (includePort && getType() != AF_UNIX && getType() != AF_UNSPEC) out += mongoutils::str::stream() << ':' << getPort(); return out; } sa_family_t SockAddr::getType() const { return sa.ss_family; } unsigned SockAddr::getPort() const { switch (getType()) { case AF_INET: return ntohs(as().sin_port); case AF_INET6: return ntohs(as().sin6_port); case AF_UNIX: return 0; case AF_UNSPEC: return 0; default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); return 0; } } string SockAddr::getAddr() const { switch (getType()) { case AF_INET: case AF_INET6: { const int buflen=128; char buffer[buflen]; int ret = getnameinfo(raw(), addressSize, buffer, buflen, NULL, 0, NI_NUMERICHOST); massert(13082, getAddrInfoStrError(ret), ret == 0); return buffer; } case AF_UNIX: return (addressSize > 2 ? as().sun_path : "anonymous unix socket"); case AF_UNSPEC: return "(NONE)"; default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); return ""; } } bool SockAddr::operator==(const SockAddr& r) const { if (getType() != r.getType()) return false; if (getPort() != r.getPort()) return false; switch (getType()) { case AF_INET: return as().sin_addr.s_addr == r.as().sin_addr.s_addr; case AF_INET6: return memcmp(as().sin6_addr.s6_addr, r.as().sin6_addr.s6_addr, sizeof(in6_addr)) == 0; case AF_UNIX: return strcmp(as().sun_path, r.as().sun_path) == 0; case AF_UNSPEC: return true; // assume all unspecified addresses are the same default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); } return false; } bool SockAddr::operator!=(const SockAddr& r) const { return !(*this == r); } bool SockAddr::operator<(const SockAddr& r) const { if (getType() < r.getType()) return true; else if (getType() > r.getType()) return false; if (getPort() < r.getPort()) return true; else if (getPort() > r.getPort()) return false; switch (getType()) { case AF_INET: return as().sin_addr.s_addr < r.as().sin_addr.s_addr; case AF_INET6: return memcmp(as().sin6_addr.s6_addr, r.as().sin6_addr.s6_addr, sizeof(in6_addr)) < 0; case AF_UNIX: return strcmp(as().sun_path, r.as().sun_path) < 0; case AF_UNSPEC: return false; default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); } return false; } SockAddr unknownAddress( "0.0.0.0", 0 ); // ------ hostname ------------------- string hostbyname(const char *hostname) { string addr = SockAddr(hostname, 0).getAddr(); if (addr == "0.0.0.0") return ""; else return addr; } // --- my -- string getHostName() { char buf[256]; int ec = gethostname(buf, 127); if ( ec || *buf == 0 ) { log() << "can't get this server's hostname " << errnoWithDescription() << endl; return ""; } return buf; } string _hostNameCached; static void _hostNameCachedInit() { _hostNameCached = getHostName(); } boost::once_flag _hostNameCachedInitFlags = BOOST_ONCE_INIT; string getHostNameCached() { boost::call_once( _hostNameCachedInit , _hostNameCachedInitFlags ); return _hostNameCached; } // --------- SocketException ---------- #ifdef MSG_NOSIGNAL const int portSendFlags = MSG_NOSIGNAL; const int portRecvFlags = MSG_NOSIGNAL; #else const int portSendFlags = 0; const int portRecvFlags = 0; #endif string SocketException::toString() const { stringstream ss; ss << _ei.code << " socket exception [" << _type << "] "; if ( _server.size() ) ss << "server [" << _server << "] "; if ( _extra.size() ) ss << _extra; return ss.str(); } // ------------ SSLManager ----------------- #ifdef MONGO_SSL SSLManager::SSLManager( bool client ) { _client = client; SSL_library_init(); SSL_load_error_strings(); ERR_load_crypto_strings(); _context = SSL_CTX_new( client ? SSLv23_client_method() : SSLv23_server_method() ); massert( 15864 , mongoutils::str::stream() << "can't create SSL Context: " << ERR_error_string(ERR_get_error(), NULL) , _context ); SSL_CTX_set_options( _context, SSL_OP_ALL); } void SSLManager::setupPubPriv( const string& privateKeyFile , const string& publicKeyFile ) { massert( 15865 , mongoutils::str::stream() << "Can't read SSL certificate from file " << publicKeyFile << ":" << ERR_error_string(ERR_get_error(), NULL) , SSL_CTX_use_certificate_file(_context, publicKeyFile.c_str(), SSL_FILETYPE_PEM) ); massert( 15866 , mongoutils::str::stream() << "Can't read SSL private key from file " << privateKeyFile << " : " << ERR_error_string(ERR_get_error(), NULL) , SSL_CTX_use_PrivateKey_file(_context, privateKeyFile.c_str(), SSL_FILETYPE_PEM) ); } int SSLManager::password_cb(char *buf,int num, int rwflag,void *userdata){ SSLManager* sm = (SSLManager*)userdata; string pass = sm->_password; strcpy(buf,pass.c_str()); return(pass.size()); } void SSLManager::setupPEM( const string& keyFile , const string& password ) { _password = password; massert( 15867 , "Can't read certificate file" , SSL_CTX_use_certificate_chain_file( _context , keyFile.c_str() ) ); SSL_CTX_set_default_passwd_cb_userdata( _context , this ); SSL_CTX_set_default_passwd_cb( _context, &SSLManager::password_cb ); massert( 15868 , "Can't read key file" , SSL_CTX_use_PrivateKey_file( _context , keyFile.c_str() , SSL_FILETYPE_PEM ) ); } SSL * SSLManager::secure( int fd ) { SSL * ssl = SSL_new( _context ); massert( 15861 , "can't create SSL" , ssl ); SSL_set_fd( ssl , fd ); return ssl; } #endif // ------------ Socket ----------------- Socket::Socket(int fd , const SockAddr& remote) : _fd(fd), _remote(remote), _timeout(0) { _logLevel = 0; _init(); } Socket::Socket( double timeout, int ll ) { _logLevel = ll; _fd = -1; _timeout = timeout; _init(); } void Socket::_init() { _bytesOut = 0; _bytesIn = 0; #ifdef MONGO_SSL _sslAccepted = 0; #endif } void Socket::close() { #ifdef MONGO_SSL _ssl.reset(); #endif if ( _fd >= 0 ) { closesocket( _fd ); _fd = -1; } } #ifdef MONGO_SSL void Socket::secure( SSLManager * ssl ) { assert( ssl ); assert( _fd >= 0 ); _ssl.reset( ssl->secure( _fd ) ); SSL_connect( _ssl.get() ); } void Socket::secureAccepted( SSLManager * ssl ) { _sslAccepted = ssl; } #endif void Socket::postFork() { #ifdef MONGO_SSL if ( _sslAccepted ) { assert( _fd ); _ssl.reset( _sslAccepted->secure( _fd ) ); SSL_accept( _ssl.get() ); _sslAccepted = 0; } #endif } class ConnectBG : public BackgroundJob { public: ConnectBG(int sock, SockAddr remote) : _sock(sock), _remote(remote) { } void run() { _res = ::connect(_sock, _remote.raw(), _remote.addressSize); } string name() const { return "ConnectBG"; } int inError() const { return _res; } private: int _sock; int _res; SockAddr _remote; }; bool Socket::connect(SockAddr& remote) { _remote = remote; _fd = socket(remote.getType(), SOCK_STREAM, 0); if ( _fd == INVALID_SOCKET ) { log(_logLevel) << "ERROR: connect invalid socket " << errnoWithDescription() << endl; return false; } if ( _timeout > 0 ) { setTimeout( _timeout ); } ConnectBG bg(_fd, remote); bg.go(); if ( bg.wait(5000) ) { if ( bg.inError() ) { close(); return false; } } else { // time out the connect close(); bg.wait(); // so bg stays in scope until bg thread terminates return false; } if (remote.getType() != AF_UNIX) disableNagle(_fd); #ifdef SO_NOSIGPIPE // osx const int one = 1; setsockopt( _fd , SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(int)); #endif return true; } int Socket::_send( const char * data , int len ) { #ifdef MONGO_SSL if ( _ssl ) { return SSL_write( _ssl.get() , data , len ); } #endif return ::send( _fd , data , len , portSendFlags ); } // sends all data or throws an exception void Socket::send( const char * data , int len, const char *context ) { while( len > 0 ) { int ret = _send( data , len ); if ( ret == -1 ) { #ifdef MONGO_SSL if ( _ssl ) { log() << "SSL Error ret: " << ret << " err: " << SSL_get_error( _ssl.get() , ret ) << " " << ERR_error_string(ERR_get_error(), NULL) << endl; } #endif #if defined(_WIN32) if ( WSAGetLastError() == WSAETIMEDOUT && _timeout != 0 ) { #else if ( ( errno == EAGAIN || errno == EWOULDBLOCK ) && _timeout != 0 ) { #endif log(_logLevel) << "Socket " << context << " send() timed out " << _remote.toString() << endl; throw SocketException( SocketException::SEND_TIMEOUT , remoteString() ); } else { SocketException::Type t = SocketException::SEND_ERROR; log(_logLevel) << "Socket " << context << " send() " << errnoWithDescription() << ' ' << remoteString() << endl; throw SocketException( t , remoteString() ); } } else { _bytesOut += ret; assert( ret <= len ); len -= ret; data += ret; } } } void Socket::_send( const vector< pair< char *, int > > &data, const char *context ) { for( vector< pair< char *, int > >::const_iterator i = data.begin(); i != data.end(); ++i ) { char * data = i->first; int len = i->second; send( data, len, context ); } } // sends all data or throws an exception void Socket::send( const vector< pair< char *, int > > &data, const char *context ) { #ifdef MONGO_SSL if ( _ssl ) { _send( data , context ); return; } #endif #if defined(_WIN32) // TODO use scatter/gather api _send( data , context ); #else vector< struct iovec > d( data.size() ); int i = 0; for( vector< pair< char *, int > >::const_iterator j = data.begin(); j != data.end(); ++j ) { if ( j->second > 0 ) { d[ i ].iov_base = j->first; d[ i ].iov_len = j->second; ++i; _bytesOut += j->second; } } struct msghdr meta; memset( &meta, 0, sizeof( meta ) ); meta.msg_iov = &d[ 0 ]; meta.msg_iovlen = d.size(); while( meta.msg_iovlen > 0 ) { int ret = ::sendmsg( _fd , &meta , portSendFlags ); if ( ret == -1 ) { if ( errno != EAGAIN || _timeout == 0 ) { log(_logLevel) << "Socket " << context << " send() " << errnoWithDescription() << ' ' << remoteString() << endl; throw SocketException( SocketException::SEND_ERROR , remoteString() ); } else { log(_logLevel) << "Socket " << context << " send() remote timeout " << remoteString() << endl; throw SocketException( SocketException::SEND_TIMEOUT , remoteString() ); } } else { struct iovec *& i = meta.msg_iov; while( ret > 0 ) { if ( i->iov_len > unsigned( ret ) ) { i->iov_len -= ret; i->iov_base = (char*)(i->iov_base) + ret; ret = 0; } else { ret -= i->iov_len; ++i; --(meta.msg_iovlen); } } } } #endif } void Socket::recv( char * buf , int len ) { unsigned retries = 0; while( len > 0 ) { int ret = unsafe_recv( buf , len ); if ( ret > 0 ) { if ( len <= 4 && ret != len ) log(_logLevel) << "Socket recv() got " << ret << " bytes wanted len=" << len << endl; assert( ret <= len ); len -= ret; buf += ret; } else if ( ret == 0 ) { log(3) << "Socket recv() conn closed? " << remoteString() << endl; throw SocketException( SocketException::CLOSED , remoteString() ); } else { /* ret < 0 */ #if defined(_WIN32) int e = WSAGetLastError(); #else int e = errno; # if defined(EINTR) if( e == EINTR ) { if( ++retries == 1 ) { log() << "EINTR retry" << endl; continue; } } # endif #endif if ( ( e == EAGAIN #if defined(_WIN32) || e == WSAETIMEDOUT #endif ) && _timeout > 0 ) { // this is a timeout log(_logLevel) << "Socket recv() timeout " << remoteString() < 3; // solaris doesn't provide these DEV report = true; bool ok = setsockopt(_fd, SOL_SOCKET, SO_RCVTIMEO, (char *) &tv, sizeof(tv) ) == 0; if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl; ok = setsockopt(_fd, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, sizeof(tv) ) == 0; DEV if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl; } #if defined(_WIN32) struct WinsockInit { WinsockInit() { WSADATA d; if ( WSAStartup(MAKEWORD(2,2), &d) != 0 ) { out() << "ERROR: wsastartup failed " << errnoWithDescription() << endl; problem() << "ERROR: wsastartup failed " << errnoWithDescription() << endl; dbexit( EXIT_NTSERVICE_ERROR ); } } } winsock_init; #endif } // namespace mongo