427 lines
14 KiB
Python
427 lines
14 KiB
Python
# References:
# https://tools.ietf.org/html/rfc6455
# https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers
# https://gist.github.com/SevenW/47be2f9ab74cac26bf21/ (SevenW/HTTPWebSocketsHandler.py)
|
|
|
|
import json
|
|
from base64 import b64decode, b64encode
|
|
from binascii import Error as BinasciiError
|
|
from hashlib import sha1
|
|
from http.server import BaseHTTPRequestHandler
|
|
from io import BufferedIOBase
|
|
from time import time
|
|
|
|
from gpg import Context as GPGContext
|
|
from gpg.errors import GPGMEError
|
|
|
|
# Fixed GUID appended to the client's Sec-WebSocket-Key before hashing it
# into the Sec-WebSocket-Accept value (see handshake()).
_WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

# Frame opcodes understood by this module, keyed by name.
# NOTE: the "continueation" spelling is a runtime key used throughout the
# module, so it is kept as is.
OPCODE = dict(
    continueation=0x0,
    text=0x1,
    binary=0x2,
    close=0x8,
    ping=0x9,
    pong=0xA,
)

# Opcodes of control frames and of data (non-control) frames.
CONTROL_OPCODES = [OPCODE[name] for name in ("close", "ping", "pong")]
NONCONTROL_OPCODES = [
    OPCODE[name] for name in ("continueation", "text", "binary")
]

# Only Sec-WebSocket-Version value accepted by handshake().
PROTOCOL_VERSION = 13
# HTTP version required from the client and used for the response.
HTTP_VERSION = "HTTP/1.1"

# Maximum accepted age, in seconds, of a PGP signature (decrypt_pgp_json()).
PGP_JSON_SIGNATURE_VALIDITY = 60
|
|
|
|
|
|
class WebSocketException(Exception):
    """Base class of every exception defined by this module."""
|
|
|
|
|
|
class FrameError(WebSocketException):
    """Error raised for a frame that is malformed or unsupported.

    Attributes:
        message: human readable description of the problem.
        frame: the (partially) parsed frame the error refers to.
    """

    def __init__(self, message: str, frame: dict):
        # The first header byte must already be parsed: every one of its
        # fields has to be present and set.
        for field in ("FIN", "RSV1", "RSV2", "RSV3", "opcode"):
            assert field in frame and frame[field] is not None
        self.message = message
        self.frame = frame
|
|
|
|
|
|
class HandshakeError(WebSocketException):
    """Error raised when the opening handshake request is not acceptable.

    Attributes:
        what: name of the offending header (or "HTTP" for the version).
        value: value received for it, None when it is missing.
        why: one of the codes listed in ``reason``.
    """

    # Reason name -> numeric code.
    reason = dict(invalid_header=0, missing_header=1, incompatible_version=2)

    def __init__(self, what: str, value: str, why: int):
        assert why in self.reason.values()
        self.what, self.value, self.why = what, value, why

    def get_reason(self, code: int):
        """Return the reason name matching *code*, or None if unknown."""
        return next(
            (name for name, number in self.reason.items() if number == code),
            None,
        )
|
|
|
|
|
|
class InvalidLengthError(FrameError):
    """Error raised when a frame announces an invalid payload length.

    Unlike FrameError it is built from the offending length only, which is
    stored in the ``length`` attribute.
    """

    def __init__(self, length: int):
        # FrameError.__init__ is intentionally not called: no frame is
        # available where the length is parsed.
        self.length = length
|
|
|
|
|
|
class WebSocketCloseException(WebSocketException):
    """Signal that the WebSocket connection is being closed.

    Attributes:
        code: close status code (may be None when the peer sent none).
        reason: textual reason for the closure.
    """

    def __init__(self, code: int, reason: str):
        self.code, self.reason = code, reason
|
|
|
|
|
|
class WebSocketMessage:
    """A WebSocket message built from the payload of one or more frames.

    Attributes:
        opcode: opcode of the (first) frame of the message.
        payload: ``str`` for text messages (decoded with the default
            codec), otherwise the payload exactly as it was given.
    """

    def __init__(self, opcode, payload):
        self.opcode = opcode
        if opcode == OPCODE["text"]:
            # may raise UnicodeError
            self.payload = payload.decode()
        else:
            self.payload = payload

    def __add__(self, other):
        """Append the payload of a continuation message and return self.

        The left operand is extended in place, then returned so that
        ``first + rest`` evaluates to the combined message (the original
        implementation returned nothing and subscripted ``other``, which
        is not subscriptable).
        """
        assert isinstance(other, WebSocketMessage)
        assert other.opcode == OPCODE["continueation"]
        if self.opcode == OPCODE["text"]:
            # A continuation payload is kept raw by __init__, so it still
            # has to be decoded here. May raise UnicodeError.
            self.payload += other.payload.decode()
        else:
            self.payload += other.payload
        return self

    def __str__(self):
        return f"WebSocket Message opcode:0x{self.opcode:x} payload:{self.payload}"
|
|
|
|
|
|
def _check_required_headers(request_handler: BaseHTTPRequestHandler):
    """Validate the HTTP request that opens a WebSocket handshake.

    Checks the HTTP version, the presence of the Host, Upgrade, Connection,
    Sec-WebSocket-Key and Sec-WebSocket-Version headers, the tokens expected
    in Upgrade ("websocket") and Connection ("upgrade"), and that the key is
    valid base64 decoding to 16 bytes.

    On the first failed check a bare 400 response is sent through
    *request_handler* before the exception is raised.

    Raises:
        HandshakeError: describing the offending header and the reason.
    """

    def rejection(what, value, why):
        # Answer 400 and build the exception for the caller to raise.
        request_handler.send_response(400)
        request_handler.end_headers()
        return HandshakeError(what, value, why)

    invalid = HandshakeError.reason["invalid_header"]
    missing = HandshakeError.reason["missing_header"]

    if request_handler.request_version != HTTP_VERSION:
        raise rejection(
            "HTTP",
            request_handler.request_version,
            HandshakeError.reason["incompatible_version"],
        )

    # (header name, token that must be listed in its value or None).
    required_headers = (
        ("Host", None),
        ("Upgrade", "websocket"),
        ("Connection", "upgrade"),
        ("Sec-WebSocket-Key", None),
        ("Sec-WebSocket-Version", None),
    )
    for name, token in required_headers:
        value = request_handler.headers.get(name)
        if value is None:
            # Checked first so that a missing header is reported as such
            # instead of failing on None.lower().
            raise rejection(name, value, missing)
        if token is None:
            continue
        # Tokens are compared case-insensitively; whitespace around the
        # commas is optional ("keep-alive,Upgrade" is accepted too).
        listed = [item.strip() for item in value.lower().split(",")]
        if token not in listed:
            raise rejection(name, value, invalid)

    ws_key = request_handler.headers.get("Sec-WebSocket-Key")
    try:
        decoded = b64decode(ws_key, validate=True)
    except (BinasciiError, ValueError):
        # binascii.Error: malformed base64; ValueError: non-ASCII text.
        decoded = None
    if decoded is None or len(decoded) != 16:
        raise rejection("Sec-WebSocket-Key", ws_key, invalid)
|
|
|
|
|
|
def _decode_payload(masking_key: bytes, encoded_payload: bytes):
    """Unmask *encoded_payload* with the 4-byte *masking_key*.

    Each payload byte is XORed with the key byte at its position modulo 4.

    Returns:
        The unmasked payload as a ``bytearray``.
    """
    return bytearray(
        octet ^ masking_key[position % 4]
        for position, octet in enumerate(encoded_payload)
    )
|
|
|
|
|
|
def _encode_data_frame(
    fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, payload: bytes
):
    """Serialize one unmasked WebSocket frame.

    Args:
        fin, rsv1, rsv2, rsv3: values (0 or 1) of the header flag bits.
        opcode: one of the values of ``OPCODE``.
        payload: frame payload (``bytes`` or ``bytearray``).

    Returns:
        A ``bytearray`` holding the header followed by the payload. The
        length is encoded on 7 bits up to 125 bytes, as 126 + 16 bits up
        to 0xFFFF bytes and as 127 + 64 bits above.

    Raises:
        ValueError: unknown opcode, fragmented control frame, control
            frame payload above 125 bytes or payload above 2**63 - 1 bytes.
    """
    # bytearray is accepted as well: it is what _decode_payload() returns.
    assert isinstance(payload, (bytes, bytearray))
    if opcode not in OPCODE.values():
        raise ValueError(
            f"Unsupported opcode {opcode}. Valid values {list(OPCODE.values())}."
        )
    if fin == 0 and opcode in CONTROL_OPCODES:
        raise ValueError("Control frames cannot be fragmented")
    if opcode in CONTROL_OPCODES and len(payload) > 125:
        raise ValueError(
            "Control frame cannot have a payload bigger than 125 bytes."
        )
    payload_length = len(payload)
    if payload_length > 0x7FFFFFFFFFFFFFFF:
        raise ValueError(
            f"Payload maximal size exceeded (provided {payload_length} bytes)."
        )

    # The header is built by appending so that the payload always starts
    # right after the length field (the previous slice arithmetic left a
    # stray zero byte before payloads longer than 0xFFFF bytes).
    frame = bytearray(
        [(fin << 7) + (rsv1 << 6) + (rsv2 << 5) + (rsv3 << 4) + opcode]
    )
    # The mask bit (top bit of the second byte) is always left at 0.
    if payload_length <= 125:
        frame.append(payload_length)
    elif payload_length <= 0xFFFF:
        frame.append(126)
        frame += payload_length.to_bytes(2, byteorder="big")
    else:
        frame.append(127)
        frame += payload_length.to_bytes(8, byteorder="big")
    frame += payload

    return frame
|
|
|
|
|
|
def _handle_control_frame(wfile: BufferedIOBase, frame: dict):
    """React to a control frame received from the peer.

    * close: the status code (if any) is echoed in a close frame, then
      WebSocketCloseException is raised.
    * ping: a pong carrying the same payload is sent back.
    * any other opcode is ignored.

    Args:
        wfile: stream the answer is written to.
        frame: frame dictionary as produced by _read_data_frame().

    Raises:
        WebSocketCloseException: when a close frame was received; the
            reason defaults to "-" when the peer gave none.
    """
    opcode = frame["opcode"]
    if opcode == OPCODE["close"]:
        code = frame.get("status_code")
        payload = code.to_bytes(2, byteorder="big") if code is not None else b""
        reason = frame.get("close_reason", "-")
        # send_message() takes the payload before the opcode.
        send_message(wfile, payload, OPCODE["close"])
        raise WebSocketCloseException(code, reason)
    if opcode == OPCODE["ping"]:
        # The received payload may be a bytearray; the frame encoder is
        # only guaranteed to accept bytes.
        payload = bytes(frame.get("payload", b""))
        send_message(wfile, payload, OPCODE["pong"])
|
|
|
|
|
|
def _read_data_frame(rfile: BufferedIOBase):
    """Read and unmask one frame sent by a client.

    Returns:
        A dict with the keys "FIN", "RSV1", "RSV2", "RSV3", "opcode" and
        "payload" (unmasked). For a close frame with a non-empty payload
        "status_code" is added, plus "close_reason" when the payload is
        longer than two bytes.

    Raises:
        FrameError: a reserved bit is set, the opcode is unknown, FIN is
            not set where this module requires it, the frame is not
            masked or its length is invalid.
    """
    first_octet = ord(rfile.read(1))
    frame = {
        "FIN": first_octet >> 7,
        "RSV1": (first_octet & 0x40) >> 6,
        "RSV2": (first_octet & 0x20) >> 5,
        "RSV3": (first_octet & 0x10) >> 4,
        "opcode": first_octet & 0x0F,
    }
    opcode = frame["opcode"]
    is_control = opcode in CONTROL_OPCODES

    if any(frame[bit] != 0 for bit in ("RSV1", "RSV2", "RSV3")):
        raise FrameError(
            "Unsupported feature. RSV1, RSV2 or RSV3 has a non-zero value.",
            frame,
        )

    if opcode not in OPCODE.values():
        raise FrameError("Unsupported opcode value.", frame)

    # NOTE(review): this also rejects the first fragment of a fragmented
    # message (FIN=0 with a text/binary opcode), which RFC 6455 appears to
    # allow - confirm whether that restriction is intended.
    if frame["FIN"] == 0 and opcode != OPCODE["continueation"]:
        raise FrameError(
            "FIN bit not set for a non-continueation frame.", frame
        )

    if is_control and frame["FIN"] == 0:
        raise FrameError("FIN bit not set for a control frame.", frame)

    second_octet = ord(rfile.read(1))

    # Top bit: mask flag. Frames read here must be masked.
    if second_octet >> 7 == 0:
        raise FrameError("Unmasked frame from client.", frame)

    # Low 7 bits: payload length, or 126/127 announcing an extended length.
    length1 = second_octet & 0x7F

    if is_control and length1 > 125:
        raise FrameError("Control frame with invalid payload length.", frame)

    try:
        length = _read_payload_length(length1, rfile)
    except InvalidLengthError as error:
        raise FrameError(
            f"Invalid payload length of {error.length} bytes.", frame
        )

    masking_key = rfile.read(4)
    frame["payload"] = _decode_payload(masking_key, rfile.read(length))

    if opcode == OPCODE["close"] and frame["payload"]:
        # The first two payload bytes carry the big-endian status code.
        frame["status_code"] = int.from_bytes(
            frame["payload"][0:2], byteorder="big"
        )
        if length > 2:
            # /!\ may raise UnicodeError /!\
            frame["close_reason"] = frame["payload"][2:].decode()

    return frame
|
|
|
|
|
|
def _read_payload_length(payload_length1: int, rfile: BufferedIOBase):
    """Resolve the payload length of a frame.

    Args:
        payload_length1: 7-bit length field of the frame header.
        rfile: stream positioned right after the second header byte.

    Returns:
        *payload_length1* itself unless it is 126 or 127, in which case the
        length is the big-endian integer read from the next 2 or 8 bytes
        of *rfile*.

    Raises:
        InvalidLengthError: the most significant bit of an 8-byte length
            is set.
    """
    if payload_length1 == 126:
        return int.from_bytes(rfile.read(2), byteorder="big")
    if payload_length1 != 127:
        return payload_length1
    extended_length = int.from_bytes(rfile.read(8), byteorder="big")
    if extended_length >> 63 == 1:
        raise InvalidLengthError(extended_length)
    return extended_length
|
|
|
|
|
|
def close(wfile: BufferedIOBase, code=1000, reason=None):
    """Write a close frame to *wfile*.

    Args:
        wfile: stream the frame is written to.
        code: status code, sent as a 2-byte big-endian integer.
        reason: optional text appended (encoded) after the status code.
    """
    payload = code.to_bytes(2, byteorder="big")
    if reason is not None:
        payload = payload + reason.encode()
    wfile.write(_encode_data_frame(1, OPCODE["close"], 0, 0, 0, payload))
|
|
|
|
|
|
def decrypt_pgp_json(
    wfile: BufferedIOBase,
    gpg_context: GPGContext,
    message: WebSocketMessage,
    signature_key=None,
):
    """Decrypt a PGP-encrypted text message and parse it as JSON.

    The message must be a text message, decrypt with *gpg_context*, carry
    at least one signature, be signed by *signature_key* when one is given
    (fingerprint of the first signature) and have a first signature no
    older than PGP_JSON_SIGNATURE_VALIDITY seconds.

    Whenever a check fails a close frame is written to *wfile* and
    WebSocketCloseException is raised with the same status code.

    Returns:
        A tuple ``(parsed JSON, signatures of the message)``.

    Raises:
        WebSocketCloseException: 1003 for a non-text message, 1002 for a
            decryption or JSON failure, 1008 for a signature problem.
    """

    def refusal(code, notice, detail):
        # Send the close frame, then build the exception to raise.
        close(wfile, code, notice)
        return WebSocketCloseException(code, detail)

    if message.opcode != OPCODE["text"]:
        raise refusal(
            1003, "Only text datatype is allowed.", "Received no-text datatype."
        )

    try:
        plaintext, _result, verify_result = gpg_context.decrypt(
            message.payload.encode()
        )
    except GPGMEError as error:
        raise refusal(
            1002,
            "Failed to decrypt message.",
            f"Failed to decrypt websocket message: {error}.",
        )

    if len(verify_result.signatures) == 0:
        raise refusal(
            1008,
            "No signature recognized.",
            "Message with missing or unknown signature.",
        )

    # Only the first signature is inspected.
    first_signature = verify_result.signatures[0]
    if signature_key is not None and signature_key.fpr != first_signature.fpr:
        raise refusal(
            1008, "Message signature check failed.", "Signing keys doesn't match."
        )

    signature_age = int(time()) - first_signature.timestamp
    if signature_age > PGP_JSON_SIGNATURE_VALIDITY:
        raise refusal(
            1008,
            "Signature too old.",
            f"Message with a {signature_age}s old signature.",
        )

    try:
        json_ = json.loads(plaintext)
    except json.JSONDecodeError:
        raise refusal(1002, "Invalid JSON data.", "Invalid JSON data.")

    return json_, verify_result.signatures
|
|
|
|
|
|
def encrypt_pgp_json(obj: dict, recipient, gpg_context: GPGContext):
    """Serialize *obj* to JSON, then sign and encrypt it for *recipient*.

    Encryption is delegated to ``gpg_context.encrypt`` with ``sign=True``
    and ``always_trust=True``.

    Returns:
        The ciphertext, first element of the result of the encryption call.
    """
    serialized = json.dumps(obj).encode()
    outcome = gpg_context.encrypt(
        serialized, recipients=[recipient], sign=True, always_trust=True
    )
    # The encryption and signing results are not used.
    ciphertext, _result, _sign_result = outcome
    return ciphertext
|
|
|
|
|
|
def handshake(request_handler: BaseHTTPRequestHandler, subprotocols=None):
    """Answer the opening handshake of a WebSocket connection.

    The request is validated, then a 101 response carrying the computed
    Sec-WebSocket-Accept value (and the selected subprotocol, if any) is
    sent through *request_handler*.

    Args:
        request_handler: handler of the HTTP request to upgrade.
        subprotocols: subprotocol names supported by the server; defaults
            to none.

    Returns:
        The selected subprotocol: the first one requested by the client
        that the server supports, or None.

    Raises:
        HandshakeError: the request is invalid or its
            Sec-WebSocket-Version is not PROTOCOL_VERSION (a 400 response
            is sent first).
    """
    supported = () if subprotocols is None else subprotocols

    request_handler.protocol_version = HTTP_VERSION
    _check_required_headers(request_handler)
    websocket_key = request_handler.headers.get("Sec-WebSocket-Key")
    digest = b64encode(
        sha1((websocket_key + _WEBSOCKET_GUID).encode()).digest()
    )

    websocket_version = request_handler.headers.get("Sec-WebSocket-Version")
    try:
        version_supported = int(websocket_version) == PROTOCOL_VERSION
    except (TypeError, ValueError):
        # A non-numeric version is as unusable as an unsupported one.
        version_supported = False
    if not version_supported:
        request_handler.send_response(400)
        # Advertise the version this server speaks.
        request_handler.send_header("Sec-WebSocket-Version", PROTOCOL_VERSION)
        request_handler.end_headers()
        raise HandshakeError(
            "Sec-WebSocket-Version",
            websocket_version,
            HandshakeError.reason["incompatible_version"],
        )

    selected_subprotocol = None
    requested_subprotocols = request_handler.headers.get(
        "Sec-WebSocket-Protocol"
    )
    if requested_subprotocols:
        # The header is a comma separated list, optionally with spaces.
        # The client lists subprotocols by preference, so the first match
        # is kept.
        for proto in requested_subprotocols.split(","):
            proto = proto.strip()
            if proto in supported:
                selected_subprotocol = proto
                break

    request_handler.send_response(101)
    request_handler.send_header("Upgrade", "websocket")
    request_handler.send_header("Connection", "Upgrade")
    request_handler.send_header("Sec-WebSocket-Accept", digest.decode())
    if selected_subprotocol is not None:
        request_handler.send_header(
            "Sec-WebSocket-Protocol", selected_subprotocol
        )
    request_handler.end_headers()

    return selected_subprotocol
|
|
|
|
|
|
def read_next_message(rfile: BufferedIOBase, wfile: BufferedIOBase):
    """Read frames from *rfile* until a data message is available.

    Control frames met on the way are answered through *wfile* by
    _handle_control_frame(), which raises WebSocketCloseException on a
    close frame. They are processed in a loop rather than by recursion so
    that a long run of pings cannot exhaust the call stack.

    Returns:
        The next WebSocketMessage. When the frame read is not final, the
        message is combined (``+``) with the one built from the following
        frames.
    """
    while True:
        frame = _read_data_frame(rfile)
        if frame["FIN"] == 1 and frame["opcode"] not in NONCONTROL_OPCODES:
            _handle_control_frame(wfile, frame)
            continue

        message = WebSocketMessage(frame["opcode"], frame["payload"])
        if frame["FIN"] == 1:
            return message
        # Not the final frame: the rest of the message is still to come.
        return message + read_next_message(rfile, wfile)
|
|
|
|
|
|
def send_message(
    wfile: BufferedIOBase,
    payload: bytes,
    opcode=OPCODE["text"],
    rsv1=0,
    rsv2=0,
    rsv3=0,
):
    """Send *payload* as one final (FIN=1) frame written to *wfile*.

    Args:
        wfile: stream the frame is written to.
        payload: frame payload.
        opcode: frame opcode, a text frame by default.
        rsv1, rsv2, rsv3: values of the reserved header bits.

    Raises:
        ValueError: the payload does not fit in a single frame, or the
            frame encoder rejects the arguments.
    """
    largest_payload = 0x7FFFFFFFFFFFFFFF
    if len(payload) > largest_payload:
        raise ValueError(
            "Payload to big. Sending fragmented messages not implemented."
        )
    encoded = _encode_data_frame(1, opcode, rsv1, rsv2, rsv3, payload)
    wfile.write(encoded)
|