python-dns/0005-Ensure-asyncio-datagram-sockets-on-windows-have-had-.patch

142 lines
5.0 KiB
Diff
Raw Permalink Normal View History

2024-04-24 17:01:11 +08:00
From adfc942725bd36d28ec53f7e5480ace9eb543bd8 Mon Sep 17 00:00:00 2001
From: Bob Halley <halley@dnspython.org>
Date: Thu, 14 Dec 2023 18:04:39 -0800
Subject: [PATCH] Ensure asyncio datagram sockets on windows have had a bind()
before recvfrom().
The fix for [#637] erroneously concluded that Windows asyncio
needed connected datagram sockets, but further investigation
showed that the actual problem was that Windows wants an
unconnected datagram socket to be bound before recvfrom() is called.
In this case Linux auto-binds the socket to the wildcard address and
port, which is why we didn't see any problems there. We now ensure that
the source is bound.
---
dns/_asyncio_backend.py | 13 ++++++-------
tests/test_async.py | 25 +++++--------------------
2 files changed, 11 insertions(+), 27 deletions(-)
diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py
index 2631228..7d4d1b5 100644
--- a/dns/_asyncio_backend.py
+++ b/dns/_asyncio_backend.py
@@ -8,6 +8,7 @@ import sys
import dns._asyncbackend
import dns.exception
+import dns.inet
_is_win32 = sys.platform == "win32"
@@ -224,14 +225,12 @@ class Backend(dns._asyncbackend.Backend):
ssl_context=None,
server_hostname=None,
):
- if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
- raise NotImplementedError(
- "destinationless datagram sockets "
- "are not supported by asyncio "
- "on Windows"
- )
loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM:
+ if _is_win32 and source is None:
+ # Win32 wants explicit binding before recvfrom(). This is the
+ # proper fix for [#637].
+ source = (dns.inet.any_for_af(af), 0)
transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol,
source,
@@ -266,7 +265,7 @@ class Backend(dns._asyncbackend.Backend):
await asyncio.sleep(interval)
def datagram_connection_required(self):
- return _is_win32
+ return False
def get_transport_class(self):
return _HTTPTransport
diff --git a/tests/test_async.py b/tests/test_async.py
index d0f977a..ac32431 100644
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -171,8 +171,6 @@ class MiscQuery(unittest.TestCase):
@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
class AsyncTests(unittest.TestCase):
- connect_udp = sys.platform == "win32"
-
def setUp(self):
self.backend = dns.asyncbackend.set_default_backend("asyncio")
@@ -327,12 +325,12 @@ class AsyncTests(unittest.TestCase):
qname = dns.name.from_text("dns.google.")
async def run():
- if self.connect_udp:
- dtuple = (address, 53)
- else:
- dtuple = None
async with await self.backend.make_socket(
- dns.inet.af_for_address(address), socket.SOCK_DGRAM, 0, None, dtuple
+ dns.inet.af_for_address(address),
+ socket.SOCK_DGRAM,
+ 0,
+ None,
+ None,
) as s:
q = dns.message.make_query(qname, dns.rdatatype.A)
return await dns.asyncquery.udp(q, address, sock=s, timeout=2)
@@ -485,9 +483,6 @@ class AsyncTests(unittest.TestCase):
self.assertFalse(tcp)
def testUDPReceiveQuery(self):
- if self.connect_udp:
- self.skipTest("test needs connectionless sockets")
-
async def run():
async with await self.backend.make_socket(
socket.AF_INET, socket.SOCK_DGRAM, source=("127.0.0.1", 0)
@@ -509,9 +504,6 @@ class AsyncTests(unittest.TestCase):
self.assertEqual(sender_address, recv_address)
def testUDPReceiveTimeout(self):
- if self.connect_udp:
- self.skipTest("test needs connectionless sockets")
-
async def arun():
async with await self.backend.make_socket(
socket.AF_INET, socket.SOCK_DGRAM, 0, ("127.0.0.1", 0)
@@ -616,8 +608,6 @@ class AsyncTests(unittest.TestCase):
@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
class AsyncioOnlyTests(unittest.TestCase):
- connect_udp = sys.platform == "win32"
-
def setUp(self):
self.backend = dns.asyncbackend.set_default_backend("asyncio")
@@ -625,9 +615,6 @@ class AsyncioOnlyTests(unittest.TestCase):
return asyncio.run(afunc())
def testUseAfterTimeout(self):
- if self.connect_udp:
- self.skipTest("test needs connectionless sockets")
-
# Test #843 fix.
async def run():
qname = dns.name.from_text("dns.google")
@@ -678,8 +665,6 @@ try:
return trio.run(afunc)
class TrioAsyncTests(AsyncTests):
- connect_udp = False
-
def setUp(self):
self.backend = dns.asyncbackend.set_default_backend("trio")
--
2.39.1