Format python with black

Various small fixes
This commit is contained in:
Andrea Blankenstijn 2021-02-07 17:45:48 +01:00
parent 83f5c3e2ad
commit 9270547776
2 changed files with 348 additions and 134 deletions

250
server.py
View File

@ -1,43 +1,51 @@
#!/usr/bin/env python3
import gpg
import json
import os
import re
from gpg.gpgme import GPG_ERR_NO_ERROR, GPGME_DELETE_FORCE, gpgme_op_delete_ext
from http.server import BaseHTTPRequestHandler, HTTPServer
from secrets import token_urlsafe
from sys import argv, exit as sysexit
from sys import argv
from sys import exit as sysexit
from tempfile import TemporaryDirectory
import gpg
from gpg.gpgme import GPG_ERR_NO_ERROR, GPGME_DELETE_FORCE, gpgme_op_delete_ext
import websocket
from database import DataBase
NONCE_BYTES = 128
ADMIN_ACCESS = 100
class JsonMissingFieldException(Exception):
def __init__(self, missing):
self.missing = missing
class RequestHandler(BaseHTTPRequestHandler):
_RE_FILES = re.compile(r"^(/|/admin\.html|/data\.html|/forms\.html|/config\.js|/public\.js|/registered\.js|/openpgp\.min\.js)$")
_RE_FILES = re.compile(
r"^(/|/admin\.html|/data\.html|/forms\.html|/config\.js|/public\.js|/registered\.js|/openpgp\.min\.js)$"
)
_RE_FORM_NAME = re.compile(r"^([-_0-9a-zA-Z]{1,64})$")
_RE_KEY_FINGERPRINT = re.compile(r"^([0-9a-fA-F]{40})$")
_RE_PGP_JSON_REQUEST_FIELD = re.compile(r"^[0-9a-zA-Z_/]{3,100}$")
_VALID_CONTENT_SUBTYPES = {"html", "javascript", "json", "plain"}
def _api_answer_pgp_json(self, payload):
request_answer = websocket.encrypt_pgp_json({
"request": self._pgpjson_pending_request,
"payload": payload
}, self._pgpjson_client_key, self._gpg_context)
request_answer = websocket.encrypt_pgp_json(
{"request": self._pgpjson_pending_request, "payload": payload},
self._pgpjson_client_key,
self._gpg_context,
)
websocket.send_message(self.wfile, request_answer)
self._close_websocket()
def _api_collect_data(self):
form_name = re.match(self._RE_FORM_NAME, self.path.replace("/post/", ""))
form_name = re.match(
self._RE_FORM_NAME, self.path.replace("/post/", "")
)
if form_name:
secret = self._read().decode()
self.server.db.store_form_data(form_name.group(1), secret)
@ -45,11 +53,17 @@ class RequestHandler(BaseHTTPRequestHandler):
self.end_headers()
def _api_get_collected_data(self):
form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("data/get/", ""))
form_name = re.match(
self._RE_FORM_NAME,
self._pgpjson_pending_request.replace("data/get/", ""),
)
if form_name:
form_name = form_name.group(1)
form_users = self.server.db.get_form_keys_fingerprints(form_name)
if self._pgpjson_user["fingerprint"] not in form_users["fingerprints"]:
if (
self._pgpjson_user["fingerprint"]
not in form_users["fingerprints"]
):
self._close_websocket(1008, "Access denied.")
else:
collected_data = self.server.db.get_collected_data(form_name)
@ -62,18 +76,29 @@ class RequestHandler(BaseHTTPRequestHandler):
self._close_websocket(1008, "Access denied.")
return
form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("form/get/", ""))
form_name = re.match(
self._RE_FORM_NAME,
self._pgpjson_pending_request.replace("form/get/", ""),
)
if form_name:
form_key_list = self.server.db.get_form_keys_fingerprints(form_name.group(1))
form_key_list = self.server.db.get_form_keys_fingerprints(
form_name.group(1)
)
self._api_answer_pgp_json(form_key_list)
else:
self._close_websocket(1002, "Invalid form name.")
def _api_get_user(self):
user_fingerprint = re.match(self._RE_KEY_FINGERPRINT, self._pgpjson_pending_request.replace("user/", ""))
user_fingerprint = re.match(
self._RE_KEY_FINGERPRINT,
self._pgpjson_pending_request.replace("user/", ""),
)
if user_fingerprint:
fpr = user_fingerprint.group(1)
if self._pgpjson_user["access_level"] == ADMIN_ACCESS or self._pgpjson_user["fingerprint"] == fpr:
if (
self._pgpjson_user["access_level"] == ADMIN_ACCESS
or self._pgpjson_user["fingerprint"] == fpr
):
user = self.server.db.get_user(fpr)
self._api_answer_pgp_json(user)
else:
@ -89,15 +114,27 @@ class RequestHandler(BaseHTTPRequestHandler):
def _api_get_user_key(self, fingerprints=None):
keys = list()
if fingerprints == None:
if fingerprints is None:
for user in self.server.db.get_user():
keys.append({ "fingerprint": user["fingerprint"],
"armored_key": (self._gpg_context.key_export(user["fingerprint"])).decode()})
keys.append(
{
"fingerprint": user["fingerprint"],
"armored_key": (
self._gpg_context.key_export(user["fingerprint"])
).decode(),
}
)
else:
for fpr in fingerprints:
if self.server.db.get_user(fpr) is not None:
keys.append({ "fingerprint": fpr,
"armored_key": (self._gpg_context.key_export(fpr)).decode()})
keys.append(
{
"fingerprint": fpr,
"armored_key": (
self._gpg_context.key_export(fpr)
).decode(),
}
)
else:
self._close_websocket(1008, "Asking for an unknown key.")
return
@ -106,14 +143,21 @@ class RequestHandler(BaseHTTPRequestHandler):
def _api_register_user(self):
try:
req = self._read_json(["pubKey"])
except JsonMissingFieldException:
self._api_register_user_answer(400, f"Missing {missingField.missing}.")
except JsonMissingFieldException as missing_field:
self._api_register_user_answer(
400, f"Missing {missing_field.missing}."
)
return
results = self.server.gpg_context.key_import(req["pubKey"].encode())
if results == "IMPORT_PROBLEM" or not results.considered:
self._api_register_user_answer(400, "Invalid pubKey.")
elif results.imported:
access_level = ADMIN_ACCESS if len(self.server.db.get_user()) == 0 else 0 # first user get admin access level
access_level = (
ADMIN_ACCESS
if len(self.server.db.get_user())
== 0 # first user get admin access level
else 0
)
self.server.db.add_user(results.imports[0].fpr, access_level)
print(f"Imported key {results.imports[0].fpr}.")
self._api_register_user_answer(201, "Key registered.")
@ -121,20 +165,25 @@ class RequestHandler(BaseHTTPRequestHandler):
self._api_register_user_answer(200, "Key already on server.")
else:
self._api_register_user_answer(418, "What happened?")
print(f"Tried to add following pubKey and got strange results:\n{req[pubKey]}\n\n{results}")
print(
f"Tried to add following pubKey and got strange results:\n{req['pubKey']}\n\n{results}"
)
def _api_register_user_answer(self, code: int, message: str):
self.send_response(code)
self._set_content_type("json")
self.end_headers()
self.wfile.write(json.dumps({ "message": message }).encode())
self.wfile.write(json.dumps({"message": message}).encode())
def _api_set_collected_data(self):
if self._pgpjson_user["access_level"] != ADMIN_ACCESS:
self._close_websocket(1008, "Access denied.")
return
form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("data/set/", ""))
form_name = re.match(
self._RE_FORM_NAME,
self._pgpjson_pending_request.replace("data/set/", ""),
)
if form_name:
if self._pgpjson_payload is None:
self._close_websocket(1002, "Missing payload.")
@ -143,10 +192,14 @@ class RequestHandler(BaseHTTPRequestHandler):
if "secret" not in data or "id" not in data:
self._close_websocket(1002, "Invalid payload")
return
elif not data["secret"].startswith("-----BEGIN PGP MESSAGE-----"):
elif not data["secret"].startswith(
"-----BEGIN PGP MESSAGE-----"
):
self._close_websocket(1002, "Invalid PGP message")
return
self.server.db.set_collected_data(form_name.group(1), self._pgpjson_payload)
self.server.db.set_collected_data(
form_name.group(1), self._pgpjson_payload
)
self._close_websocket()
else:
self._close_websocket(1002, "Invalid form name.")
@ -156,10 +209,19 @@ class RequestHandler(BaseHTTPRequestHandler):
self._close_websocket(1008, "Access denied.")
return
form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("form/set/", ""))
form_name = re.match(
self._RE_FORM_NAME,
self._pgpjson_pending_request.replace("form/set/", ""),
)
if form_name:
if self._pgpjson_payload is not None and "fingerprints" in self._pgpjson_payload and isinstance(self._pgpjson_payload["fingerprints"], list):
self.server.db.set_form_keys_fingerprints(form_name.group(1), self._pgpjson_payload["fingerprints"])
if (
self._pgpjson_payload is not None
and "fingerprints" in self._pgpjson_payload
and isinstance(self._pgpjson_payload["fingerprints"], list)
):
self.server.db.set_form_keys_fingerprints(
form_name.group(1), self._pgpjson_payload["fingerprints"]
)
self._close_websocket()
else:
self._close_websocket(1002, "Invalid request payload.")
@ -170,29 +232,43 @@ class RequestHandler(BaseHTTPRequestHandler):
if self._pgpjson_user["access_level"] == ADMIN_ACCESS:
self._close_websocket(1008, "Cannot delete admin account.")
return
result = gpgme_op_delete_ext(self._gpg_context.wrapped, self._pgpjson_client_key, GPGME_DELETE_FORCE)
result = gpgme_op_delete_ext(
self._gpg_context.wrapped,
self._pgpjson_client_key,
GPGME_DELETE_FORCE,
)
if result == GPG_ERR_NO_ERROR:
self.server.db.delete_user(self._pgpjson_client_key.fpr)
print(f"Key {self._pgpjson_client_key.fpr} deleted.")
self._close_websocket()
else:
print(f"Failed to delete key {self._pgpjson_client_key.fpr}, status code: {result}.")
print(
f"Failed to delete key {self._pgpjson_client_key.fpr}, status code: {result}."
)
self._close_websocket(1011, "Failed to delete key.")
def _close_websocket(self, code=1000, reason=None):
websocket.close(self.wfile, code, reason)
self._websocket_connected = False
print(f"Closed WebSocket connection with {self.client_address[0]}:{self.client_address[1]}")
print(
f"Closed WebSocket connection with {self.client_address[0]}:{self.client_address[1]}"
)
if code != 1000:
print(f"Code: {code} reason: {reason}")
def _handle_api_request(self, message: websocket.WebSocketMessage):
signature_key = self._pgpjson_client_key if self._pgpjson_client_key is not None else None
json_, signature = websocket.decrypt_pgp_json(self.wfile, self._gpg_context,
message, signature_key)
signature_key = (
self._pgpjson_client_key
if self._pgpjson_client_key is not None
else None
)
json_, signature = websocket.decrypt_pgp_json(
self.wfile, self._gpg_context, message, signature_key
)
if "request" not in json_ or not re.match(self._RE_PGP_JSON_REQUEST_FIELD, json_["request"]):
if "request" not in json_ or not re.match(
self._RE_PGP_JSON_REQUEST_FIELD, json_["request"]
):
self._close_websocket(1002, "Missing or invalid request field.")
return
@ -200,14 +276,18 @@ class RequestHandler(BaseHTTPRequestHandler):
self._pgpjson_pending_request = json_["request"]
if "payload" in json_:
self._pgpjson_payload = json_["payload"]
self._pgpjson_client_key = self._gpg_context.get_key(signature[0].fpr)
self._pgpjson_client_key = self._gpg_context.get_key(
signature[0].fpr
)
self._request_signature()
return
elif "nonce" not in json_ or json_["nonce"] != self._pgpjson_nonce:
self._close_websocket(1002, "Authentication failed.")
return
self._pgpjson_user = self.server.db.get_user(self._pgpjson_client_key.fpr)
self._pgpjson_user = self.server.db.get_user(
self._pgpjson_client_key.fpr
)
# API "routes"
if self._pgpjson_pending_request.startswith("data/get/"):
@ -232,20 +312,26 @@ class RequestHandler(BaseHTTPRequestHandler):
self._pgpjson_pending_request = None
self._pgpjson_payload = None
self._pgpjson_user = None
self._gpg_context = gpg.Context(armor=True,
home_dir=self.server.gpg_context.home_dir,
offline=True,
signers=[self.server.key])
self._gpg_context = gpg.Context(
armor=True,
home_dir=self.server.gpg_context.home_dir,
offline=True,
signers=[self.server.key],
)
try:
proto = websocket.handshake(self, ["pgp-json"])
if proto is None:
self._close_websocket(1002, "I only speak pgp-json")
return
self._websocket_connected = True
print(f"WebSocket connection with {self.client_address[0]}:{self.client_address[1]}")
print(
f"WebSocket connection with {self.client_address[0]}:{self.client_address[1]}"
)
except websocket.HandshakeError as error:
print(f"WebSocket handshake failed. {error.get_reason(error.why)}: {error.what}={error.value}")
print(
f"WebSocket handshake failed. {error.get_reason(error.why)}: {error.what}={error.value}"
)
return
try:
@ -253,7 +339,9 @@ class RequestHandler(BaseHTTPRequestHandler):
message = websocket.read_next_message(self.rfile, self.wfile)
self._handle_api_request(message)
except websocket.WebSocketCloseException as close:
print(f"WebSocket closed with code {close.code}, reason {close.reason}")
print(
f"WebSocket closed with code {close.code}, reason {close.reason}"
)
def _read(self):
if not self.headers["Content-Length"]:
@ -267,16 +355,20 @@ class RequestHandler(BaseHTTPRequestHandler):
except json.decoder.JSONDecodeError:
data = {}
for f in required_fields:
if not f in data:
if f not in data:
raise JsonMissingFieldException(missing=f)
return data
def _request_signature(self):
self._pgpjson_nonce = token_urlsafe(NONCE_BYTES)
signature_request = websocket.encrypt_pgp_json({
"nonce": self._pgpjson_nonce,
"request": self._pgpjson_pending_request
}, self._pgpjson_client_key, self._gpg_context)
signature_request = websocket.encrypt_pgp_json(
{
"nonce": self._pgpjson_nonce,
"request": self._pgpjson_pending_request,
},
self._pgpjson_client_key,
self._gpg_context,
)
websocket.send_message(self.wfile, signature_request)
def _serve_err404(self):
@ -302,7 +394,9 @@ class RequestHandler(BaseHTTPRequestHandler):
self.wfile.write(output)
def _serve_form_keys(self):
form_name = re.match(self._RE_FORM_NAME, self.path.replace("/formkeys/", ""))
form_name = re.match(
self._RE_FORM_NAME, self.path.replace("/formkeys/", "")
)
if form_name:
form_name = form_name.group(1)
keys = []
@ -313,13 +407,17 @@ class RequestHandler(BaseHTTPRequestHandler):
self._set_cors()
self._set_content_type("json")
self.end_headers()
self.wfile.write(json.dumps({ "form": form_name, "keys": keys }).encode())
self.wfile.write(
json.dumps({"form": form_name, "keys": keys}).encode()
)
else:
self._serve_err404()
def _set_content_type(self, subtype):
if subtype not in self._VALID_CONTENT_SUBTYPES:
raise ValueError(f"RequestHandler._set_content_type: _type must be one of {self._VALID_CONTENT_SUBTYPES}.")
raise ValueError(
f"RequestHandler._set_content_type: _type must be one of {self._VALID_CONTENT_SUBTYPES}."
)
_type = "text"
if subtype == "json":
_type = "application"
@ -336,14 +434,18 @@ class RequestHandler(BaseHTTPRequestHandler):
self._serve_form_keys()
elif self.path == "/key/srv":
# May raise GPGMEerror
key = self.server.gpg_context.key_export(self.server.db.get_config("server_key"))
key = self.server.gpg_context.key_export(
self.server.db.get_config("server_key")
)
self.send_response(200)
self._set_cors()
self._set_content_type("plain")
self.end_headers()
self.wfile.write(key)
elif self.path.startswith("/key/"):
key_fingerprint = re.match(self._RE_KEY_FINGERPRINT, self.path.replace("/key/", ""))
key_fingerprint = re.match(
self._RE_KEY_FINGERPRINT, self.path.replace("/key/", "")
)
key = self.server.gpg_context.key_export(key_fingerprint.group(1))
key = key if key is not None else b"Unknown"
self.send_response(200)
@ -365,9 +467,12 @@ class RequestHandler(BaseHTTPRequestHandler):
self.send_response(404)
self.end_headers()
class Server(HTTPServer):
def __init__(self, listen_to, gpg_home, db):
self.gpg_context = gpg.Context(armor=True, home_dir=gpg_home, offline=True)
self.gpg_context = gpg.Context(
armor=True, home_dir=gpg_home, offline=True
)
self.db = DataBase(db)
self.initSrvKeys()
super().__init__(listen_to, RequestHandler)
@ -376,11 +481,19 @@ class Server(HTTPServer):
fingerprint = self.db.get_config("server_key")
if not fingerprint:
try:
result = self.gpg_context.create_key(userid="PoC server", algorithm="ed25519",
expires=False, sign=True)
result = self.gpg_context.create_key(
userid="PoC server",
algorithm="ed25519",
expires=False,
sign=True,
)
fingerprint = result.fpr
result = self.gpg_context.create_subkey(key=self.gpg_context.get_key(fingerprint),
algorithm="cv25519", expires=False, encrypt=True)
result = self.gpg_context.create_subkey(
key=self.gpg_context.get_key(fingerprint),
algorithm="cv25519",
expires=False,
encrypt=True,
)
self.db.set_config("server_key", fingerprint)
self.key = self.gpg_context.get_key(fingerprint)
except gpg.errors.GPGMEError as error:
@ -392,17 +505,20 @@ class Server(HTTPServer):
else:
try:
self.key = self.gpg_context.get_key(fingerprint)
except gpg.errors.KeyError as error:
except KeyError as error:
print(f"Cannot find server key {fingerprint}.")
sysexit(error)
print(f"Loaded server key {fingerprint}.")
if __name__ == "__main__":
def run(listen_to, gpg_home, db):
httpd = Server(listen_to, gpg_home, db)
try:
print(f"Starting server… listening to {listen_to[0]}:{listen_to[1]} with keystore {gpg_home}, db {db}")
print(
f"Starting server… listening to {listen_to[0]}:{listen_to[1]} with keystore {gpg_home}, db {db}"
)
httpd.serve_forever()
except KeyboardInterrupt:
pass

View File

@ -3,35 +3,44 @@
# https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers
# https://gist.github.com/SevenW/47be2f9ab74cac26bf21/ (SevenW/HTTPWebSocketsHandler.py)
from base64 import b64decode, b64encode
import json
from gpg import Context as GPGContext
from gpg.errors import GPGMEError
from hashlib import sha1
from base64 import b64decode, b64encode
from binascii import Error as BinasciiError
from hashlib import sha1
from http.server import BaseHTTPRequestHandler
from io import BufferedIOBase
from time import time
from gpg import Context as GPGContext
from gpg.errors import GPGMEError
_WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
OPCODE = { "continueation": 0x0,
"text": 0x1,
"binary": 0x2,
"close": 0x8,
"ping": 0x9,
"pong": 0xa }
OPCODE = {
"continueation": 0x0,
"text": 0x1,
"binary": 0x2,
"close": 0x8,
"ping": 0x9,
"pong": 0xA,
}
CONTROL_OPCODES = [OPCODE["close"], OPCODE["ping"], OPCODE["pong"]]
NONCONTROL_OPCODES = [OPCODE["continueation"], OPCODE["text"], OPCODE["binary"]]
NONCONTROL_OPCODES = [
OPCODE["continueation"],
OPCODE["text"],
OPCODE["binary"],
]
PROTOCOL_VERSION = 13
HTTP_VERSION = "HTTP/1.1"
PGP_JSON_SIGNATURE_VALIDITY = 60
class WebSocketException(Exception):
pass
class FrameError(WebSocketException):
def __init__(message: str, frame: dict):
def __init__(self, message: str, frame: dict):
assert "FIN" in frame and frame["FIN"] is not None
assert "RSV1" in frame and frame["RSV1"] is not None
assert "RSV2" in frame and frame["RSV2"] is not None
@ -40,10 +49,13 @@ class FrameError(WebSocketException):
self.message = message
self.frame = frame
class HandshakeError(WebSocketException):
reason = { "invalid_header": 0,
"missing_header": 1,
"incompatible_version": 2 }
reason = {
"invalid_header": 0,
"missing_header": 1,
"incompatible_version": 2,
}
def __init__(self, what: str, value: str, why: int):
assert why in self.reason.values()
@ -56,20 +68,23 @@ class HandshakeError(WebSocketException):
if v == code:
return k
class InvalidLengthError(FrameError):
def __init__(length: int):
def __init__(self, length: int):
self.length = length
class WebSocketCloseException(WebSocketException):
def __init__(self, code: int, reason: str):
self.code = code
self.reason = reason
class WebSocketMessage():
class WebSocketMessage:
def __add__(self, other):
assert isinstance(other, WebSocketMessage)
assert other["opcode"] == OPCODE["continueation"]
if opcode == OPCODE["text"]:
if self.opcode == OPCODE["text"]:
# may raise UnicodeError
self.payload += other.payload.decode()
else:
@ -86,21 +101,35 @@ class WebSocketMessage():
def __str__(self):
return f"WebSocket Message opcode:0x{self.opcode:x} payload:{self.payload}"
def _check_required_headers(request_handler: BaseHTTPRequestHandler):
if request_handler.request_version != HTTP_VERSION:
request_handler.send_response(400)
request_handler.end_headers()
raise HandshakeError("HTTP", request_handler.request_version,
HandshakeError.reason["incompatible_version"])
raise HandshakeError(
"HTTP",
request_handler.request_version,
HandshakeError.reason["incompatible_version"],
)
for header in [["Host", None],
["Upgrade", "websocket"],
["Connection", "upgrade"],
["Sec-WebSocket-Key", None],
["Sec-WebSocket-Version", None]]:
for header in [
["Host", None],
["Upgrade", "websocket"],
["Connection", "upgrade"],
["Sec-WebSocket-Key", None],
["Sec-WebSocket-Version", None],
]:
value = request_handler.headers.get(header[0])
requisite = header[1] in value.lower().split(", ") if header[1] is not None else value is not None
reason = HandshakeError.reason["invalid_header"] if value is not None else HandshakeError.reason["missing_header"]
requisite = (
header[1] in value.lower().split(", ")
if header[1] is not None
else value is not None
)
reason = (
HandshakeError.reason["invalid_header"]
if value is not None
else HandshakeError.reason["missing_header"]
)
if requisite is False:
request_handler.send_response(400)
request_handler.end_headers()
@ -110,9 +139,15 @@ def _check_required_headers(request_handler: BaseHTTPRequestHandler):
if ws_key is None:
request_handler.send_response(400)
request_handler.end_headers()
raise HandshakeError("Sec-WebSocket-Key", None, HandshakeError.reason["missing_header"])
raise HandshakeError(
"Sec-WebSocket-Key", None, HandshakeError.reason["missing_header"]
)
try:
invalid_key_error = HandshakeError("Sec-WebSocket-Key", ws_key, HandshakeError.reason["invalid_header"])
invalid_key_error = HandshakeError(
"Sec-WebSocket-Key",
ws_key,
HandshakeError.reason["invalid_header"],
)
decoded = b64decode(s=ws_key, altchars=None, validate=True)
if len(decoded) != 16:
request_handler.send_response(400)
@ -123,25 +158,37 @@ def _check_required_headers(request_handler: BaseHTTPRequestHandler):
request_handler.end_headers()
raise invalid_key_error
def _decode_payload(masking_key: bytes, encoded_payload: bytes):
decoded_payload = bytearray()
for byte in encoded_payload:
decoded_payload += bytes([byte ^ masking_key[len(decoded_payload) % 4]])
decoded_payload += bytes(
[byte ^ masking_key[len(decoded_payload) % 4]]
)
return decoded_payload
def _encode_data_frame(fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, payload: bytes):
def _encode_data_frame(
fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, payload: bytes
):
assert isinstance(payload, bytes)
if not opcode in OPCODE.values():
raise ValueError(f"Unsupported opcode {opcode}. Valid values {list(OPCODE.values())}.")
if opcode not in OPCODE.values():
raise ValueError(
f"Unsupported opcode {opcode}. Valid values {list(OPCODE.values())}."
)
if fin == 0 and opcode in CONTROL_OPCODES:
raise ValueError(f"Control frames cannot be fragmented")
raise ValueError("Control frames cannot be fragmented")
if opcode in CONTROL_OPCODES and len(payload) > 125:
raise ValueError("Control frame cannot have a payload bigger than 125 bytes.")
raise ValueError(
"Control frame cannot have a payload bigger than 125 bytes."
)
payload_length = len(payload)
length_bits = 7
if payload_length > 0x7fffffffffffffff:
raise ValueError(f"Payload maximal size exceeded (provided {payload_length} bytes).")
elif payload_length > 0xffff:
if payload_length > 0x7FFFFFFFFFFFFFFF:
raise ValueError(
f"Payload maximal size exceeded (provided {payload_length} bytes)."
)
elif payload_length > 0xFFFF:
length_bits = 7 + 64
elif payload_length > 125:
length_bits = 7 + 16
@ -162,10 +209,13 @@ def _encode_data_frame(fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, p
return frame
def _handle_control_frame(wfile: BufferedIOBase, frame: dict):
if frame["opcode"] == OPCODE["close"]:
code = frame["status_code"] if "status_code" in frame else None
payload = code.to_bytes(2, byteorder="big") if code is not None else b""
payload = (
code.to_bytes(2, byteorder="big") if code is not None else b""
)
reason = frame["close_reason"] if "close_reason" in frame else "-"
send_message(wfile, OPCODE["close"], payload)
raise WebSocketCloseException(code, reason)
@ -173,26 +223,29 @@ def _handle_control_frame(wfile: BufferedIOBase, frame: dict):
payload = frame["payload"] if "payload" in frame else b""
send_message(wfile, OPCODE["pong"], payload)
def _read_data_frame(rfile: BufferedIOBase):
frame = {}
#char = rfile.read(1)
#if len(char) == 0:
# return
net_bytes = ord(rfile.read(1))
frame["FIN"] = net_bytes >> 7
frame["RSV1"] = (net_bytes & 0x40) >> 6
frame["RSV2"] = (net_bytes & 0x20) >> 5
frame["RSV3"] = (net_bytes & 0x10) >> 4
frame["opcode"] = net_bytes & 0x0f
frame["opcode"] = net_bytes & 0x0F
if frame["RSV1"] != 0 or frame["RSV2"] != 0 or frame["RSV3"] != 0:
raise FrameError("Unsupported feature. RSV1, RSV2 or RSV3 has a non-zero value.", frame)
raise FrameError(
"Unsupported feature. RSV1, RSV2 or RSV3 has a non-zero value.",
frame,
)
if not frame["opcode"] in OPCODE.values():
raise FrameError("Unsupported opcode value.", frame)
if frame["FIN"] == 0 and frame["opcode"] != OPCODE["continueation"]:
raise FrameError("FIN bit not set for a non-continueation frame.", frame)
raise FrameError(
"FIN bit not set for a non-continueation frame.", frame
)
if frame["opcode"] in CONTROL_OPCODES and frame["FIN"] == 0:
raise FrameError("FIN bit not set for a control frame.", frame)
@ -203,7 +256,7 @@ def _read_data_frame(rfile: BufferedIOBase):
if mask_bit == 0:
raise FrameError("Unmasked frame from client.", frame)
length1 = net_bytes & 0x7f
length1 = net_bytes & 0x7F
if frame["opcode"] in CONTROL_OPCODES and length1 > 125:
raise FrameError("Control frame with invalid payload length.", frame)
@ -211,20 +264,25 @@ def _read_data_frame(rfile: BufferedIOBase):
try:
length = _read_payload_length(length1, rfile)
except InvalidLengthError as error:
raise FrameError(f"Invalid payload length of {error.length} bytes.", frame)
raise FrameError(
f"Invalid payload length of {error.length} bytes.", frame
)
masking_key = rfile.read(4)
encoded_payload = rfile.read(length)
frame["payload"] = _decode_payload(masking_key, encoded_payload)
if frame["opcode"] == OPCODE["close"] and frame["payload"]:
frame["status_code"] = int.from_bytes(frame["payload"][0:2], byteorder="big")
frame["status_code"] = int.from_bytes(
frame["payload"][0:2], byteorder="big"
)
if length > 2:
# /!\ may raise UnicodeError /!\
frame["close_reason"] = frame["payload"][2:].decode()
return frame
def _read_payload_length(payload_length1: int, rfile: BufferedIOBase):
final_length = payload_length1
if payload_length1 == 126:
@ -235,33 +293,51 @@ def _read_payload_length(payload_length1: int, rfile: BufferedIOBase):
raise InvalidLengthError(final_length)
return final_length
def close(wfile: BufferedIOBase, code=1000, reason=None):
code_bytes = code.to_bytes(2, byteorder="big")
payload = code_bytes if reason is None else code_bytes + reason.encode()
frame = _encode_data_frame(1, OPCODE["close"], 0, 0, 0, payload)
wfile.write(frame)
def decrypt_pgp_json(wfile: BufferedIOBase, gpg_context: GPGContext, message: WebSocketMessage, signature_key=None):
def decrypt_pgp_json(
wfile: BufferedIOBase,
gpg_context: GPGContext,
message: WebSocketMessage,
signature_key=None,
):
if message.opcode != OPCODE["text"]:
close(wfile, 1003, "Only text datatype is allowed.")
raise WebSocketCloseException(1003, "Received no-text datatype.")
try:
plaintext, result, verify_result = gpg_context.decrypt(message.payload.encode())
plaintext, result, verify_result = gpg_context.decrypt(
message.payload.encode()
)
except GPGMEError as error:
close(wfile, 1002, f"Failed to decrypt message.")
raise WebSocketCloseException(1002, f"Failed to decrypt websocket message: {error}.")
close(wfile, 1002, "Failed to decrypt message.")
raise WebSocketCloseException(
1002, f"Failed to decrypt websocket message: {error}."
)
if len(verify_result.signatures) == 0:
close(wfile, 1008, "No signature recognized.")
raise WebSocketCloseException(1008, "Message with missing or unknown signature.")
if signature_key is not None and signature_key.fpr != verify_result.signatures[0].fpr:
raise WebSocketCloseException(
1008, "Message with missing or unknown signature."
)
if (
signature_key is not None
and signature_key.fpr != verify_result.signatures[0].fpr
):
close(wfile, 1008, "Message signature check failed.")
raise WebSocketCloseException(1008, "Signing keys doesn't match.")
signature_age = int(time()) - verify_result.signatures[0].timestamp
if signature_age > PGP_JSON_SIGNATURE_VALIDITY:
close(wfile, 1008, "Signature too old.")
raise WebSocketCloseException(1008, f"Message with a {signature_age}s old signature.")
raise WebSocketCloseException(
1008, f"Message with a {signature_age}s old signature."
)
try:
json_ = json.loads(plaintext)
@ -271,27 +347,36 @@ def decrypt_pgp_json(wfile: BufferedIOBase, gpg_context: GPGContext, message: We
return json_, verify_result.signatures
def encrypt_pgp_json(obj: dict, recipient, gpg_context: GPGContext):
json_ = json.dumps(obj).encode()
ciphertext, result, sign_result = gpg_context.encrypt(json_,
recipients=[recipient],
sign=True,
always_trust=True)
ciphertext, result, sign_result = gpg_context.encrypt(
json_, recipients=[recipient], sign=True, always_trust=True
)
return ciphertext
def handshake(request_handler: BaseHTTPRequestHandler, subprotocols=[]):
request_handler.protocol_version = HTTP_VERSION
_check_required_headers(request_handler)
websocket_key = request_handler.headers.get("Sec-WebSocket-Key")
digest = b64encode(sha1((websocket_key + _WEBSOCKET_GUID).encode()).digest())
digest = b64encode(
sha1((websocket_key + _WEBSOCKET_GUID).encode()).digest()
)
websocket_version = request_handler.headers.get("Sec-WebSocket-Version")
if int(websocket_version) != PROTOCOL_VERSION:
request_handler.send_response(400)
request_handler.send_header("Sec-WebSocket-Version", PROTOCOL_VERSION)
request_handler.end_headers()
raise HandshakeError("Sec-WebSocket-Version", websocket_version, HandshakeError.reason["incompatible_version"])
raise HandshakeError(
"Sec-WebSocket-Version",
websocket_version,
HandshakeError.reason["incompatible_version"],
)
selected_subprotocol = None
requested_subprotocols = request_handler.headers.get("Sec-WebSocket-Protocol")
requested_subprotocols = request_handler.headers.get(
"Sec-WebSocket-Protocol"
)
if requested_subprotocols:
requested_subprotocols = requested_subprotocols.split(",")
for proto in requested_subprotocols:
@ -304,11 +389,14 @@ def handshake(request_handler: BaseHTTPRequestHandler, subprotocols=[]):
request_handler.send_header("Connection", "Upgrade")
request_handler.send_header("Sec-WebSocket-Accept", digest.decode())
if requested_subprotocols and selected_subprotocol is not None:
request_handler.send_header("Sec-WebSocket-Protocol", selected_subprotocol)
request_handler.send_header(
"Sec-WebSocket-Protocol", selected_subprotocol
)
request_handler.end_headers()
return selected_subprotocol
def read_next_message(rfile: BufferedIOBase, wfile: BufferedIOBase):
frame = _read_data_frame(rfile)
message = WebSocketMessage(frame["opcode"], frame["payload"])
@ -319,10 +407,20 @@ def read_next_message(rfile: BufferedIOBase, wfile: BufferedIOBase):
_handle_control_frame(wfile, frame)
return read_next_message(rfile, wfile)
else:
return message + read_next_message(rfile)
return message + read_next_message(rfile, wfile)
def send_message(wfile: BufferedIOBase, payload: bytes, opcode=OPCODE["text"], rsv1=0, rsv2=0, rsv3=0):
if len(payload) > 0x7fffffffffffffff:
raise ValueError(f"Payload to big. Sending fragmented messages not implemented.")
def send_message(
wfile: BufferedIOBase,
payload: bytes,
opcode=OPCODE["text"],
rsv1=0,
rsv2=0,
rsv3=0,
):
if len(payload) > 0x7FFFFFFFFFFFFFFF:
raise ValueError(
"Payload to big. Sending fragmented messages not implemented."
)
frame = _encode_data_frame(1, opcode, rsv1, rsv2, rsv3, payload)
wfile.write(frame)