summaryrefslogtreecommitdiff
path: root/util/net
diff options
context:
space:
mode:
Diffstat (limited to 'util/net')
-rw-r--r--util/net/hostandport.h165
-rw-r--r--util/net/httpclient.cpp177
-rw-r--r--util/net/httpclient.h79
-rw-r--r--util/net/listen.cpp391
-rw-r--r--util/net/listen.h190
-rw-r--r--util/net/message.cpp64
-rw-r--r--util/net/message.h312
-rw-r--r--util/net/message_port.cpp298
-rw-r--r--util/net/message_port.h107
-rw-r--r--util/net/message_server.h66
-rw-r--r--util/net/message_server_asio.cpp261
-rw-r--r--util/net/message_server_port.cpp197
-rw-r--r--util/net/miniwebserver.cpp207
-rw-r--r--util/net/miniwebserver.h60
-rw-r--r--util/net/sock.cpp713
-rw-r--r--util/net/sock.h256
16 files changed, 3543 insertions, 0 deletions
diff --git a/util/net/hostandport.h b/util/net/hostandport.h
new file mode 100644
index 0000000..573e8ee
--- /dev/null
+++ b/util/net/hostandport.h
@@ -0,0 +1,165 @@
+// hostandport.h
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "sock.h"
+#include "../../db/cmdline.h"
+#include "../mongoutils/str.h"
+
+namespace mongo {
+
+ using namespace mongoutils;
+
+ /** helper for manipulating host:port connection endpoints.
+ */
+ struct HostAndPort {
+ HostAndPort() : _port(-1) { }
+
+ /** From a string hostname[:portnumber]
+ Throws user assertion if bad config string or bad port #.
+ */
+ HostAndPort(string s);
+
+ /** @param p port number. -1 is ok to use default. */
+ HostAndPort(string h, int p /*= -1*/) : _host(h), _port(p) { }
+
+ HostAndPort(const SockAddr& sock )
+ : _host( sock.getAddr() ) , _port( sock.getPort() ) {
+ }
+
+ static HostAndPort me() {
+ return HostAndPort("localhost", cmdLine.port);
+ }
+
+ /* uses real hostname instead of localhost */
+ static HostAndPort Me();
+
+ bool operator<(const HostAndPort& r) const {
+ if( _host < r._host )
+ return true;
+ if( _host == r._host )
+ return port() < r.port();
+ return false;
+ }
+
+ bool operator==(const HostAndPort& r) const {
+ return _host == r._host && port() == r.port();
+ }
+
+ bool operator!=(const HostAndPort& r) const {
+ return _host != r._host || port() != r.port();
+ }
+
+ /* returns true if the host/port combo identifies this process instance. */
+ bool isSelf() const; // defined in message.cpp
+
+ bool isLocalHost() const;
+
+ /**
+ * @param includePort host:port if true, host otherwise
+ */
+ string toString( bool includePort=true ) const;
+
+ operator string() const { return toString(); }
+
+ string host() const { return _host; }
+
+ int port() const { return _port >= 0 ? _port : CmdLine::DefaultDBPort; }
+ bool hasPort() const { return _port >= 0; }
+ void setPort( int port ) { _port = port; }
+
+ private:
+ // invariant (except full obj assignment):
+ string _host;
+ int _port; // -1 indicates unspecified
+ };
+
+ inline HostAndPort HostAndPort::Me() {
+ const char* ips = cmdLine.bind_ip.c_str();
+ while(*ips) {
+ string ip;
+ const char * comma = strchr(ips, ',');
+ if (comma) {
+ ip = string(ips, comma - ips);
+ ips = comma + 1;
+ }
+ else {
+ ip = string(ips);
+ ips = "";
+ }
+ HostAndPort h = HostAndPort(ip, cmdLine.port);
+ if (!h.isLocalHost()) {
+ return h;
+ }
+ }
+
+ string h = getHostName();
+ assert( !h.empty() );
+ assert( h != "localhost" );
+ return HostAndPort(h, cmdLine.port);
+ }
+
+ inline string HostAndPort::toString( bool includePort ) const {
+ if ( ! includePort )
+ return _host;
+
+ stringstream ss;
+ ss << _host;
+ if ( _port != -1 ) {
+ ss << ':';
+#if defined(_DEBUG)
+ if( _port >= 44000 && _port < 44100 ) {
+ log() << "warning: special debug port 44xxx used" << endl;
+ ss << _port+1;
+ }
+ else
+ ss << _port;
+#else
+ ss << _port;
+#endif
+ }
+ return ss.str();
+ }
+
+ inline bool HostAndPort::isLocalHost() const {
+ return ( _host == "localhost"
+ || startsWith(_host.c_str(), "127.")
+ || _host == "::1"
+ || _host == "anonymous unix socket"
+ || _host.c_str()[0] == '/' // unix socket
+ );
+ }
+
+ inline HostAndPort::HostAndPort(string s) {
+ const char *p = s.c_str();
+ uassert(13110, "HostAndPort: bad config string", *p);
+ const char *colon = strrchr(p, ':');
+ if( colon ) {
+ int port = atoi(colon+1);
+ uassert(13095, "HostAndPort: bad port #", port > 0);
+ _host = string(p,colon-p);
+ _port = port;
+ }
+ else {
+ // no port specified.
+ _host = p;
+ _port = -1;
+ }
+ }
+
+}
diff --git a/util/net/httpclient.cpp b/util/net/httpclient.cpp
new file mode 100644
index 0000000..16eaa0a
--- /dev/null
+++ b/util/net/httpclient.cpp
@@ -0,0 +1,177 @@
+// httpclient.cpp
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pch.h"
+#include "httpclient.h"
+#include "sock.h"
+#include "message.h"
+#include "message_port.h"
+#include "../mongoutils/str.h"
+#include "../../bson/util/builder.h"
+
+namespace mongo {
+
+ //#define HD(x) cout << x << endl;
+#define HD(x)
+
+
+ int HttpClient::get( string url , Result * result ) {
+ return _go( "GET" , url , 0 , result );
+ }
+
+ int HttpClient::post( string url , string data , Result * result ) {
+ return _go( "POST" , url , data.c_str() , result );
+ }
+
+ int HttpClient::_go( const char * command , string url , const char * body , Result * result ) {
+ bool ssl = false;
+ if ( url.find( "https://" ) == 0 ) {
+ ssl = true;
+ url = url.substr( 8 );
+ }
+ else {
+ uassert( 10271 , "invalid url" , url.find( "http://" ) == 0 );
+ url = url.substr( 7 );
+ }
+
+ string host , path;
+ if ( url.find( "/" ) == string::npos ) {
+ host = url;
+ path = "/";
+ }
+ else {
+ host = url.substr( 0 , url.find( "/" ) );
+ path = url.substr( url.find( "/" ) );
+ }
+
+
+ HD( "host [" << host << "]" );
+ HD( "path [" << path << "]" );
+
+ string server = host;
+ int port = ssl ? 443 : 80;
+
+ string::size_type idx = host.find( ":" );
+ if ( idx != string::npos ) {
+ server = host.substr( 0 , idx );
+ string t = host.substr( idx + 1 );
+ port = atoi( t.c_str() );
+ }
+
+ HD( "server [" << server << "]" );
+ HD( "port [" << port << "]" );
+
+ string req;
+ {
+ stringstream ss;
+ ss << command << " " << path << " HTTP/1.1\r\n";
+ ss << "Host: " << host << "\r\n";
+ ss << "Connection: Close\r\n";
+ ss << "User-Agent: mongodb http client\r\n";
+ if ( body ) {
+ ss << "Content-Length: " << strlen( body ) << "\r\n";
+ }
+ ss << "\r\n";
+ if ( body ) {
+ ss << body;
+ }
+
+ req = ss.str();
+ }
+
+ SockAddr addr( server.c_str() , port );
+ HD( "addr: " << addr.toString() );
+
+ Socket sock;
+ if ( ! sock.connect( addr ) )
+ return -1;
+
+ if ( ssl ) {
+#ifdef MONGO_SSL
+ _checkSSLManager();
+ sock.secure( _sslManager.get() );
+#else
+ uasserted( 15862 , "no ssl support" );
+#endif
+ }
+
+ {
+ const char * out = req.c_str();
+ int toSend = req.size();
+ sock.send( out , toSend, "_go" );
+ }
+
+ char buf[4096];
+ int got = sock.unsafe_recv( buf , 4096 );
+ buf[got] = 0;
+
+ int rc;
+ char version[32];
+ assert( sscanf( buf , "%s %d" , version , &rc ) == 2 );
+ HD( "rc: " << rc );
+
+ StringBuilder sb;
+ if ( result )
+ sb << buf;
+
+ while ( ( got = sock.unsafe_recv( buf , 4096 ) ) > 0) {
+ if ( result )
+ sb << buf;
+ }
+
+ if ( result ) {
+ result->_init( rc , sb.str() );
+ }
+
+ return rc;
+ }
+
+ void HttpClient::Result::_init( int code , string entire ) {
+ _code = code;
+ _entireResponse = entire;
+
+ while ( true ) {
+ size_t i = entire.find( '\n' );
+ if ( i == string::npos ) {
+ // invalid
+ break;
+ }
+
+ string h = entire.substr( 0 , i );
+ entire = entire.substr( i + 1 );
+
+ if ( h.size() && h[h.size()-1] == '\r' )
+ h = h.substr( 0 , h.size() - 1 );
+
+ if ( h.size() == 0 )
+ break;
+
+ i = h.find( ':' );
+ if ( i != string::npos )
+ _headers[h.substr(0,i)] = str::ltrim(h.substr(i+1));
+ }
+
+ _body = entire;
+ }
+
+#ifdef MONGO_SSL
+ void HttpClient::_checkSSLManager() {
+ _sslManager.reset( new SSLManager( true ) );
+ }
+#endif
+
+}
diff --git a/util/net/httpclient.h b/util/net/httpclient.h
new file mode 100644
index 0000000..c3f8c82
--- /dev/null
+++ b/util/net/httpclient.h
@@ -0,0 +1,79 @@
+// httpclient.h
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "../../pch.h"
+#include "sock.h"
+
+namespace mongo {
+
+ class HttpClient : boost::noncopyable {
+ public:
+
+ typedef map<string,string> Headers;
+
+ class Result {
+ public:
+ Result() {}
+
+ const string& getEntireResponse() const {
+ return _entireResponse;
+ }
+
+ const Headers getHeaders() const {
+ return _headers;
+ }
+
+ const string& getBody() const {
+ return _body;
+ }
+
+ private:
+
+ void _init( int code , string entire );
+
+ int _code;
+ string _entireResponse;
+
+ Headers _headers;
+ string _body;
+
+ friend class HttpClient;
+ };
+
+ /**
+ * @return response code
+ */
+ int get( string url , Result * result = 0 );
+
+ /**
+ * @return response code
+ */
+ int post( string url , string body , Result * result = 0 );
+
+ private:
+ int _go( const char * command , string url , const char * body , Result * result );
+
+#ifdef MONGO_SSL
+ void _checkSSLManager();
+
+ scoped_ptr<SSLManager> _sslManager;
+#endif
+ };
+}
+
diff --git a/util/net/listen.cpp b/util/net/listen.cpp
new file mode 100644
index 0000000..6ee25b4
--- /dev/null
+++ b/util/net/listen.cpp
@@ -0,0 +1,391 @@
+// listen.h
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#include "pch.h"
+#include "listen.h"
+#include "message_port.h"
+
+#ifndef _WIN32
+
+# ifndef __sunos__
+# include <ifaddrs.h>
+# endif
+# include <sys/resource.h>
+# include <sys/stat.h>
+
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <arpa/inet.h>
+#include <errno.h>
+#include <netdb.h>
+#ifdef __openbsd__
+# include <sys/uio.h>
+#endif
+
+#else
+
+// errno doesn't work for winsock.
+#undef errno
+#define errno WSAGetLastError()
+
+#endif
+
+namespace mongo {
+
+
+ void checkTicketNumbers();
+
+
+ // ----- Listener -------
+
+ const Listener* Listener::_timeTracker;
+
+ vector<SockAddr> ipToAddrs(const char* ips, int port, bool useUnixSockets) {
+ vector<SockAddr> out;
+ if (*ips == '\0') {
+ out.push_back(SockAddr("0.0.0.0", port)); // IPv4 all
+
+ if (IPv6Enabled())
+ out.push_back(SockAddr("::", port)); // IPv6 all
+#ifndef _WIN32
+ if (useUnixSockets)
+ out.push_back(SockAddr(makeUnixSockPath(port).c_str(), port)); // Unix socket
+#endif
+ return out;
+ }
+
+ while(*ips) {
+ string ip;
+ const char * comma = strchr(ips, ',');
+ if (comma) {
+ ip = string(ips, comma - ips);
+ ips = comma + 1;
+ }
+ else {
+ ip = string(ips);
+ ips = "";
+ }
+
+ SockAddr sa(ip.c_str(), port);
+ out.push_back(sa);
+
+#ifndef _WIN32
+ if (useUnixSockets && (sa.getAddr() == "127.0.0.1" || sa.getAddr() == "0.0.0.0")) // only IPv4
+ out.push_back(SockAddr(makeUnixSockPath(port).c_str(), port));
+#endif
+ }
+ return out;
+
+ }
+
+ Listener::Listener(const string& name, const string &ip, int port, bool logConnect )
+ : _port(port), _name(name), _ip(ip), _logConnect(logConnect), _elapsedTime(0) {
+#ifdef MONGO_SSL
+ _ssl = 0;
+ _sslPort = 0;
+
+ if ( cmdLine.sslOnNormalPorts && cmdLine.sslServerManager ) {
+ secure( cmdLine.sslServerManager );
+ }
+#endif
+ }
+
+ Listener::~Listener() {
+ if ( _timeTracker == this )
+ _timeTracker = 0;
+ }
+
+#ifdef MONGO_SSL
+ void Listener::secure( SSLManager* manager ) {
+ _ssl = manager;
+ }
+
+ void Listener::addSecurePort( SSLManager* manager , int additionalPort ) {
+ _ssl = manager;
+ _sslPort = additionalPort;
+ }
+
+#endif
+
+ bool Listener::_setupSockets( const vector<SockAddr>& mine , vector<int>& socks ) {
+ for (vector<SockAddr>::const_iterator it=mine.begin(), end=mine.end(); it != end; ++it) {
+ const SockAddr& me = *it;
+
+ SOCKET sock = ::socket(me.getType(), SOCK_STREAM, 0);
+ massert( 15863 , str::stream() << "listen(): invalid socket? " << errnoWithDescription() , sock >= 0 );
+
+ if (me.getType() == AF_UNIX) {
+#if !defined(_WIN32)
+ if (unlink(me.getAddr().c_str()) == -1) {
+ int x = errno;
+ if (x != ENOENT) {
+ log() << "couldn't unlink socket file " << me << errnoWithDescription(x) << " skipping" << endl;
+ continue;
+ }
+ }
+#endif
+ }
+ else if (me.getType() == AF_INET6) {
+ // IPv6 can also accept IPv4 connections as mapped addresses (::ffff:127.0.0.1)
+ // That causes a conflict if we don't do set it to IPV6_ONLY
+ const int one = 1;
+ setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*) &one, sizeof(one));
+ }
+
+#if !defined(_WIN32)
+ {
+ const int one = 1;
+ if ( setsockopt( sock , SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0 )
+ out() << "Failed to set socket opt, SO_REUSEADDR" << endl;
+ }
+#endif
+
+ if ( ::bind(sock, me.raw(), me.addressSize) != 0 ) {
+ int x = errno;
+ error() << "listen(): bind() failed " << errnoWithDescription(x) << " for socket: " << me.toString() << endl;
+ if ( x == EADDRINUSE )
+ error() << " addr already in use" << endl;
+ closesocket(sock);
+ return false;
+ }
+
+#if !defined(_WIN32)
+ if (me.getType() == AF_UNIX) {
+ if (chmod(me.getAddr().c_str(), 0777) == -1) {
+ error() << "couldn't chmod socket file " << me << errnoWithDescription() << endl;
+ }
+ ListeningSockets::get()->addPath( me.getAddr() );
+ }
+#endif
+
+ if ( ::listen(sock, 128) != 0 ) {
+ error() << "listen(): listen() failed " << errnoWithDescription() << endl;
+ closesocket(sock);
+ return false;
+ }
+
+ ListeningSockets::get()->add( sock );
+
+ socks.push_back(sock);
+ }
+
+ return true;
+ }
+
+ void Listener::initAndListen() {
+ checkTicketNumbers();
+ vector<int> socks;
+ set<int> sslSocks;
+
+ { // normal sockets
+ vector<SockAddr> mine = ipToAddrs(_ip.c_str(), _port, (!cmdLine.noUnixSocket && useUnixSockets()));
+ if ( ! _setupSockets( mine , socks ) )
+ return;
+ }
+
+#ifdef MONGO_SSL
+ if ( _ssl && _sslPort > 0 ) {
+ unsigned prev = socks.size();
+
+ vector<SockAddr> mine = ipToAddrs(_ip.c_str(), _sslPort, false );
+ if ( ! _setupSockets( mine , socks ) )
+ return;
+
+ for ( unsigned i=prev; i<socks.size(); i++ ) {
+ sslSocks.insert( socks[i] );
+ }
+
+ }
+#endif
+
+ SOCKET maxfd = 0; // needed for select()
+ for ( unsigned i=0; i<socks.size(); i++ ) {
+ if ( socks[i] > maxfd )
+ maxfd = socks[i];
+ }
+
+#ifdef MONGO_SSL
+ if ( _ssl == 0 ) {
+ _logListen( _port , false );
+ }
+ else if ( _sslPort == 0 ) {
+ _logListen( _port , true );
+ }
+ else {
+ // both
+ _logListen( _port , false );
+ _logListen( _sslPort , true );
+ }
+#else
+ _logListen( _port , false );
+#endif
+
+ static long connNumber = 0;
+ struct timeval maxSelectTime;
+ while ( ! inShutdown() ) {
+ fd_set fds[1];
+ FD_ZERO(fds);
+
+ for (vector<int>::iterator it=socks.begin(), end=socks.end(); it != end; ++it) {
+ FD_SET(*it, fds);
+ }
+
+ maxSelectTime.tv_sec = 0;
+ maxSelectTime.tv_usec = 10000;
+ const int ret = select(maxfd+1, fds, NULL, NULL, &maxSelectTime);
+
+ if (ret == 0) {
+#if defined(__linux__)
+ _elapsedTime += ( 10000 - maxSelectTime.tv_usec ) / 1000;
+#else
+ _elapsedTime += 10;
+#endif
+ continue;
+ }
+
+ if (ret < 0) {
+ int x = errno;
+#ifdef EINTR
+ if ( x == EINTR ) {
+ log() << "select() signal caught, continuing" << endl;
+ continue;
+ }
+#endif
+ if ( ! inShutdown() )
+ log() << "select() failure: ret=" << ret << " " << errnoWithDescription(x) << endl;
+ return;
+ }
+
+#if defined(__linux__)
+ _elapsedTime += max(ret, (int)(( 10000 - maxSelectTime.tv_usec ) / 1000));
+#else
+ _elapsedTime += ret; // assume 1ms to grab connection. very rough
+#endif
+
+ for (vector<int>::iterator it=socks.begin(), end=socks.end(); it != end; ++it) {
+ if (! (FD_ISSET(*it, fds)))
+ continue;
+
+ SockAddr from;
+ int s = accept(*it, from.raw(), &from.addressSize);
+ if ( s < 0 ) {
+ int x = errno; // so no global issues
+ if ( x == ECONNABORTED || x == EBADF ) {
+ log() << "Listener on port " << _port << " aborted" << endl;
+ return;
+ }
+ if ( x == 0 && inShutdown() ) {
+ return; // socket closed
+ }
+ if( !inShutdown() ) {
+ log() << "Listener: accept() returns " << s << " " << errnoWithDescription(x) << endl;
+ if (x == EMFILE || x == ENFILE) {
+ // Connection still in listen queue but we can't accept it yet
+ error() << "Out of file descriptors. Waiting one second before trying to accept more connections." << warnings;
+ sleepsecs(1);
+ }
+ }
+ continue;
+ }
+ if (from.getType() != AF_UNIX)
+ disableNagle(s);
+ if ( _logConnect && ! cmdLine.quiet )
+ log() << "connection accepted from " << from.toString() << " #" << ++connNumber << endl;
+
+ Socket newSock = Socket(s, from);
+#ifdef MONGO_SSL
+ if ( _ssl && ( _sslPort == 0 || sslSocks.count(*it) ) ) {
+ newSock.secureAccepted( _ssl );
+ }
+#endif
+ accepted( newSock );
+ }
+ }
+ }
+
+ void Listener::_logListen( int port , bool ssl ) {
+ log() << _name << ( _name.size() ? " " : "" ) << "waiting for connections on port " << port << ( ssl ? " ssl" : "" ) << endl;
+ }
+
+
+ void Listener::accepted(Socket socket) {
+ accepted( new MessagingPort(socket) );
+ }
+
+ void Listener::accepted(MessagingPort *mp) {
+ assert(!"You must overwrite one of the accepted methods");
+ }
+
+ // ----- ListeningSockets -------
+
+ ListeningSockets* ListeningSockets::_instance = new ListeningSockets();
+
+ ListeningSockets* ListeningSockets::get() {
+ return _instance;
+ }
+
+ // ------ connection ticket and control ------
+
+ const int DEFAULT_MAX_CONN = 20000;
+ const int MAX_MAX_CONN = 20000;
+
+ int getMaxConnections() {
+#ifdef _WIN32
+ return DEFAULT_MAX_CONN;
+#else
+ struct rlimit limit;
+ assert( getrlimit(RLIMIT_NOFILE,&limit) == 0 );
+
+ int max = (int)(limit.rlim_cur * .8);
+
+ log(1) << "fd limit"
+ << " hard:" << limit.rlim_max
+ << " soft:" << limit.rlim_cur
+ << " max conn: " << max
+ << endl;
+
+ if ( max > MAX_MAX_CONN )
+ max = MAX_MAX_CONN;
+
+ return max;
+#endif
+ }
+
+ void checkTicketNumbers() {
+ int want = getMaxConnections();
+ int current = connTicketHolder.outof();
+ if ( current != DEFAULT_MAX_CONN ) {
+ if ( current < want ) {
+ // they want fewer than they can handle
+ // which is fine
+ log(1) << " only allowing " << current << " connections" << endl;
+ return;
+ }
+ if ( current > want ) {
+ log() << " --maxConns too high, can only handle " << want << endl;
+ }
+ }
+ connTicketHolder.resize( want );
+ }
+
+ TicketHolder connTicketHolder(DEFAULT_MAX_CONN);
+
+}
diff --git a/util/net/listen.h b/util/net/listen.h
new file mode 100644
index 0000000..415db1e
--- /dev/null
+++ b/util/net/listen.h
@@ -0,0 +1,190 @@
+// listen.h
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "sock.h"
+
+namespace mongo {
+
+ class MessagingPort;
+
+ class Listener : boost::noncopyable {
+ public:
+
+ Listener(const string& name, const string &ip, int port, bool logConnect=true );
+
+ virtual ~Listener();
+
+#ifdef MONGO_SSL
+ /**
+ * make this an ssl socket
+ * ownership of SSLManager remains with the caller
+ */
+ void secure( SSLManager* manager );
+
+ void addSecurePort( SSLManager* manager , int additionalPort );
+#endif
+
+ void initAndListen(); // never returns unless error (start a thread)
+
+ /* spawn a thread, etc., then return */
+ virtual void accepted(Socket socket);
+ virtual void accepted(MessagingPort *mp);
+
+ const int _port;
+
+ /**
+ * @return a rough estimate of elapsed time since the server started
+ */
+ long long getMyElapsedTimeMillis() const { return _elapsedTime; }
+
+ void setAsTimeTracker() {
+ _timeTracker = this;
+ }
+
+ static const Listener* getTimeTracker() {
+ return _timeTracker;
+ }
+
+ static long long getElapsedTimeMillis() {
+ if ( _timeTracker )
+ return _timeTracker->getMyElapsedTimeMillis();
+
+ // should this assert or throw? seems like callers may not expect to get zero back, certainly not forever.
+ return 0;
+ }
+
+ private:
+ string _name;
+ string _ip;
+ bool _logConnect;
+ long long _elapsedTime;
+
+#ifdef MONGO_SSL
+ SSLManager* _ssl;
+ int _sslPort;
+#endif
+
+ /**
+ * @return true iff everything went ok
+ */
+ bool _setupSockets( const vector<SockAddr>& mine , vector<int>& socks );
+
+ void _logListen( int port , bool ssl );
+
+ static const Listener* _timeTracker;
+
+ virtual bool useUnixSockets() const { return false; }
+ };
+
+ /**
+ * keep track of elapsed time
+ * after a set amount of time, tells you to do something
+ * only in this file because depends on Listener
+ */
+ class ElapsedTracker {
+ public:
+ ElapsedTracker( int hitsBetweenMarks , int msBetweenMarks )
+ : _h( hitsBetweenMarks ) , _ms( msBetweenMarks ) , _pings(0) {
+ _last = Listener::getElapsedTimeMillis();
+ }
+
+ /**
+ * call this for every iteration
+ * returns true if one of the triggers has gone off
+ */
+ bool ping() {
+ if ( ( ++_pings % _h ) == 0 ) {
+ _last = Listener::getElapsedTimeMillis();
+ return true;
+ }
+
+ long long now = Listener::getElapsedTimeMillis();
+ if ( now - _last > _ms ) {
+ _last = now;
+ return true;
+ }
+
+ return false;
+ }
+
+ private:
+ int _h;
+ int _ms;
+
+ unsigned long long _pings;
+
+ long long _last;
+
+ };
+
+ class ListeningSockets {
+ public:
+ ListeningSockets()
+ : _mutex("ListeningSockets")
+ , _sockets( new set<int>() )
+ , _socketPaths( new set<string>() )
+ { }
+ void add( int sock ) {
+ scoped_lock lk( _mutex );
+ _sockets->insert( sock );
+ }
+ void addPath( string path ) {
+ scoped_lock lk( _mutex );
+ _socketPaths->insert( path );
+ }
+ void remove( int sock ) {
+ scoped_lock lk( _mutex );
+ _sockets->erase( sock );
+ }
+ void closeAll() {
+ set<int>* sockets;
+ set<string>* paths;
+
+ {
+ scoped_lock lk( _mutex );
+ sockets = _sockets;
+ _sockets = new set<int>();
+ paths = _socketPaths;
+ _socketPaths = new set<string>();
+ }
+
+ for ( set<int>::iterator i=sockets->begin(); i!=sockets->end(); i++ ) {
+ int sock = *i;
+ log() << "closing listening socket: " << sock << endl;
+ closesocket( sock );
+ }
+
+ for ( set<string>::iterator i=paths->begin(); i!=paths->end(); i++ ) {
+ string path = *i;
+ log() << "removing socket file: " << path << endl;
+ ::remove( path.c_str() );
+ }
+ }
+ static ListeningSockets* get();
+ private:
+ mongo::mutex _mutex;
+ set<int>* _sockets;
+ set<string>* _socketPaths; // for unix domain sockets
+ static ListeningSockets* _instance;
+ };
+
+
+ extern TicketHolder connTicketHolder;
+
+}
diff --git a/util/net/message.cpp b/util/net/message.cpp
new file mode 100644
index 0000000..a84e5c4
--- /dev/null
+++ b/util/net/message.cpp
@@ -0,0 +1,64 @@
+// message.cpp
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pch.h"
+
+#include <fcntl.h>
+#include <errno.h>
+#include <time.h>
+
+#include "message.h"
+#include "message_port.h"
+#include "listen.h"
+
+#include "../goodies.h"
+#include "../../client/dbclient.h"
+
+namespace mongo {
+
+ void Message::send( MessagingPort &p, const char *context ) {
+ if ( empty() ) {
+ return;
+ }
+ if ( _buf != 0 ) {
+ p.send( (char*)_buf, _buf->len, context );
+ }
+ else {
+ p.send( _data, context );
+ }
+ }
+
+ MSGID NextMsgId;
+
+ /*struct MsgStart {
+ MsgStart() {
+ NextMsgId = (((unsigned) time(0)) << 16) ^ curTimeMillis();
+ assert(MsgDataHeaderSize == 16);
+ }
+ } msgstart;*/
+
+ MSGID nextMessageId() {
+ MSGID msgid = NextMsgId++;
+ return msgid;
+ }
+
+ bool doesOpGetAResponse( int op ) {
+ return op == dbQuery || op == dbGetMore;
+ }
+
+
+} // namespace mongo
diff --git a/util/net/message.h b/util/net/message.h
new file mode 100644
index 0000000..16da5d6
--- /dev/null
+++ b/util/net/message.h
@@ -0,0 +1,312 @@
+// message.h
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "sock.h"
+#include "../../bson/util/atomic_int.h"
+#include "hostandport.h"
+
+namespace mongo {
+
+ class Message;
+ class MessagingPort;
+ class PiggyBackData;
+
+ typedef AtomicUInt MSGID;
+
+ enum Operations {
+ opReply = 1, /* reply. responseTo is set. */
+ dbMsg = 1000, /* generic msg command followed by a string */
+ dbUpdate = 2001, /* update object */
+ dbInsert = 2002,
+ //dbGetByOID = 2003,
+ dbQuery = 2004,
+ dbGetMore = 2005,
+ dbDelete = 2006,
+ dbKillCursors = 2007
+ };
+
+ bool doesOpGetAResponse( int op );
+
+ inline const char * opToString( int op ) {
+ switch ( op ) {
+ case 0: return "none";
+ case opReply: return "reply";
+ case dbMsg: return "msg";
+ case dbUpdate: return "update";
+ case dbInsert: return "insert";
+ case dbQuery: return "query";
+ case dbGetMore: return "getmore";
+ case dbDelete: return "remove";
+ case dbKillCursors: return "killcursors";
+ default:
+ PRINT(op);
+ assert(0);
+ return "";
+ }
+ }
+
+ inline bool opIsWrite( int op ) {
+ switch ( op ) {
+
+ case 0:
+ case opReply:
+ case dbMsg:
+ case dbQuery:
+ case dbGetMore:
+ case dbKillCursors:
+ return false;
+
+ case dbUpdate:
+ case dbInsert:
+ case dbDelete:
+ return false;
+
+ default:
+ PRINT(op);
+ assert(0);
+ return "";
+ }
+
+ }
+
+#pragma pack(1)
+ /* see http://www.mongodb.org/display/DOCS/Mongo+Wire+Protocol
+ */
+ struct MSGHEADER {
+ int messageLength; // total message size, including this
+ int requestID; // identifier for this message
+ int responseTo; // requestID from the original request
+ // (used in reponses from db)
+ int opCode;
+ };
+ struct OP_GETMORE : public MSGHEADER {
+ MSGHEADER header; // standard message header
+ int ZERO_or_flags; // 0 - reserved for future use
+ //cstring fullCollectionName; // "dbname.collectionname"
+ //int32 numberToReturn; // number of documents to return
+ //int64 cursorID; // cursorID from the OP_REPLY
+ };
+#pragma pack()
+
+#pragma pack(1)
+ /* todo merge this with MSGHEADER (or inherit from it). */
+ struct MsgData {
+ int len; /* len of the msg, including this field */
+ MSGID id; /* request/reply id's match... */
+ MSGID responseTo; /* id of the message we are responding to */
+ short _operation;
+ char _flags;
+ char _version;
+ int operation() const {
+ return _operation;
+ }
+ void setOperation(int o) {
+ _flags = 0;
+ _version = 0;
+ _operation = o;
+ }
+ char _data[4];
+
+ int& dataAsInt() {
+ return *((int *) _data);
+ }
+
+ bool valid() {
+ if ( len <= 0 || len > ( 4 * BSONObjMaxInternalSize ) )
+ return false;
+ if ( _operation < 0 || _operation > 30000 )
+ return false;
+ return true;
+ }
+
+ long long getCursor() {
+ assert( responseTo > 0 );
+ assert( _operation == opReply );
+ long long * l = (long long *)(_data + 4);
+ return l[0];
+ }
+
+ int dataLen(); // len without header
+ };
+ const int MsgDataHeaderSize = sizeof(MsgData) - 4;
+ inline int MsgData::dataLen() {
+ return len - MsgDataHeaderSize;
+ }
+#pragma pack()
+
+ class Message {
+ public:
+ // we assume here that a vector with initial size 0 does no allocation (0 is the default, but wanted to make it explicit).
+ Message() : _buf( 0 ), _data( 0 ), _freeIt( false ) {}
+ Message( void * data , bool freeIt ) :
+ _buf( 0 ), _data( 0 ), _freeIt( false ) {
+ _setData( reinterpret_cast< MsgData* >( data ), freeIt );
+ };
+ Message(Message& r) : _buf( 0 ), _data( 0 ), _freeIt( false ) {
+ *this = r;
+ }
+ ~Message() {
+ reset();
+ }
+
+ SockAddr _from;
+
+ MsgData *header() const {
+ assert( !empty() );
+ return _buf ? _buf : reinterpret_cast< MsgData* > ( _data[ 0 ].first );
+ }
+ int operation() const { return header()->operation(); }
+
+ MsgData *singleData() const {
+ massert( 13273, "single data buffer expected", _buf );
+ return header();
+ }
+
+ bool empty() const { return !_buf && _data.empty(); }
+
+ int size() const {
+ int res = 0;
+ if ( _buf ) {
+ res = _buf->len;
+ }
+ else {
+ for (MsgVec::const_iterator it = _data.begin(); it != _data.end(); ++it) {
+ res += it->second;
+ }
+ }
+ return res;
+ }
+
+ int dataSize() const { return size() - sizeof(MSGHEADER); }
+
+ // concat multiple buffers - noop if <2 buffers already, otherwise can be expensive copy
+ // can get rid of this if we make response handling smarter
+ void concat() {
+ if ( _buf || empty() ) {
+ return;
+ }
+
+ assert( _freeIt );
+ int totalSize = 0;
+ for( vector< pair< char *, int > >::const_iterator i = _data.begin(); i != _data.end(); ++i ) {
+ totalSize += i->second;
+ }
+ char *buf = (char*)malloc( totalSize );
+ char *p = buf;
+ for( vector< pair< char *, int > >::const_iterator i = _data.begin(); i != _data.end(); ++i ) {
+ memcpy( p, i->first, i->second );
+ p += i->second;
+ }
+ reset();
+ _setData( (MsgData*)buf, true );
+ }
+
+ // vector swap() so this is fast
+ Message& operator=(Message& r) {
+ assert( empty() );
+ assert( r._freeIt );
+ _buf = r._buf;
+ r._buf = 0;
+ if ( r._data.size() > 0 ) {
+ _data.swap( r._data );
+ }
+ r._freeIt = false;
+ _freeIt = true;
+ return *this;
+ }
+
+ void reset() {
+ if ( _freeIt ) {
+ if ( _buf ) {
+ free( _buf );
+ }
+ for( vector< pair< char *, int > >::const_iterator i = _data.begin(); i != _data.end(); ++i ) {
+ free(i->first);
+ }
+ }
+ _buf = 0;
+ _data.clear();
+ _freeIt = false;
+ }
+
+ // use to add a buffer
+ // assumes message will free everything
+ void appendData(char *d, int size) {
+ if ( size <= 0 ) {
+ return;
+ }
+ if ( empty() ) {
+ MsgData *md = (MsgData*)d;
+ md->len = size; // can be updated later if more buffers added
+ _setData( md, true );
+ return;
+ }
+ assert( _freeIt );
+ if ( _buf ) {
+ _data.push_back( make_pair( (char*)_buf, _buf->len ) );
+ _buf = 0;
+ }
+ _data.push_back( make_pair( d, size ) );
+ header()->len += size;
+ }
+
+ // use to set first buffer if empty
+ void setData(MsgData *d, bool freeIt) {
+ assert( empty() );
+ _setData( d, freeIt );
+ }
+ void setData(int operation, const char *msgtxt) {
+ setData(operation, msgtxt, strlen(msgtxt)+1);
+ }
+ void setData(int operation, const char *msgdata, size_t len) {
+ assert( empty() );
+ size_t dataLen = len + sizeof(MsgData) - 4;
+ MsgData *d = (MsgData *) malloc(dataLen);
+ memcpy(d->_data, msgdata, len);
+ d->len = fixEndian(dataLen);
+ d->setOperation(operation);
+ _setData( d, true );
+ }
+
+ bool doIFreeIt() {
+ return _freeIt;
+ }
+
+ void send( MessagingPort &p, const char *context );
+
+ string toString() const;
+
+ private:
+ void _setData( MsgData *d, bool freeIt ) {
+ _freeIt = freeIt;
+ _buf = d;
+ }
+ // if just one buffer, keep it in _buf, otherwise keep a sequence of buffers in _data
+ MsgData * _buf;
+ // byte buffer(s) - the first must contain at least a full MsgData unless using _buf for storage instead
+ typedef vector< pair< char*, int > > MsgVec;
+ MsgVec _data;
+ bool _freeIt;
+ };
+
+
+ MSGID nextMessageId();
+
+
+} // namespace mongo
diff --git a/util/net/message_port.cpp b/util/net/message_port.cpp
new file mode 100644
index 0000000..9abfaf7
--- /dev/null
+++ b/util/net/message_port.cpp
@@ -0,0 +1,298 @@
+// message_port.cpp
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pch.h"
+
+#include <fcntl.h>
+#include <errno.h>
+#include <time.h>
+
+#include "message.h"
+#include "message_port.h"
+#include "listen.h"
+
+#include "../goodies.h"
+#include "../background.h"
+#include "../time_support.h"
+#include "../../db/cmdline.h"
+#include "../../client/dbclient.h"
+
+
+#ifndef _WIN32
+# ifndef __sunos__
+# include <ifaddrs.h>
+# endif
+# include <sys/resource.h>
+# include <sys/stat.h>
+#else
+
+// errno doesn't work for winsock.
+#undef errno
+#define errno WSAGetLastError()
+
+#endif
+
+namespace mongo {
+
+
+// if you want trace output:
+#define mmm(x)
+
+ /* messagingport -------------------------------------------------------------- */
+
+ class PiggyBackData {
+ public:
+ PiggyBackData( MessagingPort * port ) {
+ _port = port;
+ _buf = new char[1300];
+ _cur = _buf;
+ }
+
+ ~PiggyBackData() {
+ DESTRUCTOR_GUARD (
+ flush();
+ delete[]( _cur );
+ );
+ }
+
+ void append( Message& m ) {
+ assert( m.header()->len <= 1300 );
+
+ if ( len() + m.header()->len > 1300 )
+ flush();
+
+ memcpy( _cur , m.singleData() , m.header()->len );
+ _cur += m.header()->len;
+ }
+
+ void flush() {
+ if ( _buf == _cur )
+ return;
+
+ _port->send( _buf , len(), "flush" );
+ _cur = _buf;
+ }
+
+ int len() const { return _cur - _buf; }
+
+ private:
+ MessagingPort* _port;
+ char * _buf;
+ char * _cur;
+ };
+
+ class Ports {
+ set<MessagingPort*> ports;
+ mongo::mutex m;
+ public:
+ Ports() : ports(), m("Ports") {}
+ void closeAll(unsigned skip_mask) {
+ scoped_lock bl(m);
+ for ( set<MessagingPort*>::iterator i = ports.begin(); i != ports.end(); i++ ) {
+ if( (*i)->tag & skip_mask )
+ continue;
+ (*i)->shutdown();
+ }
+ }
+ void insert(MessagingPort* p) {
+ scoped_lock bl(m);
+ ports.insert(p);
+ }
+ void erase(MessagingPort* p) {
+ scoped_lock bl(m);
+ ports.erase(p);
+ }
+ };
+
+ // we "new" this so it is still be around when other automatic global vars
+ // are being destructed during termination.
+ Ports& ports = *(new Ports());
+
+ void MessagingPort::closeAllSockets(unsigned mask) {
+ ports.closeAll(mask);
+ }
+
+ MessagingPort::MessagingPort(int fd, const SockAddr& remote)
+ : Socket( fd , remote ) , piggyBackData(0) {
+ ports.insert(this);
+ }
+
+ MessagingPort::MessagingPort( double timeout, int ll )
+ : Socket( timeout, ll ) {
+ ports.insert(this);
+ piggyBackData = 0;
+ }
+
+ MessagingPort::MessagingPort( Socket& sock )
+ : Socket( sock ) , piggyBackData( 0 ) {
+ }
+
+ void MessagingPort::shutdown() {
+ close();
+ }
+
+ MessagingPort::~MessagingPort() {
+ if ( piggyBackData )
+ delete( piggyBackData );
+ shutdown();
+ ports.erase(this);
+ }
+
+ bool MessagingPort::recv(Message& m) {
+ try {
+again:
+ mmm( log() << "* recv() sock:" << this->sock << endl; )
+ int len = -1;
+
+ char *lenbuf = (char *) &len;
+ int lft = 4;
+ Socket::recv( lenbuf, lft );
+
+ if ( len < 16 || len > 48000000 ) { // messages must be large enough for headers
+ if ( len == -1 ) {
+ // Endian check from the client, after connecting, to see what mode server is running in.
+ unsigned foo = 0x10203040;
+ send( (char *) &foo, 4, "endian" );
+ goto again;
+ }
+
+ if ( len == 542393671 ) {
+ // an http GET
+ log(_logLevel) << "looks like you're trying to access db over http on native driver port. please add 1000 for webserver" << endl;
+ string msg = "You are trying to access MongoDB on the native driver port. For http diagnostic access, add 1000 to the port number\n";
+ stringstream ss;
+ ss << "HTTP/1.0 200 OK\r\nConnection: close\r\nContent-Type: text/plain\r\nContent-Length: " << msg.size() << "\r\n\r\n" << msg;
+ string s = ss.str();
+ send( s.c_str(), s.size(), "http" );
+ return false;
+ }
+ log(0) << "recv(): message len " << len << " is too large" << len << endl;
+ return false;
+ }
+
+ int z = (len+1023)&0xfffffc00;
+ assert(z>=len);
+ MsgData *md = (MsgData *) malloc(z);
+ assert(md);
+ md->len = len;
+
+ char *p = (char *) &md->id;
+ int left = len -4;
+
+ try {
+ Socket::recv( p, left );
+ }
+ catch (...) {
+ free(md);
+ throw;
+ }
+
+ m.setData(md, true);
+ return true;
+
+ }
+ catch ( const SocketException & e ) {
+ log(_logLevel + (e.shouldPrint() ? 0 : 1) ) << "SocketException: remote: " << remote() << " error: " << e << endl;
+ m.reset();
+ return false;
+ }
+ }
+
+ void MessagingPort::reply(Message& received, Message& response) {
+ say(/*received.from, */response, received.header()->id);
+ }
+
+ void MessagingPort::reply(Message& received, Message& response, MSGID responseTo) {
+ say(/*received.from, */response, responseTo);
+ }
+
+ bool MessagingPort::call(Message& toSend, Message& response) {
+ mmm( log() << "*call()" << endl; )
+ say(toSend);
+ return recv( toSend , response );
+ }
+
+ bool MessagingPort::recv( const Message& toSend , Message& response ) {
+ while ( 1 ) {
+ bool ok = recv(response);
+ if ( !ok )
+ return false;
+ //log() << "got response: " << response.data->responseTo << endl;
+ if ( response.header()->responseTo == toSend.header()->id )
+ break;
+ error() << "MessagingPort::call() wrong id got:" << hex << (unsigned)response.header()->responseTo << " expect:" << (unsigned)toSend.header()->id << '\n'
+ << dec
+ << " toSend op: " << (unsigned)toSend.operation() << '\n'
+ << " response msgid:" << (unsigned)response.header()->id << '\n'
+ << " response len: " << (unsigned)response.header()->len << '\n'
+ << " response op: " << response.operation() << '\n'
+ << " remote: " << remoteString() << endl;
+ assert(false);
+ response.reset();
+ }
+ mmm( log() << "*call() end" << endl; )
+ return true;
+ }
+
+ void MessagingPort::say(Message& toSend, int responseTo) {
+ assert( !toSend.empty() );
+ mmm( log() << "* say() sock:" << this->sock << " thr:" << GetCurrentThreadId() << endl; )
+ toSend.header()->id = nextMessageId();
+ toSend.header()->responseTo = responseTo;
+
+ if ( piggyBackData && piggyBackData->len() ) {
+ mmm( log() << "* have piggy back" << endl; )
+ if ( ( piggyBackData->len() + toSend.header()->len ) > 1300 ) {
+ // won't fit in a packet - so just send it off
+ piggyBackData->flush();
+ }
+ else {
+ piggyBackData->append( toSend );
+ piggyBackData->flush();
+ return;
+ }
+ }
+
+ toSend.send( *this, "say" );
+ }
+
+ void MessagingPort::piggyBack( Message& toSend , int responseTo ) {
+
+ if ( toSend.header()->len > 1300 ) {
+ // not worth saving because its almost an entire packet
+ say( toSend );
+ return;
+ }
+
+ // we're going to be storing this, so need to set it up
+ toSend.header()->id = nextMessageId();
+ toSend.header()->responseTo = responseTo;
+
+ if ( ! piggyBackData )
+ piggyBackData = new PiggyBackData( this );
+
+ piggyBackData->append( toSend );
+ }
+
+ HostAndPort MessagingPort::remote() const {
+ if ( ! _remoteParsed.hasPort() )
+ _remoteParsed = HostAndPort( remoteAddr() );
+ return _remoteParsed;
+ }
+
+
+} // namespace mongo
diff --git a/util/net/message_port.h b/util/net/message_port.h
new file mode 100644
index 0000000..22ecafe
--- /dev/null
+++ b/util/net/message_port.h
@@ -0,0 +1,107 @@
+// message_port.h
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "sock.h"
+#include "message.h"
+
+namespace mongo {
+
+ class MessagingPort;
+ class PiggyBackData;
+
+ typedef AtomicUInt MSGID;
+
+ class AbstractMessagingPort : boost::noncopyable {
+ public:
+ AbstractMessagingPort() : tag(0) {}
+ virtual ~AbstractMessagingPort() { }
+ virtual void reply(Message& received, Message& response, MSGID responseTo) = 0; // like the reply below, but doesn't rely on received.data still being available
+ virtual void reply(Message& received, Message& response) = 0;
+
+ virtual HostAndPort remote() const = 0;
+ virtual unsigned remotePort() const = 0;
+
+ private:
+
+ public:
+ // TODO make this private with some helpers
+
+ /* ports can be tagged with various classes. see closeAllSockets(tag). defaults to 0. */
+ unsigned tag;
+
+ };
+
+ class MessagingPort : public AbstractMessagingPort , public Socket {
+ public:
+ MessagingPort(int fd, const SockAddr& remote);
+
+ // in some cases the timeout will actually be 2x this value - eg we do a partial send,
+ // then the timeout fires, then we try to send again, then the timeout fires again with
+ // no data sent, then we detect that the other side is down
+ MessagingPort(double so_timeout = 0, int logLevel = 0 );
+
+ MessagingPort(Socket& socket);
+
+ virtual ~MessagingPort();
+
+ void shutdown();
+
+ /* it's assumed if you reuse a message object, that it doesn't cross MessagingPort's.
+ also, the Message data will go out of scope on the subsequent recv call.
+ */
+ bool recv(Message& m);
+ void reply(Message& received, Message& response, MSGID responseTo);
+ void reply(Message& received, Message& response);
+ bool call(Message& toSend, Message& response);
+
+ void say(Message& toSend, int responseTo = -1);
+
+ /**
+ * this is used for doing 'async' queries
+ * instead of doing call( to , from )
+ * you would do
+ * say( to )
+ * recv( from )
+ * Note: if you fail to call recv and someone else uses this port,
+ * horrible things will happend
+ */
+ bool recv( const Message& sent , Message& response );
+
+ void piggyBack( Message& toSend , int responseTo = -1 );
+
+ unsigned remotePort() const { return Socket::remotePort(); }
+ virtual HostAndPort remote() const;
+
+
+ private:
+
+ PiggyBackData * piggyBackData;
+
+ // this is the parsed version of remote
+ // mutable because its initialized only on call to remote()
+ mutable HostAndPort _remoteParsed;
+
+ public:
+ static void closeAllSockets(unsigned tagMask = 0xffffffff);
+
+ friend class PiggyBackData;
+ };
+
+
+} // namespace mongo
diff --git a/util/net/message_server.h b/util/net/message_server.h
new file mode 100644
index 0000000..ae77b97
--- /dev/null
+++ b/util/net/message_server.h
@@ -0,0 +1,66 @@
+// message_server.h
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ abstract database server
+ async io core, worker thread system
+ */
+
+#pragma once
+
+#include "../../pch.h"
+
+namespace mongo {
+
+ class MessageHandler {
+ public:
+ virtual ~MessageHandler() {}
+
+ /**
+ * called once when a socket is connected
+ */
+ virtual void connected( AbstractMessagingPort* p ) = 0;
+
+ /**
+ * called every time a message comes in
+ * handler is responsible for responding to client
+ */
+ virtual void process( Message& m , AbstractMessagingPort* p , LastError * err ) = 0;
+
+ /**
+ * called once when a socket is disconnected
+ */
+ virtual void disconnected( AbstractMessagingPort* p ) = 0;
+ };
+
+ class MessageServer {
+ public:
+ struct Options {
+ int port; // port to bind to
+ string ipList; // addresses to bind to
+
+ Options() : port(0), ipList("") {}
+ };
+
+ virtual ~MessageServer() {}
+ virtual void run() = 0;
+ virtual void setAsTimeTracker() = 0;
+ };
+
+ // TODO use a factory here to decide between port and asio variations
+ MessageServer * createServer( const MessageServer::Options& opts , MessageHandler * handler );
+}
diff --git a/util/net/message_server_asio.cpp b/util/net/message_server_asio.cpp
new file mode 100644
index 0000000..0c6a7d9
--- /dev/null
+++ b/util/net/message_server_asio.cpp
@@ -0,0 +1,261 @@
+// message_server_asio.cpp
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef USE_ASIO
+
+#include <boost/asio.hpp>
+#include <boost/bind.hpp>
+#include <boost/enable_shared_from_this.hpp>
+#include <boost/shared_ptr.hpp>
+
+#include <iostream>
+#include <vector>
+
+#include "message.h"
+#include "message_server.h"
+#include "../util/concurrency/mvar.h"
+
+using namespace boost;
+using namespace boost::asio;
+using namespace boost::asio::ip;
+
+namespace mongo {
+ class MessageServerSession;
+
+ namespace {
+ class StickyThread {
+ public:
+ StickyThread()
+ : _thread(boost::ref(*this))
+ {}
+
+ ~StickyThread() {
+ _mss.put(boost::shared_ptr<MessageServerSession>());
+ _thread.join();
+ }
+
+ void ready(boost::shared_ptr<MessageServerSession> mss) {
+ _mss.put(mss);
+ }
+
+ void operator() () {
+ boost::shared_ptr<MessageServerSession> mss;
+ while((mss = _mss.take())) { // intentionally not using ==
+ task(mss.get());
+ mss.reset();
+ }
+ }
+
+ private:
+ boost::thread _thread;
+ inline void task(MessageServerSession* mss); // must be defined after MessageServerSession
+
+ MVar<boost::shared_ptr<MessageServerSession> > _mss; // populated when given a task
+ };
+
+ vector<boost::shared_ptr<StickyThread> > thread_pool;
+ mongo::mutex tp_mutex; // this is only needed if io_service::run() is called from multiple threads
+ }
+
+ class MessageServerSession : public boost::enable_shared_from_this<MessageServerSession> , public AbstractMessagingPort {
+ public:
+ MessageServerSession( MessageHandler * handler , io_service& ioservice )
+ : _handler( handler )
+ , _socket( ioservice )
+ , _portCache(0)
+ { }
+
+ ~MessageServerSession() {
+ cout << "disconnect from: " << _socket.remote_endpoint() << endl;
+ }
+
+ tcp::socket& socket() {
+ return _socket;
+ }
+
+ void start() {
+ cout << "MessageServerSession start from:" << _socket.remote_endpoint() << endl;
+ _startHeaderRead();
+ }
+
+ void handleReadHeader( const boost::system::error_code& error ) {
+ if ( _inHeader.len == 0 )
+ return;
+
+ if ( ! _inHeader.valid() ) {
+ cout << " got invalid header from: " << _socket.remote_endpoint() << " closing connected" << endl;
+ return;
+ }
+
+ char * raw = (char*)malloc( _inHeader.len );
+
+ MsgData * data = (MsgData*)raw;
+ memcpy( data , &_inHeader , sizeof( _inHeader ) );
+ assert( data->len == _inHeader.len );
+
+ uassert( 10273 , "_cur not empty! pipelining requests not supported" , ! _cur.data );
+
+ _cur.setData( data , true );
+ async_read( _socket ,
+ buffer( raw + sizeof( _inHeader ) , _inHeader.len - sizeof( _inHeader ) ) ,
+ boost::bind( &MessageServerSession::handleReadBody , shared_from_this() , boost::asio::placeholders::error ) );
+ }
+
+ void handleReadBody( const boost::system::error_code& error ) {
+ if (!_myThread) {
+ mongo::mutex::scoped_lock(tp_mutex);
+ if (!thread_pool.empty()) {
+ _myThread = thread_pool.back();
+ thread_pool.pop_back();
+ }
+ }
+
+ if (!_myThread) // pool is empty
+ _myThread.reset(new StickyThread());
+
+ assert(_myThread);
+
+ _myThread->ready(shared_from_this());
+ }
+
+ void process() {
+ _handler->process( _cur , this );
+
+ if (_reply.data) {
+ async_write( _socket ,
+ buffer( (char*)_reply.data , _reply.data->len ) ,
+ boost::bind( &MessageServerSession::handleWriteDone , shared_from_this() , boost::asio::placeholders::error ) );
+ }
+ else {
+ _cur.reset();
+ _startHeaderRead();
+ }
+ }
+
+ void handleWriteDone( const boost::system::error_code& error ) {
+ {
+ // return thread to pool after we have sent data to the client
+ mongo::mutex::scoped_lock(tp_mutex);
+ assert(_myThread);
+ thread_pool.push_back(_myThread);
+ _myThread.reset();
+ }
+ _cur.reset();
+ _reply.reset();
+ _startHeaderRead();
+ }
+
+ virtual void reply( Message& received, Message& response ) {
+ reply( received , response , received.data->id );
+ }
+
+ virtual void reply( Message& query , Message& toSend, MSGID responseTo ) {
+ _reply = toSend;
+
+ _reply.data->id = nextMessageId();
+ _reply.data->responseTo = responseTo;
+ uassert( 10274 , "pipelining requests doesn't work yet" , query.data->id == _cur.data->id );
+ }
+
+
+ virtual unsigned remotePort() {
+ if (!_portCache)
+ _portCache = _socket.remote_endpoint().port(); //this is expensive
+ return _portCache;
+ }
+
+ private:
+
+ void _startHeaderRead() {
+ _inHeader.len = 0;
+ async_read( _socket ,
+ buffer( &_inHeader , sizeof( _inHeader ) ) ,
+ boost::bind( &MessageServerSession::handleReadHeader , shared_from_this() , boost::asio::placeholders::error ) );
+ }
+
+ MessageHandler * _handler;
+ tcp::socket _socket;
+ MsgData _inHeader;
+ Message _cur;
+ Message _reply;
+
+ unsigned _portCache;
+
+ boost::shared_ptr<StickyThread> _myThread;
+ };
+
+ void StickyThread::task(MessageServerSession* mss) {
+ mss->process();
+ }
+
+
+ class AsyncMessageServer : public MessageServer {
+ public:
+ // TODO accept an IP address to bind to
+ AsyncMessageServer( const MessageServer::Options& opts , MessageHandler * handler )
+ : _port( opts.port )
+ , _handler(handler)
+ , _endpoint( tcp::v4() , opts.port )
+ , _acceptor( _ioservice , _endpoint ) {
+ _accept();
+ }
+ virtual ~AsyncMessageServer() {
+
+ }
+
+ void run() {
+ cout << "AsyncMessageServer starting to listen on: " << _port << endl;
+ boost::thread other(boost::bind(&io_service::run, &_ioservice));
+ _ioservice.run();
+ cout << "AsyncMessageServer done listening on: " << _port << endl;
+ }
+
+ void handleAccept( shared_ptr<MessageServerSession> session ,
+ const boost::system::error_code& error ) {
+ if ( error ) {
+ cout << "handleAccept error!" << endl;
+ return;
+ }
+ session->start();
+ _accept();
+ }
+
+ void _accept( ) {
+ shared_ptr<MessageServerSession> session( new MessageServerSession( _handler , _ioservice ) );
+ _acceptor.async_accept( session->socket() ,
+ boost::bind( &AsyncMessageServer::handleAccept,
+ this,
+ session,
+ boost::asio::placeholders::error )
+ );
+ }
+
+ private:
+ int _port;
+ MessageHandler * _handler;
+ io_service _ioservice;
+ tcp::endpoint _endpoint;
+ tcp::acceptor _acceptor;
+ };
+
+ MessageServer * createServer( const MessageServer::Options& opts , MessageHandler * handler ) {
+ return new AsyncMessageServer( opts , handler );
+ }
+
+}
+
+#endif
diff --git a/util/net/message_server_port.cpp b/util/net/message_server_port.cpp
new file mode 100644
index 0000000..ca0b13d
--- /dev/null
+++ b/util/net/message_server_port.cpp
@@ -0,0 +1,197 @@
+// message_server_port.cpp
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pch.h"
+
+#ifndef USE_ASIO
+
+#include "message.h"
+#include "message_port.h"
+#include "message_server.h"
+#include "listen.h"
+
+#include "../../db/cmdline.h"
+#include "../../db/lasterror.h"
+#include "../../db/stats/counters.h"
+
+#ifdef __linux__ // TODO: consider making this ifndef _WIN32
+# include <sys/resource.h>
+#endif
+
+namespace mongo {
+
+ namespace pms {
+
+ MessageHandler * handler;
+
+ void threadRun( MessagingPort * inPort) {
+ TicketHolderReleaser connTicketReleaser( &connTicketHolder );
+
+ setThreadName( "conn" );
+
+ assert( inPort );
+ inPort->setLogLevel(1);
+ scoped_ptr<MessagingPort> p( inPort );
+
+ p->postFork();
+
+ string otherSide;
+
+ Message m;
+ try {
+ LastError * le = new LastError();
+ lastError.reset( le ); // lastError now has ownership
+
+ otherSide = p->remoteString();
+
+ handler->connected( p.get() );
+
+ while ( ! inShutdown() ) {
+ m.reset();
+ p->clearCounters();
+
+ if ( ! p->recv(m) ) {
+ if( !cmdLine.quiet )
+ log() << "end connection " << otherSide << endl;
+ p->shutdown();
+ break;
+ }
+
+ handler->process( m , p.get() , le );
+ networkCounter.hit( p->getBytesIn() , p->getBytesOut() );
+ }
+ }
+ catch ( AssertionException& e ) {
+ log() << "AssertionException handling request, closing client connection: " << e << endl;
+ p->shutdown();
+ }
+ catch ( SocketException& e ) {
+ log() << "SocketException handling request, closing client connection: " << e << endl;
+ p->shutdown();
+ }
+ catch ( const ClockSkewException & ) {
+ log() << "ClockSkewException - shutting down" << endl;
+ exitCleanly( EXIT_CLOCK_SKEW );
+ }
+ catch ( std::exception &e ) {
+ error() << "Uncaught std::exception: " << e.what() << ", terminating" << endl;
+ dbexit( EXIT_UNCAUGHT );
+ }
+ catch ( ... ) {
+ error() << "Uncaught exception, terminating" << endl;
+ dbexit( EXIT_UNCAUGHT );
+ }
+
+ handler->disconnected( p.get() );
+ }
+
+ }
+
+ class PortMessageServer : public MessageServer , public Listener {
+ public:
+ PortMessageServer( const MessageServer::Options& opts, MessageHandler * handler ) :
+ Listener( "" , opts.ipList, opts.port ) {
+
+ uassert( 10275 , "multiple PortMessageServer not supported" , ! pms::handler );
+ pms::handler = handler;
+ }
+
+ virtual void accepted(MessagingPort * p) {
+
+ if ( ! connTicketHolder.tryAcquire() ) {
+ log() << "connection refused because too many open connections: " << connTicketHolder.used() << endl;
+
+ // TODO: would be nice if we notified them...
+ p->shutdown();
+ delete p;
+
+ sleepmillis(2); // otherwise we'll hard loop
+ return;
+ }
+
+ try {
+#ifndef __linux__ // TODO: consider making this ifdef _WIN32
+ boost::thread thr( boost::bind( &pms::threadRun , p ) );
+#else
+ pthread_attr_t attrs;
+ pthread_attr_init(&attrs);
+ pthread_attr_setdetachstate(&attrs, PTHREAD_CREATE_DETACHED);
+
+ static const size_t STACK_SIZE = 1024*1024; // if we change this we need to update the warning
+
+ struct rlimit limits;
+ verify(15887, getrlimit(RLIMIT_STACK, &limits) == 0);
+ if (limits.rlim_cur > STACK_SIZE) {
+ pthread_attr_setstacksize(&attrs, (DEBUG_BUILD
+ ? (STACK_SIZE / 2)
+ : STACK_SIZE));
+ } else if (limits.rlim_cur < 1024*1024) {
+ warning() << "Stack size set to " << (limits.rlim_cur/1024) << "KB. We suggest 1MB" << endl;
+ }
+
+
+ pthread_t thread;
+ int failed = pthread_create(&thread, &attrs, (void*(*)(void*)) &pms::threadRun, p);
+
+ pthread_attr_destroy(&attrs);
+
+ if (failed) {
+ log() << "pthread_create failed: " << errnoWithDescription(failed) << endl;
+ throw boost::thread_resource_error(); // for consistency with boost::thread
+ }
+#endif
+ }
+ catch ( boost::thread_resource_error& ) {
+ connTicketHolder.release();
+ log() << "can't create new thread, closing connection" << endl;
+
+ p->shutdown();
+ delete p;
+
+ sleepmillis(2);
+ }
+ catch ( ... ) {
+ connTicketHolder.release();
+ log() << "unknown error accepting new socket" << endl;
+
+ p->shutdown();
+ delete p;
+
+ sleepmillis(2);
+ }
+
+ }
+
+ virtual void setAsTimeTracker() {
+ Listener::setAsTimeTracker();
+ }
+
+ void run() {
+ initAndListen();
+ }
+
+ virtual bool useUnixSockets() const { return true; }
+ };
+
+
+ MessageServer * createServer( const MessageServer::Options& opts , MessageHandler * handler ) {
+ return new PortMessageServer( opts , handler );
+ }
+
+}
+
+#endif
diff --git a/util/net/miniwebserver.cpp b/util/net/miniwebserver.cpp
new file mode 100644
index 0000000..0793100
--- /dev/null
+++ b/util/net/miniwebserver.cpp
@@ -0,0 +1,207 @@
+// miniwebserver.cpp
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pch.h"
+#include "miniwebserver.h"
+#include "../hex.h"
+
+#include "pcrecpp.h"
+
+namespace mongo {
+
+ MiniWebServer::MiniWebServer(const string& name, const string &ip, int port)
+ : Listener(name, ip, port, false)
+ {}
+
+ string MiniWebServer::parseURL( const char * buf ) {
+ const char * urlStart = strchr( buf , ' ' );
+ if ( ! urlStart )
+ return "/";
+
+ urlStart++;
+
+ const char * end = strchr( urlStart , ' ' );
+ if ( ! end ) {
+ end = strchr( urlStart , '\r' );
+ if ( ! end ) {
+ end = strchr( urlStart , '\n' );
+ }
+ }
+
+ if ( ! end )
+ return "/";
+
+ int diff = (int)(end-urlStart);
+ if ( diff < 0 || diff > 255 )
+ return "/";
+
+ return string( urlStart , (int)(end-urlStart) );
+ }
+
+ void MiniWebServer::parseParams( BSONObj & params , string query ) {
+ if ( query.size() == 0 )
+ return;
+
+ BSONObjBuilder b;
+ while ( query.size() ) {
+
+ string::size_type amp = query.find( "&" );
+
+ string cur;
+ if ( amp == string::npos ) {
+ cur = query;
+ query = "";
+ }
+ else {
+ cur = query.substr( 0 , amp );
+ query = query.substr( amp + 1 );
+ }
+
+ string::size_type eq = cur.find( "=" );
+ if ( eq == string::npos )
+ continue;
+
+ b.append( urlDecode(cur.substr(0,eq)) , urlDecode(cur.substr(eq+1) ) );
+ }
+
+ params = b.obj();
+ }
+
+ string MiniWebServer::parseMethod( const char * headers ) {
+ const char * end = strchr( headers , ' ' );
+ if ( ! end )
+ return "GET";
+ return string( headers , (int)(end-headers) );
+ }
+
+ const char *MiniWebServer::body( const char *buf ) {
+ const char *ret = strstr( buf, "\r\n\r\n" );
+ return ret ? ret + 4 : ret;
+ }
+
+ bool MiniWebServer::fullReceive( const char *buf ) {
+ const char *bod = body( buf );
+ if ( !bod )
+ return false;
+ const char *lenString = "Content-Length:";
+ const char *lengthLoc = strstr( buf, lenString );
+ if ( !lengthLoc )
+ return true;
+ lengthLoc += strlen( lenString );
+ long len = strtol( lengthLoc, 0, 10 );
+ if ( long( strlen( bod ) ) == len )
+ return true;
+ return false;
+ }
+
+ void MiniWebServer::accepted(Socket sock) {
+ sock.postFork();
+ sock.setTimeout(8);
+ char buf[4096];
+ int len = 0;
+ while ( 1 ) {
+ int left = sizeof(buf) - 1 - len;
+ if( left == 0 )
+ break;
+ int x = sock.unsafe_recv( buf + len , left );
+ if ( x <= 0 ) {
+ sock.close();
+ return;
+ }
+ len += x;
+ buf[ len ] = 0;
+ if ( fullReceive( buf ) ) {
+ break;
+ }
+ }
+ buf[len] = 0;
+
+ string responseMsg;
+ int responseCode = 599;
+ vector<string> headers;
+
+ try {
+ doRequest(buf, parseURL( buf ), responseMsg, responseCode, headers, sock.remoteAddr() );
+ }
+ catch ( std::exception& e ) {
+ responseCode = 500;
+ responseMsg = "error loading page: ";
+ responseMsg += e.what();
+ }
+ catch ( ... ) {
+ responseCode = 500;
+ responseMsg = "unknown error loading page";
+ }
+
+ stringstream ss;
+ ss << "HTTP/1.0 " << responseCode;
+ if ( responseCode == 200 ) ss << " OK";
+ ss << "\r\n";
+ if ( headers.empty() ) {
+ ss << "Content-Type: text/html\r\n";
+ }
+ else {
+ for ( vector<string>::iterator i = headers.begin(); i != headers.end(); i++ ) {
+ assert( strncmp("Content-Length", i->c_str(), 14) );
+ ss << *i << "\r\n";
+ }
+ }
+ ss << "Connection: close\r\n";
+ ss << "Content-Length: " << responseMsg.size() << "\r\n";
+ ss << "\r\n";
+ ss << responseMsg;
+ string response = ss.str();
+
+ sock.send( response.c_str(), response.size() , "http response" );
+ sock.close();
+ }
+
+ string MiniWebServer::getHeader( const char * req , string wanted ) {
+ const char * headers = strchr( req , '\n' );
+ if ( ! headers )
+ return "";
+ pcrecpp::StringPiece input( headers + 1 );
+
+ string name;
+ string val;
+ pcrecpp::RE re("([\\w\\-]+): (.*?)\r?\n");
+ while ( re.Consume( &input, &name, &val) ) {
+ if ( name == wanted )
+ return val;
+ }
+ return "";
+ }
+
+ string MiniWebServer::urlDecode(const char* s) {
+ stringstream out;
+ while(*s) {
+ if (*s == '+') {
+ out << ' ';
+ }
+ else if (*s == '%') {
+ out << fromHex(s+1);
+ s+=2;
+ }
+ else {
+ out << *s;
+ }
+ s++;
+ }
+ return out.str();
+ }
+
+} // namespace mongo
diff --git a/util/net/miniwebserver.h b/util/net/miniwebserver.h
new file mode 100644
index 0000000..1fb6b3f
--- /dev/null
+++ b/util/net/miniwebserver.h
@@ -0,0 +1,60 @@
+// miniwebserver.h
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "../../pch.h"
+#include "message.h"
+#include "message_port.h"
+#include "listen.h"
+#include "../../db/jsobj.h"
+
+namespace mongo {
+
+ class MiniWebServer : public Listener {
+ public:
+ MiniWebServer(const string& name, const string &ip, int _port);
+ virtual ~MiniWebServer() {}
+
+ virtual void doRequest(
+ const char *rq, // the full request
+ string url,
+ // set these and return them:
+ string& responseMsg,
+ int& responseCode,
+ vector<string>& headers, // if completely empty, content-type: text/html will be added
+ const SockAddr &from
+ ) = 0;
+
+ // --- static helpers ----
+
+ static void parseParams( BSONObj & params , string query );
+
+ static string parseURL( const char * buf );
+ static string parseMethod( const char * headers );
+ static string getHeader( const char * headers , string name );
+ static const char *body( const char *buf );
+
+ static string urlDecode(const char* s);
+ static string urlDecode(string s) {return urlDecode(s.c_str());}
+
+ private:
+ void accepted(Socket socket);
+ static bool fullReceive( const char *buf );
+ };
+
+} // namespace mongo
diff --git a/util/net/sock.cpp b/util/net/sock.cpp
new file mode 100644
index 0000000..69c42f2
--- /dev/null
+++ b/util/net/sock.cpp
@@ -0,0 +1,713 @@
+// @file sock.cpp
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pch.h"
+#include "sock.h"
+#include "../background.h"
+
+#if !defined(_WIN32)
+# include <sys/socket.h>
+# include <sys/types.h>
+# include <sys/socket.h>
+# include <sys/un.h>
+# include <netinet/in.h>
+# include <netinet/tcp.h>
+# include <arpa/inet.h>
+# include <errno.h>
+# include <netdb.h>
+# if defined(__openbsd__)
+# include <sys/uio.h>
+# endif
+#endif
+
+#ifdef MONGO_SSL
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+#endif
+
+
+namespace mongo {
+
+ static bool ipv6 = false;
+ void enableIPv6(bool state) { ipv6 = state; }
+ bool IPv6Enabled() { return ipv6; }
+
+ void setSockTimeouts(int sock, double secs) {
+ struct timeval tv;
+ tv.tv_sec = (int)secs;
+ tv.tv_usec = (int)((long long)(secs*1000*1000) % (1000*1000));
+ bool report = logLevel > 3; // solaris doesn't provide these
+ DEV report = true;
+ bool ok = setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *) &tv, sizeof(tv) ) == 0;
+ if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl;
+ ok = setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, sizeof(tv) ) == 0;
+ DEV if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl;
+ }
+
+#if defined(_WIN32)
+ void disableNagle(int sock) {
+ int x = 1;
+ if ( setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *) &x, sizeof(x)) )
+ error() << "disableNagle failed" << endl;
+ if ( setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char *) &x, sizeof(x)) )
+ error() << "SO_KEEPALIVE failed" << endl;
+ }
+#else
+
+ void disableNagle(int sock) {
+ int x = 1;
+
+#ifdef SOL_TCP
+ int level = SOL_TCP;
+#else
+ int level = SOL_SOCKET;
+#endif
+
+ if ( setsockopt(sock, level, TCP_NODELAY, (char *) &x, sizeof(x)) )
+ error() << "disableNagle failed: " << errnoWithDescription() << endl;
+
+#ifdef SO_KEEPALIVE
+ if ( setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (char *) &x, sizeof(x)) )
+ error() << "SO_KEEPALIVE failed: " << errnoWithDescription() << endl;
+
+# ifdef __linux__
+ socklen_t len = sizeof(x);
+ if ( getsockopt(sock, level, TCP_KEEPIDLE, (char *) &x, &len) )
+ error() << "can't get TCP_KEEPIDLE: " << errnoWithDescription() << endl;
+
+ if (x > 300) {
+ x = 300;
+ if ( setsockopt(sock, level, TCP_KEEPIDLE, (char *) &x, sizeof(x)) ) {
+ error() << "can't set TCP_KEEPIDLE: " << errnoWithDescription() << endl;
+ }
+ }
+
+ len = sizeof(x); // just in case it changed
+ if ( getsockopt(sock, level, TCP_KEEPINTVL, (char *) &x, &len) )
+ error() << "can't get TCP_KEEPINTVL: " << errnoWithDescription() << endl;
+
+ if (x > 300) {
+ x = 300;
+ if ( setsockopt(sock, level, TCP_KEEPINTVL, (char *) &x, sizeof(x)) ) {
+ error() << "can't set TCP_KEEPINTVL: " << errnoWithDescription() << endl;
+ }
+ }
+# endif
+#endif
+
+ }
+
+#endif
+
+ string getAddrInfoStrError(int code) {
+#if !defined(_WIN32)
+ return gai_strerror(code);
+#else
+ /* gai_strerrorA is not threadsafe on windows. don't use it. */
+ return errnoWithDescription(code);
+#endif
+ }
+
+
+ // --- SockAddr
+
+ SockAddr::SockAddr(int sourcePort) {
+ memset(as<sockaddr_in>().sin_zero, 0, sizeof(as<sockaddr_in>().sin_zero));
+ as<sockaddr_in>().sin_family = AF_INET;
+ as<sockaddr_in>().sin_port = htons(sourcePort);
+ as<sockaddr_in>().sin_addr.s_addr = htonl(INADDR_ANY);
+ addressSize = sizeof(sockaddr_in);
+ }
+
+ SockAddr::SockAddr(const char * iporhost , int port) {
+ if (!strcmp(iporhost, "localhost"))
+ iporhost = "127.0.0.1";
+
+ if (strchr(iporhost, '/')) {
+#ifdef _WIN32
+ uassert(13080, "no unix socket support on windows", false);
+#endif
+ uassert(13079, "path to unix socket too long", strlen(iporhost) < sizeof(as<sockaddr_un>().sun_path));
+ as<sockaddr_un>().sun_family = AF_UNIX;
+ strcpy(as<sockaddr_un>().sun_path, iporhost);
+ addressSize = sizeof(sockaddr_un);
+ }
+ else {
+ addrinfo* addrs = NULL;
+ addrinfo hints;
+ memset(&hints, 0, sizeof(addrinfo));
+ hints.ai_socktype = SOCK_STREAM;
+ //hints.ai_flags = AI_ADDRCONFIG; // This is often recommended but don't do it. SERVER-1579
+ hints.ai_flags |= AI_NUMERICHOST; // first pass tries w/o DNS lookup
+ hints.ai_family = (IPv6Enabled() ? AF_UNSPEC : AF_INET);
+
+ StringBuilder ss;
+ ss << port;
+ int ret = getaddrinfo(iporhost, ss.str().c_str(), &hints, &addrs);
+
+ // old C compilers on IPv6-capable hosts return EAI_NODATA error
+#ifdef EAI_NODATA
+ int nodata = (ret == EAI_NODATA);
+#else
+ int nodata = false;
+#endif
+ if (ret == EAI_NONAME || nodata) {
+ // iporhost isn't an IP address, allow DNS lookup
+ hints.ai_flags &= ~AI_NUMERICHOST;
+ ret = getaddrinfo(iporhost, ss.str().c_str(), &hints, &addrs);
+ }
+
+ if (ret) {
+ // don't log if this as it is a CRT construction and log() may not work yet.
+ if( strcmp("0.0.0.0", iporhost) ) {
+ log() << "getaddrinfo(\"" << iporhost << "\") failed: " << gai_strerror(ret) << endl;
+ }
+ *this = SockAddr(port);
+ }
+ else {
+ //TODO: handle other addresses in linked list;
+ assert(addrs->ai_addrlen <= sizeof(sa));
+ memcpy(&sa, addrs->ai_addr, addrs->ai_addrlen);
+ addressSize = addrs->ai_addrlen;
+ freeaddrinfo(addrs);
+ }
+ }
+ }
+
+ bool SockAddr::isLocalHost() const {
+ switch (getType()) {
+ case AF_INET: return getAddr() == "127.0.0.1";
+ case AF_INET6: return getAddr() == "::1";
+ case AF_UNIX: return true;
+ default: return false;
+ }
+ assert(false);
+ return false;
+ }
+
+ string SockAddr::toString(bool includePort) const {
+ string out = getAddr();
+ if (includePort && getType() != AF_UNIX && getType() != AF_UNSPEC)
+ out += mongoutils::str::stream() << ':' << getPort();
+ return out;
+ }
+
+ sa_family_t SockAddr::getType() const {
+ return sa.ss_family;
+ }
+
+ unsigned SockAddr::getPort() const {
+ switch (getType()) {
+ case AF_INET: return ntohs(as<sockaddr_in>().sin_port);
+ case AF_INET6: return ntohs(as<sockaddr_in6>().sin6_port);
+ case AF_UNIX: return 0;
+ case AF_UNSPEC: return 0;
+ default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); return 0;
+ }
+ }
+
+ string SockAddr::getAddr() const {
+ switch (getType()) {
+ case AF_INET:
+ case AF_INET6: {
+ const int buflen=128;
+ char buffer[buflen];
+ int ret = getnameinfo(raw(), addressSize, buffer, buflen, NULL, 0, NI_NUMERICHOST);
+ massert(13082, getAddrInfoStrError(ret), ret == 0);
+ return buffer;
+ }
+
+ case AF_UNIX: return (addressSize > 2 ? as<sockaddr_un>().sun_path : "anonymous unix socket");
+ case AF_UNSPEC: return "(NONE)";
+ default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false); return "";
+ }
+ }
+
+ bool SockAddr::operator==(const SockAddr& r) const {
+ if (getType() != r.getType())
+ return false;
+
+ if (getPort() != r.getPort())
+ return false;
+
+ switch (getType()) {
+ case AF_INET: return as<sockaddr_in>().sin_addr.s_addr == r.as<sockaddr_in>().sin_addr.s_addr;
+ case AF_INET6: return memcmp(as<sockaddr_in6>().sin6_addr.s6_addr, r.as<sockaddr_in6>().sin6_addr.s6_addr, sizeof(in6_addr)) == 0;
+ case AF_UNIX: return strcmp(as<sockaddr_un>().sun_path, r.as<sockaddr_un>().sun_path) == 0;
+ case AF_UNSPEC: return true; // assume all unspecified addresses are the same
+ default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false);
+ }
+ return false;
+ }
+
+ bool SockAddr::operator!=(const SockAddr& r) const {
+ return !(*this == r);
+ }
+
+ bool SockAddr::operator<(const SockAddr& r) const {
+ if (getType() < r.getType())
+ return true;
+ else if (getType() > r.getType())
+ return false;
+
+ if (getPort() < r.getPort())
+ return true;
+ else if (getPort() > r.getPort())
+ return false;
+
+ switch (getType()) {
+ case AF_INET: return as<sockaddr_in>().sin_addr.s_addr < r.as<sockaddr_in>().sin_addr.s_addr;
+ case AF_INET6: return memcmp(as<sockaddr_in6>().sin6_addr.s6_addr, r.as<sockaddr_in6>().sin6_addr.s6_addr, sizeof(in6_addr)) < 0;
+ case AF_UNIX: return strcmp(as<sockaddr_un>().sun_path, r.as<sockaddr_un>().sun_path) < 0;
+ case AF_UNSPEC: return false;
+ default: massert(SOCK_FAMILY_UNKNOWN_ERROR, "unsupported address family", false);
+ }
+ return false;
+ }
+
+ SockAddr unknownAddress( "0.0.0.0", 0 );
+
+ // ------ hostname -------------------
+
+ string hostbyname(const char *hostname) {
+ string addr = SockAddr(hostname, 0).getAddr();
+ if (addr == "0.0.0.0")
+ return "";
+ else
+ return addr;
+ }
+
+ // --- my --
+
+ string getHostName() {
+ char buf[256];
+ int ec = gethostname(buf, 127);
+ if ( ec || *buf == 0 ) {
+ log() << "can't get this server's hostname " << errnoWithDescription() << endl;
+ return "";
+ }
+ return buf;
+ }
+
+
+ string _hostNameCached;
+ static void _hostNameCachedInit() {
+ _hostNameCached = getHostName();
+ }
+ boost::once_flag _hostNameCachedInitFlags = BOOST_ONCE_INIT;
+
+ string getHostNameCached() {
+ boost::call_once( _hostNameCachedInit , _hostNameCachedInitFlags );
+ return _hostNameCached;
+ }
+
+ // --------- SocketException ----------
+
+#ifdef MSG_NOSIGNAL
+ const int portSendFlags = MSG_NOSIGNAL;
+ const int portRecvFlags = MSG_NOSIGNAL;
+#else
+ const int portSendFlags = 0;
+ const int portRecvFlags = 0;
+#endif
+
+ string SocketException::toString() const {
+ stringstream ss;
+ ss << _ei.code << " socket exception [" << _type << "] ";
+
+ if ( _server.size() )
+ ss << "server [" << _server << "] ";
+
+ if ( _extra.size() )
+ ss << _extra;
+
+ return ss.str();
+ }
+
+
+ // ------------ SSLManager -----------------
+
+#ifdef MONGO_SSL
+ SSLManager::SSLManager( bool client ) {
+ _client = client;
+ SSL_library_init();
+ SSL_load_error_strings();
+ ERR_load_crypto_strings();
+
+ _context = SSL_CTX_new( client ? SSLv23_client_method() : SSLv23_server_method() );
+ massert( 15864 , mongoutils::str::stream() << "can't create SSL Context: " << ERR_error_string(ERR_get_error(), NULL) , _context );
+
+ SSL_CTX_set_options( _context, SSL_OP_ALL);
+ }
+
+ void SSLManager::setupPubPriv( const string& privateKeyFile , const string& publicKeyFile ) {
+ massert( 15865 ,
+ mongoutils::str::stream() << "Can't read SSL certificate from file "
+ << publicKeyFile << ":" << ERR_error_string(ERR_get_error(), NULL) ,
+ SSL_CTX_use_certificate_file(_context, publicKeyFile.c_str(), SSL_FILETYPE_PEM) );
+
+
+ massert( 15866 ,
+ mongoutils::str::stream() << "Can't read SSL private key from file "
+ << privateKeyFile << " : " << ERR_error_string(ERR_get_error(), NULL) ,
+ SSL_CTX_use_PrivateKey_file(_context, privateKeyFile.c_str(), SSL_FILETYPE_PEM) );
+ }
+
+
+ int SSLManager::password_cb(char *buf,int num, int rwflag,void *userdata){
+ SSLManager* sm = (SSLManager*)userdata;
+ string pass = sm->_password;
+ strcpy(buf,pass.c_str());
+ return(pass.size());
+ }
+
+ void SSLManager::setupPEM( const string& keyFile , const string& password ) {
+ _password = password;
+
+ massert( 15867 , "Can't read certificate file" , SSL_CTX_use_certificate_chain_file( _context , keyFile.c_str() ) );
+
+ SSL_CTX_set_default_passwd_cb_userdata( _context , this );
+ SSL_CTX_set_default_passwd_cb( _context, &SSLManager::password_cb );
+
+ massert( 15868 , "Can't read key file" , SSL_CTX_use_PrivateKey_file( _context , keyFile.c_str() , SSL_FILETYPE_PEM ) );
+ }
+
+ SSL * SSLManager::secure( int fd ) {
+ SSL * ssl = SSL_new( _context );
+ massert( 15861 , "can't create SSL" , ssl );
+ SSL_set_fd( ssl , fd );
+ return ssl;
+ }
+
+
+#endif
+
+ // ------------ Socket -----------------
+
+ Socket::Socket(int fd , const SockAddr& remote) :
+ _fd(fd), _remote(remote), _timeout(0) {
+ _logLevel = 0;
+ _init();
+ }
+
+ Socket::Socket( double timeout, int ll ) {
+ _logLevel = ll;
+ _fd = -1;
+ _timeout = timeout;
+ _init();
+ }
+
+ void Socket::_init() {
+ _bytesOut = 0;
+ _bytesIn = 0;
+#ifdef MONGO_SSL
+ _sslAccepted = 0;
+#endif
+ }
+
+ void Socket::close() {
+#ifdef MONGO_SSL
+ _ssl.reset();
+#endif
+ if ( _fd >= 0 ) {
+ closesocket( _fd );
+ _fd = -1;
+ }
+ }
+
+#ifdef MONGO_SSL
+ void Socket::secure( SSLManager * ssl ) {
+ assert( ssl );
+ assert( _fd >= 0 );
+ _ssl.reset( ssl->secure( _fd ) );
+ SSL_connect( _ssl.get() );
+ }
+
+ void Socket::secureAccepted( SSLManager * ssl ) {
+ _sslAccepted = ssl;
+ }
+#endif
+
+ void Socket::postFork() {
+#ifdef MONGO_SSL
+ if ( _sslAccepted ) {
+ assert( _fd );
+ _ssl.reset( _sslAccepted->secure( _fd ) );
+ SSL_accept( _ssl.get() );
+ _sslAccepted = 0;
+ }
+#endif
+ }
+
+ class ConnectBG : public BackgroundJob {
+ public:
+ ConnectBG(int sock, SockAddr remote) : _sock(sock), _remote(remote) { }
+
+ void run() { _res = ::connect(_sock, _remote.raw(), _remote.addressSize); }
+ string name() const { return "ConnectBG"; }
+ int inError() const { return _res; }
+
+ private:
+ int _sock;
+ int _res;
+ SockAddr _remote;
+ };
+
+ bool Socket::connect(SockAddr& remote) {
+ _remote = remote;
+
+ _fd = socket(remote.getType(), SOCK_STREAM, 0);
+ if ( _fd == INVALID_SOCKET ) {
+ log(_logLevel) << "ERROR: connect invalid socket " << errnoWithDescription() << endl;
+ return false;
+ }
+
+ if ( _timeout > 0 ) {
+ setTimeout( _timeout );
+ }
+
+ ConnectBG bg(_fd, remote);
+ bg.go();
+ if ( bg.wait(5000) ) {
+ if ( bg.inError() ) {
+ close();
+ return false;
+ }
+ }
+ else {
+ // time out the connect
+ close();
+ bg.wait(); // so bg stays in scope until bg thread terminates
+ return false;
+ }
+
+ if (remote.getType() != AF_UNIX)
+ disableNagle(_fd);
+
+#ifdef SO_NOSIGPIPE
+ // osx
+ const int one = 1;
+ setsockopt( _fd , SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(int));
+#endif
+
+ return true;
+ }
+
+ int Socket::_send( const char * data , int len ) {
+#ifdef MONGO_SSL
+ if ( _ssl ) {
+ return SSL_write( _ssl.get() , data , len );
+ }
+#endif
+ return ::send( _fd , data , len , portSendFlags );
+ }
+
+ // sends all data or throws an exception
+ void Socket::send( const char * data , int len, const char *context ) {
+ while( len > 0 ) {
+ int ret = _send( data , len );
+ if ( ret == -1 ) {
+
+#ifdef MONGO_SSL
+ if ( _ssl ) {
+ log() << "SSL Error ret: " << ret << " err: " << SSL_get_error( _ssl.get() , ret )
+ << " " << ERR_error_string(ERR_get_error(), NULL)
+ << endl;
+ }
+#endif
+
+#if defined(_WIN32)
+ if ( WSAGetLastError() == WSAETIMEDOUT && _timeout != 0 ) {
+#else
+ if ( ( errno == EAGAIN || errno == EWOULDBLOCK ) && _timeout != 0 ) {
+#endif
+ log(_logLevel) << "Socket " << context << " send() timed out " << _remote.toString() << endl;
+ throw SocketException( SocketException::SEND_TIMEOUT , remoteString() );
+ }
+ else {
+ SocketException::Type t = SocketException::SEND_ERROR;
+ log(_logLevel) << "Socket " << context << " send() "
+ << errnoWithDescription() << ' ' << remoteString() << endl;
+ throw SocketException( t , remoteString() );
+ }
+ }
+ else {
+ _bytesOut += ret;
+
+ assert( ret <= len );
+ len -= ret;
+ data += ret;
+ }
+ }
+ }
+
+ void Socket::_send( const vector< pair< char *, int > > &data, const char *context ) {
+ for( vector< pair< char *, int > >::const_iterator i = data.begin(); i != data.end(); ++i ) {
+ char * data = i->first;
+ int len = i->second;
+ send( data, len, context );
+ }
+ }
+
+ // sends all data or throws an exception
+ void Socket::send( const vector< pair< char *, int > > &data, const char *context ) {
+
+#ifdef MONGO_SSL
+ if ( _ssl ) {
+ _send( data , context );
+ return;
+ }
+#endif
+
+#if defined(_WIN32)
+ // TODO use scatter/gather api
+ _send( data , context );
+#else
+ vector< struct iovec > d( data.size() );
+ int i = 0;
+ for( vector< pair< char *, int > >::const_iterator j = data.begin(); j != data.end(); ++j ) {
+ if ( j->second > 0 ) {
+ d[ i ].iov_base = j->first;
+ d[ i ].iov_len = j->second;
+ ++i;
+ _bytesOut += j->second;
+ }
+ }
+ struct msghdr meta;
+ memset( &meta, 0, sizeof( meta ) );
+ meta.msg_iov = &d[ 0 ];
+ meta.msg_iovlen = d.size();
+
+ while( meta.msg_iovlen > 0 ) {
+ int ret = ::sendmsg( _fd , &meta , portSendFlags );
+ if ( ret == -1 ) {
+ if ( errno != EAGAIN || _timeout == 0 ) {
+ log(_logLevel) << "Socket " << context << " send() " << errnoWithDescription() << ' ' << remoteString() << endl;
+ throw SocketException( SocketException::SEND_ERROR , remoteString() );
+ }
+ else {
+ log(_logLevel) << "Socket " << context << " send() remote timeout " << remoteString() << endl;
+ throw SocketException( SocketException::SEND_TIMEOUT , remoteString() );
+ }
+ }
+ else {
+ struct iovec *& i = meta.msg_iov;
+ while( ret > 0 ) {
+ if ( i->iov_len > unsigned( ret ) ) {
+ i->iov_len -= ret;
+ i->iov_base = (char*)(i->iov_base) + ret;
+ ret = 0;
+ }
+ else {
+ ret -= i->iov_len;
+ ++i;
+ --(meta.msg_iovlen);
+ }
+ }
+ }
+ }
+#endif
+ }
+
+ void Socket::recv( char * buf , int len ) {
+ unsigned retries = 0;
+ while( len > 0 ) {
+ int ret = unsafe_recv( buf , len );
+ if ( ret > 0 ) {
+ if ( len <= 4 && ret != len )
+ log(_logLevel) << "Socket recv() got " << ret << " bytes wanted len=" << len << endl;
+ assert( ret <= len );
+ len -= ret;
+ buf += ret;
+ }
+ else if ( ret == 0 ) {
+ log(3) << "Socket recv() conn closed? " << remoteString() << endl;
+ throw SocketException( SocketException::CLOSED , remoteString() );
+ }
+ else { /* ret < 0 */
+#if defined(_WIN32)
+ int e = WSAGetLastError();
+#else
+ int e = errno;
+# if defined(EINTR)
+ if( e == EINTR ) {
+ if( ++retries == 1 ) {
+ log() << "EINTR retry" << endl;
+ continue;
+ }
+ }
+# endif
+#endif
+ if ( ( e == EAGAIN
+#if defined(_WIN32)
+ || e == WSAETIMEDOUT
+#endif
+ ) && _timeout > 0 )
+ {
+ // this is a timeout
+ log(_logLevel) << "Socket recv() timeout " << remoteString() <<endl;
+ throw SocketException( SocketException::RECV_TIMEOUT, remoteString() );
+ }
+
+ log(_logLevel) << "Socket recv() " << errnoWithDescription(e) << " " << remoteString() <<endl;
+ throw SocketException( SocketException::RECV_ERROR , remoteString() );
+ }
+ }
+ }
+
+ int Socket::unsafe_recv( char *buf, int max ) {
+ int x = _recv( buf , max );
+ _bytesIn += x;
+ return x;
+ }
+
+
+ int Socket::_recv( char *buf, int max ) {
+#ifdef MONGO_SSL
+ if ( _ssl ){
+ return SSL_read( _ssl.get() , buf , max );
+ }
+#endif
+ return ::recv( _fd , buf , max , portRecvFlags );
+ }
+
+ void Socket::setTimeout( double secs ) {
+ struct timeval tv;
+ tv.tv_sec = (int)secs;
+ tv.tv_usec = (int)((long long)(secs*1000*1000) % (1000*1000));
+ bool report = logLevel > 3; // solaris doesn't provide these
+ DEV report = true;
+ bool ok = setsockopt(_fd, SOL_SOCKET, SO_RCVTIMEO, (char *) &tv, sizeof(tv) ) == 0;
+ if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl;
+ ok = setsockopt(_fd, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, sizeof(tv) ) == 0;
+ DEV if( report && !ok ) log() << "unabled to set SO_RCVTIMEO" << endl;
+ }
+
+#if defined(_WIN32)
+ struct WinsockInit {
+ WinsockInit() {
+ WSADATA d;
+ if ( WSAStartup(MAKEWORD(2,2), &d) != 0 ) {
+ out() << "ERROR: wsastartup failed " << errnoWithDescription() << endl;
+ problem() << "ERROR: wsastartup failed " << errnoWithDescription() << endl;
+ dbexit( EXIT_NTSERVICE_ERROR );
+ }
+ }
+ } winsock_init;
+#endif
+
+} // namespace mongo
diff --git a/util/net/sock.h b/util/net/sock.h
new file mode 100644
index 0000000..1cd5133
--- /dev/null
+++ b/util/net/sock.h
@@ -0,0 +1,256 @@
+// @file sock.h
+
+/* Copyright 2009 10gen Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "../../pch.h"
+
+#include <stdio.h>
+#include <sstream>
+#include "../goodies.h"
+#include "../../db/cmdline.h"
+#include "../mongoutils/str.h"
+
+#ifndef _WIN32
+
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <errno.h>
+
+#ifdef __openbsd__
+# include <sys/uio.h>
+#endif
+
+#endif // _WIN32
+
+#ifdef MONGO_SSL
+#include <openssl/ssl.h>
+#endif
+
+namespace mongo {
+
+ const int SOCK_FAMILY_UNKNOWN_ERROR=13078;
+
+ void disableNagle(int sock);
+
+#if defined(_WIN32)
+
+ typedef short sa_family_t;
+ typedef int socklen_t;
+
+ // This won't actually be used on windows
+ struct sockaddr_un {
+ short sun_family;
+ char sun_path[108]; // length from unix header
+ };
+
+#else // _WIN32
+
+ inline void closesocket(int s) { close(s); }
+ const int INVALID_SOCKET = -1;
+ typedef int SOCKET;
+
+#endif // _WIN32
+
+ inline string makeUnixSockPath(int port) {
+ return mongoutils::str::stream() << cmdLine.socket << "/mongodb-" << port << ".sock";
+ }
+
+ // If an ip address is passed in, just return that. If a hostname is passed
+ // in, look up its ip and return that. Returns "" on failure.
+ string hostbyname(const char *hostname);
+
+ void enableIPv6(bool state=true);
+ bool IPv6Enabled();
+ void setSockTimeouts(int sock, double secs);
+
+ /**
+ * wrapped around os representation of network address
+ */
+ struct SockAddr {
+ SockAddr() {
+ addressSize = sizeof(sa);
+ memset(&sa, 0, sizeof(sa));
+ sa.ss_family = AF_UNSPEC;
+ }
+ SockAddr(int sourcePort); /* listener side */
+ SockAddr(const char *ip, int port); /* EndPoint (remote) side, or if you want to specify which interface locally */
+
+ template <typename T> T& as() { return *(T*)(&sa); }
+ template <typename T> const T& as() const { return *(const T*)(&sa); }
+
+ string toString(bool includePort=true) const;
+
+ /**
+ * @return one of AF_INET, AF_INET6, or AF_UNIX
+ */
+ sa_family_t getType() const;
+
+ unsigned getPort() const;
+
+ string getAddr() const;
+
+ bool isLocalHost() const;
+
+ bool operator==(const SockAddr& r) const;
+
+ bool operator!=(const SockAddr& r) const;
+
+ bool operator<(const SockAddr& r) const;
+
+ const sockaddr* raw() const {return (sockaddr*)&sa;}
+ sockaddr* raw() {return (sockaddr*)&sa;}
+
+ socklen_t addressSize;
+ private:
+ struct sockaddr_storage sa;
+ };
+
+ extern SockAddr unknownAddress; // ( "0.0.0.0", 0 )
+
+ /** this is not cache and does a syscall */
+ string getHostName();
+
+ /** this is cached, so if changes during the process lifetime
+ * will be stale */
+ string getHostNameCached();
+
+ /**
+ * thrown by Socket and SockAddr
+ */
+ class SocketException : public DBException {
+ public:
+ const enum Type { CLOSED , RECV_ERROR , SEND_ERROR, RECV_TIMEOUT, SEND_TIMEOUT, FAILED_STATE, CONNECT_ERROR } _type;
+
+ SocketException( Type t , string server , int code = 9001 , string extra="" )
+ : DBException( "socket exception" , code ) , _type(t) , _server(server), _extra(extra){ }
+ virtual ~SocketException() throw() {}
+
+ bool shouldPrint() const { return _type != CLOSED; }
+ virtual string toString() const;
+
+ private:
+ string _server;
+ string _extra;
+ };
+
+#ifdef MONGO_SSL
+ class SSLManager : boost::noncopyable {
+ public:
+ SSLManager( bool client );
+
+ void setupPEM( const string& keyFile , const string& password );
+ void setupPubPriv( const string& privateKeyFile , const string& publicKeyFile );
+
+ /**
+ * creates an SSL context to be used for this file descriptor
+ * caller should delete
+ */
+ SSL * secure( int fd );
+
+ static int password_cb( char *buf,int num, int rwflag,void *userdata );
+
+ private:
+ bool _client;
+ SSL_CTX* _context;
+ string _password;
+ };
+#endif
+
+ /**
+ * thin wrapped around file descriptor and system calls
+ * todo: ssl
+ */
+ class Socket {
+ public:
+ Socket(int sock, const SockAddr& farEnd);
+
+ /** In some cases the timeout will actually be 2x this value - eg we do a partial send,
+ then the timeout fires, then we try to send again, then the timeout fires again with
+ no data sent, then we detect that the other side is down.
+
+ Generally you don't want a timeout, you should be very prepared for errors if you set one.
+ */
+ Socket(double so_timeout = 0, int logLevel = 0 );
+
+ bool connect(SockAddr& farEnd);
+ void close();
+
+ void send( const char * data , int len, const char *context );
+ void send( const vector< pair< char *, int > > &data, const char *context );
+
+ // recv len or throw SocketException
+ void recv( char * data , int len );
+ int unsafe_recv( char *buf, int max );
+
+ int getLogLevel() const { return _logLevel; }
+ void setLogLevel( int ll ) { _logLevel = ll; }
+
+ SockAddr remoteAddr() const { return _remote; }
+ string remoteString() const { return _remote.toString(); }
+ unsigned remotePort() const { return _remote.getPort(); }
+
+ void clearCounters() { _bytesIn = 0; _bytesOut = 0; }
+ long long getBytesIn() const { return _bytesIn; }
+ long long getBytesOut() const { return _bytesOut; }
+
+ void setTimeout( double secs );
+
+#ifdef MONGO_SSL
+ /** secures inline */
+ void secure( SSLManager * ssl );
+
+ void secureAccepted( SSLManager * ssl );
+#endif
+
+ /**
+ * call this after a fork for server sockets
+ */
+ void postFork();
+
+ private:
+ void _init();
+ /** raw send, same semantics as ::send */
+ int _send( const char * data , int len );
+
+ /** sends dumbly, just each buffer at a time */
+ void _send( const vector< pair< char *, int > > &data, const char *context );
+
+ /** raw recv, same semantics as ::recv */
+ int _recv( char * buf , int max );
+
+ int _fd;
+ SockAddr _remote;
+ double _timeout;
+
+ long long _bytesIn;
+ long long _bytesOut;
+
+#ifdef MONGO_SSL
+ shared_ptr<SSL> _ssl;
+ SSLManager * _sslAccepted;
+#endif
+
+ protected:
+ int _logLevel; // passed to log() when logging errors
+
+ };
+
+
+} // namespace mongo