Skip to content

Commit

Permalink
test IgnoreErrors
Browse files Browse the repository at this point in the history
(cherry picked from commit ac6763f)
  • Loading branch information
rthalley committed Feb 16, 2024
1 parent e093299 commit 7952e31
Showing 1 changed file with 140 additions and 0 deletions.
140 changes: 140 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import contextlib
import socket
import sys
import time
Expand All @@ -32,6 +33,7 @@
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.tsigkeyring
Expand Down Expand Up @@ -659,3 +661,141 @@ def test_matches_destination(self):
dns.query._matches_destination(
socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False
)


@contextlib.contextmanager
def mock_udp_recv(wire1, from1, wire2, from2):
saved = dns.query._udp_recv
first_time = True

def mock(sock, max_size, expiration):
nonlocal first_time
if first_time:
first_time = False
return wire1, from1
else:
return wire2, from2

try:
dns.query._udp_recv = mock
yield None
finally:
dns.query._udp_recv = saved


class IgnoreErrors(unittest.TestCase):
def setUp(self):
self.q = dns.message.make_query("example.", "A")
self.good_r = dns.message.make_response(self.q)
self.good_r.set_rcode(dns.rcode.NXDOMAIN)
self.good_r_wire = self.good_r.to_wire()

def mock_receive(
self,
wire1,
from1,
wire2,
from2,
ignore_unexpected=True,
ignore_errors=True,
):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
with mock_udp_recv(wire1, from1, wire2, from2):
(r, when) = dns.query.receive_udp(
s,
("127.0.0.1", 53),
time.time() + 2,
ignore_unexpected=ignore_unexpected,
ignore_errors=ignore_errors,
query=self.q,
)
self.assertEqual(r, self.good_r)
finally:
s.close()

def test_good_mock(self):
self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None)

def test_bad_address(self):
self.mock_receive(
self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53)
)

def test_bad_address_not_ignored(self):
def bad():
self.mock_receive(
self.good_r_wire,
("127.0.0.2", 53),
self.good_r_wire,
("127.0.0.1", 53),
ignore_unexpected=False,
)

self.assertRaises(dns.query.UnexpectedSource, bad)

def test_bad_id(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r_wire = bad_r.to_wire()
self.mock_receive(
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
)

def test_bad_id_not_ignored(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r_wire = bad_r.to_wire()

def bad():
(r, wire) = self.mock_receive(
bad_r_wire,
("127.0.0.1", 53),
self.good_r_wire,
("127.0.0.1", 53),
ignore_errors=False,
)

self.assertRaises(AssertionError, bad)

def test_bad_wire(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r_wire = bad_r.to_wire()
self.mock_receive(
bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
)

def test_bad_wire_not_ignored(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r_wire = bad_r.to_wire()

def bad():
self.mock_receive(
bad_r_wire[:10],
("127.0.0.1", 53),
self.good_r_wire,
("127.0.0.1", 53),
ignore_errors=False,
)

self.assertRaises(dns.message.ShortHeader, bad)

def test_trailing_wire(self):
wire = self.good_r_wire + b"abcd"
self.mock_receive(wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53))

def test_trailing_wire_not_ignored(self):
wire = self.good_r_wire + b"abcd"

def bad():
self.mock_receive(
wire,
("127.0.0.1", 53),
self.good_r_wire,
("127.0.0.1", 53),
ignore_errors=False,
)

self.assertRaises(dns.message.TrailingJunk, bad)

0 comments on commit 7952e31

Please sign in to comment.