poc-websocket-pgp-json/websocket.py

427 lines
14 KiB
Python
Raw Permalink Normal View History

2021-01-01 12:40:42 +01:00
# References:
# https://tools.ietf.org/html/rfc6455
# https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers
# https://gist.github.com/SevenW/47be2f9ab74cac26bf21/ (SevenW/HTTPWebSocketsHandler.py)
import json
from base64 import b64decode, b64encode
2021-01-01 12:40:42 +01:00
from binascii import Error as BinasciiError
from hashlib import sha1
2021-01-01 12:40:42 +01:00
from http.server import BaseHTTPRequestHandler
from io import BufferedIOBase
from time import time
from gpg import Context as GPGContext
from gpg.errors import GPGMEError
2021-01-01 12:40:42 +01:00
# Magic GUID appended to the client's Sec-WebSocket-Key before hashing
# (RFC 6455 section 1.3).
_WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

# Frame opcodes (RFC 6455 section 5.2). The misspelled "continueation" key
# is looked up by that exact name elsewhere in this module, so it is kept.
OPCODE = dict(
    continueation=0x0,
    text=0x1,
    binary=0x2,
    close=0x8,
    ping=0x9,
    pong=0xA,
)

# Control frames (close/ping/pong) versus data frames, as plain lists.
CONTROL_OPCODES = [OPCODE[name] for name in ("close", "ping", "pong")]
NONCONTROL_OPCODES = [
    OPCODE[name] for name in ("continueation", "text", "binary")
]

# Only this Sec-WebSocket-Version is accepted during the handshake.
PROTOCOL_VERSION = 13
# Exact HTTP version required on the upgrade request.
HTTP_VERSION = "HTTP/1.1"
# Maximum accepted age, in seconds, of the signature on incoming PGP JSON.
PGP_JSON_SIGNATURE_VALIDITY = 60
2021-01-01 12:40:42 +01:00
class WebSocketException(Exception):
    """Common base class of every exception defined in this module."""
2021-01-01 12:40:42 +01:00
class FrameError(WebSocketException):
    """Raised when a received frame is malformed or uses an unsupported feature.

    Attributes:
        message: human-readable description of the problem.
        frame: the (partially) parsed frame dict; it always carries non-None
            "FIN", "RSV1", "RSV2", "RSV3" and "opcode" entries.
    """

    def __init__(self, message: str, frame: dict):
        # Internal invariant: a frame is only reported once its first header
        # byte has been parsed.
        for key in ("FIN", "RSV1", "RSV2", "RSV3", "opcode"):
            assert frame.get(key) is not None, f"frame is missing {key!r}"
        self.message = message
        self.frame = frame

    def __str__(self):
        # Without this, str() renders the raw (message, frame) args tuple.
        # Subclasses may bypass this __init__ and never set ``message``, so
        # fall back to the default rendering in that case.
        message = getattr(self, "message", None)
        return message if message is not None else super().__str__()
2021-01-01 12:40:42 +01:00
class HandshakeError(WebSocketException):
    """Raised when an HTTP upgrade request is not an acceptable WebSocket handshake.

    Attributes:
        what: name of the offending header, or "HTTP" for the request version.
        value: the value that was received (None when the header is missing).
        why: one of the numeric codes in ``HandshakeError.reason``.
    """

    # Failure category name -> numeric code stored in ``why``.
    reason = {
        "invalid_header": 0,
        "missing_header": 1,
        "incompatible_version": 2,
    }

    def __init__(self, what: str, value: str, why: int):
        assert why in self.reason.values()
        self.what = what
        self.value = value
        self.why = why

    def get_reason(self, code: int = None):
        """Return the category name for *code*, or None if it is unknown.

        When *code* is omitted, the instance's own ``why`` code is used.
        """
        if code is None:
            code = self.why
        return next(
            (name for name, value in self.reason.items() if value == code),
            None,
        )

    def __str__(self):
        return f"{self.get_reason()}: {self.what}={self.value!r}"
2021-01-01 12:40:42 +01:00
class InvalidLengthError(FrameError):
    """Raised for a 64-bit payload length whose most significant bit is set.

    It is raised while the length field is decoded, before a frame dict is
    available. It therefore does not go through FrameError.__init__, but it
    still provides the ``message`` and ``frame`` attributes of its base
    class (``frame`` is None).

    Attributes:
        length: the rejected payload length value.
    """

    def __init__(self, length: int):
        self.length = length
        self.message = f"Invalid payload length of {length} bytes."
        self.frame = None
2021-01-01 12:40:42 +01:00
class WebSocketCloseException(WebSocketException):
    """Signals that the WebSocket connection is being closed.

    Attributes:
        code: close status code; may be None when the peer's close frame
            carried no status code.
        reason: textual explanation of why the connection closes.
    """

    def __init__(self, code: int, reason: str):
        self.code, self.reason = code, reason
class WebSocketMessage:
    """A WebSocket data message as handed to application code.

    Attributes:
        opcode: opcode of the message's (first) frame.
        payload: ``str`` for text messages (decoded as UTF-8), otherwise the
            payload exactly as it was passed in.
    """

    def __init__(self, opcode, payload):
        self.opcode = opcode
        if opcode == OPCODE["text"]:
            # may raise UnicodeError
            self.payload = payload.decode()
        else:
            self.payload = payload

    def __add__(self, other):
        """Append the payload of continuation message *other* to this message.

        This message is extended in place and also returned, so ``msg + cont``
        evaluates to the combined message.

        Raises:
            ValueError: if *other* is not a continuation message.
            UnicodeError: if this is a text message and *other*'s payload is
                not valid UTF-8 by itself. Each fragment is decoded
                separately, so a code point split across fragments also fails.
        """
        if not isinstance(other, WebSocketMessage):
            return NotImplemented
        if other.opcode != OPCODE["continueation"]:
            raise ValueError(
                "Only continuation messages can be appended, "
                f"got opcode 0x{other.opcode:x}."
            )
        if self.opcode == OPCODE["text"]:
            self.payload += other.payload.decode()
        else:
            self.payload += other.payload
        return self

    def __str__(self):
        return f"WebSocket Message opcode:0x{self.opcode:x} payload:{self.payload}"
2021-01-01 12:40:42 +01:00
def _send_bad_request(request_handler: BaseHTTPRequestHandler):
    """Answer the pending upgrade request with an empty 400 response."""
    request_handler.send_response(400)
    request_handler.end_headers()


def _check_required_headers(request_handler: BaseHTTPRequestHandler):
    """Validate the HTTP version and headers of a WebSocket upgrade request.

    Checks the following:
      - the request uses exactly HTTP_VERSION;
      - Host, Upgrade, Connection, Sec-WebSocket-Key and Sec-WebSocket-Version
        are all present;
      - Upgrade lists the "websocket" token and Connection lists the
        "upgrade" token (comma-separated, case-insensitive);
      - Sec-WebSocket-Key is valid base64 that decodes to 16 bytes.

    On failure, a 400 response is sent before raising.

    Raises:
        HandshakeError: describing the first failed check.
    """
    if request_handler.request_version != HTTP_VERSION:
        _send_bad_request(request_handler)
        raise HandshakeError(
            "HTTP",
            request_handler.request_version,
            HandshakeError.reason["incompatible_version"],
        )

    # (header name, token that must appear in it, or None for "any value")
    for name, required_token in (
        ("Host", None),
        ("Upgrade", "websocket"),
        ("Connection", "upgrade"),
        ("Sec-WebSocket-Key", None),
        ("Sec-WebSocket-Version", None),
    ):
        value = request_handler.headers.get(name)
        if value is None:
            _send_bad_request(request_handler)
            raise HandshakeError(
                name, value, HandshakeError.reason["missing_header"]
            )
        if required_token is not None:
            # Token lists may be written with or without spaces after commas.
            tokens = [token.strip().lower() for token in value.split(",")]
            if required_token not in tokens:
                _send_bad_request(request_handler)
                raise HandshakeError(
                    name, value, HandshakeError.reason["invalid_header"]
                )

    ws_key = request_handler.headers.get("Sec-WebSocket-Key")
    try:
        decoded = b64decode(ws_key, validate=True)
    except ValueError:
        # binascii.Error (bad base64) is a ValueError subclass, and
        # b64decode raises a plain ValueError for non-ASCII str input.
        decoded = None
    if decoded is None or len(decoded) != 16:
        _send_bad_request(request_handler)
        raise HandshakeError(
            "Sec-WebSocket-Key",
            ws_key,
            HandshakeError.reason["invalid_header"],
        )
2021-01-01 12:40:42 +01:00
def _decode_payload(masking_key: bytes, encoded_payload: bytes):
decoded_payload = bytearray()
for byte in encoded_payload:
decoded_payload += bytes(
[byte ^ masking_key[len(decoded_payload) % 4]]
)
2021-01-01 12:40:42 +01:00
return decoded_payload
def _encode_data_frame(
fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, payload: bytes
):
2021-01-01 12:40:42 +01:00
assert isinstance(payload, bytes)
if opcode not in OPCODE.values():
raise ValueError(
f"Unsupported opcode {opcode}. Valid values {list(OPCODE.values())}."
)
2021-01-01 12:40:42 +01:00
if fin == 0 and opcode in CONTROL_OPCODES:
raise ValueError("Control frames cannot be fragmented")
2021-01-01 12:40:42 +01:00
if opcode in CONTROL_OPCODES and len(payload) > 125:
raise ValueError(
"Control frame cannot have a payload bigger than 125 bytes."
)
2021-01-01 12:40:42 +01:00
payload_length = len(payload)
length_bits = 7
if payload_length > 0x7FFFFFFFFFFFFFFF:
raise ValueError(
f"Payload maximal size exceeded (provided {payload_length} bytes)."
)
elif payload_length > 0xFFFF:
2021-01-01 12:40:42 +01:00
length_bits = 7 + 64
2021-01-08 19:26:38 +01:00
elif payload_length > 125:
length_bits = 7 + 16
2021-01-01 12:40:42 +01:00
frame_size = int(1 + (1 + length_bits) / 8 + payload_length)
frame = bytearray(frame_size)
frame[0] = (fin << 7) + (rsv1 << 6) + (rsv2 << 5) + (rsv3 << 4) + opcode
if length_bits == 7:
frame[1] = payload_length
frame[2:] = payload
elif length_bits == 7 + 16:
frame[1] = 126
frame[2:2] = payload_length.to_bytes(2, byteorder="big")
frame[4:] = payload
else:
frame[1] = 127
frame[2:8] = payload_length.to_bytes(8, byteorder="big")
frame[11:] = payload
return frame
2021-01-01 12:40:42 +01:00
def _handle_control_frame(wfile: BufferedIOBase, frame: dict):
    """React to a received control frame.

    The handling depends on the opcode:
      - close: echo a close frame carrying the same status code (if any),
        then raise WebSocketCloseException;
      - ping: answer with a pong carrying the same payload;
      - pong: ignored.

    Raises:
        WebSocketCloseException: on a close frame, with the received status
            code (or None) and close reason (or "-").
    """
    opcode = frame["opcode"]
    if opcode == OPCODE["close"]:
        code = frame.get("status_code")
        payload = (
            code.to_bytes(2, byteorder="big") if code is not None else b""
        )
        reason = frame.get("close_reason", "-")
        # send_message takes (wfile, payload, opcode); the arguments were
        # previously passed in the wrong order.
        send_message(wfile, payload, opcode=OPCODE["close"])
        raise WebSocketCloseException(code, reason)
    if opcode == OPCODE["ping"]:
        # Frame payloads are bytearrays, while the frame encoder expects bytes.
        payload = bytes(frame.get("payload", b""))
        send_message(wfile, payload, opcode=OPCODE["pong"])
2021-01-01 12:40:42 +01:00
def _read_data_frame(rfile: BufferedIOBase):
frame = {}
net_bytes = ord(rfile.read(1))
frame["FIN"] = net_bytes >> 7
frame["RSV1"] = (net_bytes & 0x40) >> 6
frame["RSV2"] = (net_bytes & 0x20) >> 5
frame["RSV3"] = (net_bytes & 0x10) >> 4
frame["opcode"] = net_bytes & 0x0F
2021-01-01 12:40:42 +01:00
if frame["RSV1"] != 0 or frame["RSV2"] != 0 or frame["RSV3"] != 0:
raise FrameError(
"Unsupported feature. RSV1, RSV2 or RSV3 has a non-zero value.",
frame,
)
2021-01-01 12:40:42 +01:00
if not frame["opcode"] in OPCODE.values():
raise FrameError("Unsupported opcode value.", frame)
if frame["FIN"] == 0 and frame["opcode"] != OPCODE["continueation"]:
raise FrameError(
"FIN bit not set for a non-continueation frame.", frame
)
2021-01-01 12:40:42 +01:00
if frame["opcode"] in CONTROL_OPCODES and frame["FIN"] == 0:
raise FrameError("FIN bit not set for a control frame.", frame)
net_bytes = ord(rfile.read(1))
mask_bit = net_bytes >> 7
if mask_bit == 0:
raise FrameError("Unmasked frame from client.", frame)
length1 = net_bytes & 0x7F
2021-01-01 12:40:42 +01:00
if frame["opcode"] in CONTROL_OPCODES and length1 > 125:
raise FrameError("Control frame with invalid payload length.", frame)
try:
length = _read_payload_length(length1, rfile)
except InvalidLengthError as error:
raise FrameError(
f"Invalid payload length of {error.length} bytes.", frame
)
2021-01-01 12:40:42 +01:00
masking_key = rfile.read(4)
encoded_payload = rfile.read(length)
frame["payload"] = _decode_payload(masking_key, encoded_payload)
if frame["opcode"] == OPCODE["close"] and frame["payload"]:
frame["status_code"] = int.from_bytes(
frame["payload"][0:2], byteorder="big"
)
2021-01-01 12:40:42 +01:00
if length > 2:
# /!\ may raise UnicodeError /!\
frame["close_reason"] = frame["payload"][2:].decode()
return frame
2021-01-01 12:40:42 +01:00
def _read_payload_length(payload_length1: int, rfile: BufferedIOBase):
final_length = payload_length1
if payload_length1 == 126:
final_length = int.from_bytes(rfile.read(2), byteorder="big")
elif payload_length1 == 127:
final_length = int.from_bytes(rfile.read(8), byteorder="big")
if final_length >> 63 == 1:
raise InvalidLengthError(final_length)
return final_length
2021-01-01 12:40:42 +01:00
def close(wfile: BufferedIOBase, code=1000, reason=None):
    """Write a close frame carrying *code* and an optional UTF-8 *reason*.

    The body is the 2-byte big-endian status code followed by the encoded
    reason. Control frames are limited to 125 bytes, so a reason longer than
    123 encoded bytes makes the frame encoder raise ValueError.
    """
    body = code.to_bytes(2, byteorder="big")
    if reason is not None:
        body += reason.encode()
    wfile.write(_encode_data_frame(1, OPCODE["close"], 0, 0, 0, body))
def decrypt_pgp_json(
    wfile: BufferedIOBase,
    gpg_context: GPGContext,
    message: WebSocketMessage,
    signature_key=None,
):
    """Decrypt a PGP text message, check its signature and parse its JSON.

    On every failure, a close frame is written to *wfile* before the matching
    WebSocketCloseException is raised.

    Args:
        wfile: stream the close frame is written to on failure.
        gpg_context: context used to decrypt and verify the message.
        message: the received message; it must be a text message.
        signature_key: when given, the first signature's fingerprint must
            equal ``signature_key.fpr``.

    Returns:
        A tuple ``(json_value, signatures)``. The signatures are those of the
        verify result.

    Raises:
        WebSocketCloseException: with one of the following codes:
            - 1003 for a non-text message;
            - 1002 when decryption or JSON parsing fails;
            - 1008 for a missing, mismatching or older-than-
              PGP_JSON_SIGNATURE_VALIDITY-seconds signature.
    """
    # GpgError is the base class of GPGMEError in the gpg bindings. Catching
    # it also covers the verification errors decrypt() may raise (e.g.
    # BadSignatures), which previously escaped without a close frame.
    from gpg.errors import GpgError

    if message.opcode != OPCODE["text"]:
        close(wfile, 1003, "Only text datatype is allowed.")
        raise WebSocketCloseException(1003, "Received no-text datatype.")
    try:
        plaintext, _result, verify_result = gpg_context.decrypt(
            message.payload.encode()
        )
    except GpgError as error:
        close(wfile, 1002, "Failed to decrypt message.")
        raise WebSocketCloseException(
            1002, f"Failed to decrypt websocket message: {error}."
        ) from error

    if len(verify_result.signatures) == 0:
        close(wfile, 1008, "No signature recognized.")
        raise WebSocketCloseException(
            1008, "Message with missing or unknown signature."
        )
    if (
        signature_key is not None
        and signature_key.fpr != verify_result.signatures[0].fpr
    ):
        close(wfile, 1008, "Message signature check failed.")
        raise WebSocketCloseException(1008, "Signing keys doesn't match.")

    signature_age = int(time()) - verify_result.signatures[0].timestamp
    if signature_age > PGP_JSON_SIGNATURE_VALIDITY:
        close(wfile, 1008, "Signature too old.")
        raise WebSocketCloseException(
            1008, f"Message with a {signature_age}s old signature."
        )

    try:
        json_ = json.loads(plaintext)
    except ValueError:
        # Covers json.JSONDecodeError and the UnicodeDecodeError that
        # json.loads raises for byte input that is not valid UTF-8.
        close(wfile, 1002, "Invalid JSON data.")
        raise WebSocketCloseException(1002, "Invalid JSON data.")
    return json_, verify_result.signatures
def encrypt_pgp_json(obj: dict, recipient, gpg_context: GPGContext):
    """Serialize *obj* as JSON, then sign it and encrypt it for *recipient*.

    ``sign=True`` signs with whatever signers *gpg_context* is configured
    with. ``always_trust=True`` skips the trust check on the recipient key.

    Returns:
        The ciphertext, which is the first element of the encrypt() result.
    """
    serialized = json.dumps(obj).encode()
    encrypted, _encrypt_result, _sign_result = gpg_context.encrypt(
        serialized,
        recipients=[recipient],
        sign=True,
        always_trust=True,
    )
    return encrypted
2021-01-01 12:40:42 +01:00
def handshake(request_handler: BaseHTTPRequestHandler, subprotocols=None):
    """Validate a WebSocket upgrade request and answer it with 101.

    Args:
        request_handler: handler whose current request is the upgrade request.
        subprotocols: subprotocol names the server supports (any container;
            defaults to none).

    Returns:
        The chosen subprotocol: the first one in the client's
        Sec-WebSocket-Protocol list that the server supports. Returns None
        if no such subprotocol exists.

    Raises:
        HandshakeError: if the request is invalid or asks for an unsupported
            protocol version. A 400 response has already been sent.
    """
    if subprotocols is None:
        subprotocols = []
    request_handler.protocol_version = HTTP_VERSION
    _check_required_headers(request_handler)
    websocket_key = request_handler.headers.get("Sec-WebSocket-Key")
    digest = b64encode(
        sha1((websocket_key + _WEBSOCKET_GUID).encode()).digest()
    )

    websocket_version = request_handler.headers.get("Sec-WebSocket-Version")
    try:
        version_matches = int(websocket_version) == PROTOCOL_VERSION
    except (TypeError, ValueError):
        # A non-numeric version used to crash; treat it as unsupported.
        version_matches = False
    if not version_matches:
        request_handler.send_response(400)
        request_handler.send_header("Sec-WebSocket-Version", PROTOCOL_VERSION)
        request_handler.end_headers()
        raise HandshakeError(
            "Sec-WebSocket-Version",
            websocket_version,
            HandshakeError.reason["incompatible_version"],
        )

    selected_subprotocol = None
    requested_subprotocols = request_handler.headers.get(
        "Sec-WebSocket-Protocol"
    )
    if requested_subprotocols:
        # The client lists subprotocols in order of preference, separated by
        # commas with optional whitespace.
        offered = [proto.strip() for proto in requested_subprotocols.split(",")]
        selected_subprotocol = next(
            (proto for proto in offered if proto in subprotocols), None
        )

    request_handler.send_response(101)
    request_handler.send_header("Upgrade", "websocket")
    request_handler.send_header("Connection", "Upgrade")
    request_handler.send_header("Sec-WebSocket-Accept", digest.decode())
    if selected_subprotocol is not None:
        request_handler.send_header(
            "Sec-WebSocket-Protocol", selected_subprotocol
        )
    request_handler.end_headers()
    return selected_subprotocol
2021-01-01 12:40:42 +01:00
def read_next_message(rfile: BufferedIOBase, wfile: BufferedIOBase):
    """Read frames until one complete data message is available and return it.

    Control frames met along the way, including between fragments of one
    message, are passed to _handle_control_frame: pings are answered and a
    close frame ends the read with an exception. Fragment payloads are joined
    as raw bytes before the message is built. This means a UTF-8 sequence
    split across fragments still decodes correctly.

    Returns:
        WebSocketMessage built from the first frame's opcode and the
        concatenated payload.

    Raises:
        FrameError: when the frame reader rejects a frame; for a continuation
            frame with no message in progress; or for a new data frame that
            arrives before the fragmented message ends.
        WebSocketCloseException: when a close frame is received.
    """
    # Iterative rather than recursive, so a stream of control frames cannot
    # exhaust the recursion limit.
    message_opcode = None
    fragments = []
    while True:
        frame = _read_data_frame(rfile)
        opcode = frame["opcode"]
        if opcode in CONTROL_OPCODES:
            _handle_control_frame(wfile, frame)
            continue
        if opcode == OPCODE["continueation"]:
            if message_opcode is None:
                raise FrameError(
                    "Continuation frame without a preceding initial frame.",
                    frame,
                )
        elif message_opcode is not None:
            raise FrameError(
                "New data frame received before the fragmented message ended.",
                frame,
            )
        else:
            message_opcode = opcode
        fragments.append(frame["payload"])
        if frame["FIN"] == 1:
            return WebSocketMessage(message_opcode, bytearray().join(fragments))
def send_message(
    wfile: BufferedIOBase,
    payload: bytes,
    opcode=OPCODE["text"],
    rsv1=0,
    rsv2=0,
    rsv3=0,
):
    """Send *payload* as a single unfragmented frame and flush *wfile*.

    Args:
        wfile: stream the frame is written to.
        payload: frame payload. A ``str`` is encoded as UTF-8;
            bytearray/memoryview are converted to bytes.
        opcode: frame opcode (defaults to text).
        rsv1, rsv2, rsv3: reserved bits, normally 0.

    Raises:
        ValueError: if the payload exceeds the largest single-frame size, or
            the frame encoder rejects the opcode/payload combination.
    """
    if isinstance(payload, str):
        payload = payload.encode()
    elif isinstance(payload, (bytearray, memoryview)):
        # The frame encoder only accepts bytes.
        payload = bytes(payload)
    if len(payload) > 0x7FFFFFFFFFFFFFFF:
        raise ValueError(
            "Payload to big. Sending fragmented messages not implemented."
        )

    frame = _encode_data_frame(1, opcode, rsv1, rsv2, rsv3, payload)
    wfile.write(frame)
    # With a buffered wfile, the frame would otherwise stay queued locally.
    wfile.flush()