Switch from gevent-websocket to gevent-ws (#2439)

* Switch from gevent-websocket to gevent-ws

* Return error handling, add gevent_ws source to lib
This commit is contained in:
Ivanq 2020-02-28 03:20:04 +03:00 committed by GitHub
parent 2862587c15
commit 219b90668f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 268 additions and 24 deletions

View File

@ -7,7 +7,7 @@ rsa
PySocks>=1.6.8
pyasn1
websocket_client
gevent-websocket
gevent-ws
coincurve
python-bitcoinlib
maxminddb

View File

@ -646,7 +646,6 @@ class Config(object):
logging.addLevelName(15, "WARNING")
logging.getLogger('').name = "-" # Remove root prefix
logging.getLogger("geventwebsocket.handler").setLevel(logging.WARNING) # Don't log ws debug messages
if console_logging:
self.initConsoleLogger()

View File

@ -814,7 +814,7 @@ class UiRequest(object):
# Remove websocket from every site (admin sites allowed to join other sites event channels)
if ui_websocket in site_check.websockets:
site_check.websockets.remove(ui_websocket)
return "Bye."
return [b"Bye."]
else: # No site found by wrapper key
ws.send(json.dumps({"error": "Wrapper key not found: %s" % wrapper_key}))
return self.error403("Wrapper key not found: %s" % wrapper_key)

View File

@ -5,8 +5,7 @@ import socket
import gevent
from gevent.pywsgi import WSGIServer
from gevent.pywsgi import WSGIHandler
from geventwebsocket.handler import WebSocketHandler
from lib.gevent_ws import WebSocketHandler
from .UiRequest import UiRequest
from Site import SiteManager
@ -27,7 +26,7 @@ class LogDb(logging.StreamHandler):
# Skip websocket handler if not necessary
class UiWSGIHandler(WSGIHandler):
class UiWSGIHandler(WebSocketHandler):
def __init__(self, *args, **kwargs):
self.server = args[2]
@ -46,24 +45,14 @@ class UiWSGIHandler(WSGIHandler):
self.write(block)
def run_application(self):
if "HTTP_UPGRADE" in self.environ: # Websocket request
try:
ws_handler = WebSocketHandler(*self.args, **self.kwargs)
ws_handler.__dict__ = self.__dict__ # Match class variables
ws_handler.run_application()
except (ConnectionAbortedError, ConnectionResetError) as err:
logging.warning("UiWSGIHandler websocket connection error: %s" % err)
except Exception as err:
logging.error("UiWSGIHandler websocket error: %s" % Debug.formatException(err))
self.handleError(err)
else: # Standard HTTP request
try:
super(UiWSGIHandler, self).run_application()
except (ConnectionAbortedError, ConnectionResetError) as err:
logging.warning("UiWSGIHandler connection error: %s" % err)
except Exception as err:
logging.error("UiWSGIHandler error: %s" % Debug.formatException(err))
self.handleError(err)
err_name = "UiWSGIHandler websocket" if "HTTP_UPGRADE" in self.environ else "UiWSGIHandler"
try:
super(UiWSGIHandler, self).run_application()
except (ConnectionAbortedError, ConnectionResetError) as err:
logging.warning("%s connection error: %s" % (err_name, err))
except Exception as err:
logging.warning("%s error: %s" % (err_name, Debug.formatException(err)))
self.handleError(err)
def handle(self):
# Save socket to be able to close them properly on exit

View File

@ -0,0 +1,256 @@
from gevent.pywsgi import WSGIHandler, _InvalidClientInput
from gevent.queue import Queue
import gevent
import hashlib
import base64
import struct
import socket
import time
import sys
SEND_PACKET_SIZE = 1300
OPCODE_TEXT = 1
OPCODE_BINARY = 2
OPCODE_CLOSE = 8
OPCODE_PING = 9
OPCODE_PONG = 10
STATUS_OK = 1000
STATUS_PROTOCOL_ERROR = 1002
STATUS_DATA_ERROR = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_TOO_LONG = 1009
class WebSocket:
def __init__(self, socket):
self.socket = socket
self.closed = False
self.status = None
self._receive_error = None
self._queue = Queue()
self.max_length = 10 * 1024 * 1024
gevent.spawn(self._listen)
def set_max_message_length(self, length):
self.max_length = length
def _listen(self):
try:
while True:
fin = False
message = bytearray()
is_first_message = True
start_opcode = None
while not fin:
payload, opcode, fin = self._get_frame(max_length=self.max_length - len(message))
# Make sure continuation frames have correct information
if not is_first_message and opcode != 0:
self._error(STATUS_PROTOCOL_ERROR)
if is_first_message:
if opcode not in (OPCODE_TEXT, OPCODE_BINARY):
self._error(STATUS_PROTOCOL_ERROR)
# Save opcode
start_opcode = opcode
message += payload
is_first_message = False
message = bytes(message)
if start_opcode == OPCODE_TEXT: # UTF-8 text
try:
message = message.decode()
except UnicodeDecodeError:
self._error(STATUS_DATA_ERROR)
self._queue.put(message)
except Exception as e:
self.closed = True
self._receive_error = e
self._queue.put(None) # To make sure the error is read
def receive(self):
if not self._queue.empty():
return self.receive_nowait()
if isinstance(self._receive_error, EOFError):
return None
if self._receive_error:
raise self._receive_error
self._queue.peek()
return self.receive_nowait()
def receive_nowait(self):
ret = self._queue.get_nowait()
if self._receive_error and not isinstance(self._receive_error, EOFError):
raise self._receive_error
return ret
def send(self, data):
if self.closed:
raise EOFError()
if isinstance(data, str):
self._send_frame(OPCODE_TEXT, data.encode())
elif isinstance(data, bytes):
self._send_frame(OPCODE_BINARY, data)
else:
raise TypeError("Expected str or bytes, got " + repr(type(data)))
# Reads a frame from the socket. Pings, pongs and close packets are handled
# automatically
def _get_frame(self, max_length):
while True:
payload, opcode, fin = self._read_frame(max_length=max_length)
if opcode == OPCODE_PING:
self._send_frame(OPCODE_PONG, payload)
elif opcode == OPCODE_PONG:
pass
elif opcode == OPCODE_CLOSE:
if len(payload) >= 2:
self.status = struct.unpack("!H", payload[:2])[0]
was_closed = self.closed
self.closed = True
if not was_closed:
# Send a close frame in response
self.close(STATUS_OK)
raise EOFError()
else:
return payload, opcode, fin
# Low-level function, use _get_frame instead
def _read_frame(self, max_length):
header = self._recv_exactly(2)
if not (header[1] & 0x80):
self._error(STATUS_POLICY_VIOLATION)
opcode = header[0] & 0xf
fin = bool(header[0] & 0x80)
payload_length = header[1] & 0x7f
if payload_length == 126:
payload_length = struct.unpack("!H", self._recv_exactly(2))[0]
elif payload_length == 127:
payload_length = struct.unpack("!Q", self._recv_exactly(8))[0]
# Control frames are handled in a special way
if opcode in (OPCODE_PING, OPCODE_PONG):
max_length = 125
if payload_length > max_length:
self._error(STATUS_TOO_LONG)
mask = self._recv_exactly(4)
payload = self._recv_exactly(payload_length)
payload = self._unmask(payload, mask)
return payload, opcode, fin
def _recv_exactly(self, length):
buf = bytearray()
while len(buf) < length:
block = self.socket.recv(min(4096, length - len(buf)))
if block == b"":
raise EOFError()
buf += block
return bytes(buf)
def _unmask(self, payload, mask):
def gen(c):
return bytes([x ^ c for x in range(256)])
payload = bytearray(payload)
payload[0::4] = payload[0::4].translate(gen(mask[0]))
payload[1::4] = payload[1::4].translate(gen(mask[1]))
payload[2::4] = payload[2::4].translate(gen(mask[2]))
payload[3::4] = payload[3::4].translate(gen(mask[3]))
return bytes(payload)
def _send_frame(self, opcode, data):
for i in range(0, len(data), SEND_PACKET_SIZE):
part = data[i:i + SEND_PACKET_SIZE]
fin = int(i == (len(data) - 1) // SEND_PACKET_SIZE * SEND_PACKET_SIZE)
header = bytes(
[
(opcode if i == 0 else 0) | (fin << 7),
min(len(part), 126)
]
)
if len(part) >= 126:
header += struct.pack("!H", len(part))
self.socket.sendall(header + part)
def _error(self, status):
self.close(status)
raise EOFError()
def close(self, status=STATUS_OK):
self.closed = True
self._send_frame(OPCODE_CLOSE, struct.pack("!H", status))
self.socket.close()
class WebSocketHandler(WSGIHandler):
def handle_one_response(self):
self.time_start = time.time()
self.status = None
self.headers_sent = False
self.result = None
self.response_use_chunked = False
self.response_length = 0
http_connection = [s.strip() for s in self.environ.get("HTTP_CONNECTION", "").split(",")]
if "Upgrade" not in http_connection or self.environ.get("HTTP_UPGRADE", "") != "websocket":
# Not my problem
return super(WebSocketHandler, self).handle_one_response()
if "HTTP_SEC_WEBSOCKET_KEY" not in self.environ:
self.start_response("400 Bad Request", [])
return
# Generate Sec-Websocket-Accept header
accept = self.environ["HTTP_SEC_WEBSOCKET_KEY"].encode()
accept += b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
accept = base64.b64encode(hashlib.sha1(accept).digest()).decode()
# Accept
self.start_response("101 Switching Protocols", [
("Upgrade", "websocket"),
("Connection", "Upgrade"),
("Sec-Websocket-Accept", accept)
])(b"")
self.environ["wsgi.websocket"] = WebSocket(self.socket)
# Can't call super because it sets invalid flags like "status"
try:
try:
self.run_application()
finally:
try:
self.wsgi_input._discard()
except (socket.error, IOError):
pass
except _InvalidClientInput:
self._send_error_response_if_possible(400)
except socket.error as ex:
if ex.args[0] in self.ignored_socket_errors:
self.close_connection = True
else:
self.handle_error(*sys.exc_info())
except: # pylint:disable=bare-except
self.handle_error(*sys.exc_info())
finally:
self.time_finish = time.time()
self.log_request()