diff options
Diffstat (limited to 'util/net')
-rw-r--r-- | util/net/hostandport.h | 165 | ||||
-rw-r--r-- | util/net/httpclient.cpp | 177 | ||||
-rw-r--r-- | util/net/httpclient.h | 79 | ||||
-rw-r--r-- | util/net/listen.cpp | 391 | ||||
-rw-r--r-- | util/net/listen.h | 190 | ||||
-rw-r--r-- | util/net/message.cpp | 64 | ||||
-rw-r--r-- | util/net/message.h | 312 | ||||
-rw-r--r-- | util/net/message_port.cpp | 298 | ||||
-rw-r--r-- | util/net/message_port.h | 107 | ||||
-rw-r--r-- | util/net/message_server.h | 66 | ||||
-rw-r--r-- | util/net/message_server_asio.cpp | 261 | ||||
-rw-r--r-- | util/net/message_server_port.cpp | 197 | ||||
-rw-r--r-- | util/net/miniwebserver.cpp | 207 | ||||
-rw-r--r-- | util/net/miniwebserver.h | 60 | ||||
-rw-r--r-- | util/net/sock.cpp | 713 | ||||
-rw-r--r-- | util/net/sock.h | 256 |
16 files changed, 3543 insertions, 0 deletions
diff --git a/util/net/hostandport.h b/util/net/hostandport.h new file mode 100644 index 0000000..573e8ee --- /dev/null +++ b/util/net/hostandport.h @@ -0,0 +1,165 @@ +// hostandport.h + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "sock.h" +#include "../../db/cmdline.h" +#include "../mongoutils/str.h" + +namespace mongo { + + using namespace mongoutils; + + /** helper for manipulating host:port connection endpoints. + */ + struct HostAndPort { + HostAndPort() : _port(-1) { } + + /** From a string hostname[:portnumber] + Throws user assertion if bad config string or bad port #. + */ + HostAndPort(string s); + + /** @param p port number. -1 is ok to use default. */ + HostAndPort(string h, int p /*= -1*/) : _host(h), _port(p) { } + + HostAndPort(const SockAddr& sock ) + : _host( sock.getAddr() ) , _port( sock.getPort() ) { + } + + static HostAndPort me() { + return HostAndPort("localhost", cmdLine.port); + } + + /* uses real hostname instead of localhost */ + static HostAndPort Me(); + + bool operator<(const HostAndPort& r) const { + if( _host < r._host ) + return true; + if( _host == r._host ) + return port() < r.port(); + return false; + } + + bool operator==(const HostAndPort& r) const { + return _host == r._host && port() == r.port(); + } + + bool operator!=(const HostAndPort& r) const { + return _host != r._host || port() != r.port(); + } + + /* returns true if the host/port combo identifies this process instance. */ + bool isSelf() const; // defined in message.cpp + + bool isLocalHost() const; + + /** + * @param includePort host:port if true, host otherwise + */ + string toString( bool includePort=true ) const; + + operator string() const { return toString(); } + + string host() const { return _host; } + + int port() const { return _port >= 0 ? _port : CmdLine::DefaultDBPort; } + bool hasPort() const { return _port >= 0; } + void setPort( int port ) { _port = port; } + + private: + // invariant (except full obj assignment): + string _host; + int _port; // -1 indicates unspecified + }; + + inline HostAndPort HostAndPort::Me() { + const char* ips = cmdLine.bind_ip.c_str(); + while(*ips) { + string ip; + const char * comma = strchr(ips, ','); + if (comma) { + ip = string(ips, comma - ips); + ips = comma + 1; + } + else { + ip = string(ips); + ips = ""; + } + HostAndPort h = HostAndPort(ip, cmdLine.port); + if (!h.isLocalHost()) { + return h; + } + } + + string h = getHostName(); + assert( !h.empty() ); + assert( h != "localhost" ); + return HostAndPort(h, cmdLine.port); + } + + inline string HostAndPort::toString( bool includePort ) const { + if ( ! includePort ) + return _host; + + stringstream ss; + ss << _host; + if ( _port != -1 ) { + ss << ':'; +#if defined(_DEBUG) + if( _port >= 44000 && _port < 44100 ) { + log() << "warning: special debug port 44xxx used" << endl; + ss << _port+1; + } + else + ss << _port; +#else + ss << _port; +#endif + } + return ss.str(); + } + + inline bool HostAndPort::isLocalHost() const { + return ( _host == "localhost" + || startsWith(_host.c_str(), "127.") + || _host == "::1" + || _host == "anonymous unix socket" + || _host.c_str()[0] == '/' // unix socket + ); + } + + inline HostAndPort::HostAndPort(string s) { + const char *p = s.c_str(); + uassert(13110, "HostAndPort: bad config string", *p); + const char *colon = strrchr(p, ':'); + if( colon ) { + int port = atoi(colon+1); + uassert(13095, "HostAndPort: bad port #", port > 0); + _host = string(p,colon-p); + _port = port; + } + else { + // no port specified. + _host = p; + _port = -1; + } + } + +} diff --git a/util/net/httpclient.cpp b/util/net/httpclient.cpp new file mode 100644 index 0000000..16eaa0a --- /dev/null +++ b/util/net/httpclient.cpp @@ -0,0 +1,177 @@ +// httpclient.cpp + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pch.h" +#include "httpclient.h" +#include "sock.h" +#include "message.h" +#include "message_port.h" +#include "../mongoutils/str.h" +#include "../../bson/util/builder.h" + +namespace mongo { + + //#define HD(x) cout << x << endl; +#define HD(x) + + + int HttpClient::get( string url , Result * result ) { + return _go( "GET" , url , 0 , result ); + } + + int HttpClient::post( string url , string data , Result * result ) { + return _go( "POST" , url , data.c_str() , result ); + } + + int HttpClient::_go( const char * command , string url , const char * body , Result * result ) { + bool ssl = false; + if ( url.find( "https://" ) == 0 ) { + ssl = true; + url = url.substr( 8 ); + } + else { + uassert( 10271 , "invalid url" , url.find( "http://" ) == 0 ); + url = url.substr( 7 ); + } + + string host , path; + if ( url.find( "/" ) == string::npos ) { + host = url; + path = "/"; + } + else { + host = url.substr( 0 , url.find( "/" ) ); + path = url.substr( url.find( "/" ) ); + } + + + HD( "host [" << host << "]" ); + HD( "path [" << path << "]" ); + + string server = host; + int port = ssl ? 443 : 80; + + string::size_type idx = host.find( ":" ); + if ( idx != string::npos ) { + server = host.substr( 0 , idx ); + string t = host.substr( idx + 1 ); + port = atoi( t.c_str() ); + } + + HD( "server [" << server << "]" ); + HD( "port [" << port << "]" ); + + string req; + { + stringstream ss; + ss << command << " " << path << " HTTP/1.1\r\n"; + ss << "Host: " << host << "\r\n"; + ss << "Connection: Close\r\n"; + ss << "User-Agent: mongodb http client\r\n"; + if ( body ) { + ss << "Content-Length: " << strlen( body ) << "\r\n"; + } + ss << "\r\n"; + if ( body ) { + ss << body; + } + + req = ss.str(); + } + + SockAddr addr( server.c_str() , port ); + HD( "addr: " << addr.toString() ); + + Socket sock; + if ( ! sock.connect( addr ) ) + return -1; + + if ( ssl ) { +#ifdef MONGO_SSL + _checkSSLManager(); + sock.secure( _sslManager.get() ); +#else + uasserted( 15862 , "no ssl support" ); +#endif + } + + { + const char * out = req.c_str(); + int toSend = req.size(); + sock.send( out , toSend, "_go" ); + } + + char buf[4096]; + int got = sock.unsafe_recv( buf , 4096 ); + buf[got] = 0; + + int rc; + char version[32]; + assert( sscanf( buf , "%s %d" , version , &rc ) == 2 ); + HD( "rc: " << rc ); + + StringBuilder sb; + if ( result ) + sb << buf; + + while ( ( got = sock.unsafe_recv( buf , 4096 ) ) > 0) { + if ( result ) + sb << buf; + } + + if ( result ) { + result->_init( rc , sb.str() ); + } + + return rc; + } + + void HttpClient::Result::_init( int code , string entire ) { + _code = code; + _entireResponse = entire; + + while ( true ) { + size_t i = entire.find( '\n' ); + if ( i == string::npos ) { + // invalid + break; + } + + string h = entire.substr( 0 , i ); + entire = entire.substr( i + 1 ); + + if ( h.size() && h[h.size()-1] == '\r' ) + h = h.substr( 0 , h.size() - 1 ); + + if ( h.size() == 0 ) + break; + + i = h.find( ':' ); + if ( i != string::npos ) + _headers[h.substr(0,i)] = str::ltrim(h.substr(i+1)); + } + + _body = entire; + } + +#ifdef MONGO_SSL + void HttpClient::_checkSSLManager() { + _sslManager.reset( new SSLManager( true ) ); + } +#endif + +} diff --git a/util/net/httpclient.h b/util/net/httpclient.h new file mode 100644 index 0000000..c3f8c82 --- /dev/null +++ b/util/net/httpclient.h @@ -0,0 +1,79 @@ +// httpclient.h + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../../pch.h" +#include "sock.h" + +namespace mongo { + + class HttpClient : boost::noncopyable { + public: + + typedef map<string,string> Headers; + + class Result { + public: + Result() {} + + const string& getEntireResponse() const { + return _entireResponse; + } + + const Headers getHeaders() const { + return _headers; + } + + const string& getBody() const { + return _body; + } + + private: + + void _init( int code , string entire ); + + int _code; + string _entireResponse; + + Headers _headers; + string _body; + + friend class HttpClient; + }; + + /** + * @return response code + */ + int get( string url , Result * result = 0 ); + + /** + * @return response code + */ + int post( string url , string body , Result * result = 0 ); + + private: + int _go( const char * command , string url , const char * body , Result * result ); + +#ifdef MONGO_SSL + void _checkSSLManager(); + + scoped_ptr<SSLManager> _sslManager; +#endif + }; +} + diff --git a/util/net/listen.cpp b/util/net/listen.cpp new file mode 100644 index 0000000..6ee25b4 --- /dev/null +++ b/util/net/listen.cpp @@ -0,0 +1,391 @@ +// listen.h + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "pch.h" +#include "listen.h" +#include "message_port.h" + +#ifndef _WIN32 + +# ifndef __sunos__ +# include <ifaddrs.h> +# endif +# include <sys/resource.h> +# include <sys/stat.h> + +#include <sys/types.h> +#include <sys/socket.h> +#include <sys/un.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <arpa/inet.h> +#include <errno.h> +#include <netdb.h> +#ifdef __openbsd__ +# include <sys/uio.h> +#endif + +#else + +// errno doesn't work for winsock. +#undef errno +#define errno WSAGetLastError() + +#endif + +namespace mongo { + + + void checkTicketNumbers(); + + + // ----- Listener ------- + + const Listener* Listener::_timeTracker; + + vector<SockAddr> ipToAddrs(const char* ips, int port, bool useUnixSockets) { + vector<SockAddr> out; + if (*ips == '\0') { + out.push_back(SockAddr("0.0.0.0", port)); // IPv4 all + + if (IPv6Enabled()) + out.push_back(SockAddr("::", port)); // IPv6 all +#ifndef _WIN32 + if (useUnixSockets) + out.push_back(SockAddr(makeUnixSockPath(port).c_str(), port)); // Unix socket +#endif + return out; + } + + while(*ips) { + string ip; + const char * comma = strchr(ips, ','); + if (comma) { + ip = string(ips, comma - ips); + ips = comma + 1; + } + else { + ip = string(ips); + ips = ""; + } + + SockAddr sa(ip.c_str(), port); + out.push_back(sa); + +#ifndef _WIN32 + if (useUnixSockets && (sa.getAddr() == "127.0.0.1" || sa.getAddr() == "0.0.0.0")) // only IPv4 + out.push_back(SockAddr(makeUnixSockPath(port).c_str(), port)); +#endif + } + return out; + + } + + Listener::Listener(const string& name, const string &ip, int port, bool logConnect ) + : _port(port), _name(name), _ip(ip), _logConnect(logConnect), _elapsedTime(0) { +#ifdef MONGO_SSL + _ssl = 0; + _sslPort = 0; + + if ( cmdLine.sslOnNormalPorts && cmdLine.sslServerManager ) { + secure( cmdLine.sslServerManager ); + } +#endif + } + + Listener::~Listener() { + if ( _timeTracker == this ) + _timeTracker = 0; + } + +#ifdef MONGO_SSL + void Listener::secure( SSLManager* manager ) { + _ssl = manager; + } + + void Listener::addSecurePort( SSLManager* manager , int additionalPort ) { + _ssl = manager; + _sslPort = additionalPort; + } + +#endif + + bool Listener::_setupSockets( const vector<SockAddr>& mine , vector<int>& socks ) { + for (vector<SockAddr>::const_iterator it=mine.begin(), end=mine.end(); it != end; ++it) { + const SockAddr& me = *it; + + SOCKET sock = ::socket(me.getType(), SOCK_STREAM, 0); + massert( 15863 , str::stream() << "listen(): invalid socket? " << errnoWithDescription() , sock >= 0 ); + + if (me.getType() == AF_UNIX) { +#if !defined(_WIN32) + if (unlink(me.getAddr().c_str()) == -1) { + int x = errno; + if (x != ENOENT) { + log() << "couldn't unlink socket file " << me << errnoWithDescription(x) << " skipping" << endl; + continue; + } + } +#endif + } + else if (me.getType() == AF_INET6) { + // IPv6 can also accept IPv4 connections as mapped addresses (::ffff:127.0.0.1) + // That causes a conflict if we don't do set it to IPV6_ONLY + const int one = 1; + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*) &one, sizeof(one)); + } + +#if !defined(_WIN32) + { + const int one = 1; + if ( setsockopt( sock , SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0 ) + out() << "Failed to set socket opt, SO_REUSEADDR" << endl; + } +#endif + + if ( ::bind(sock, me.raw(), me.addressSize) != 0 ) { + int x = errno; + error() << "listen(): bind() failed " << errnoWithDescription(x) << " for socket: " << me.toString() << endl; + if ( x == EADDRINUSE ) + error() << " addr already in use" << endl; + closesocket(sock); + return false; + } + +#if !defined(_WIN32) + if (me.getType() == AF_UNIX) { + if (chmod(me.getAddr().c_str(), 0777) == -1) { + error() << "couldn't chmod socket file " << me << errnoWithDescription() << endl; + } + ListeningSockets::get()->addPath( me.getAddr() ); + } +#endif + + if ( ::listen(sock, 128) != 0 ) { + error() << "listen(): listen() failed " << errnoWithDescription() << endl; + closesocket(sock); + return false; + } + + ListeningSockets::get()->add( sock ); + + socks.push_back(sock); + } + + return true; + } + + void Listener::initAndListen() { + checkTicketNumbers(); + vector<int> socks; + set<int> sslSocks; + + { // normal sockets + vector<SockAddr> mine = ipToAddrs(_ip.c_str(), _port, (!cmdLine.noUnixSocket && useUnixSockets())); + if ( ! _setupSockets( mine , socks ) ) + return; + } + +#ifdef MONGO_SSL + if ( _ssl && _sslPort > 0 ) { + unsigned prev = socks.size(); + + vector<SockAddr> mine = ipToAddrs(_ip.c_str(), _sslPort, false ); + if ( ! _setupSockets( mine , socks ) ) + return; + + for ( unsigned i=prev; i<socks.size(); i++ ) { + sslSocks.insert( socks[i] ); + } + + } +#endif + + SOCKET maxfd = 0; // needed for select() + for ( unsigned i=0; i<socks.size(); i++ ) { + if ( socks[i] > maxfd ) + maxfd = socks[i]; + } + +#ifdef MONGO_SSL + if ( _ssl == 0 ) { + _logListen( _port , false ); + } + else if ( _sslPort == 0 ) { + _logListen( _port , true ); + } + else { + // both + _logListen( _port , false ); + _logListen( _sslPort , true ); + } +#else + _logListen( _port , false ); +#endif + + static long connNumber = 0; + struct timeval maxSelectTime; + while ( ! inShutdown() ) { + fd_set fds[1]; + FD_ZERO(fds); + + for (vector<int>::iterator it=socks.begin(), end=socks.end(); it != end; ++it) { + FD_SET(*it, fds); + } + + maxSelectTime.tv_sec = 0; + maxSelectTime.tv_usec = 10000; + const int ret = select(maxfd+1, fds, NULL, NULL, &maxSelectTime); + + if (ret == 0) { +#if defined(__linux__) + _elapsedTime += ( 10000 - maxSelectTime.tv_usec ) / 1000; +#else + _elapsedTime += 10; +#endif + continue; + } + + if (ret < 0) { + int x = errno; +#ifdef EINTR + if ( x == EINTR ) { + log() << "select() signal caught, continuing" << endl; + continue; + } +#endif + if ( ! inShutdown() ) + log() << "select() failure: ret=" << ret << " " << errnoWithDescription(x) << endl; + return; + } + +#if defined(__linux__) + _elapsedTime += max(ret, (int)(( 10000 - maxSelectTime.tv_usec ) / 1000)); +#else + _elapsedTime += ret; // assume 1ms to grab connection. very rough +#endif + + for (vector<int>::iterator it=socks.begin(), end=socks.end(); it != end; ++it) { + if (! (FD_ISSET(*it, fds))) + continue; + + SockAddr from; + int s = accept(*it, from.raw(), &from.addressSize); + if ( s < 0 ) { + int x = errno; // so no global issues + if ( x == ECONNABORTED || x == EBADF ) { + log() << "Listener on port " << _port << " aborted" << endl; + return; + } + if ( x == 0 && inShutdown() ) { + return; // socket closed + } + if( !inShutdown() ) { + log() << "Listener: accept() returns " << s << " " << errnoWithDescription(x) << endl; + if (x == EMFILE || x == ENFILE) { + // Connection still in listen queue but we can't accept it yet + error() << "Out of file descriptors. Waiting one second before trying to accept more connections." << warnings; + sleepsecs(1); + } + } + continue; + } + if (from.getType() != AF_UNIX) + disableNagle(s); + if ( _logConnect && ! cmdLine.quiet ) + log() << "connection accepted from " << from.toString() << " #" << ++connNumber << endl; + + Socket newSock = Socket(s, from); +#ifdef MONGO_SSL + if ( _ssl && ( _sslPort == 0 || sslSocks.count(*it) ) ) { + newSock.secureAccepted( _ssl ); + } +#endif + accepted( newSock ); + } + } + } + + void Listener::_logListen( int port , bool ssl ) { + log() << _name << ( _name.size() ? " " : "" ) << "waiting for connections on port " << port << ( ssl ? " ssl" : "" ) << endl; + } + + + void Listener::accepted(Socket socket) { + accepted( new MessagingPort(socket) ); + } + + void Listener::accepted(MessagingPort *mp) { + assert(!"You must overwrite one of the accepted methods"); + } + + // ----- ListeningSockets ------- + + ListeningSockets* ListeningSockets::_instance = new ListeningSockets(); + + ListeningSockets* ListeningSockets::get() { + return _instance; + } + + // ------ connection ticket and control ------ + + const int DEFAULT_MAX_CONN = 20000; + const int MAX_MAX_CONN = 20000; + + int getMaxConnections() { +#ifdef _WIN32 + return DEFAULT_MAX_CONN; +#else + struct rlimit limit; + assert( getrlimit(RLIMIT_NOFILE,&limit) == 0 ); + + int max = (int)(limit.rlim_cur * .8); + + log(1) << "fd limit" + << " hard:" << limit.rlim_max + << " soft:" << limit.rlim_cur + << " max conn: " << max + << endl; + + if ( max > MAX_MAX_CONN ) + max = MAX_MAX_CONN; + + return max; +#endif + } + + void checkTicketNumbers() { + int want = getMaxConnections(); + int current = connTicketHolder.outof(); + if ( current != DEFAULT_MAX_CONN ) { + if ( current < want ) { + // they want fewer than they can handle + // which is fine + log(1) << " only allowing " << current << " connections" << endl; + return; + } + if ( current > want ) { + log() << " --maxConns too high, can only handle " << want << endl; + } + } + connTicketHolder.resize( want ); + } + + TicketHolder connTicketHolder(DEFAULT_MAX_CONN); + +} diff --git a/util/net/listen.h b/util/net/listen.h new file mode 100644 index 0000000..415db1e --- /dev/null +++ b/util/net/listen.h @@ -0,0 +1,190 @@ +// listen.h + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "sock.h" + +namespace mongo { + + class MessagingPort; + + class Listener : boost::noncopyable { + public: + + Listener(const string& name, const string &ip, int port, bool logConnect=true ); + + virtual ~Listener(); + +#ifdef MONGO_SSL + /** + * make this an ssl socket + * ownership of SSLManager remains with the caller + */ + void secure( SSLManager* manager ); + + void addSecurePort( SSLManager* manager , int additionalPort ); +#endif + + void initAndListen(); // never returns unless error (start a thread) + + /* spawn a thread, etc., then return */ + virtual void accepted(Socket socket); + virtual void accepted(MessagingPort *mp); + + const int _port; + + /** + * @return a rough estimate of elapsed time since the server started + */ + long long getMyElapsedTimeMillis() const { return _elapsedTime; } + + void setAsTimeTracker() { + _timeTracker = this; + } + + static const Listener* getTimeTracker() { + return _timeTracker; + } + + static long long getElapsedTimeMillis() { + if ( _timeTracker ) + return _timeTracker->getMyElapsedTimeMillis(); + + // should this assert or throw? seems like callers may not expect to get zero back, certainly not forever. + return 0; + } + + private: + string _name; + string _ip; + bool _logConnect; + long long _elapsedTime; + +#ifdef MONGO_SSL + SSLManager* _ssl; + int _sslPort; +#endif + + /** + * @return true iff everything went ok + */ + bool _setupSockets( const vector<SockAddr>& mine , vector<int>& socks ); + + void _logListen( int port , bool ssl ); + + static const Listener* _timeTracker; + + virtual bool useUnixSockets() const { return false; } + }; + + /** + * keep track of elapsed time + * after a set amount of time, tells you to do something + * only in this file because depends on Listener + */ + class ElapsedTracker { + public: + ElapsedTracker( int hitsBetweenMarks , int msBetweenMarks ) + : _h( hitsBetweenMarks ) , _ms( msBetweenMarks ) , _pings(0) { + _last = Listener::getElapsedTimeMillis(); + } + + /** + * call this for every iteration + * returns true if one of the triggers has gone off + */ + bool ping() { + if ( ( ++_pings % _h ) == 0 ) { + _last = Listener::getElapsedTimeMillis(); + return true; + } + + long long now = Listener::getElapsedTimeMillis(); + if ( now - _last > _ms ) { + _last = now; + return true; + } + + return false; + } + + private: + int _h; + int _ms; + + unsigned long long _pings; + + long long _last; + + }; + + class ListeningSockets { + public: + ListeningSockets() + : _mutex("ListeningSockets") + , _sockets( new set<int>() ) + , _socketPaths( new set<string>() ) + { } + void add( int sock ) { + scoped_lock lk( _mutex ); + _sockets->insert( sock ); + } + void addPath( string path ) { + scoped_lock lk( _mutex ); + _socketPaths->insert( path ); + } + void remove( int sock ) { + scoped_lock lk( _mutex ); + _sockets->erase( sock ); + } + void closeAll() { + set<int>* sockets; + set<string>* paths; + + { + scoped_lock lk( _mutex ); + sockets = _sockets; + _sockets = new set<int>(); + paths = _socketPaths; + _socketPaths = new set<string>(); + } + + for ( set<int>::iterator i=sockets->begin(); i!=sockets->end(); i++ ) { + int sock = *i; + log() << "closing listening socket: " << sock << endl; + closesocket( sock ); + } + + for ( set<string>::iterator i=paths->begin(); i!=paths->end(); i++ ) { + string path = *i; + log() << "removing socket file: " << path << endl; + ::remove( path.c_str() ); + } + } + static ListeningSockets* get(); + private: + mongo::mutex _mutex; + set<int>* _sockets; + set<string>* _socketPaths; // for unix domain sockets + static ListeningSockets* _instance; + }; + + + extern TicketHolder connTicketHolder; + +} diff --git a/util/net/message.cpp b/util/net/message.cpp new file mode 100644 index 0000000..a84e5c4 --- /dev/null +++ b/util/net/message.cpp @@ -0,0 +1,64 @@ +// message.cpp + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pch.h" + +#include <fcntl.h> +#include <errno.h> +#include <time.h> + +#include "message.h" +#include "message_port.h" +#include "listen.h" + +#include "../goodies.h" +#include "../../client/dbclient.h" + +namespace mongo { + + void Message::send( MessagingPort &p, const char *context ) { + if ( empty() ) { + return; + } + if ( _buf != 0 ) { + p.send( (char*)_buf, _buf->len, context ); + } + else { + p.send( _data, context ); + } + } + + MSGID NextMsgId; + + /*struct MsgStart { + MsgStart() { + NextMsgId = (((unsigned) time(0)) << 16) ^ curTimeMillis(); + assert(MsgDataHeaderSize == 16); + } + } msgstart;*/ + + MSGID nextMessageId() { + MSGID msgid = NextMsgId++; + return msgid; + } + + bool doesOpGetAResponse( int op ) { + return op == dbQuery || op == dbGetMore; + } + + +} // namespace mongo diff --git a/util/net/message.h b/util/net/message.h new file mode 100644 index 0000000..16da5d6 --- /dev/null +++ b/util/net/message.h @@ -0,0 +1,312 @@ +// message.h + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "sock.h" +#include "../../bson/util/atomic_int.h" +#include "hostandport.h" + +namespace mongo { + + class Message; + class MessagingPort; + class PiggyBackData; + + typedef AtomicUInt MSGID; + + enum Operations { + opReply = 1, /* reply. responseTo is set. */ + dbMsg = 1000, /* generic msg command followed by a string */ + dbUpdate = 2001, /* update object */ + dbInsert = 2002, + //dbGetByOID = 2003, + dbQuery = 2004, + dbGetMore = 2005, + dbDelete = 2006, + dbKillCursors = 2007 + }; + + bool doesOpGetAResponse( int op ); + + inline const char * opToString( int op ) { + switch ( op ) { + case 0: return "none"; + case opReply: return "reply"; + case dbMsg: return "msg"; + case dbUpdate: return "update"; + case dbInsert: return "insert"; + case dbQuery: return "query"; + case dbGetMore: return "getmore"; + case dbDelete: return "remove"; + case dbKillCursors: return "killcursors"; + default: + PRINT(op); + assert(0); + return ""; + } + } + + inline bool opIsWrite( int op ) { + switch ( op ) { + + case 0: + case opReply: + case dbMsg: + case dbQuery: + case dbGetMore: + case dbKillCursors: + return false; + + case dbUpdate: + case dbInsert: + case dbDelete: + return false; + + default: + PRINT(op); + assert(0); + return ""; + } + + } + +#pragma pack(1) + /* see http://www.mongodb.org/display/DOCS/Mongo+Wire+Protocol + */ + struct MSGHEADER { + int messageLength; // total message size, including this + int requestID; // identifier for this message + int responseTo; // requestID from the original request + // (used in reponses from db) + int opCode; + }; + struct OP_GETMORE : public MSGHEADER { + MSGHEADER header; // standard message header + int ZERO_or_flags; // 0 - reserved for future use + //cstring fullCollectionName; // "dbname.collectionname" + //int32 numberToReturn; // number of documents to return + //int64 cursorID; // cursorID from the OP_REPLY + }; +#pragma pack() + +#pragma pack(1) + /* todo merge this with MSGHEADER (or inherit from it). */ + struct MsgData { + int len; /* len of the msg, including this field */ + MSGID id; /* request/reply id's match... */ + MSGID responseTo; /* id of the message we are responding to */ + short _operation; + char _flags; + char _version; + int operation() const { + return _operation; + } + void setOperation(int o) { + _flags = 0; + _version = 0; + _operation = o; + } + char _data[4]; + + int& dataAsInt() { + return *((int *) _data); + } + + bool valid() { + if ( len <= 0 || len > ( 4 * BSONObjMaxInternalSize ) ) + return false; + if ( _operation < 0 || _operation > 30000 ) + return false; + return true; + } + + long long getCursor() { + assert( responseTo > 0 ); + assert( _operation == opReply ); + long long * l = (long long *)(_data + 4); + return l[0]; + } + + int dataLen(); // len without header + }; + const int MsgDataHeaderSize = sizeof(MsgData) - 4; + inline int MsgData::dataLen() { + return len - MsgDataHeaderSize; + } +#pragma pack() + + class Message { + public: + // we assume here that a vector with initial size 0 does no allocation (0 is the default, but wanted to make it explicit). + Message() : _buf( 0 ), _data( 0 ), _freeIt( false ) {} + Message( void * data , bool freeIt ) : + _buf( 0 ), _data( 0 ), _freeIt( false ) { + _setData( reinterpret_cast< MsgData* >( data ), freeIt ); + }; + Message(Message& r) : _buf( 0 ), _data( 0 ), _freeIt( false ) { + *this = r; + } + ~Message() { + reset(); + } + + SockAddr _from; + + MsgData *header() const { + assert( !empty() ); + return _buf ? _buf : reinterpret_cast< MsgData* > ( _data[ 0 ].first ); + } + int operation() const { return header()->operation(); } + + MsgData *singleData() const { + massert( 13273, "single data buffer expected", _buf ); + return header(); + } + + bool empty() const { return !_buf && _data.empty(); } + + int size() const { + int res = 0; + if ( _buf ) { + res = _buf->len; + } + else { + for (MsgVec::const_iterator it = _data.begin(); it != _data.end(); ++it) { + res += it->second; + } + } + return res; + } + + int dataSize() const { return size() - sizeof(MSGHEADER); } + + // concat multiple buffers - noop if <2 buffers already, otherwise can be expensive copy + // can get rid of this if we make response handling smarter + void concat() { + if ( _buf || empty() ) { + return; + } + + assert( _freeIt ); + int totalSize = 0; + for( vector< pair< char *, int > >::const_iterator i = _data.begin(); i != _data.end(); ++i ) { + totalSize += i->second; + } + char *buf = (char*)malloc( totalSize ); + char *p = buf; + for( vector< pair< char *, int > >::const_iterator i = _data.begin(); i != _data.end(); ++i ) { + memcpy( p, i->first, i->second ); + p += i->second; + } + reset(); + _setData( (MsgData*)buf, true ); + } + + // vector swap() so this is fast + Message& operator=(Message& r) { + assert( empty() ); + assert( r._freeIt ); + _buf = r._buf; + r._buf = 0; + if ( r._data.size() > 0 ) { + _data.swap( r._data ); + } + r._freeIt = false; + _freeIt = true; + return *this; + } + + void reset() { + if ( _freeIt ) { + if ( _buf ) { + free( _buf ); + } + for( vector< pair< char *, int > >::const_iterator i = _data.begin(); i != _data.end(); ++i ) { + free(i->first); + } + } + _buf = 0; + _data.clear(); + _freeIt = false; + } + + // use to add a buffer + // assumes message will free everything + void appendData(char *d, int size) { + if ( size <= 0 ) { + return; + } + if ( empty() ) { + MsgData *md = (MsgData*)d; + md->len = size; // can be updated later if more buffers added + _setData( md, true ); + return; + } + assert( _freeIt ); + if ( _buf ) { + _data.push_back( make_pair( (char*)_buf, _buf->len ) ); + _buf = 0; + } + _data.push_back( make_pair( d, size ) ); + header()->len += size; + } + + // use to set first buffer if empty + void setData(MsgData *d, bool freeIt) { + assert( empty() ); + _setData( d, freeIt ); + } + void setData(int operation, const char *msgtxt) { + setData(operation, msgtxt, strlen(msgtxt)+1); + } + void setData(int operation, const char *msgdata, size_t len) { + assert( empty() ); + size_t dataLen = len + sizeof(MsgData) - 4; + MsgData *d = (MsgData *) malloc(dataLen); + memcpy(d->_data, msgdata, len); + d->len = fixEndian(dataLen); + d->setOperation(operation); + _setData( d, true ); + } + + bool doIFreeIt() { + return _freeIt; + } + + void send( MessagingPort &p, const char *context ); + + string toString() const; + + private: + void _setData( MsgData *d, bool freeIt ) { + _freeIt = freeIt; + _buf = d; + } + // if just one buffer, keep it in _buf, otherwise keep a sequence of buffers in _data + MsgData * _buf; + // byte buffer(s) - the first must contain at least a full MsgData unless using _buf for storage instead + typedef vector< pair< char*, int > > MsgVec; + MsgVec _data; + bool _freeIt; + }; + + + MSGID nextMessageId(); + + +} // namespace mongo diff --git a/util/net/message_port.cpp b/util/net/message_port.cpp new file mode 100644 index 0000000..9abfaf7 --- /dev/null +++ b/util/net/message_port.cpp @@ -0,0 +1,298 @@ +// message_port.cpp + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pch.h" + +#include <fcntl.h> +#include <errno.h> +#include <time.h> + +#include "message.h" +#include "message_port.h" +#include "listen.h" + +#include "../goodies.h" +#include "../background.h" +#include "../time_support.h" +#include "../../db/cmdline.h" +#include "../../client/dbclient.h" + + +#ifndef _WIN32 +# ifndef __sunos__ +# include <ifaddrs.h> +# endif +# include <sys/resource.h> +# include <sys/stat.h> +#else + +// errno doesn't work for winsock. +#undef errno +#define errno WSAGetLastError() + +#endif + +namespace mongo { + + +// if you want trace output: +#define mmm(x) + + /* messagingport -------------------------------------------------------------- */ + + class PiggyBackData { + public: + PiggyBackData( MessagingPort * port ) { + _port = port; + _buf = new char[1300]; + _cur = _buf; + } + + ~PiggyBackData() { + DESTRUCTOR_GUARD ( + flush(); + delete[]( _cur ); + ); + } + + void append( Message& m ) { + assert( m.header()->len <= 1300 ); + + if ( len() + m.header()->len > 1300 ) + flush(); + + memcpy( _cur , m.singleData() , m.header()->len ); + _cur += m.header()->len; + } + + void flush() { + if ( _buf == _cur ) + return; + + _port->send( _buf , len(), "flush" ); + _cur = _buf; + } + + int len() const { return _cur - _buf; } + + private: + MessagingPort* _port; + char * _buf; + char * _cur; + }; + + class Ports { + set<MessagingPort*> ports; + mongo::mutex m; + public: + Ports() : ports(), m("Ports") {} + void closeAll(unsigned skip_mask) { + scoped_lock bl(m); + for ( set<MessagingPort*>::iterator i = ports.begin(); i != ports.end(); i++ ) { + if( (*i)->tag & skip_mask ) + continue; + (*i)->shutdown(); + } + } + void insert(MessagingPort* p) { + scoped_lock bl(m); + ports.insert(p); + } + void erase(MessagingPort* p) { + scoped_lock bl(m); + ports.erase(p); + } + }; + + // we "new" this so it is still be around when other automatic global vars + // are being destructed during termination. + Ports& ports = *(new Ports()); + + void MessagingPort::closeAllSockets(unsigned mask) { + ports.closeAll(mask); + } + + MessagingPort::MessagingPort(int fd, const SockAddr& remote) + : Socket( fd , remote ) , piggyBackData(0) { + ports.insert(this); + } + + MessagingPort::MessagingPort( double timeout, int ll ) + : Socket( timeout, ll ) { + ports.insert(this); + piggyBackData = 0; + } + + MessagingPort::MessagingPort( Socket& sock ) + : Socket( sock ) , piggyBackData( 0 ) { + } + + void MessagingPort::shutdown() { + close(); + } + + MessagingPort::~MessagingPort() { + if ( piggyBackData ) + delete( piggyBackData ); + shutdown(); + ports.erase(this); + } + + bool MessagingPort::recv(Message& m) { + try { +again: + mmm( log() << "* recv() sock:" << this->sock << endl; ) + int len = -1; + + char *lenbuf = (char *) &len; + int lft = 4; + Socket::recv( lenbuf, lft ); + + if ( len < 16 || len > 48000000 ) { // messages must be large enough for headers + if ( len == -1 ) { + // Endian check from the client, after connecting, to see what mode server is running in. + unsigned foo = 0x10203040; + send( (char *) &foo, 4, "endian" ); + goto again; + } + + if ( len == 542393671 ) { + // an http GET + log(_logLevel) << "looks like you're trying to access db over http on native driver port. please add 1000 for webserver" << endl; + string msg = "You are trying to access MongoDB on the native driver port. For http diagnostic access, add 1000 to the port number\n"; + stringstream ss; + ss << "HTTP/1.0 200 OK\r\nConnection: close\r\nContent-Type: text/plain\r\nContent-Length: " << msg.size() << "\r\n\r\n" << msg; + string s = ss.str(); + send( s.c_str(), s.size(), "http" ); + return false; + } + log(0) << "recv(): message len " << len << " is too large" << len << endl; + return false; + } + + int z = (len+1023)&0xfffffc00; + assert(z>=len); + MsgData *md = (MsgData *) malloc(z); + assert(md); + md->len = len; + + char *p = (char *) &md->id; + int left = len -4; + + try { + Socket::recv( p, left ); + } + catch (...) { + free(md); + throw; + } + + m.setData(md, true); + return true; + + } + catch ( const SocketException & e ) { + log(_logLevel + (e.shouldPrint() ? 0 : 1) ) << "SocketException: remote: " << remote() << " error: " << e << endl; + m.reset(); + return false; + } + } + + void MessagingPort::reply(Message& received, Message& response) { + say(/*received.from, */response, received.header()->id); + } + + void MessagingPort::reply(Message& received, Message& response, MSGID responseTo) { + say(/*received.from, */response, responseTo); + } + + bool MessagingPort::call(Message& toSend, Message& response) { + mmm( log() << "*call()" << endl; ) + say(toSend); + return recv( toSend , response ); + } + + bool MessagingPort::recv( const Message& toSend , Message& response ) { + while ( 1 ) { + bool ok = recv(response); + if ( !ok ) + return false; + //log() << "got response: " << response.data->responseTo << endl; + if ( response.header()->responseTo == toSend.header()->id ) + break; + error() << "MessagingPort::call() wrong id got:" << hex << (unsigned)response.header()->responseTo << " expect:" << (unsigned)toSend.header()->id << '\n' + << dec + << " toSend op: " << (unsigned)toSend.operation() << '\n' + << " response msgid:" << (unsigned)response.header()->id << '\n' + << " response len: " << (unsigned)response.header()->len << '\n' + << " response op: " << response.operation() << '\n' + << " remote: " << remoteString() << endl; + assert(false); + response.reset(); + } + mmm( log() << "*call() end" << endl; ) + return true; + } + + void MessagingPort::say(Message& toSend, int responseTo) { + assert( !toSend.empty() ); + mmm( log() << "* say() sock:" << this->sock << " thr:" << GetCurrentThreadId() << endl; ) + toSend.header()->id = nextMessageId(); + toSend.header()->responseTo = responseTo; + + if ( piggyBackData && piggyBackData->len() ) { + mmm( log() << "* have piggy back" << endl; ) + if ( ( piggyBackData->len() + toSend.header()->len ) > 1300 ) { + // won't fit in a packet - so just send it off + piggyBackData->flush(); + } + else { + piggyBackData->append( toSend ); + piggyBackData->flush(); + return; + } + } + + toSend.send( *this, "say" ); + } + + void MessagingPort::piggyBack( Message& toSend , int responseTo ) { + + if ( toSend.header()->len > 1300 ) { + // not worth saving because its almost an entire packet + say( toSend ); + return; + } + + // we're going to be storing this, so need to set it up + toSend.header()->id = nextMessageId(); + toSend.header()->responseTo = responseTo; + + if ( ! piggyBackData ) + piggyBackData = new PiggyBackData( this ); + + piggyBackData->append( toSend ); + } + + HostAndPort MessagingPort::remote() const { + if ( ! _remoteParsed.hasPort() ) + _remoteParsed = HostAndPort( remoteAddr() ); + return _remoteParsed; + } + + +} // namespace mongo diff --git a/util/net/message_port.h b/util/net/message_port.h new file mode 100644 index 0000000..22ecafe --- /dev/null +++ b/util/net/message_port.h @@ -0,0 +1,107 @@ +// message_port.h + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "sock.h" +#include "message.h" + +namespace mongo { + + class MessagingPort; + class PiggyBackData; + + typedef AtomicUInt MSGID; + + class AbstractMessagingPort : boost::noncopyable { + public: + AbstractMessagingPort() : tag(0) {} + virtual ~AbstractMessagingPort() { } + virtual void reply(Message& received, Message& response, MSGID responseTo) = 0; // like the reply below, but doesn't rely on received.data still being available + virtual void reply(Message& received, Message& response) = 0; + + virtual HostAndPort remote() const = 0; + virtual unsigned remotePort() const = 0; + + private: + + public: + // TODO make this private with some helpers + + /* ports can be tagged with various classes. see closeAllSockets(tag). defaults to 0. */ + unsigned tag; + + }; + + class MessagingPort : public AbstractMessagingPort , public Socket { + public: + MessagingPort(int fd, const SockAddr& remote); + + // in some cases the timeout will actually be 2x this value - eg we do a partial send, + // then the timeout fires, then we try to send again, then the timeout fires again with + // no data sent, then we detect that the other side is down + MessagingPort(double so_timeout = 0, int logLevel = 0 ); + + MessagingPort(Socket& socket); + + virtual ~MessagingPort(); + + void shutdown(); + + /* it's assumed if you reuse a message object, that it doesn't cross MessagingPort's. + also, the Message data will go out of scope on the subsequent recv call. + */ + bool recv(Message& m); + void reply(Message& received, Message& response, MSGID responseTo); + void reply(Message& received, Message& response); + bool call(Message& toSend, Message& response); + + void say(Message& toSend, int responseTo = -1); + + /** + * this is used for doing 'async' queries + * instead of doing call( to , from ) + * you would do + * say( to ) + * recv( from ) + * Note: if you fail to call recv and someone else uses this port, + * horrible things will happend + */ + bool recv( const Message& sent , Message& response ); + + void piggyBack( Message& toSend , int responseTo = -1 ); + + unsigned remotePort() const { return Socket::remotePort(); } + virtual HostAndPort remote() const; + + + private: + + PiggyBackData * piggyBackData; + + // this is the parsed version of remote + // mutable because its initialized only on call to remote() + mutable HostAndPort _remoteParsed; + + public: + static void closeAllSockets(unsigned tagMask = 0xffffffff); + + friend class PiggyBackData; + }; + + +} // namespace mongo diff --git a/util/net/message_server.h b/util/net/message_server.h new file mode 100644 index 0000000..ae77b97 --- /dev/null +++ b/util/net/message_server.h @@ -0,0 +1,66 @@ +// message_server.h + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + abstract database server + async io core, worker thread system + */ + +#pragma once + +#include "../../pch.h" + +namespace mongo { + + class MessageHandler { + public: + virtual ~MessageHandler() {} + + /** + * called once when a socket is connected + */ + virtual void connected( AbstractMessagingPort* p ) = 0; + + /** + * called every time a message comes in + * handler is responsible for responding to client + */ + virtual void process( Message& m , AbstractMessagingPort* p , LastError * err ) = 0; + + /** + * called once when a socket is disconnected + */ + virtual void disconnected( AbstractMessagingPort* p ) = 0; + }; + + class MessageServer { + public: + struct Options { + int port; // port to bind to + string ipList; // addresses to bind to + + Options() : port(0), ipList("") {} + }; + + virtual ~MessageServer() {} + virtual void run() = 0; + virtual void setAsTimeTracker() = 0; + }; + + // TODO use a factory here to decide between port and asio variations + MessageServer * createServer( const MessageServer::Options& opts , MessageHandler * handler ); +} diff --git a/util/net/message_server_asio.cpp b/util/net/message_server_asio.cpp new file mode 100644 index 0000000..0c6a7d9 --- /dev/null +++ b/util/net/message_server_asio.cpp @@ -0,0 +1,261 @@ +// message_server_asio.cpp + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef USE_ASIO + +#include <boost/asio.hpp> +#include <boost/bind.hpp> +#include <boost/enable_shared_from_this.hpp> +#include <boost/shared_ptr.hpp> + +#include <iostream> +#include <vector> + +#include "message.h" +#include "message_server.h" +#include "../util/concurrency/mvar.h" + +using namespace boost; +using namespace boost::asio; +using namespace boost::asio::ip; + +namespace mongo { + class MessageServerSession; + + namespace { + class StickyThread { + public: + StickyThread() + : _thread(boost::ref(*this)) + {} + + ~StickyThread() { + _mss.put(boost::shared_ptr<MessageServerSession>()); + _thread.join(); + } + + void ready(boost::shared_ptr<MessageServerSession> mss) { + _mss.put(mss); + } + + void operator() () { + boost::shared_ptr<MessageServerSession> mss; + while((mss = _mss.take())) { // intentionally not using == + task(mss.get()); + mss.reset(); + } + } + + private: + boost::thread _thread; + inline void task(MessageServerSession* mss); // must be defined after MessageServerSession + + MVar<boost::shared_ptr<MessageServerSession> > _mss; // populated when given a task + }; + + vector<boost::shared_ptr<StickyThread> > thread_pool; + mongo::mutex tp_mutex; // this is only needed if io_service::run() is called from multiple threads + } + + class MessageServerSession : public boost::enable_shared_from_this<MessageServerSession> , public AbstractMessagingPort { + public: + MessageServerSession( MessageHandler * handler , io_service& ioservice ) + : _handler( handler ) + , _socket( ioservice ) + , _portCache(0) + { } + + ~MessageServerSession() { + cout << "disconnect from: " << _socket.remote_endpoint() << endl; + } + + tcp::socket& socket() { + return _socket; + } + + void start() { + cout << "MessageServerSession start from:" << _socket.remote_endpoint() << endl; + _startHeaderRead(); + } + + void handleReadHeader( const boost::system::error_code& error ) { + if ( _inHeader.len == 0 ) + return; + + if ( ! _inHeader.valid() ) { + cout << " got invalid header from: " << _socket.remote_endpoint() << " closing connected" << endl; + return; + } + + char * raw = (char*)malloc( _inHeader.len ); + + MsgData * data = (MsgData*)raw; + memcpy( data , &_inHeader , sizeof( _inHeader ) ); + assert( data->len == _inHeader.len ); + + uassert( 10273 , "_cur not empty! pipelining requests not supported" , ! _cur.data ); + + _cur.setData( data , true ); + async_read( _socket , + buffer( raw + sizeof( _inHeader ) , _inHeader.len - sizeof( _inHeader ) ) , + boost::bind( &MessageServerSession::handleReadBody , shared_from_this() , boost::asio::placeholders::error ) ); + } + + void handleReadBody( const boost::system::error_code& error ) { + if (!_myThread) { + mongo::mutex::scoped_lock(tp_mutex); + if (!thread_pool.empty()) { + _myThread = thread_pool.back(); + thread_pool.pop_back(); + } + } + + if (!_myThread) // pool is empty + _myThread.reset(new StickyThread()); + + assert(_myThread); + + _myThread->ready(shared_from_this()); + } + + void process() { + _handler->process( _cur , this ); + + if (_reply.data) { + async_write( _socket , + buffer( (char*)_reply.data , _reply.data->len ) , + boost::bind( &MessageServerSession::handleWriteDone , shared_from_this() , boost::asio::placeholders::error ) ); + } + else { + _cur.reset(); + _startHeaderRead(); + } + } + + void handleWriteDone( const boost::system::error_code& error ) { + { + // return thread to pool after we have sent data to the client + mongo::mutex::scoped_lock(tp_mutex); + assert(_myThread); + thread_pool.push_back(_myThread); + _myThread.reset(); + } + _cur.reset(); + _reply.reset(); + _startHeaderRead(); + } + + virtual void reply( Message& received, Message& response ) { + reply( received , response , received.data->id ); + } + + virtual void reply( Message& query , Message& toSend, MSGID responseTo ) { + _reply = toSend; + + _reply.data->id = nextMessageId(); + _reply.data->responseTo = responseTo; + uassert( 10274 , "pipelining requests doesn't work yet" , query.data->id == _cur.data->id ); + } + + + virtual unsigned remotePort() { + if (!_portCache) + _portCache = _socket.remote_endpoint().port(); //this is expensive + return _portCache; + } + + private: + + void _startHeaderRead() { + _inHeader.len = 0; + async_read( _socket , + buffer( &_inHeader , sizeof( _inHeader ) ) , + boost::bind( &MessageServerSession::handleReadHeader , shared_from_this() , boost::asio::placeholders::error ) ); + } + + MessageHandler * _handler; + tcp::socket _socket; + MsgData _inHeader; + Message _cur; + Message _reply; + + unsigned _portCache; + + boost::shared_ptr<StickyThread> _myThread; + }; + + void StickyThread::task(MessageServerSession* mss) { + mss->process(); + } + + + class AsyncMessageServer : public MessageServer { + public: + // TODO accept an IP address to bind to + AsyncMessageServer( const MessageServer::Options& opts , MessageHandler * handler ) + : _port( opts.port ) + , _handler(handler) + , _endpoint( tcp::v4() , opts.port ) + , _acceptor( _ioservice , _endpoint ) { + _accept(); + } + virtual ~AsyncMessageServer() { + + } + + void run() { + cout << "AsyncMessageServer starting to listen on: " << _port << endl; + boost::thread other(boost::bind(&io_service::run, &_ioservice)); + _ioservice.run(); + cout << "AsyncMessageServer done listening on: " << _port << endl; + } + + void handleAccept( shared_ptr<MessageServerSession> session , + const boost::system::error_code& error ) { + if ( error ) { + cout << "handleAccept error!" << endl; + return; + } + session->start(); + _accept(); + } + + void _accept( ) { + shared_ptr<MessageServerSession> session( new MessageServerSession( _handler , _ioservice ) ); + _acceptor.async_accept( session->socket() , + boost::bind( &AsyncMessageServer::handleAccept, + this, + session, + boost::asio::placeholders::error ) + ); + } + + private: + int _port; + MessageHandler * _handler; + io_service _ioservice; + tcp::endpoint _endpoint; + tcp::acceptor _acceptor; + }; + + MessageServer * createServer( const MessageServer::Options& opts , MessageHandler * handler ) { + return new AsyncMessageServer( opts , handler ); + } + +} + +#endif diff --git a/util/net/message_server_port.cpp b/util/net/message_server_port.cpp new file mode 100644 index 0000000..ca0b13d --- /dev/null +++ b/util/net/message_server_port.cpp @@ -0,0 +1,197 @@ +// message_server_port.cpp + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pch.h" + +#ifndef USE_ASIO + +#include "message.h" +#include "message_port.h" +#include "message_server.h" +#include "listen.h" + +#include "../../db/cmdline.h" +#include "../../db/lasterror.h" +#include "../../db/stats/counters.h" + +#ifdef __linux__ // TODO: consider making this ifndef _WIN32 +# include <sys/resource.h> +#endif + +namespace mongo { + + namespace pms { + + MessageHandler * handler; + + void threadRun( MessagingPort * inPort) { + TicketHolderReleaser connTicketReleaser( &connTicketHolder ); + + setThreadName( "conn" ); + + assert( inPort ); + inPort->setLogLevel(1); + scoped_ptr<MessagingPort> p( inPort ); + + p->postFork(); + + string otherSide; + + Message m; + try { + LastError * le = new LastError(); + lastError.reset( le ); // lastError now has ownership + + otherSide = p->remoteString(); + + handler->connected( p.get() ); + + while ( ! inShutdown() ) { + m.reset(); + p->clearCounters(); + + if ( ! p->recv(m) ) { + if( !cmdLine.quiet ) + log() << "end connection " << otherSide << endl; + p->shutdown(); + break; + } + + handler->process( m , p.get() , le ); + networkCounter.hit( p->getBytesIn() , p->getBytesOut() ); + } + } + catch ( AssertionException& e ) { + log() << "AssertionException handling request, closing client connection: " << e << endl; + p->shutdown(); + } + catch ( SocketException& e ) { + log() << "SocketException handling request, closing client connection: " << e << endl; + p->shutdown(); + } + catch ( const ClockSkewException & ) { + log() << "ClockSkewException - shutting down" << endl; + exitCleanly( EXIT_CLOCK_SKEW ); + } + catch ( std::exception &e ) { + error() << "Uncaught std::exception: " << e.what() << ", terminating" << endl; + dbexit( EXIT_UNCAUGHT ); + } + catch ( ... ) { + error() << "Uncaught exception, terminating" << endl; + dbexit( EXIT_UNCAUGHT ); + } + + handler->disconnected( p.get() ); + } + + } + + class PortMessageServer : public MessageServer , public Listener { + public: + PortMessageServer( const MessageServer::Options& opts, MessageHandler * handler ) : + Listener( "" , opts.ipList, opts.port ) { + + uassert( 10275 , "multiple PortMessageServer not supported" , ! pms::handler ); + pms::handler = handler; + } + + virtual void accepted(MessagingPort * p) { + + if ( ! connTicketHolder.tryAcquire() ) { + log() << "connection refused because too many open connections: " << connTicketHolder.used() << endl; + + // TODO: would be nice if we notified them... + p->shutdown(); + delete p; + + sleepmillis(2); // otherwise we'll hard loop + return; + } + + try { +#ifndef __linux__ // TODO: consider making this ifdef _WIN32 + boost::thread thr( boost::bind( &pms::threadRun , p ) ); +#else + pthread_attr_t attrs; + pthread_attr_init(&attrs); + pthread_attr_setdetachstate(&attrs, PTHREAD_CREATE_DETACHED); + + static const size_t STACK_SIZE = 1024*1024; // if we change this we need to update the warning + + struct rlimit limits; + verify(15887, getrlimit(RLIMIT_STACK, &limits) == 0); + if (limits.rlim_cur > STACK_SIZE) { + pthread_attr_setstacksize(&attrs, (DEBUG_BUILD + ? (STACK_SIZE / 2) + : STACK_SIZE)); + } else if (limits.rlim_cur < 1024*1024) { + warning() << "Stack size set to " << (limits.rlim_cur/1024) << "KB. We suggest 1MB" << endl; + } + + + pthread_t thread; + int failed = pthread_create(&thread, &attrs, (void*(*)(void*)) &pms::threadRun, p); + + pthread_attr_destroy(&attrs); + + if (failed) { + log() << "pthread_create failed: " << errnoWithDescription(failed) << endl; + throw boost::thread_resource_error(); // for consistency with boost::thread + } +#endif + } + catch ( boost::thread_resource_error& ) { + connTicketHolder.release(); + log() << "can't create new thread, closing connection" << endl; + + p->shutdown(); + delete p; + + sleepmillis(2); + } + catch ( ... ) { + connTicketHolder.release(); + log() << "unknown error accepting new socket" << endl; + + p->shutdown(); + delete p; + + sleepmillis(2); + } + + } + + virtual void setAsTimeTracker() { + Listener::setAsTimeTracker(); + } + + void run() { + initAndListen(); + } + + virtual bool useUnixSockets() const { return true; } + }; + + + MessageServer * createServer( const MessageServer::Options& opts , MessageHandler * handler ) { + return new PortMessageServer( opts , handler ); + } + +} + +#endif diff --git a/util/net/miniwebserver.cpp b/util/net/miniwebserver.cpp new file mode 100644 index 0000000..0793100 --- /dev/null +++ b/util/net/miniwebserver.cpp @@ -0,0 +1,207 @@ +// miniwebserver.cpp + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pch.h" +#include "miniwebserver.h" +#include "../hex.h" + +#include "pcrecpp.h" + +namespace mongo { + + MiniWebServer::MiniWebServer(const string& name, const string &ip, int port) + : Listener(name, ip, port, false) + {} + + string MiniWebServer::parseURL( const char * buf ) { + const char * urlStart = strchr( buf , ' ' ); + if ( ! urlStart ) + return "/"; + + urlStart++; + + const char * end = strchr( urlStart , ' ' ); + if ( ! end ) { + end = strchr( urlStart , '\r' ); + if ( ! end ) { + end = strchr( urlStart , '\n' ); + } + } + + if ( ! end ) + return "/"; + + int diff = (int)(end-urlStart); + if ( diff < 0 || diff > 255 ) + return "/"; + + return string( urlStart , (int)(end-urlStart) ); + } + + void MiniWebServer::parseParams( BSONObj & params , string query ) { + if ( query.size() == 0 ) + return; + + BSONObjBuilder b; + while ( query.size() ) { + + string::size_type amp = query.find( "&" ); + + string cur; + if ( amp == string::npos ) { + cur = query; + query = ""; + } + else { + cur = query.substr( 0 , amp ); + query = query.substr( amp + 1 ); + } + + string::size_type eq = cur.find( "=" ); + if ( eq == string::npos ) + continue; + + b.append( urlDecode(cur.substr(0,eq)) , urlDecode(cur.substr(eq+1) ) ); + } + + params = b.obj(); + } + + string MiniWebServer::parseMethod( const char * headers ) { + const char * end = strchr( headers , ' ' ); + if ( ! end ) + return "GET"; + return string( headers , (int)(end-headers) ); + } + + const char *MiniWebServer::body( const char *buf ) { + const char *ret = strstr( buf, "\r\n\r\n" ); + return ret ? ret + 4 : ret; + } + + bool MiniWebServer::fullReceive( const char *buf ) { + const char *bod = body( buf ); + if ( !bod ) + return false; + const char *lenString = "Content-Length:"; + const char *lengthLoc = strstr( buf, lenString ); + if ( !lengthLoc ) + return true; + lengthLoc += strlen( lenString ); + long len = strtol( lengthLoc, 0, 10 ); + if ( long( strlen( bod ) ) == len ) + return true; + return false; + } + + void MiniWebServer::accepted(Socket sock) { + sock.postFork(); + sock.setTimeout(8); + char buf[4096]; + int len = 0; + while ( 1 ) { + int left = sizeof(buf) - 1 - len; + if( left == 0 ) + break; + int x = sock.unsafe_recv( buf + len , left ); + if ( x <= 0 ) { + sock.close(); + return; + } + len += x; + buf[ len ] = 0; + if ( fullReceive( buf ) ) { + break; + } + } + buf[len] = 0; + + string responseMsg; + int responseCode = 599; + vector<string> headers; + + try { + doRequest(buf, parseURL( buf ), responseMsg, responseCode, headers, sock.remoteAddr() ); + } + catch ( std::exception& e ) { + responseCode = 500; + responseMsg = "error loading page: "; + responseMsg += e.what(); + } + catch ( ... ) { + responseCode = 500; + responseMsg = "unknown error loading page"; + } + + stringstream ss; + ss << "HTTP/1.0 " << responseCode; + if ( responseCode == 200 ) ss << " OK"; + ss << "\r\n"; + if ( headers.empty() ) { + ss << "Content-Type: text/html\r\n"; + } + else { + for ( vector<string>::iterator i = headers.begin(); i != headers.end(); i++ ) { + assert( strncmp("Content-Length", i->c_str(), 14) ); + ss << *i << "\r\n"; + } + } + ss << "Connection: close\r\n"; + ss << "Content-Length: " << responseMsg.size() << "\r\n"; + ss << "\r\n"; + ss << responseMsg; + string response = ss.str(); + + sock.send( response.c_str(), response.size() , "http response" ); + sock.close(); + } + + string MiniWebServer::getHeader( const char * req , string wanted ) { + const char * headers = strchr( req , '\n' ); + if ( ! headers ) + return ""; + pcrecpp::StringPiece input( headers + 1 ); + + string name; + string val; + pcrecpp::RE re("([\\w\\-]+): (.*?)\r?\n"); + while ( re.Consume( &input, &name, &val) ) { + if ( name == wanted ) + return val; + } + return ""; + } + + string MiniWebServer::urlDecode(const char* s) { + stringstream out; + while(*s) { + if (*s == '+') { + out << ' '; + } + else if (*s == '%') { + out << fromHex(s+1); + s+=2; + } + else { + out << *s; + } + s++; + } + return out.str(); + } + +} // namespace mongo diff --git a/util/net/miniwebserver.h b/util/net/miniwebserver.h new file mode 100644 index 0000000..1fb6b3f --- /dev/null +++ b/util/net/miniwebserver.h @@ -0,0 +1,60 @@ +// miniwebserver.h + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../../pch.h" +#include "message.h" +#include "message_port.h" +#include "listen.h" +#include "../../db/jsobj.h" + +namespace mongo { + + class MiniWebServer : public Listener { + public: + MiniWebServer(const string& name, const string &ip, int _port); + virtual ~MiniWebServer() {} + + virtual void doRequest( + const char *rq, // the full request + string url, + // set these and return them: + string& responseMsg, + int& responseCode, + vector<string>& headers, // if completely empty, content-type: text/html will be added + const SockAddr &from + ) = 0; + + // --- static helpers ---- + + static void parseParams( BSONObj & params , string query ); + + static string parseURL( const char * buf ); + static string parseMethod( const char * headers ); + static string getHeader( const char * headers , string name ); + static const char *body( const char *buf ); + + static string urlDecode(const char* s); + static string urlDecode(string s) {return urlDecode(s.c_str());} + + private: + void accepted(Socket socket); + static bool fullReceive( const char *buf ); + }; + +} // namespace mongo diff --git a/util/net/sock.cpp b/util/net/sock.cpp new file mode 100644 index 0000000..69c42f2 --- /dev/null +++ b/util/net/sock.cpp @@ -0,0 +1,713 @@ +// @file sock.cpp + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pch.h" +#include "sock.h" +#include "../background.h" + +#if !defined(_WIN32) +# include <sys/socket.h> +# include <sys/types.h> +# include <sys/socket.h> +# include <sys/un.h> +# include <netinet/in.h> +# include <netinet/tcp.h> +# include <arpa/inet.h> +# include <errno.h> +# include <netdb.h> +# if defined(__openbsd__) +# include <sys/uio.h> +# endif +#endif + +#ifdef MONGO_SSL +#include <openssl/err.h> +#include <openssl/ssl.h> +#endif + + +namespace mongo { + + static bool ipv6 = false; + void enableIPv6(bool state) { ipv6 = state; } + bool IPv6Enabled() { return ipv6; } + + void setSockTimeouts(int sock, double secs) { + struct timeval tv; + tv.tv_sec = (int)secs; + tv.tv_usec = (int)((long long)(secs*1000*1000) % (1000*1000)); + bool report = logLevel > 3; // solaris doesn't provide these + DEV report = true; + bool ok = setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *) &tv, sizeof(tv) ) == 0; + if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl; + ok = setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, sizeof(tv) ) == 0; + DEV if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl; + } + +#if defined(_WIN32) + void disableNagle(int sock) { + int x = 1; + if ( setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *) &x, sizeof(x)) ) + error() << "disableNagle failed" << endl; + if ( setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char *) &x, sizeof(x)) ) + error() << "SO_KEEPALIVE failed" << endl; + } +#else + + void disableNagle(int sock) { + int x = 1; + +#ifdef SOL_TCP + int level = SOL_TCP; +#else + int level = SOL_SOCKET; +#endif + + if ( setsockopt(sock, level, TCP_NODELAY, (char *) &x, sizeof(x)) ) + error() << "disableNagle failed: " << errnoWithDescription() << endl; + +#ifdef SO_KEEPALIVE + if ( setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char *) &x, sizeof(x)) ) + error() << "SO_KEEPALIVE failed: " << errnoWithDescription() << endl; + +# ifdef __linux__ + socklen_t len = sizeof(x); + if ( getsockopt(sock, level, TCP_KEEPIDLE, (char *) &x, &len) ) + error() << "can't get TCP_KEEPIDLE: " << errnoWithDescription() << endl; + + if (x > 300) { + x = 300; + if ( setsockopt(sock, level, TCP_KEEPIDLE, (char *) &x, sizeof(x)) ) { + error() << "can't set TCP_KEEPIDLE: " << errnoWithDescription() << endl; + } + } + + len = sizeof(x); // just in case it changed + if ( getsockopt(sock, level, TCP_KEEPINTVL, (char *) &x, &len) ) + error() << "can't get TCP_KEEPINTVL: " << errnoWithDescription() << endl; + + if (x > 300) { + x = 300; + if ( setsockopt(sock, level, TCP_KEEPINTVL, (char *) &x, sizeof(x)) ) { + error() << "can't set TCP_KEEPINTVL: " << errnoWithDescription() << endl; + } + } +# endif +#endif + + } + +#endif + + string getAddrInfoStrError(int code) { +#if !defined(_WIN32) + return gai_strerror(code); +#else + /* gai_strerrorA is not threadsafe on windows. don't use it. */ + return errnoWithDescription(code); +#endif + } + + + // --- SockAddr + + SockAddr::SockAddr(int sourcePort) { + memset(as<sockaddr_in>().sin_zero, 0, sizeof(as<sockaddr_in>().sin_zero)); + as<sockaddr_in>().sin_family = AF_INET; + as<sockaddr_in>().sin_port = htons(sourcePort); + as<sockaddr_in>().sin_addr.s_addr = htonl(INADDR_ANY); + addressSize = sizeof(sockaddr_in); + } + + SockAddr::SockAddr(const char * iporhost , int port) { + if (!strcmp(iporhost, "localhost")) + iporhost = "127.0.0.1"; + + if (strchr(iporhost, '/')) { +#ifdef _WIN32 + uassert(13080, "no unix socket support on windows", false); +#endif + uassert(13079, "path to unix socket too long", strlen(iporhost) < sizeof(as<sockaddr_un>().sun_path)); + as<sockaddr_un>().sun_family = AF_UNIX; + strcpy(as<sockaddr_un>().sun_path, iporhost); + addressSize = sizeof(sockaddr_un); + } + else { + addrinfo* addrs = NULL; + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + hints.ai_socktype = SOCK_STREAM; + //hints.ai_flags = AI_ADDRCONFIG; // This is often recommended but don't do it. SERVER-1579 + hints.ai_flags |= AI_NUMERICHOST; // first pass tries w/o DNS lookup + hints.ai_family = (IPv6Enabled() ? AF_UNSPEC : AF_INET); + + StringBuilder ss; + ss << port; + int ret = getaddrinfo(iporhost, ss.str().c_str(), &hints, &addrs); + + // old C compilers on IPv6-capable hosts return EAI_NODATA error +#ifdef EAI_NODATA + int nodata = (ret == EAI_NODATA); +#else + int nodata = false; +#endif + if (ret == EAI_NONAME || nodata) { + // iporhost isn't an IP address, allow DNS lookup + hints.ai_flags &= ~AI_NUMERICHOST; + ret = getaddrinfo(iporhost, ss.str().c_str(), &hints, &addrs); + } + + if (ret) { + // don't log if this as it is a CRT construction and log() may not work yet. + if( strcmp("0.0.0.0", iporhost) ) { + log() << "getaddrinfo(\"" << iporhost << "\") failed: " << gai_strerror(ret) << endl; + } + *this = SockAddr(port); + } + else { + //TODO: handle other addresses in linked list; + assert(addrs->ai_addrlen <= sizeof(sa)); + memcpy(&sa, addrs->ai_addr, addrs->ai_addrlen); + addressSize = addrs->ai_addrlen; + freeaddrinfo(addrs); + } + } + } + + bool SockAddr::isLocalHost() const { + switch (getType()) { + case AF_INET: return getAddr() == "127.0.0.1"; + case AF_INET6: return getAddr() == "::1"; + case AF_UNIX: return true; + default: return false; + } + assert(false); + return false; + } + + string SockAddr::toString(bool includePort) const { + string out = getAddr(); + if (includePort && getType() != AF_UNIX && getType() != AF_UNSPEC) + out += mongoutils::str::stream() << ':' << getPort(); + return out; + } + + sa_family_t SockAddr::getType() const { + return sa.ss_family; + } + + unsigned SockAddr::getPort() const { + switch (getType()) { + case AF_INET: return ntohs(as<sockaddr_in>().sin_port); + case AF_INET6: return ntohs(as<sockaddr_in6>().sin6_port); + case AF_UNIX: return 0; + case AF_UNSPEC: return 0; + default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); return 0; + } + } + + string SockAddr::getAddr() const { + switch (getType()) { + case AF_INET: + case AF_INET6: { + const int buflen=128; + char buffer[buflen]; + int ret = getnameinfo(raw(), addressSize, buffer, buflen, NULL, 0, NI_NUMERICHOST); + massert(13082, getAddrInfoStrError(ret), ret == 0); + return buffer; + } + + case AF_UNIX: return (addressSize > 2 ? as<sockaddr_un>().sun_path : "anonymous unix socket"); + case AF_UNSPEC: return "(NONE)"; + default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); return ""; + } + } + + bool SockAddr::operator==(const SockAddr& r) const { + if (getType() != r.getType()) + return false; + + if (getPort() != r.getPort()) + return false; + + switch (getType()) { + case AF_INET: return as<sockaddr_in>().sin_addr.s_addr == r.as<sockaddr_in>().sin_addr.s_addr; + case AF_INET6: return memcmp(as<sockaddr_in6>().sin6_addr.s6_addr, r.as<sockaddr_in6>().sin6_addr.s6_addr, sizeof(in6_addr)) == 0; + case AF_UNIX: return strcmp(as<sockaddr_un>().sun_path, r.as<sockaddr_un>().sun_path) == 0; + case AF_UNSPEC: return true; // assume all unspecified addresses are the same + default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); + } + return false; + } + + bool SockAddr::operator!=(const SockAddr& r) const { + return !(*this == r); + } + + bool SockAddr::operator<(const SockAddr& r) const { + if (getType() < r.getType()) + return true; + else if (getType() > r.getType()) + return false; + + if (getPort() < r.getPort()) + return true; + else if (getPort() > r.getPort()) + return false; + + switch (getType()) { + case AF_INET: return as<sockaddr_in>().sin_addr.s_addr < r.as<sockaddr_in>().sin_addr.s_addr; + case AF_INET6: return memcmp(as<sockaddr_in6>().sin6_addr.s6_addr, r.as<sockaddr_in6>().sin6_addr.s6_addr, sizeof(in6_addr)) < 0; + case AF_UNIX: return strcmp(as<sockaddr_un>().sun_path, r.as<sockaddr_un>().sun_path) < 0; + case AF_UNSPEC: return false; + default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); + } + return false; + } + + SockAddr unknownAddress( "0.0.0.0", 0 ); + + // ------ hostname ------------------- + + string hostbyname(const char *hostname) { + string addr = SockAddr(hostname, 0).getAddr(); + if (addr == "0.0.0.0") + return ""; + else + return addr; + } + + // --- my -- + + string getHostName() { + char buf[256]; + int ec = gethostname(buf, 127); + if ( ec || *buf == 0 ) { + log() << "can't get this server's hostname " << errnoWithDescription() << endl; + return ""; + } + return buf; + } + + + string _hostNameCached; + static void _hostNameCachedInit() { + _hostNameCached = getHostName(); + } + boost::once_flag _hostNameCachedInitFlags = BOOST_ONCE_INIT; + + string getHostNameCached() { + boost::call_once( _hostNameCachedInit , _hostNameCachedInitFlags ); + return _hostNameCached; + } + + // --------- SocketException ---------- + +#ifdef MSG_NOSIGNAL + const int portSendFlags = MSG_NOSIGNAL; + const int portRecvFlags = MSG_NOSIGNAL; +#else + const int portSendFlags = 0; + const int portRecvFlags = 0; +#endif + + string SocketException::toString() const { + stringstream ss; + ss << _ei.code << " socket exception [" << _type << "] "; + + if ( _server.size() ) + ss << "server [" << _server << "] "; + + if ( _extra.size() ) + ss << _extra; + + return ss.str(); + } + + + // ------------ SSLManager ----------------- + +#ifdef MONGO_SSL + SSLManager::SSLManager( bool client ) { + _client = client; + SSL_library_init(); + SSL_load_error_strings(); + ERR_load_crypto_strings(); + + _context = SSL_CTX_new( client ? SSLv23_client_method() : SSLv23_server_method() ); + massert( 15864 , mongoutils::str::stream() << "can't create SSL Context: " << ERR_error_string(ERR_get_error(), NULL) , _context ); + + SSL_CTX_set_options( _context, SSL_OP_ALL); + } + + void SSLManager::setupPubPriv( const string& privateKeyFile , const string& publicKeyFile ) { + massert( 15865 , + mongoutils::str::stream() << "Can't read SSL certificate from file " + << publicKeyFile << ":" << ERR_error_string(ERR_get_error(), NULL) , + SSL_CTX_use_certificate_file(_context, publicKeyFile.c_str(), SSL_FILETYPE_PEM) ); + + + massert( 15866 , + mongoutils::str::stream() << "Can't read SSL private key from file " + << privateKeyFile << " : " << ERR_error_string(ERR_get_error(), NULL) , + SSL_CTX_use_PrivateKey_file(_context, privateKeyFile.c_str(), SSL_FILETYPE_PEM) ); + } + + + int SSLManager::password_cb(char *buf,int num, int rwflag,void *userdata){ + SSLManager* sm = (SSLManager*)userdata; + string pass = sm->_password; + strcpy(buf,pass.c_str()); + return(pass.size()); + } + + void SSLManager::setupPEM( const string& keyFile , const string& password ) { + _password = password; + + massert( 15867 , "Can't read certificate file" , SSL_CTX_use_certificate_chain_file( _context , keyFile.c_str() ) ); + + SSL_CTX_set_default_passwd_cb_userdata( _context , this ); + SSL_CTX_set_default_passwd_cb( _context, &SSLManager::password_cb ); + + massert( 15868 , "Can't read key file" , SSL_CTX_use_PrivateKey_file( _context , keyFile.c_str() , SSL_FILETYPE_PEM ) ); + } + + SSL * SSLManager::secure( int fd ) { + SSL * ssl = SSL_new( _context ); + massert( 15861 , "can't create SSL" , ssl ); + SSL_set_fd( ssl , fd ); + return ssl; + } + + +#endif + + // ------------ Socket ----------------- + + Socket::Socket(int fd , const SockAddr& remote) : + _fd(fd), _remote(remote), _timeout(0) { + _logLevel = 0; + _init(); + } + + Socket::Socket( double timeout, int ll ) { + _logLevel = ll; + _fd = -1; + _timeout = timeout; + _init(); + } + + void Socket::_init() { + _bytesOut = 0; + _bytesIn = 0; +#ifdef MONGO_SSL + _sslAccepted = 0; +#endif + } + + void Socket::close() { +#ifdef MONGO_SSL + _ssl.reset(); +#endif + if ( _fd >= 0 ) { + closesocket( _fd ); + _fd = -1; + } + } + +#ifdef MONGO_SSL + void Socket::secure( SSLManager * ssl ) { + assert( ssl ); + assert( _fd >= 0 ); + _ssl.reset( ssl->secure( _fd ) ); + SSL_connect( _ssl.get() ); + } + + void Socket::secureAccepted( SSLManager * ssl ) { + _sslAccepted = ssl; + } +#endif + + void Socket::postFork() { +#ifdef MONGO_SSL + if ( _sslAccepted ) { + assert( _fd ); + _ssl.reset( _sslAccepted->secure( _fd ) ); + SSL_accept( _ssl.get() ); + _sslAccepted = 0; + } +#endif + } + + class ConnectBG : public BackgroundJob { + public: + ConnectBG(int sock, SockAddr remote) : _sock(sock), _remote(remote) { } + + void run() { _res = ::connect(_sock, _remote.raw(), _remote.addressSize); } + string name() const { return "ConnectBG"; } + int inError() const { return _res; } + + private: + int _sock; + int _res; + SockAddr _remote; + }; + + bool Socket::connect(SockAddr& remote) { + _remote = remote; + + _fd = socket(remote.getType(), SOCK_STREAM, 0); + if ( _fd == INVALID_SOCKET ) { + log(_logLevel) << "ERROR: connect invalid socket " << errnoWithDescription() << endl; + return false; + } + + if ( _timeout > 0 ) { + setTimeout( _timeout ); + } + + ConnectBG bg(_fd, remote); + bg.go(); + if ( bg.wait(5000) ) { + if ( bg.inError() ) { + close(); + return false; + } + } + else { + // time out the connect + close(); + bg.wait(); // so bg stays in scope until bg thread terminates + return false; + } + + if (remote.getType() != AF_UNIX) + disableNagle(_fd); + +#ifdef SO_NOSIGPIPE + // osx + const int one = 1; + setsockopt( _fd , SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(int)); +#endif + + return true; + } + + int Socket::_send( const char * data , int len ) { +#ifdef MONGO_SSL + if ( _ssl ) { + return SSL_write( _ssl.get() , data , len ); + } +#endif + return ::send( _fd , data , len , portSendFlags ); + } + + // sends all data or throws an exception + void Socket::send( const char * data , int len, const char *context ) { + while( len > 0 ) { + int ret = _send( data , len ); + if ( ret == -1 ) { + +#ifdef MONGO_SSL + if ( _ssl ) { + log() << "SSL Error ret: " << ret << " err: " << SSL_get_error( _ssl.get() , ret ) + << " " << ERR_error_string(ERR_get_error(), NULL) + << endl; + } +#endif + +#if defined(_WIN32) + if ( WSAGetLastError() == WSAETIMEDOUT && _timeout != 0 ) { +#else + if ( ( errno == EAGAIN || errno == EWOULDBLOCK ) && _timeout != 0 ) { +#endif + log(_logLevel) << "Socket " << context << " send() timed out " << _remote.toString() << endl; + throw SocketException( SocketException::SEND_TIMEOUT , remoteString() ); + } + else { + SocketException::Type t = SocketException::SEND_ERROR; + log(_logLevel) << "Socket " << context << " send() " + << errnoWithDescription() << ' ' << remoteString() << endl; + throw SocketException( t , remoteString() ); + } + } + else { + _bytesOut += ret; + + assert( ret <= len ); + len -= ret; + data += ret; + } + } + } + + void Socket::_send( const vector< pair< char *, int > > &data, const char *context ) { + for( vector< pair< char *, int > >::const_iterator i = data.begin(); i != data.end(); ++i ) { + char * data = i->first; + int len = i->second; + send( data, len, context ); + } + } + + // sends all data or throws an exception + void Socket::send( const vector< pair< char *, int > > &data, const char *context ) { + +#ifdef MONGO_SSL + if ( _ssl ) { + _send( data , context ); + return; + } +#endif + +#if defined(_WIN32) + // TODO use scatter/gather api + _send( data , context ); +#else + vector< struct iovec > d( data.size() ); + int i = 0; + for( vector< pair< char *, int > >::const_iterator j = data.begin(); j != data.end(); ++j ) { + if ( j->second > 0 ) { + d[ i ].iov_base = j->first; + d[ i ].iov_len = j->second; + ++i; + _bytesOut += j->second; + } + } + struct msghdr meta; + memset( &meta, 0, sizeof( meta ) ); + meta.msg_iov = &d[ 0 ]; + meta.msg_iovlen = d.size(); + + while( meta.msg_iovlen > 0 ) { + int ret = ::sendmsg( _fd , &meta , portSendFlags ); + if ( ret == -1 ) { + if ( errno != EAGAIN || _timeout == 0 ) { + log(_logLevel) << "Socket " << context << " send() " << errnoWithDescription() << ' ' << remoteString() << endl; + throw SocketException( SocketException::SEND_ERROR , remoteString() ); + } + else { + log(_logLevel) << "Socket " << context << " send() remote timeout " << remoteString() << endl; + throw SocketException( SocketException::SEND_TIMEOUT , remoteString() ); + } + } + else { + struct iovec *& i = meta.msg_iov; + while( ret > 0 ) { + if ( i->iov_len > unsigned( ret ) ) { + i->iov_len -= ret; + i->iov_base = (char*)(i->iov_base) + ret; + ret = 0; + } + else { + ret -= i->iov_len; + ++i; + --(meta.msg_iovlen); + } + } + } + } +#endif + } + + void Socket::recv( char * buf , int len ) { + unsigned retries = 0; + while( len > 0 ) { + int ret = unsafe_recv( buf , len ); + if ( ret > 0 ) { + if ( len <= 4 && ret != len ) + log(_logLevel) << "Socket recv() got " << ret << " bytes wanted len=" << len << endl; + assert( ret <= len ); + len -= ret; + buf += ret; + } + else if ( ret == 0 ) { + log(3) << "Socket recv() conn closed? " << remoteString() << endl; + throw SocketException( SocketException::CLOSED , remoteString() ); + } + else { /* ret < 0 */ +#if defined(_WIN32) + int e = WSAGetLastError(); +#else + int e = errno; +# if defined(EINTR) + if( e == EINTR ) { + if( ++retries == 1 ) { + log() << "EINTR retry" << endl; + continue; + } + } +# endif +#endif + if ( ( e == EAGAIN +#if defined(_WIN32) + || e == WSAETIMEDOUT +#endif + ) && _timeout > 0 ) + { + // this is a timeout + log(_logLevel) << "Socket recv() timeout " << remoteString() <<endl; + throw SocketException( SocketException::RECV_TIMEOUT, remoteString() ); + } + + log(_logLevel) << "Socket recv() " << errnoWithDescription(e) << " " << remoteString() <<endl; + throw SocketException( SocketException::RECV_ERROR , remoteString() ); + } + } + } + + int Socket::unsafe_recv( char *buf, int max ) { + int x = _recv( buf , max ); + _bytesIn += x; + return x; + } + + + int Socket::_recv( char *buf, int max ) { +#ifdef MONGO_SSL + if ( _ssl ){ + return SSL_read( _ssl.get() , buf , max ); + } +#endif + return ::recv( _fd , buf , max , portRecvFlags ); + } + + void Socket::setTimeout( double secs ) { + struct timeval tv; + tv.tv_sec = (int)secs; + tv.tv_usec = (int)((long long)(secs*1000*1000) % (1000*1000)); + bool report = logLevel > 3; // solaris doesn't provide these + DEV report = true; + bool ok = setsockopt(_fd, SOL_SOCKET, SO_RCVTIMEO, (char *) &tv, sizeof(tv) ) == 0; + if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl; + ok = setsockopt(_fd, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, sizeof(tv) ) == 0; + DEV if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl; + } + +#if defined(_WIN32) + struct WinsockInit { + WinsockInit() { + WSADATA d; + if ( WSAStartup(MAKEWORD(2,2), &d) != 0 ) { + out() << "ERROR: wsastartup failed " << errnoWithDescription() << endl; + problem() << "ERROR: wsastartup failed " << errnoWithDescription() << endl; + dbexit( EXIT_NTSERVICE_ERROR ); + } + } + } winsock_init; +#endif + +} // namespace mongo diff --git a/util/net/sock.h b/util/net/sock.h new file mode 100644 index 0000000..1cd5133 --- /dev/null +++ b/util/net/sock.h @@ -0,0 +1,256 @@ +// @file sock.h + +/* Copyright 2009 10gen Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../../pch.h" + +#include <stdio.h> +#include <sstream> +#include "../goodies.h" +#include "../../db/cmdline.h" +#include "../mongoutils/str.h" + +#ifndef _WIN32 + +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/socket.h> +#include <sys/un.h> +#include <errno.h> + +#ifdef __openbsd__ +# include <sys/uio.h> +#endif + +#endif // _WIN32 + +#ifdef MONGO_SSL +#include <openssl/ssl.h> +#endif + +namespace mongo { + + const int SOCK_FAMILY_UNKNOWN_ERROR=13078; + + void disableNagle(int sock); + +#if defined(_WIN32) + + typedef short sa_family_t; + typedef int socklen_t; + + // This won't actually be used on windows + struct sockaddr_un { + short sun_family; + char sun_path[108]; // length from unix header + }; + +#else // _WIN32 + + inline void closesocket(int s) { close(s); } + const int INVALID_SOCKET = -1; + typedef int SOCKET; + +#endif // _WIN32 + + inline string makeUnixSockPath(int port) { + return mongoutils::str::stream() << cmdLine.socket << "/mongodb-" << port << ".sock"; + } + + // If an ip address is passed in, just return that. If a hostname is passed + // in, look up its ip and return that. Returns "" on failure. + string hostbyname(const char *hostname); + + void enableIPv6(bool state=true); + bool IPv6Enabled(); + void setSockTimeouts(int sock, double secs); + + /** + * wrapped around os representation of network address + */ + struct SockAddr { + SockAddr() { + addressSize = sizeof(sa); + memset(&sa, 0, sizeof(sa)); + sa.ss_family = AF_UNSPEC; + } + SockAddr(int sourcePort); /* listener side */ + SockAddr(const char *ip, int port); /* EndPoint (remote) side, or if you want to specify which interface locally */ + + template <typename T> T& as() { return *(T*)(&sa); } + template <typename T> const T& as() const { return *(const T*)(&sa); } + + string toString(bool includePort=true) const; + + /** + * @return one of AF_INET, AF_INET6, or AF_UNIX + */ + sa_family_t getType() const; + + unsigned getPort() const; + + string getAddr() const; + + bool isLocalHost() const; + + bool operator==(const SockAddr& r) const; + + bool operator!=(const SockAddr& r) const; + + bool operator<(const SockAddr& r) const; + + const sockaddr* raw() const {return (sockaddr*)&sa;} + sockaddr* raw() {return (sockaddr*)&sa;} + + socklen_t addressSize; + private: + struct sockaddr_storage sa; + }; + + extern SockAddr unknownAddress; // ( "0.0.0.0", 0 ) + + /** this is not cache and does a syscall */ + string getHostName(); + + /** this is cached, so if changes during the process lifetime + * will be stale */ + string getHostNameCached(); + + /** + * thrown by Socket and SockAddr + */ + class SocketException : public DBException { + public: + const enum Type { CLOSED , RECV_ERROR , SEND_ERROR, RECV_TIMEOUT, SEND_TIMEOUT, FAILED_STATE, CONNECT_ERROR } _type; + + SocketException( Type t , string server , int code = 9001 , string extra="" ) + : DBException( "socket exception" , code ) , _type(t) , _server(server), _extra(extra){ } + virtual ~SocketException() throw() {} + + bool shouldPrint() const { return _type != CLOSED; } + virtual string toString() const; + + private: + string _server; + string _extra; + }; + +#ifdef MONGO_SSL + class SSLManager : boost::noncopyable { + public: + SSLManager( bool client ); + + void setupPEM( const string& keyFile , const string& password ); + void setupPubPriv( const string& privateKeyFile , const string& publicKeyFile ); + + /** + * creates an SSL context to be used for this file descriptor + * caller should delete + */ + SSL * secure( int fd ); + + static int password_cb( char *buf,int num, int rwflag,void *userdata ); + + private: + bool _client; + SSL_CTX* _context; + string _password; + }; +#endif + + /** + * thin wrapped around file descriptor and system calls + * todo: ssl + */ + class Socket { + public: + Socket(int sock, const SockAddr& farEnd); + + /** In some cases the timeout will actually be 2x this value - eg we do a partial send, + then the timeout fires, then we try to send again, then the timeout fires again with + no data sent, then we detect that the other side is down. + + Generally you don't want a timeout, you should be very prepared for errors if you set one. + */ + Socket(double so_timeout = 0, int logLevel = 0 ); + + bool connect(SockAddr& farEnd); + void close(); + + void send( const char * data , int len, const char *context ); + void send( const vector< pair< char *, int > > &data, const char *context ); + + // recv len or throw SocketException + void recv( char * data , int len ); + int unsafe_recv( char *buf, int max ); + + int getLogLevel() const { return _logLevel; } + void setLogLevel( int ll ) { _logLevel = ll; } + + SockAddr remoteAddr() const { return _remote; } + string remoteString() const { return _remote.toString(); } + unsigned remotePort() const { return _remote.getPort(); } + + void clearCounters() { _bytesIn = 0; _bytesOut = 0; } + long long getBytesIn() const { return _bytesIn; } + long long getBytesOut() const { return _bytesOut; } + + void setTimeout( double secs ); + +#ifdef MONGO_SSL + /** secures inline */ + void secure( SSLManager * ssl ); + + void secureAccepted( SSLManager * ssl ); +#endif + + /** + * call this after a fork for server sockets + */ + void postFork(); + + private: + void _init(); + /** raw send, same semantics as ::send */ + int _send( const char * data , int len ); + + /** sends dumbly, just each buffer at a time */ + void _send( const vector< pair< char *, int > > &data, const char *context ); + + /** raw recv, same semantics as ::recv */ + int _recv( char * buf , int max ); + + int _fd; + SockAddr _remote; + double _timeout; + + long long _bytesIn; + long long _bytesOut; + +#ifdef MONGO_SSL + shared_ptr<SSL> _ssl; + SSLManager * _sslAccepted; +#endif + + protected: + int _logLevel; // passed to log() when logging errors + + }; + + +} // namespace mongo |