python-dns/0006-The-Tudoor-fix-should-not-eat-valid-Truncated-except.patch

230 lines
7.9 KiB
Diff
Raw Normal View History

2024-04-24 17:01:11 +08:00
From 2ab3d1628c9ae0545e225522b3b445c3478dc6ad Mon Sep 17 00:00:00 2001
From: Bob Halley <halley@dnspython.org>
Date: Sun, 18 Feb 2024 10:27:43 -0800
Subject: [PATCH] The Tudoor fix should not eat valid Truncated exceptions
[#1053] (#1054)
* The Tudoor fix should not eat valid Truncated exceptions [#1053]
* Make logic more readable
---
dns/asyncquery.py | 10 ++++++++
dns/query.py | 14 +++++++++++
tests/test_async.py | 60 ++++++++++++++++++++++++++++++++++++++++++++-
tests/test_query.py | 44 ++++++++++++++++++++++++++++++++-
4 files changed, 126 insertions(+), 2 deletions(-)
diff --git a/dns/asyncquery.py b/dns/asyncquery.py
index 94cb2413..4d9ab9ae 100644
--- a/dns/asyncquery.py
+++ b/dns/asyncquery.py
@@ -151,6 +151,16 @@ async def receive_udp(
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
+ except dns.message.Truncated as e:
+ # See the comment in query.py for details.
+ if (
+ ignore_errors
+ and query is not None
+ and not query.is_response(e.message())
+ ):
+ continue
+ else:
+ raise
except Exception:
if ignore_errors:
continue
diff --git a/dns/query.py b/dns/query.py
index 06d186c7..384bf31e 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -618,6 +618,20 @@ def receive_udp(
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
+ except dns.message.Truncated as e:
+ # If we got Truncated and not FORMERR, we at least got the header with TC
+ # set, and very likely the question section, so we'll re-raise if the
+ # message seems to be a response as we need to know when truncation happens.
+ # We need to check that it seems to be a response as we don't want a random
+ # injected message with TC set to cause us to bail out.
+ if (
+ ignore_errors
+ and query is not None
+ and not query.is_response(e.message())
+ ):
+ continue
+ else:
+ raise
except Exception:
if ignore_errors:
continue
diff --git a/tests/test_async.py b/tests/test_async.py
index ba2078cd..9373548d 100644
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -705,7 +705,11 @@ async def mock_receive(
from2,
ignore_unexpected=True,
ignore_errors=True,
+ raise_on_truncation=False,
+ good_r=None,
):
+ if good_r is None:
+ good_r = self.good_r
s = MockSock(wire1, from1, wire2, from2)
(r, when, _) = await dns.asyncquery.receive_udp(
s,
@@ -713,9 +717,10 @@ async def mock_receive(
time.time() + 2,
ignore_unexpected=ignore_unexpected,
ignore_errors=ignore_errors,
+ raise_on_truncation=raise_on_truncation,
query=self.q,
)
- self.assertEqual(r, self.good_r)
+ self.assertEqual(r, good_r)
def test_good_mock(self):
async def run():
@@ -802,6 +807,59 @@ async def run():
self.async_run(run)
+ def test_good_wire_with_truncation_flag_and_no_truncation_raise(self):
+ async def run():
+ tc_r = dns.message.make_response(self.q)
+ tc_r.flags |= dns.flags.TC
+ tc_r_wire = tc_r.to_wire()
+ await self.mock_receive(
+ tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r
+ )
+
+ self.async_run(run)
+
+ def test_good_wire_with_truncation_flag_and_truncation_raise(self):
+ async def agood():
+ tc_r = dns.message.make_response(self.q)
+ tc_r.flags |= dns.flags.TC
+ tc_r_wire = tc_r.to_wire()
+ await self.mock_receive(
+ tc_r_wire, ("127.0.0.1", 53), None, None, raise_on_truncation=True
+ )
+
+ def good():
+ self.async_run(agood)
+
+ self.assertRaises(dns.message.Truncated, good)
+
+ def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self):
+ async def run():
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r.flags |= dns.flags.TC
+ bad_r_wire = bad_r.to_wire()
+ await self.mock_receive(
+ bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+ )
+
+ self.async_run(run)
+
+ def test_wrong_id_wire_with_truncation_flag_and_truncation_raise(self):
+ async def run():
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r.flags |= dns.flags.TC
+ bad_r_wire = bad_r.to_wire()
+ await self.mock_receive(
+ bad_r_wire,
+ ("127.0.0.1", 53),
+ self.good_r_wire,
+ ("127.0.0.1", 53),
+ raise_on_truncation=True,
+ )
+
+ self.async_run(run)
+
def test_bad_wire_not_ignored(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
diff --git a/tests/test_query.py b/tests/test_query.py
index 1039a14e..62007e85 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -29,6 +29,7 @@
have_ssl = False
import dns.exception
+import dns.flags
import dns.inet
import dns.message
import dns.name
@@ -706,7 +707,11 @@ def mock_receive(
from2,
ignore_unexpected=True,
ignore_errors=True,
+ raise_on_truncation=False,
+ good_r=None,
):
+ if good_r is None:
+ good_r = self.good_r
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
with mock_udp_recv(wire1, from1, wire2, from2):
@@ -716,9 +721,10 @@ def mock_receive(
time.time() + 2,
ignore_unexpected=ignore_unexpected,
ignore_errors=ignore_errors,
+ raise_on_truncation=raise_on_truncation,
query=self.q,
)
- self.assertEqual(r, self.good_r)
+ self.assertEqual(r, good_r)
finally:
s.close()
@@ -787,6 +793,42 @@ def test_bad_wire(self):
bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
)
+ def test_good_wire_with_truncation_flag_and_no_truncation_raise(self):
+ tc_r = dns.message.make_response(self.q)
+ tc_r.flags |= dns.flags.TC
+ tc_r_wire = tc_r.to_wire()
+ self.mock_receive(tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r)
+
+ def test_good_wire_with_truncation_flag_and_truncation_raise(self):
+ def good():
+ tc_r = dns.message.make_response(self.q)
+ tc_r.flags |= dns.flags.TC
+ tc_r_wire = tc_r.to_wire()
+ self.mock_receive(
+ tc_r_wire, ("127.0.0.1", 53), None, None, raise_on_truncation=True
+ )
+
+ self.assertRaises(dns.message.Truncated, good)
+
+ def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r.flags |= dns.flags.TC
+ bad_r_wire = bad_r.to_wire()
+ self.mock_receive(
+ bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
+ )
+
+ def test_wrong_id_wire_with_truncation_flag_and_truncation_raise(self):
+ bad_r = dns.message.make_response(self.q)
+ bad_r.id += 1
+ bad_r.flags |= dns.flags.TC
+ bad_r_wire = bad_r.to_wire()
+ self.mock_receive(
+ bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53),
+ raise_on_truncation=True
+ )
+
def test_bad_wire_not_ignored(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1