diff options
Diffstat (limited to 'debian/patches/asyncio366.diff')
-rw-r--r-- | debian/patches/asyncio366.diff | 200 |
1 files changed, 200 insertions, 0 deletions
diff --git a/debian/patches/asyncio366.diff b/debian/patches/asyncio366.diff new file mode 100644 index 0000000..cd7825b --- /dev/null +++ b/debian/patches/asyncio366.diff @@ -0,0 +1,200 @@ +# DP: Fix callbacks race in SelectorLoop.sock_connect. +# DP: https://github.com/python/asyncio/pull/366 + +Index: b/Lib/asyncio/selector_events.py +=================================================================== +--- a/Lib/asyncio/selector_events.py ++++ b/Lib/asyncio/selector_events.py +@@ -382,6 +382,7 @@ class BaseSelectorEventLoop(base_events. + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + ++ @coroutine + def sock_connect(self, sock, address): + """Connect to a remote socket at address. + +@@ -390,24 +391,16 @@ class BaseSelectorEventLoop(base_events. + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + +- fut = self.create_future() +- if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX: +- self._sock_connect(fut, sock, address) +- else: ++ if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX: + resolved = base_events._ensure_resolved( + address, family=sock.family, proto=sock.proto, loop=self) +- resolved.add_done_callback( +- lambda resolved: self._on_resolved(fut, sock, resolved)) +- +- return fut +- +- def _on_resolved(self, fut, sock, resolved): +- try: ++ if not resolved.done(): ++ yield from resolved + _, _, _, _, address = resolved.result()[0] +- except Exception as exc: +- fut.set_exception(exc) +- else: +- self._sock_connect(fut, sock, address) ++ ++ fut = self.create_future() ++ self._sock_connect(fut, sock, address) ++ return (yield from fut) + + def _sock_connect(self, fut, sock, address): + fd = sock.fileno() +@@ -418,8 +411,8 @@ class BaseSelectorEventLoop(base_events. + # connection runs in background. We have to wait until the socket + # becomes writable to be notified when the connection succeed or + # fails. +- fut.add_done_callback(functools.partial(self._sock_connect_done, +- fd)) ++ fut.add_done_callback( ++ functools.partial(self._sock_connect_done, fd)) + self.add_writer(fd, self._sock_connect_cb, fut, sock, address) + except Exception as exc: + fut.set_exception(exc) +Index: b/Lib/test/test_asyncio/test_selector_events.py +=================================================================== +--- a/Lib/test/test_asyncio/test_selector_events.py ++++ b/Lib/test/test_asyncio/test_selector_events.py +@@ -2,6 +2,8 @@ + + import errno + import socket ++import threading ++import time + import unittest + from unittest import mock + try: +@@ -337,18 +339,6 @@ class BaseSelectorEventLoopTests(test_ut + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + +- def test_sock_connect(self): +- sock = test_utils.mock_nonblocking_socket() +- self.loop._sock_connect = mock.Mock() +- +- f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) +- self.assertIsInstance(f, asyncio.Future) +- self.loop._run_once() +- future_in, sock_in, address_in = self.loop._sock_connect.call_args[0] +- self.assertEqual(future_in, f) +- self.assertEqual(sock_in, sock) +- self.assertEqual(address_in, ('127.0.0.1', 8080)) +- + def test_sock_connect_timeout(self): + # asyncio issue #205: sock_connect() must unregister the socket on + # timeout error +@@ -360,16 +350,16 @@ class BaseSelectorEventLoopTests(test_ut + sock.connect.side_effect = BlockingIOError + + # first call to sock_connect() registers the socket +- fut = self.loop.sock_connect(sock, ('127.0.0.1', 80)) ++ fut = self.loop.create_task( ++ self.loop.sock_connect(sock, ('127.0.0.1', 80))) + self.loop._run_once() + self.assertTrue(sock.connect.called) + self.assertTrue(self.loop.add_writer.called) +- self.assertEqual(len(fut._callbacks), 1) + + # on timeout, the socket must be unregistered + sock.connect.reset_mock() +- fut.set_exception(asyncio.TimeoutError) +- with self.assertRaises(asyncio.TimeoutError): ++ fut.cancel() ++ with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(fut) + self.assertTrue(self.loop.remove_writer.called) + +@@ -1778,5 +1768,88 @@ class SelectorDatagramTransportTests(tes + exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) + + ++class SelectorLoopFunctionalTests(unittest.TestCase): ++ ++ def setUp(self): ++ self.loop = asyncio.new_event_loop() ++ asyncio.set_event_loop(None) ++ ++ def tearDown(self): ++ self.loop.close() ++ ++ @asyncio.coroutine ++ def recv_all(self, sock, nbytes): ++ buf = b'' ++ while len(buf) < nbytes: ++ buf += yield from self.loop.sock_recv(sock, nbytes - len(buf)) ++ return buf ++ ++ def test_sock_connect_sock_write_race(self): ++ TIMEOUT = 3.0 ++ PAYLOAD = b'DATA' * 1024 * 1024 ++ ++ class Server(threading.Thread): ++ def __init__(self, *args, srv_sock, **kwargs): ++ super().__init__(*args, **kwargs) ++ self.srv_sock = srv_sock ++ ++ def run(self): ++ with self.srv_sock: ++ srv_sock.listen(100) ++ ++ sock, addr = self.srv_sock.accept() ++ sock.settimeout(TIMEOUT) ++ ++ with sock: ++ sock.sendall(b'helo') ++ ++ buf = bytearray() ++ while len(buf) < len(PAYLOAD): ++ pack = sock.recv(1024 * 65) ++ if not pack: ++ break ++ buf.extend(pack) ++ ++ @asyncio.coroutine ++ def client(addr): ++ sock = socket.socket() ++ with sock: ++ sock.setblocking(False) ++ ++ started = time.monotonic() ++ while True: ++ if time.monotonic() - started > TIMEOUT: ++ self.fail('unable to connect to the socket') ++ return ++ try: ++ yield from self.loop.sock_connect(sock, addr) ++ except OSError: ++ yield from asyncio.sleep(0.05, loop=self.loop) ++ else: ++ break ++ ++ # Give 'Server' thread a chance to accept and send b'helo' ++ time.sleep(0.1) ++ ++ data = yield from self.recv_all(sock, 4) ++ self.assertEqual(data, b'helo') ++ yield from self.loop.sock_sendall(sock, PAYLOAD) ++ ++ srv_sock = socket.socket() ++ srv_sock.settimeout(TIMEOUT) ++ srv_sock.bind(('127.0.0.1', 0)) ++ srv_addr = srv_sock.getsockname() ++ ++ srv = Server(srv_sock=srv_sock, daemon=True) ++ srv.start() ++ ++ try: ++ self.loop.run_until_complete( ++ asyncio.wait_for(client(srv_addr), loop=self.loop, ++ timeout=TIMEOUT)) ++ finally: ++ srv.join() ++ ++ + if __name__ == '__main__': + unittest.main() |