summaryrefslogtreecommitdiff
path: root/debian/patches/asyncio366.diff
diff options
context:
space:
mode:
Diffstat (limited to 'debian/patches/asyncio366.diff')
-rw-r--r--debian/patches/asyncio366.diff200
1 files changed, 200 insertions, 0 deletions
diff --git a/debian/patches/asyncio366.diff b/debian/patches/asyncio366.diff
new file mode 100644
index 0000000..cd7825b
--- /dev/null
+++ b/debian/patches/asyncio366.diff
@@ -0,0 +1,200 @@
+# DP: Fix callbacks race in SelectorLoop.sock_connect.
+# DP: https://github.com/python/asyncio/pull/366
+
+Index: b/Lib/asyncio/selector_events.py
+===================================================================
+--- a/Lib/asyncio/selector_events.py
++++ b/Lib/asyncio/selector_events.py
+@@ -382,6 +382,7 @@ class BaseSelectorEventLoop(base_events.
+ data = data[n:]
+ self.add_writer(fd, self._sock_sendall, fut, True, sock, data)
+
++ @coroutine
+ def sock_connect(self, sock, address):
+ """Connect to a remote socket at address.
+
+@@ -390,24 +391,16 @@ class BaseSelectorEventLoop(base_events.
+ if self._debug and sock.gettimeout() != 0:
+ raise ValueError("the socket must be non-blocking")
+
+- fut = self.create_future()
+- if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
+- self._sock_connect(fut, sock, address)
+- else:
++ if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
+ resolved = base_events._ensure_resolved(
+ address, family=sock.family, proto=sock.proto, loop=self)
+- resolved.add_done_callback(
+- lambda resolved: self._on_resolved(fut, sock, resolved))
+-
+- return fut
+-
+- def _on_resolved(self, fut, sock, resolved):
+- try:
++ if not resolved.done():
++ yield from resolved
+ _, _, _, _, address = resolved.result()[0]
+- except Exception as exc:
+- fut.set_exception(exc)
+- else:
+- self._sock_connect(fut, sock, address)
++
++ fut = self.create_future()
++ self._sock_connect(fut, sock, address)
++ return (yield from fut)
+
+ def _sock_connect(self, fut, sock, address):
+ fd = sock.fileno()
+@@ -418,8 +411,8 @@ class BaseSelectorEventLoop(base_events.
+ # connection runs in background. We have to wait until the socket
+ # becomes writable to be notified when the connection succeed or
+ # fails.
+- fut.add_done_callback(functools.partial(self._sock_connect_done,
+- fd))
++ fut.add_done_callback(
++ functools.partial(self._sock_connect_done, fd))
+ self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
+ except Exception as exc:
+ fut.set_exception(exc)
+Index: b/Lib/test/test_asyncio/test_selector_events.py
+===================================================================
+--- a/Lib/test/test_asyncio/test_selector_events.py
++++ b/Lib/test/test_asyncio/test_selector_events.py
+@@ -2,6 +2,8 @@
+
+ import errno
+ import socket
++import threading
++import time
+ import unittest
+ from unittest import mock
+ try:
+@@ -337,18 +339,6 @@ class BaseSelectorEventLoopTests(test_ut
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+- def test_sock_connect(self):
+- sock = test_utils.mock_nonblocking_socket()
+- self.loop._sock_connect = mock.Mock()
+-
+- f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
+- self.assertIsInstance(f, asyncio.Future)
+- self.loop._run_once()
+- future_in, sock_in, address_in = self.loop._sock_connect.call_args[0]
+- self.assertEqual(future_in, f)
+- self.assertEqual(sock_in, sock)
+- self.assertEqual(address_in, ('127.0.0.1', 8080))
+-
+ def test_sock_connect_timeout(self):
+ # asyncio issue #205: sock_connect() must unregister the socket on
+ # timeout error
+@@ -360,16 +350,16 @@ class BaseSelectorEventLoopTests(test_ut
+ sock.connect.side_effect = BlockingIOError
+
+ # first call to sock_connect() registers the socket
+- fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
++ fut = self.loop.create_task(
++ self.loop.sock_connect(sock, ('127.0.0.1', 80)))
+ self.loop._run_once()
+ self.assertTrue(sock.connect.called)
+ self.assertTrue(self.loop.add_writer.called)
+- self.assertEqual(len(fut._callbacks), 1)
+
+ # on timeout, the socket must be unregistered
+ sock.connect.reset_mock()
+- fut.set_exception(asyncio.TimeoutError)
+- with self.assertRaises(asyncio.TimeoutError):
++ fut.cancel()
++ with self.assertRaises(asyncio.CancelledError):
+ self.loop.run_until_complete(fut)
+ self.assertTrue(self.loop.remove_writer.called)
+
+@@ -1778,5 +1768,88 @@ class SelectorDatagramTransportTests(tes
+ exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))
+
+
++class SelectorLoopFunctionalTests(unittest.TestCase):
++
++ def setUp(self):
++ self.loop = asyncio.new_event_loop()
++ asyncio.set_event_loop(None)
++
++ def tearDown(self):
++ self.loop.close()
++
++ @asyncio.coroutine
++ def recv_all(self, sock, nbytes):
++ buf = b''
++ while len(buf) < nbytes:
++ buf += yield from self.loop.sock_recv(sock, nbytes - len(buf))
++ return buf
++
++ def test_sock_connect_sock_write_race(self):
++ TIMEOUT = 3.0
++ PAYLOAD = b'DATA' * 1024 * 1024
++
++ class Server(threading.Thread):
++ def __init__(self, *args, srv_sock, **kwargs):
++ super().__init__(*args, **kwargs)
++ self.srv_sock = srv_sock
++
++ def run(self):
++ with self.srv_sock:
++ srv_sock.listen(100)
++
++ sock, addr = self.srv_sock.accept()
++ sock.settimeout(TIMEOUT)
++
++ with sock:
++ sock.sendall(b'helo')
++
++ buf = bytearray()
++ while len(buf) < len(PAYLOAD):
++ pack = sock.recv(1024 * 65)
++ if not pack:
++ break
++ buf.extend(pack)
++
++ @asyncio.coroutine
++ def client(addr):
++ sock = socket.socket()
++ with sock:
++ sock.setblocking(False)
++
++ started = time.monotonic()
++ while True:
++ if time.monotonic() - started > TIMEOUT:
++ self.fail('unable to connect to the socket')
++ return
++ try:
++ yield from self.loop.sock_connect(sock, addr)
++ except OSError:
++ yield from asyncio.sleep(0.05, loop=self.loop)
++ else:
++ break
++
++ # Give 'Server' thread a chance to accept and send b'helo'
++ time.sleep(0.1)
++
++ data = yield from self.recv_all(sock, 4)
++ self.assertEqual(data, b'helo')
++ yield from self.loop.sock_sendall(sock, PAYLOAD)
++
++ srv_sock = socket.socket()
++ srv_sock.settimeout(TIMEOUT)
++ srv_sock.bind(('127.0.0.1', 0))
++ srv_addr = srv_sock.getsockname()
++
++ srv = Server(srv_sock=srv_sock, daemon=True)
++ srv.start()
++
++ try:
++ self.loop.run_until_complete(
++ asyncio.wait_for(client(srv_addr), loop=self.loop,
++ timeout=TIMEOUT))
++ finally:
++ srv.join()
++
++
+ if __name__ == '__main__':
+ unittest.main()