diff --git a/server.py b/server.py index c0aafad..2bd2c08 100755 --- a/server.py +++ b/server.py @@ -1,43 +1,51 @@ #!/usr/bin/env python3 -import gpg import json import os import re -from gpg.gpgme import GPG_ERR_NO_ERROR, GPGME_DELETE_FORCE, gpgme_op_delete_ext from http.server import BaseHTTPRequestHandler, HTTPServer from secrets import token_urlsafe -from sys import argv, exit as sysexit +from sys import argv +from sys import exit as sysexit from tempfile import TemporaryDirectory +import gpg +from gpg.gpgme import GPG_ERR_NO_ERROR, GPGME_DELETE_FORCE, gpgme_op_delete_ext + import websocket from database import DataBase NONCE_BYTES = 128 ADMIN_ACCESS = 100 + class JsonMissingFieldException(Exception): def __init__(self, missing): self.missing = missing - + class RequestHandler(BaseHTTPRequestHandler): - _RE_FILES = re.compile(r"^(/|/admin\.html|/data\.html|/forms\.html|/config\.js|/public\.js|/registered\.js|/openpgp\.min\.js)$") + _RE_FILES = re.compile( + r"^(/|/admin\.html|/data\.html|/forms\.html|/config\.js|/public\.js|/registered\.js|/openpgp\.min\.js)$" + ) _RE_FORM_NAME = re.compile(r"^([-_0-9a-zA-Z]{1,64})$") _RE_KEY_FINGERPRINT = re.compile(r"^([0-9a-fA-F]{40})$") _RE_PGP_JSON_REQUEST_FIELD = re.compile(r"^[0-9a-zA-Z_/]{3,100}$") _VALID_CONTENT_SUBTYPES = {"html", "javascript", "json", "plain"} def _api_answer_pgp_json(self, payload): - request_answer = websocket.encrypt_pgp_json({ - "request": self._pgpjson_pending_request, - "payload": payload - }, self._pgpjson_client_key, self._gpg_context) + request_answer = websocket.encrypt_pgp_json( + {"request": self._pgpjson_pending_request, "payload": payload}, + self._pgpjson_client_key, + self._gpg_context, + ) websocket.send_message(self.wfile, request_answer) self._close_websocket() def _api_collect_data(self): - form_name = re.match(self._RE_FORM_NAME, self.path.replace("/post/", "")) + form_name = re.match( + self._RE_FORM_NAME, self.path.replace("/post/", "") + ) if form_name: secret = self._read().decode() self.server.db.store_form_data(form_name.group(1), secret) @@ -45,11 +53,17 @@ class RequestHandler(BaseHTTPRequestHandler): self.end_headers() def _api_get_collected_data(self): - form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("data/get/", "")) + form_name = re.match( + self._RE_FORM_NAME, + self._pgpjson_pending_request.replace("data/get/", ""), + ) if form_name: form_name = form_name.group(1) form_users = self.server.db.get_form_keys_fingerprints(form_name) - if self._pgpjson_user["fingerprint"] not in form_users["fingerprints"]: + if ( + self._pgpjson_user["fingerprint"] + not in form_users["fingerprints"] + ): self._close_websocket(1008, "Access denied.") else: collected_data = self.server.db.get_collected_data(form_name) @@ -62,18 +76,29 @@ class RequestHandler(BaseHTTPRequestHandler): self._close_websocket(1008, "Access denied.") return - form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("form/get/", "")) + form_name = re.match( + self._RE_FORM_NAME, + self._pgpjson_pending_request.replace("form/get/", ""), + ) if form_name: - form_key_list = self.server.db.get_form_keys_fingerprints(form_name.group(1)) + form_key_list = self.server.db.get_form_keys_fingerprints( + form_name.group(1) + ) self._api_answer_pgp_json(form_key_list) else: self._close_websocket(1002, "Invalid form name.") def _api_get_user(self): - user_fingerprint = re.match(self._RE_KEY_FINGERPRINT, self._pgpjson_pending_request.replace("user/", "")) + user_fingerprint = re.match( + self._RE_KEY_FINGERPRINT, + self._pgpjson_pending_request.replace("user/", ""), + ) if user_fingerprint: fpr = user_fingerprint.group(1) - if self._pgpjson_user["access_level"] == ADMIN_ACCESS or self._pgpjson_user["fingerprint"] == fpr: + if ( + self._pgpjson_user["access_level"] == ADMIN_ACCESS + or self._pgpjson_user["fingerprint"] == fpr + ): user = self.server.db.get_user(fpr) self._api_answer_pgp_json(user) else: @@ -89,15 +114,27 @@ class RequestHandler(BaseHTTPRequestHandler): def _api_get_user_key(self, fingerprints=None): keys = list() - if fingerprints == None: + if fingerprints is None: for user in self.server.db.get_user(): - keys.append({ "fingerprint": user["fingerprint"], - "armored_key": (self._gpg_context.key_export(user["fingerprint"])).decode()}) + keys.append( + { + "fingerprint": user["fingerprint"], + "armored_key": ( + self._gpg_context.key_export(user["fingerprint"]) + ).decode(), + } + ) else: for fpr in fingerprints: if self.server.db.get_user(fpr) is not None: - keys.append({ "fingerprint": fpr, - "armored_key": (self._gpg_context.key_export(fpr)).decode()}) + keys.append( + { + "fingerprint": fpr, + "armored_key": ( + self._gpg_context.key_export(fpr) + ).decode(), + } + ) else: self._close_websocket(1008, "Asking for an unknown key.") return @@ -106,14 +143,21 @@ class RequestHandler(BaseHTTPRequestHandler): def _api_register_user(self): try: req = self._read_json(["pubKey"]) - except JsonMissingFieldException: - self._api_register_user_answer(400, f"Missing {missingField.missing}.") + except JsonMissingFieldException as missing_field: + self._api_register_user_answer( + 400, f"Missing {missing_field.missing}." + ) return results = self.server.gpg_context.key_import(req["pubKey"].encode()) if results == "IMPORT_PROBLEM" or not results.considered: self._api_register_user_answer(400, "Invalid pubKey.") elif results.imported: - access_level = ADMIN_ACCESS if len(self.server.db.get_user()) == 0 else 0 # first user get admin access level + access_level = ( + ADMIN_ACCESS + if len(self.server.db.get_user()) + == 0 # first user get admin access level + else 0 + ) self.server.db.add_user(results.imports[0].fpr, access_level) print(f"Imported key {results.imports[0].fpr}.") self._api_register_user_answer(201, "Key registered.") @@ -121,20 +165,25 @@ class RequestHandler(BaseHTTPRequestHandler): self._api_register_user_answer(200, "Key already on server.") else: self._api_register_user_answer(418, "What happened?") - print(f"Tried to add following pubKey and got strange results:\n{req[pubKey]}\n\n{results}") + print( + f"Tried to add following pubKey and got strange results:\n{req['pubKey']}\n\n{results}" + ) def _api_register_user_answer(self, code: int, message: str): self.send_response(code) self._set_content_type("json") self.end_headers() - self.wfile.write(json.dumps({ "message": message }).encode()) + self.wfile.write(json.dumps({"message": message}).encode()) def _api_set_collected_data(self): if self._pgpjson_user["access_level"] != ADMIN_ACCESS: self._close_websocket(1008, "Access denied.") return - form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("data/set/", "")) + form_name = re.match( + self._RE_FORM_NAME, + self._pgpjson_pending_request.replace("data/set/", ""), + ) if form_name: if self._pgpjson_payload is None: self._close_websocket(1002, "Missing payload.") @@ -143,10 +192,14 @@ class RequestHandler(BaseHTTPRequestHandler): if "secret" not in data or "id" not in data: self._close_websocket(1002, "Invalid payload") return - elif not data["secret"].startswith("-----BEGIN PGP MESSAGE-----"): + elif not data["secret"].startswith( + "-----BEGIN PGP MESSAGE-----" + ): self._close_websocket(1002, "Invalid PGP message") return - self.server.db.set_collected_data(form_name.group(1), self._pgpjson_payload) + self.server.db.set_collected_data( + form_name.group(1), self._pgpjson_payload + ) self._close_websocket() else: self._close_websocket(1002, "Invalid form name.") @@ -156,10 +209,19 @@ class RequestHandler(BaseHTTPRequestHandler): self._close_websocket(1008, "Access denied.") return - form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("form/set/", "")) + form_name = re.match( + self._RE_FORM_NAME, + self._pgpjson_pending_request.replace("form/set/", ""), + ) if form_name: - if self._pgpjson_payload is not None and "fingerprints" in self._pgpjson_payload and isinstance(self._pgpjson_payload["fingerprints"], list): - self.server.db.set_form_keys_fingerprints(form_name.group(1), self._pgpjson_payload["fingerprints"]) + if ( + self._pgpjson_payload is not None + and "fingerprints" in self._pgpjson_payload + and isinstance(self._pgpjson_payload["fingerprints"], list) + ): + self.server.db.set_form_keys_fingerprints( + form_name.group(1), self._pgpjson_payload["fingerprints"] + ) self._close_websocket() else: self._close_websocket(1002, "Invalid request payload.") @@ -170,29 +232,43 @@ class RequestHandler(BaseHTTPRequestHandler): if self._pgpjson_user["access_level"] == ADMIN_ACCESS: self._close_websocket(1008, "Cannot delete admin account.") return - result = gpgme_op_delete_ext(self._gpg_context.wrapped, self._pgpjson_client_key, GPGME_DELETE_FORCE) + result = gpgme_op_delete_ext( + self._gpg_context.wrapped, + self._pgpjson_client_key, + GPGME_DELETE_FORCE, + ) if result == GPG_ERR_NO_ERROR: self.server.db.delete_user(self._pgpjson_client_key.fpr) print(f"Key {self._pgpjson_client_key.fpr} deleted.") self._close_websocket() else: - print(f"Failed to delete key {self._pgpjson_client_key.fpr}, status code: {result}.") + print( + f"Failed to delete key {self._pgpjson_client_key.fpr}, status code: {result}." + ) self._close_websocket(1011, "Failed to delete key.") def _close_websocket(self, code=1000, reason=None): websocket.close(self.wfile, code, reason) self._websocket_connected = False - print(f"Closed WebSocket connection with {self.client_address[0]}:{self.client_address[1]}") + print( + f"Closed WebSocket connection with {self.client_address[0]}:{self.client_address[1]}" + ) if code != 1000: print(f"Code: {code} reason: {reason}") - def _handle_api_request(self, message: websocket.WebSocketMessage): - signature_key = self._pgpjson_client_key if self._pgpjson_client_key is not None else None - json_, signature = websocket.decrypt_pgp_json(self.wfile, self._gpg_context, - message, signature_key) + signature_key = ( + self._pgpjson_client_key + if self._pgpjson_client_key is not None + else None + ) + json_, signature = websocket.decrypt_pgp_json( + self.wfile, self._gpg_context, message, signature_key + ) - if "request" not in json_ or not re.match(self._RE_PGP_JSON_REQUEST_FIELD, json_["request"]): + if "request" not in json_ or not re.match( + self._RE_PGP_JSON_REQUEST_FIELD, json_["request"] + ): self._close_websocket(1002, "Missing or invalid request field.") return @@ -200,14 +276,18 @@ class RequestHandler(BaseHTTPRequestHandler): self._pgpjson_pending_request = json_["request"] if "payload" in json_: self._pgpjson_payload = json_["payload"] - self._pgpjson_client_key = self._gpg_context.get_key(signature[0].fpr) + self._pgpjson_client_key = self._gpg_context.get_key( + signature[0].fpr + ) self._request_signature() return elif "nonce" not in json_ or json_["nonce"] != self._pgpjson_nonce: self._close_websocket(1002, "Authentication failed.") return - self._pgpjson_user = self.server.db.get_user(self._pgpjson_client_key.fpr) + self._pgpjson_user = self.server.db.get_user( + self._pgpjson_client_key.fpr + ) # API "routes" if self._pgpjson_pending_request.startswith("data/get/"): @@ -232,20 +312,26 @@ class RequestHandler(BaseHTTPRequestHandler): self._pgpjson_pending_request = None self._pgpjson_payload = None self._pgpjson_user = None - self._gpg_context = gpg.Context(armor=True, - home_dir=self.server.gpg_context.home_dir, - offline=True, - signers=[self.server.key]) - + self._gpg_context = gpg.Context( + armor=True, + home_dir=self.server.gpg_context.home_dir, + offline=True, + signers=[self.server.key], + ) + try: proto = websocket.handshake(self, ["pgp-json"]) if proto is None: self._close_websocket(1002, "I only speak pgp-json") return self._websocket_connected = True - print(f"WebSocket connection with {self.client_address[0]}:{self.client_address[1]}") + print( + f"WebSocket connection with {self.client_address[0]}:{self.client_address[1]}" + ) except websocket.HandshakeError as error: - print(f"WebSocket handshake failed. {error.get_reason(error.why)}: {error.what}={error.value}") + print( + f"WebSocket handshake failed. {error.get_reason(error.why)}: {error.what}={error.value}" + ) return try: @@ -253,7 +339,9 @@ class RequestHandler(BaseHTTPRequestHandler): message = websocket.read_next_message(self.rfile, self.wfile) self._handle_api_request(message) except websocket.WebSocketCloseException as close: - print(f"WebSocket closed with code {close.code}, reason {close.reason}") + print( + f"WebSocket closed with code {close.code}, reason {close.reason}" + ) def _read(self): if not self.headers["Content-Length"]: @@ -267,16 +355,20 @@ class RequestHandler(BaseHTTPRequestHandler): except json.decoder.JSONDecodeError: data = {} for f in required_fields: - if not f in data: + if f not in data: raise JsonMissingFieldException(missing=f) return data def _request_signature(self): self._pgpjson_nonce = token_urlsafe(NONCE_BYTES) - signature_request = websocket.encrypt_pgp_json({ - "nonce": self._pgpjson_nonce, - "request": self._pgpjson_pending_request - }, self._pgpjson_client_key, self._gpg_context) + signature_request = websocket.encrypt_pgp_json( + { + "nonce": self._pgpjson_nonce, + "request": self._pgpjson_pending_request, + }, + self._pgpjson_client_key, + self._gpg_context, + ) websocket.send_message(self.wfile, signature_request) def _serve_err404(self): @@ -302,7 +394,9 @@ class RequestHandler(BaseHTTPRequestHandler): self.wfile.write(output) def _serve_form_keys(self): - form_name = re.match(self._RE_FORM_NAME, self.path.replace("/formkeys/", "")) + form_name = re.match( + self._RE_FORM_NAME, self.path.replace("/formkeys/", "") + ) if form_name: form_name = form_name.group(1) keys = [] @@ -313,13 +407,17 @@ class RequestHandler(BaseHTTPRequestHandler): self._set_cors() self._set_content_type("json") self.end_headers() - self.wfile.write(json.dumps({ "form": form_name, "keys": keys }).encode()) + self.wfile.write( + json.dumps({"form": form_name, "keys": keys}).encode() + ) else: self._serve_err404() def _set_content_type(self, subtype): if subtype not in self._VALID_CONTENT_SUBTYPES: - raise ValueError(f"RequestHandler._set_content_type: _type must be one of {self._VALID_CONTENT_SUBTYPES}.") + raise ValueError( + f"RequestHandler._set_content_type: _type must be one of {self._VALID_CONTENT_SUBTYPES}." + ) _type = "text" if subtype == "json": _type = "application" @@ -336,14 +434,18 @@ class RequestHandler(BaseHTTPRequestHandler): self._serve_form_keys() elif self.path == "/key/srv": # May raise GPGMEerror - key = self.server.gpg_context.key_export(self.server.db.get_config("server_key")) + key = self.server.gpg_context.key_export( + self.server.db.get_config("server_key") + ) self.send_response(200) self._set_cors() self._set_content_type("plain") self.end_headers() self.wfile.write(key) elif self.path.startswith("/key/"): - key_fingerprint = re.match(self._RE_KEY_FINGERPRINT, self.path.replace("/key/", "")) + key_fingerprint = re.match( + self._RE_KEY_FINGERPRINT, self.path.replace("/key/", "") + ) key = self.server.gpg_context.key_export(key_fingerprint.group(1)) key = key if key is not None else b"Unknown" self.send_response(200) @@ -365,9 +467,12 @@ class RequestHandler(BaseHTTPRequestHandler): self.send_response(404) self.end_headers() + class Server(HTTPServer): def __init__(self, listen_to, gpg_home, db): - self.gpg_context = gpg.Context(armor=True, home_dir=gpg_home, offline=True) + self.gpg_context = gpg.Context( + armor=True, home_dir=gpg_home, offline=True + ) self.db = DataBase(db) self.initSrvKeys() super().__init__(listen_to, RequestHandler) @@ -376,11 +481,19 @@ class Server(HTTPServer): fingerprint = self.db.get_config("server_key") if not fingerprint: try: - result = self.gpg_context.create_key(userid="PoC server", algorithm="ed25519", - expires=False, sign=True) + result = self.gpg_context.create_key( + userid="PoC server", + algorithm="ed25519", + expires=False, + sign=True, + ) fingerprint = result.fpr - result = self.gpg_context.create_subkey(key=self.gpg_context.get_key(fingerprint), - algorithm="cv25519", expires=False, encrypt=True) + result = self.gpg_context.create_subkey( + key=self.gpg_context.get_key(fingerprint), + algorithm="cv25519", + expires=False, + encrypt=True, + ) self.db.set_config("server_key", fingerprint) self.key = self.gpg_context.get_key(fingerprint) except gpg.errors.GPGMEError as error: @@ -392,17 +505,20 @@ class Server(HTTPServer): else: try: self.key = self.gpg_context.get_key(fingerprint) - except gpg.errors.KeyError as error: + except KeyError as error: print(f"Cannot find server key {fingerprint}.") sysexit(error) print(f"Loaded server key {fingerprint}.") + if __name__ == "__main__": def run(listen_to, gpg_home, db): httpd = Server(listen_to, gpg_home, db) try: - print(f"Starting server… listening to {listen_to[0]}:{listen_to[1]} with keystore {gpg_home}, db {db}") + print( + f"Starting server… listening to {listen_to[0]}:{listen_to[1]} with keystore {gpg_home}, db {db}" + ) httpd.serve_forever() except KeyboardInterrupt: pass diff --git a/websocket.py b/websocket.py index 8cfa016..faa2f96 100644 --- a/websocket.py +++ b/websocket.py @@ -3,35 +3,44 @@ # https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers # https://gist.github.com/SevenW/47be2f9ab74cac26bf21/ (SevenW/HTTPWebSocketsHandler.py) -from base64 import b64decode, b64encode import json -from gpg import Context as GPGContext -from gpg.errors import GPGMEError -from hashlib import sha1 +from base64 import b64decode, b64encode from binascii import Error as BinasciiError +from hashlib import sha1 from http.server import BaseHTTPRequestHandler from io import BufferedIOBase from time import time +from gpg import Context as GPGContext +from gpg.errors import GPGMEError + _WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" -OPCODE = { "continueation": 0x0, - "text": 0x1, - "binary": 0x2, - "close": 0x8, - "ping": 0x9, - "pong": 0xa } +OPCODE = { + "continueation": 0x0, + "text": 0x1, + "binary": 0x2, + "close": 0x8, + "ping": 0x9, + "pong": 0xA, +} CONTROL_OPCODES = [OPCODE["close"], OPCODE["ping"], OPCODE["pong"]] -NONCONTROL_OPCODES = [OPCODE["continueation"], OPCODE["text"], OPCODE["binary"]] +NONCONTROL_OPCODES = [ + OPCODE["continueation"], + OPCODE["text"], + OPCODE["binary"], +] PROTOCOL_VERSION = 13 HTTP_VERSION = "HTTP/1.1" PGP_JSON_SIGNATURE_VALIDITY = 60 + class WebSocketException(Exception): pass + class FrameError(WebSocketException): - def __init__(message: str, frame: dict): + def __init__(self, message: str, frame: dict): assert "FIN" in frame and frame["FIN"] is not None assert "RSV1" in frame and frame["RSV1"] is not None assert "RSV2" in frame and frame["RSV2"] is not None @@ -40,10 +49,13 @@ class FrameError(WebSocketException): self.message = message self.frame = frame + class HandshakeError(WebSocketException): - reason = { "invalid_header": 0, - "missing_header": 1, - "incompatible_version": 2 } + reason = { + "invalid_header": 0, + "missing_header": 1, + "incompatible_version": 2, + } def __init__(self, what: str, value: str, why: int): assert why in self.reason.values() @@ -56,20 +68,23 @@ class HandshakeError(WebSocketException): if v == code: return k + class InvalidLengthError(FrameError): - def __init__(length: int): + def __init__(self, length: int): self.length = length + class WebSocketCloseException(WebSocketException): def __init__(self, code: int, reason: str): self.code = code self.reason = reason -class WebSocketMessage(): + +class WebSocketMessage: def __add__(self, other): assert isinstance(other, WebSocketMessage) assert other["opcode"] == OPCODE["continueation"] - if opcode == OPCODE["text"]: + if self.opcode == OPCODE["text"]: # may raise UnicodeError self.payload += other.payload.decode() else: @@ -86,21 +101,35 @@ class WebSocketMessage(): def __str__(self): return f"WebSocket Message opcode:0x{self.opcode:x} payload:{self.payload}" + def _check_required_headers(request_handler: BaseHTTPRequestHandler): if request_handler.request_version != HTTP_VERSION: request_handler.send_response(400) request_handler.end_headers() - raise HandshakeError("HTTP", request_handler.request_version, - HandshakeError.reason["incompatible_version"]) + raise HandshakeError( + "HTTP", + request_handler.request_version, + HandshakeError.reason["incompatible_version"], + ) - for header in [["Host", None], - ["Upgrade", "websocket"], - ["Connection", "upgrade"], - ["Sec-WebSocket-Key", None], - ["Sec-WebSocket-Version", None]]: + for header in [ + ["Host", None], + ["Upgrade", "websocket"], + ["Connection", "upgrade"], + ["Sec-WebSocket-Key", None], + ["Sec-WebSocket-Version", None], + ]: value = request_handler.headers.get(header[0]) - requisite = header[1] in value.lower().split(", ") if header[1] is not None else value is not None - reason = HandshakeError.reason["invalid_header"] if value is not None else HandshakeError.reason["missing_header"] + requisite = ( + header[1] in value.lower().split(", ") + if header[1] is not None + else value is not None + ) + reason = ( + HandshakeError.reason["invalid_header"] + if value is not None + else HandshakeError.reason["missing_header"] + ) if requisite is False: request_handler.send_response(400) request_handler.end_headers() @@ -110,9 +139,15 @@ def _check_required_headers(request_handler: BaseHTTPRequestHandler): if ws_key is None: request_handler.send_response(400) request_handler.end_headers() - raise HandshakeError("Sec-WebSocket-Key", None, HandshakeError.reason["missing_header"]) + raise HandshakeError( + "Sec-WebSocket-Key", None, HandshakeError.reason["missing_header"] + ) try: - invalid_key_error = HandshakeError("Sec-WebSocket-Key", ws_key, HandshakeError.reason["invalid_header"]) + invalid_key_error = HandshakeError( + "Sec-WebSocket-Key", + ws_key, + HandshakeError.reason["invalid_header"], + ) decoded = b64decode(s=ws_key, altchars=None, validate=True) if len(decoded) != 16: request_handler.send_response(400) @@ -123,25 +158,37 @@ def _check_required_headers(request_handler: BaseHTTPRequestHandler): request_handler.end_headers() raise invalid_key_error + def _decode_payload(masking_key: bytes, encoded_payload: bytes): decoded_payload = bytearray() for byte in encoded_payload: - decoded_payload += bytes([byte ^ masking_key[len(decoded_payload) % 4]]) + decoded_payload += bytes( + [byte ^ masking_key[len(decoded_payload) % 4]] + ) return decoded_payload -def _encode_data_frame(fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, payload: bytes): + +def _encode_data_frame( + fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, payload: bytes +): assert isinstance(payload, bytes) - if not opcode in OPCODE.values(): - raise ValueError(f"Unsupported opcode {opcode}. Valid values {list(OPCODE.values())}.") + if opcode not in OPCODE.values(): + raise ValueError( + f"Unsupported opcode {opcode}. Valid values {list(OPCODE.values())}." + ) if fin == 0 and opcode in CONTROL_OPCODES: - raise ValueError(f"Control frames cannot be fragmented") + raise ValueError("Control frames cannot be fragmented") if opcode in CONTROL_OPCODES and len(payload) > 125: - raise ValueError("Control frame cannot have a payload bigger than 125 bytes.") + raise ValueError( + "Control frame cannot have a payload bigger than 125 bytes." + ) payload_length = len(payload) length_bits = 7 - if payload_length > 0x7fffffffffffffff: - raise ValueError(f"Payload maximal size exceeded (provided {payload_length} bytes).") - elif payload_length > 0xffff: + if payload_length > 0x7FFFFFFFFFFFFFFF: + raise ValueError( + f"Payload maximal size exceeded (provided {payload_length} bytes)." + ) + elif payload_length > 0xFFFF: length_bits = 7 + 64 elif payload_length > 125: length_bits = 7 + 16 @@ -162,10 +209,13 @@ def _encode_data_frame(fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, p return frame + def _handle_control_frame(wfile: BufferedIOBase, frame: dict): if frame["opcode"] == OPCODE["close"]: code = frame["status_code"] if "status_code" in frame else None - payload = code.to_bytes(2, byteorder="big") if code is not None else b"" + payload = ( + code.to_bytes(2, byteorder="big") if code is not None else b"" + ) reason = frame["close_reason"] if "close_reason" in frame else "-" send_message(wfile, OPCODE["close"], payload) raise WebSocketCloseException(code, reason) @@ -173,26 +223,29 @@ def _handle_control_frame(wfile: BufferedIOBase, frame: dict): payload = frame["payload"] if "payload" in frame else b"" send_message(wfile, OPCODE["pong"], payload) + def _read_data_frame(rfile: BufferedIOBase): frame = {} - #char = rfile.read(1) - #if len(char) == 0: - # return net_bytes = ord(rfile.read(1)) frame["FIN"] = net_bytes >> 7 frame["RSV1"] = (net_bytes & 0x40) >> 6 frame["RSV2"] = (net_bytes & 0x20) >> 5 frame["RSV3"] = (net_bytes & 0x10) >> 4 - frame["opcode"] = net_bytes & 0x0f + frame["opcode"] = net_bytes & 0x0F if frame["RSV1"] != 0 or frame["RSV2"] != 0 or frame["RSV3"] != 0: - raise FrameError("Unsupported feature. RSV1, RSV2 or RSV3 has a non-zero value.", frame) + raise FrameError( + "Unsupported feature. RSV1, RSV2 or RSV3 has a non-zero value.", + frame, + ) if not frame["opcode"] in OPCODE.values(): raise FrameError("Unsupported opcode value.", frame) if frame["FIN"] == 0 and frame["opcode"] != OPCODE["continueation"]: - raise FrameError("FIN bit not set for a non-continueation frame.", frame) + raise FrameError( + "FIN bit not set for a non-continueation frame.", frame + ) if frame["opcode"] in CONTROL_OPCODES and frame["FIN"] == 0: raise FrameError("FIN bit not set for a control frame.", frame) @@ -203,7 +256,7 @@ def _read_data_frame(rfile: BufferedIOBase): if mask_bit == 0: raise FrameError("Unmasked frame from client.", frame) - length1 = net_bytes & 0x7f + length1 = net_bytes & 0x7F if frame["opcode"] in CONTROL_OPCODES and length1 > 125: raise FrameError("Control frame with invalid payload length.", frame) @@ -211,20 +264,25 @@ def _read_data_frame(rfile: BufferedIOBase): try: length = _read_payload_length(length1, rfile) except InvalidLengthError as error: - raise FrameError(f"Invalid payload length of {error.length} bytes.", frame) + raise FrameError( + f"Invalid payload length of {error.length} bytes.", frame + ) masking_key = rfile.read(4) encoded_payload = rfile.read(length) frame["payload"] = _decode_payload(masking_key, encoded_payload) if frame["opcode"] == OPCODE["close"] and frame["payload"]: - frame["status_code"] = int.from_bytes(frame["payload"][0:2], byteorder="big") + frame["status_code"] = int.from_bytes( + frame["payload"][0:2], byteorder="big" + ) if length > 2: # /!\ may raise UnicodeError /!\ frame["close_reason"] = frame["payload"][2:].decode() return frame + def _read_payload_length(payload_length1: int, rfile: BufferedIOBase): final_length = payload_length1 if payload_length1 == 126: @@ -235,33 +293,51 @@ def _read_payload_length(payload_length1: int, rfile: BufferedIOBase): raise InvalidLengthError(final_length) return final_length + def close(wfile: BufferedIOBase, code=1000, reason=None): code_bytes = code.to_bytes(2, byteorder="big") payload = code_bytes if reason is None else code_bytes + reason.encode() frame = _encode_data_frame(1, OPCODE["close"], 0, 0, 0, payload) wfile.write(frame) -def decrypt_pgp_json(wfile: BufferedIOBase, gpg_context: GPGContext, message: WebSocketMessage, signature_key=None): + +def decrypt_pgp_json( + wfile: BufferedIOBase, + gpg_context: GPGContext, + message: WebSocketMessage, + signature_key=None, +): if message.opcode != OPCODE["text"]: close(wfile, 1003, "Only text datatype is allowed.") raise WebSocketCloseException(1003, "Received no-text datatype.") try: - plaintext, result, verify_result = gpg_context.decrypt(message.payload.encode()) + plaintext, result, verify_result = gpg_context.decrypt( + message.payload.encode() + ) except GPGMEError as error: - close(wfile, 1002, f"Failed to decrypt message.") - raise WebSocketCloseException(1002, f"Failed to decrypt websocket message: {error}.") + close(wfile, 1002, "Failed to decrypt message.") + raise WebSocketCloseException( + 1002, f"Failed to decrypt websocket message: {error}." + ) if len(verify_result.signatures) == 0: close(wfile, 1008, "No signature recognized.") - raise WebSocketCloseException(1008, "Message with missing or unknown signature.") - if signature_key is not None and signature_key.fpr != verify_result.signatures[0].fpr: + raise WebSocketCloseException( + 1008, "Message with missing or unknown signature." + ) + if ( + signature_key is not None + and signature_key.fpr != verify_result.signatures[0].fpr + ): close(wfile, 1008, "Message signature check failed.") raise WebSocketCloseException(1008, "Signing keys doesn't match.") signature_age = int(time()) - verify_result.signatures[0].timestamp if signature_age > PGP_JSON_SIGNATURE_VALIDITY: close(wfile, 1008, "Signature too old.") - raise WebSocketCloseException(1008, f"Message with a {signature_age}s old signature.") + raise WebSocketCloseException( + 1008, f"Message with a {signature_age}s old signature." + ) try: json_ = json.loads(plaintext) @@ -271,27 +347,36 @@ def decrypt_pgp_json(wfile: BufferedIOBase, gpg_context: GPGContext, message: We return json_, verify_result.signatures + def encrypt_pgp_json(obj: dict, recipient, gpg_context: GPGContext): json_ = json.dumps(obj).encode() - ciphertext, result, sign_result = gpg_context.encrypt(json_, - recipients=[recipient], - sign=True, - always_trust=True) + ciphertext, result, sign_result = gpg_context.encrypt( + json_, recipients=[recipient], sign=True, always_trust=True + ) return ciphertext + def handshake(request_handler: BaseHTTPRequestHandler, subprotocols=[]): request_handler.protocol_version = HTTP_VERSION _check_required_headers(request_handler) websocket_key = request_handler.headers.get("Sec-WebSocket-Key") - digest = b64encode(sha1((websocket_key + _WEBSOCKET_GUID).encode()).digest()) + digest = b64encode( + sha1((websocket_key + _WEBSOCKET_GUID).encode()).digest() + ) websocket_version = request_handler.headers.get("Sec-WebSocket-Version") if int(websocket_version) != PROTOCOL_VERSION: request_handler.send_response(400) request_handler.send_header("Sec-WebSocket-Version", PROTOCOL_VERSION) request_handler.end_headers() - raise HandshakeError("Sec-WebSocket-Version", websocket_version, HandshakeError.reason["incompatible_version"]) + raise HandshakeError( + "Sec-WebSocket-Version", + websocket_version, + HandshakeError.reason["incompatible_version"], + ) selected_subprotocol = None - requested_subprotocols = request_handler.headers.get("Sec-WebSocket-Protocol") + requested_subprotocols = request_handler.headers.get( + "Sec-WebSocket-Protocol" + ) if requested_subprotocols: requested_subprotocols = requested_subprotocols.split(",") for proto in requested_subprotocols: @@ -304,11 +389,14 @@ def handshake(request_handler: BaseHTTPRequestHandler, subprotocols=[]): request_handler.send_header("Connection", "Upgrade") request_handler.send_header("Sec-WebSocket-Accept", digest.decode()) if requested_subprotocols and selected_subprotocol is not None: - request_handler.send_header("Sec-WebSocket-Protocol", selected_subprotocol) + request_handler.send_header( + "Sec-WebSocket-Protocol", selected_subprotocol + ) request_handler.end_headers() return selected_subprotocol + def read_next_message(rfile: BufferedIOBase, wfile: BufferedIOBase): frame = _read_data_frame(rfile) message = WebSocketMessage(frame["opcode"], frame["payload"]) @@ -319,10 +407,20 @@ def read_next_message(rfile: BufferedIOBase, wfile: BufferedIOBase): _handle_control_frame(wfile, frame) return read_next_message(rfile, wfile) else: - return message + read_next_message(rfile) + return message + read_next_message(rfile, wfile) -def send_message(wfile: BufferedIOBase, payload: bytes, opcode=OPCODE["text"], rsv1=0, rsv2=0, rsv3=0): - if len(payload) > 0x7fffffffffffffff: - raise ValueError(f"Payload to big. Sending fragmented messages not implemented.") + +def send_message( + wfile: BufferedIOBase, + payload: bytes, + opcode=OPCODE["text"], + rsv1=0, + rsv2=0, + rsv3=0, +): + if len(payload) > 0x7FFFFFFFFFFFFFFF: + raise ValueError( + "Payload to big. Sending fragmented messages not implemented." + ) frame = _encode_data_frame(1, opcode, rsv1, rsv2, rsv3, payload) wfile.write(frame)