Skip to content

Commit

Permalink
Use custom exceptions for internal errors
Browse files Browse the repository at this point in the history
Fixes #44
  • Loading branch information
miguelgrinberg committed Apr 7, 2017
1 parent 7dda70c commit 36814a4
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 15 deletions.
3 changes: 2 additions & 1 deletion engineio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import six
from six.moves import urllib

from .exceptions import EngineIOError
from . import packet
from . import server
from . import asyncio_socket
Expand Down Expand Up @@ -152,7 +153,7 @@ async def handle_request(self, *args, **kwargs):
try:
await socket.handle_post_request(environ)
r = self._ok()
except ValueError:
except EngineIOError:
r = self._bad_request()
else:
self.logger.warning('Method %s not supported', method)
Expand Down
7 changes: 4 additions & 3 deletions engineio/asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import six

from . import exceptions
from . import packet
from . import payload
from . import socket
Expand Down Expand Up @@ -44,7 +45,7 @@ async def receive(self, pkt):
elif pkt.packet_type == packet.CLOSE:
await self.close(wait=False, abort=True)
else:
raise ValueError
raise exceptions.UnknownPacketError()

async def send(self, pkt):
"""Send a packet to the client."""
Expand Down Expand Up @@ -81,7 +82,7 @@ async def handle_post_request(self, environ):
"""Handle a long-polling POST request from the client."""
length = int(environ.get('CONTENT_LENGTH', '0'))
if length > self.server.max_http_buffer_size:
raise ValueError()
raise exceptions.ContentTooLongError()
else:
body = await environ['wsgi.input'].read(length)
p = payload.Payload(encoded_payload=body)
Expand Down Expand Up @@ -179,7 +180,7 @@ async def writer():
pkt = packet.Packet(encoded_packet=p)
try:
await self.receive(pkt)
except ValueError:
except exceptions.UnknownPacketError:
pass

await self.queue.put(None) # unlock the writer task so it can exit
Expand Down
10 changes: 10 additions & 0 deletions engineio/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class EngineIOError(Exception):
pass


class ContentTooLongError(EngineIOError):
pass


class UnknownPacketError(EngineIOError):
pass
3 changes: 2 additions & 1 deletion engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import six
from six.moves import urllib

from .exceptions import EngineIOError
from . import packet
from . import payload
from . import socket
Expand Down Expand Up @@ -277,7 +278,7 @@ def handle_request(self, environ, start_response):
try:
socket.handle_post_request(environ)
r = self._ok()
except ValueError:
except EngineIOError:
r = self._bad_request()
else:
self.logger.warning('Method %s not supported', method)
Expand Down
7 changes: 4 additions & 3 deletions engineio/socket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
import six

from . import exceptions
from . import packet
from . import payload

Expand Down Expand Up @@ -58,7 +59,7 @@ def receive(self, pkt):
elif pkt.packet_type == packet.CLOSE:
self.close(wait=False, abort=True)
else:
raise ValueError
raise exceptions.UnknownPacketError()

def send(self, pkt):
"""Send a packet to the client."""
Expand Down Expand Up @@ -97,7 +98,7 @@ def handle_post_request(self, environ):
"""Handle a long-polling POST request from the client."""
length = int(environ.get('CONTENT_LENGTH', '0'))
if length > self.server.max_http_buffer_size:
raise ValueError()
raise exceptions.ContentTooLongError()
else:
body = environ['wsgi.input'].read(length)
p = payload.Payload(encoded_payload=body)
Expand Down Expand Up @@ -194,7 +195,7 @@ def writer():
pkt = packet.Packet(encoded_packet=p)
try:
self.receive(pkt)
except ValueError:
except exceptions.UnknownPacketError:
pass

self.queue.put(None) # unlock the writer task so that it can exit
Expand Down
3 changes: 2 additions & 1 deletion tests/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
else:
import mock

from engineio import exceptions
from engineio import packet
from engineio import payload
if sys.version_info >= (3, 5):
Expand Down Expand Up @@ -522,7 +523,7 @@ def test_post_request_error(self, import_module):

@asyncio.coroutine
def mock_post_request(*args, **kwargs):
raise ValueError()
raise exceptions.ContentTooLongError()

mock_socket.handle_post_request.mock.return_value = mock_post_request()
_run(s.handle_request('request'))
Expand Down
5 changes: 3 additions & 2 deletions tests/test_asyncio_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
else:
import mock

from engineio import exceptions
from engineio import packet
from engineio import payload
if sys.version_info >= (3, 5):
Expand Down Expand Up @@ -111,7 +112,7 @@ def test_message_handler(self):
def test_invalid_packet(self):
mock_server = self._get_mock_server()
s = asyncio_socket.AsyncSocket(mock_server, 'sid')
self.assertRaises(ValueError, _run,
self.assertRaises(exceptions.UnknownPacketError, _run,
s.receive(packet.Packet(packet.OPEN)))

def test_timeout(self):
Expand Down Expand Up @@ -165,7 +166,7 @@ def test_polling_write_too_large(self):
environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo',
'CONTENT_LENGTH': len(p),
'wsgi.input': self._get_read_mock_coro(p)}
self.assertRaises(ValueError, _run,
self.assertRaises(exceptions.ContentTooLongError, _run,
s.handle_post_request(environ))

def test_upgrade_handshake(self):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
else:
import mock

from engineio import exceptions
from engineio import packet
from engineio import payload
from engineio import server
Expand Down Expand Up @@ -626,7 +627,7 @@ def test_post_request_error(self):
s = server.Server()
mock_socket = self._get_mock_socket()
mock_socket.handle_post_request = mock.MagicMock(
side_effect=[ValueError])
side_effect=[exceptions.EngineIOError])
s.sockets['foo'] = mock_socket
environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo'}
start_response = mock.MagicMock()
Expand Down
9 changes: 6 additions & 3 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
else:
import mock

from engineio import exceptions
from engineio import packet
from engineio import payload
from engineio import socket
Expand Down Expand Up @@ -98,7 +99,8 @@ def test_message_sync_handler(self):
def test_invalid_packet(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
self.assertRaises(ValueError, s.receive, packet.Packet(packet.OPEN))
self.assertRaises(exceptions.UnknownPacketError, s.receive,
packet.Packet(packet.OPEN))

def test_timeout(self):
mock_server = self._get_mock_server()
Expand Down Expand Up @@ -152,7 +154,8 @@ def test_polling_write_too_large(self):
s.receive = mock.MagicMock()
environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo',
'CONTENT_LENGTH': len(p), 'wsgi.input': six.BytesIO(p)}
self.assertRaises(ValueError, s.handle_post_request, environ)
self.assertRaises(exceptions.ContentTooLongError,
s.handle_post_request, environ)

def test_upgrade_handshake(self):
mock_server = self._get_mock_server()
Expand Down Expand Up @@ -239,7 +242,7 @@ def test_invalid_packet_type(self):
mock_server = self._get_mock_server()
s = socket.Socket(mock_server, 'sid')
pkt = packet.Packet(packet_type=99)
self.assertRaises(ValueError, lambda: s.receive(pkt))
self.assertRaises(exceptions.UnknownPacketError, s.receive, pkt)

def test_upgrade_not_supported(self):
mock_server = self._get_mock_server()
Expand Down

0 comments on commit 36814a4

Please sign in to comment.