parent
83f5c3e2ad
commit
9270547776
250
server.py
250
server.py
|
@ -1,43 +1,51 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import gpg
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from gpg.gpgme import GPG_ERR_NO_ERROR, GPGME_DELETE_FORCE, gpgme_op_delete_ext
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from secrets import token_urlsafe
|
||||
from sys import argv, exit as sysexit
|
||||
from sys import argv
|
||||
from sys import exit as sysexit
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import gpg
|
||||
from gpg.gpgme import GPG_ERR_NO_ERROR, GPGME_DELETE_FORCE, gpgme_op_delete_ext
|
||||
|
||||
import websocket
|
||||
from database import DataBase
|
||||
|
||||
NONCE_BYTES = 128
|
||||
ADMIN_ACCESS = 100
|
||||
|
||||
|
||||
class JsonMissingFieldException(Exception):
|
||||
def __init__(self, missing):
|
||||
self.missing = missing
|
||||
|
||||
|
||||
|
||||
class RequestHandler(BaseHTTPRequestHandler):
|
||||
_RE_FILES = re.compile(r"^(/|/admin\.html|/data\.html|/forms\.html|/config\.js|/public\.js|/registered\.js|/openpgp\.min\.js)$")
|
||||
_RE_FILES = re.compile(
|
||||
r"^(/|/admin\.html|/data\.html|/forms\.html|/config\.js|/public\.js|/registered\.js|/openpgp\.min\.js)$"
|
||||
)
|
||||
_RE_FORM_NAME = re.compile(r"^([-_0-9a-zA-Z]{1,64})$")
|
||||
_RE_KEY_FINGERPRINT = re.compile(r"^([0-9a-fA-F]{40})$")
|
||||
_RE_PGP_JSON_REQUEST_FIELD = re.compile(r"^[0-9a-zA-Z_/]{3,100}$")
|
||||
_VALID_CONTENT_SUBTYPES = {"html", "javascript", "json", "plain"}
|
||||
|
||||
def _api_answer_pgp_json(self, payload):
|
||||
request_answer = websocket.encrypt_pgp_json({
|
||||
"request": self._pgpjson_pending_request,
|
||||
"payload": payload
|
||||
}, self._pgpjson_client_key, self._gpg_context)
|
||||
request_answer = websocket.encrypt_pgp_json(
|
||||
{"request": self._pgpjson_pending_request, "payload": payload},
|
||||
self._pgpjson_client_key,
|
||||
self._gpg_context,
|
||||
)
|
||||
websocket.send_message(self.wfile, request_answer)
|
||||
self._close_websocket()
|
||||
|
||||
def _api_collect_data(self):
|
||||
form_name = re.match(self._RE_FORM_NAME, self.path.replace("/post/", ""))
|
||||
form_name = re.match(
|
||||
self._RE_FORM_NAME, self.path.replace("/post/", "")
|
||||
)
|
||||
if form_name:
|
||||
secret = self._read().decode()
|
||||
self.server.db.store_form_data(form_name.group(1), secret)
|
||||
|
@ -45,11 +53,17 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
self.end_headers()
|
||||
|
||||
def _api_get_collected_data(self):
|
||||
form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("data/get/", ""))
|
||||
form_name = re.match(
|
||||
self._RE_FORM_NAME,
|
||||
self._pgpjson_pending_request.replace("data/get/", ""),
|
||||
)
|
||||
if form_name:
|
||||
form_name = form_name.group(1)
|
||||
form_users = self.server.db.get_form_keys_fingerprints(form_name)
|
||||
if self._pgpjson_user["fingerprint"] not in form_users["fingerprints"]:
|
||||
if (
|
||||
self._pgpjson_user["fingerprint"]
|
||||
not in form_users["fingerprints"]
|
||||
):
|
||||
self._close_websocket(1008, "Access denied.")
|
||||
else:
|
||||
collected_data = self.server.db.get_collected_data(form_name)
|
||||
|
@ -62,18 +76,29 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
self._close_websocket(1008, "Access denied.")
|
||||
return
|
||||
|
||||
form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("form/get/", ""))
|
||||
form_name = re.match(
|
||||
self._RE_FORM_NAME,
|
||||
self._pgpjson_pending_request.replace("form/get/", ""),
|
||||
)
|
||||
if form_name:
|
||||
form_key_list = self.server.db.get_form_keys_fingerprints(form_name.group(1))
|
||||
form_key_list = self.server.db.get_form_keys_fingerprints(
|
||||
form_name.group(1)
|
||||
)
|
||||
self._api_answer_pgp_json(form_key_list)
|
||||
else:
|
||||
self._close_websocket(1002, "Invalid form name.")
|
||||
|
||||
def _api_get_user(self):
|
||||
user_fingerprint = re.match(self._RE_KEY_FINGERPRINT, self._pgpjson_pending_request.replace("user/", ""))
|
||||
user_fingerprint = re.match(
|
||||
self._RE_KEY_FINGERPRINT,
|
||||
self._pgpjson_pending_request.replace("user/", ""),
|
||||
)
|
||||
if user_fingerprint:
|
||||
fpr = user_fingerprint.group(1)
|
||||
if self._pgpjson_user["access_level"] == ADMIN_ACCESS or self._pgpjson_user["fingerprint"] == fpr:
|
||||
if (
|
||||
self._pgpjson_user["access_level"] == ADMIN_ACCESS
|
||||
or self._pgpjson_user["fingerprint"] == fpr
|
||||
):
|
||||
user = self.server.db.get_user(fpr)
|
||||
self._api_answer_pgp_json(user)
|
||||
else:
|
||||
|
@ -89,15 +114,27 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
|
||||
def _api_get_user_key(self, fingerprints=None):
|
||||
keys = list()
|
||||
if fingerprints == None:
|
||||
if fingerprints is None:
|
||||
for user in self.server.db.get_user():
|
||||
keys.append({ "fingerprint": user["fingerprint"],
|
||||
"armored_key": (self._gpg_context.key_export(user["fingerprint"])).decode()})
|
||||
keys.append(
|
||||
{
|
||||
"fingerprint": user["fingerprint"],
|
||||
"armored_key": (
|
||||
self._gpg_context.key_export(user["fingerprint"])
|
||||
).decode(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
for fpr in fingerprints:
|
||||
if self.server.db.get_user(fpr) is not None:
|
||||
keys.append({ "fingerprint": fpr,
|
||||
"armored_key": (self._gpg_context.key_export(fpr)).decode()})
|
||||
keys.append(
|
||||
{
|
||||
"fingerprint": fpr,
|
||||
"armored_key": (
|
||||
self._gpg_context.key_export(fpr)
|
||||
).decode(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
self._close_websocket(1008, "Asking for an unknown key.")
|
||||
return
|
||||
|
@ -106,14 +143,21 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
def _api_register_user(self):
|
||||
try:
|
||||
req = self._read_json(["pubKey"])
|
||||
except JsonMissingFieldException:
|
||||
self._api_register_user_answer(400, f"Missing {missingField.missing}.")
|
||||
except JsonMissingFieldException as missing_field:
|
||||
self._api_register_user_answer(
|
||||
400, f"Missing {missing_field.missing}."
|
||||
)
|
||||
return
|
||||
results = self.server.gpg_context.key_import(req["pubKey"].encode())
|
||||
if results == "IMPORT_PROBLEM" or not results.considered:
|
||||
self._api_register_user_answer(400, "Invalid pubKey.")
|
||||
elif results.imported:
|
||||
access_level = ADMIN_ACCESS if len(self.server.db.get_user()) == 0 else 0 # first user get admin access level
|
||||
access_level = (
|
||||
ADMIN_ACCESS
|
||||
if len(self.server.db.get_user())
|
||||
== 0 # first user get admin access level
|
||||
else 0
|
||||
)
|
||||
self.server.db.add_user(results.imports[0].fpr, access_level)
|
||||
print(f"Imported key {results.imports[0].fpr}.")
|
||||
self._api_register_user_answer(201, "Key registered.")
|
||||
|
@ -121,20 +165,25 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
self._api_register_user_answer(200, "Key already on server.")
|
||||
else:
|
||||
self._api_register_user_answer(418, "What happened?")
|
||||
print(f"Tried to add following pubKey and got strange results:\n{req[pubKey]}\n\n{results}")
|
||||
print(
|
||||
f"Tried to add following pubKey and got strange results:\n{req['pubKey']}\n\n{results}"
|
||||
)
|
||||
|
||||
def _api_register_user_answer(self, code: int, message: str):
|
||||
self.send_response(code)
|
||||
self._set_content_type("json")
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({ "message": message }).encode())
|
||||
self.wfile.write(json.dumps({"message": message}).encode())
|
||||
|
||||
def _api_set_collected_data(self):
|
||||
if self._pgpjson_user["access_level"] != ADMIN_ACCESS:
|
||||
self._close_websocket(1008, "Access denied.")
|
||||
return
|
||||
|
||||
form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("data/set/", ""))
|
||||
form_name = re.match(
|
||||
self._RE_FORM_NAME,
|
||||
self._pgpjson_pending_request.replace("data/set/", ""),
|
||||
)
|
||||
if form_name:
|
||||
if self._pgpjson_payload is None:
|
||||
self._close_websocket(1002, "Missing payload.")
|
||||
|
@ -143,10 +192,14 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
if "secret" not in data or "id" not in data:
|
||||
self._close_websocket(1002, "Invalid payload")
|
||||
return
|
||||
elif not data["secret"].startswith("-----BEGIN PGP MESSAGE-----"):
|
||||
elif not data["secret"].startswith(
|
||||
"-----BEGIN PGP MESSAGE-----"
|
||||
):
|
||||
self._close_websocket(1002, "Invalid PGP message")
|
||||
return
|
||||
self.server.db.set_collected_data(form_name.group(1), self._pgpjson_payload)
|
||||
self.server.db.set_collected_data(
|
||||
form_name.group(1), self._pgpjson_payload
|
||||
)
|
||||
self._close_websocket()
|
||||
else:
|
||||
self._close_websocket(1002, "Invalid form name.")
|
||||
|
@ -156,10 +209,19 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
self._close_websocket(1008, "Access denied.")
|
||||
return
|
||||
|
||||
form_name = re.match(self._RE_FORM_NAME, self._pgpjson_pending_request.replace("form/set/", ""))
|
||||
form_name = re.match(
|
||||
self._RE_FORM_NAME,
|
||||
self._pgpjson_pending_request.replace("form/set/", ""),
|
||||
)
|
||||
if form_name:
|
||||
if self._pgpjson_payload is not None and "fingerprints" in self._pgpjson_payload and isinstance(self._pgpjson_payload["fingerprints"], list):
|
||||
self.server.db.set_form_keys_fingerprints(form_name.group(1), self._pgpjson_payload["fingerprints"])
|
||||
if (
|
||||
self._pgpjson_payload is not None
|
||||
and "fingerprints" in self._pgpjson_payload
|
||||
and isinstance(self._pgpjson_payload["fingerprints"], list)
|
||||
):
|
||||
self.server.db.set_form_keys_fingerprints(
|
||||
form_name.group(1), self._pgpjson_payload["fingerprints"]
|
||||
)
|
||||
self._close_websocket()
|
||||
else:
|
||||
self._close_websocket(1002, "Invalid request payload.")
|
||||
|
@ -170,29 +232,43 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
if self._pgpjson_user["access_level"] == ADMIN_ACCESS:
|
||||
self._close_websocket(1008, "Cannot delete admin account.")
|
||||
return
|
||||
result = gpgme_op_delete_ext(self._gpg_context.wrapped, self._pgpjson_client_key, GPGME_DELETE_FORCE)
|
||||
result = gpgme_op_delete_ext(
|
||||
self._gpg_context.wrapped,
|
||||
self._pgpjson_client_key,
|
||||
GPGME_DELETE_FORCE,
|
||||
)
|
||||
if result == GPG_ERR_NO_ERROR:
|
||||
self.server.db.delete_user(self._pgpjson_client_key.fpr)
|
||||
print(f"Key {self._pgpjson_client_key.fpr} deleted.")
|
||||
self._close_websocket()
|
||||
else:
|
||||
print(f"Failed to delete key {self._pgpjson_client_key.fpr}, status code: {result}.")
|
||||
print(
|
||||
f"Failed to delete key {self._pgpjson_client_key.fpr}, status code: {result}."
|
||||
)
|
||||
self._close_websocket(1011, "Failed to delete key.")
|
||||
|
||||
def _close_websocket(self, code=1000, reason=None):
|
||||
websocket.close(self.wfile, code, reason)
|
||||
self._websocket_connected = False
|
||||
print(f"Closed WebSocket connection with {self.client_address[0]}:{self.client_address[1]}")
|
||||
print(
|
||||
f"Closed WebSocket connection with {self.client_address[0]}:{self.client_address[1]}"
|
||||
)
|
||||
if code != 1000:
|
||||
print(f"Code: {code} reason: {reason}")
|
||||
|
||||
|
||||
def _handle_api_request(self, message: websocket.WebSocketMessage):
|
||||
signature_key = self._pgpjson_client_key if self._pgpjson_client_key is not None else None
|
||||
json_, signature = websocket.decrypt_pgp_json(self.wfile, self._gpg_context,
|
||||
message, signature_key)
|
||||
signature_key = (
|
||||
self._pgpjson_client_key
|
||||
if self._pgpjson_client_key is not None
|
||||
else None
|
||||
)
|
||||
json_, signature = websocket.decrypt_pgp_json(
|
||||
self.wfile, self._gpg_context, message, signature_key
|
||||
)
|
||||
|
||||
if "request" not in json_ or not re.match(self._RE_PGP_JSON_REQUEST_FIELD, json_["request"]):
|
||||
if "request" not in json_ or not re.match(
|
||||
self._RE_PGP_JSON_REQUEST_FIELD, json_["request"]
|
||||
):
|
||||
self._close_websocket(1002, "Missing or invalid request field.")
|
||||
return
|
||||
|
||||
|
@ -200,14 +276,18 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
self._pgpjson_pending_request = json_["request"]
|
||||
if "payload" in json_:
|
||||
self._pgpjson_payload = json_["payload"]
|
||||
self._pgpjson_client_key = self._gpg_context.get_key(signature[0].fpr)
|
||||
self._pgpjson_client_key = self._gpg_context.get_key(
|
||||
signature[0].fpr
|
||||
)
|
||||
self._request_signature()
|
||||
return
|
||||
elif "nonce" not in json_ or json_["nonce"] != self._pgpjson_nonce:
|
||||
self._close_websocket(1002, "Authentication failed.")
|
||||
return
|
||||
|
||||
self._pgpjson_user = self.server.db.get_user(self._pgpjson_client_key.fpr)
|
||||
self._pgpjson_user = self.server.db.get_user(
|
||||
self._pgpjson_client_key.fpr
|
||||
)
|
||||
|
||||
# API "routes"
|
||||
if self._pgpjson_pending_request.startswith("data/get/"):
|
||||
|
@ -232,20 +312,26 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
self._pgpjson_pending_request = None
|
||||
self._pgpjson_payload = None
|
||||
self._pgpjson_user = None
|
||||
self._gpg_context = gpg.Context(armor=True,
|
||||
home_dir=self.server.gpg_context.home_dir,
|
||||
offline=True,
|
||||
signers=[self.server.key])
|
||||
|
||||
self._gpg_context = gpg.Context(
|
||||
armor=True,
|
||||
home_dir=self.server.gpg_context.home_dir,
|
||||
offline=True,
|
||||
signers=[self.server.key],
|
||||
)
|
||||
|
||||
try:
|
||||
proto = websocket.handshake(self, ["pgp-json"])
|
||||
if proto is None:
|
||||
self._close_websocket(1002, "I only speak pgp-json")
|
||||
return
|
||||
self._websocket_connected = True
|
||||
print(f"WebSocket connection with {self.client_address[0]}:{self.client_address[1]}")
|
||||
print(
|
||||
f"WebSocket connection with {self.client_address[0]}:{self.client_address[1]}"
|
||||
)
|
||||
except websocket.HandshakeError as error:
|
||||
print(f"WebSocket handshake failed. {error.get_reason(error.why)}: {error.what}={error.value}")
|
||||
print(
|
||||
f"WebSocket handshake failed. {error.get_reason(error.why)}: {error.what}={error.value}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
|
@ -253,7 +339,9 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
message = websocket.read_next_message(self.rfile, self.wfile)
|
||||
self._handle_api_request(message)
|
||||
except websocket.WebSocketCloseException as close:
|
||||
print(f"WebSocket closed with code {close.code}, reason {close.reason}")
|
||||
print(
|
||||
f"WebSocket closed with code {close.code}, reason {close.reason}"
|
||||
)
|
||||
|
||||
def _read(self):
|
||||
if not self.headers["Content-Length"]:
|
||||
|
@ -267,16 +355,20 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
except json.decoder.JSONDecodeError:
|
||||
data = {}
|
||||
for f in required_fields:
|
||||
if not f in data:
|
||||
if f not in data:
|
||||
raise JsonMissingFieldException(missing=f)
|
||||
return data
|
||||
|
||||
def _request_signature(self):
|
||||
self._pgpjson_nonce = token_urlsafe(NONCE_BYTES)
|
||||
signature_request = websocket.encrypt_pgp_json({
|
||||
"nonce": self._pgpjson_nonce,
|
||||
"request": self._pgpjson_pending_request
|
||||
}, self._pgpjson_client_key, self._gpg_context)
|
||||
signature_request = websocket.encrypt_pgp_json(
|
||||
{
|
||||
"nonce": self._pgpjson_nonce,
|
||||
"request": self._pgpjson_pending_request,
|
||||
},
|
||||
self._pgpjson_client_key,
|
||||
self._gpg_context,
|
||||
)
|
||||
websocket.send_message(self.wfile, signature_request)
|
||||
|
||||
def _serve_err404(self):
|
||||
|
@ -302,7 +394,9 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
self.wfile.write(output)
|
||||
|
||||
def _serve_form_keys(self):
|
||||
form_name = re.match(self._RE_FORM_NAME, self.path.replace("/formkeys/", ""))
|
||||
form_name = re.match(
|
||||
self._RE_FORM_NAME, self.path.replace("/formkeys/", "")
|
||||
)
|
||||
if form_name:
|
||||
form_name = form_name.group(1)
|
||||
keys = []
|
||||
|
@ -313,13 +407,17 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
self._set_cors()
|
||||
self._set_content_type("json")
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({ "form": form_name, "keys": keys }).encode())
|
||||
self.wfile.write(
|
||||
json.dumps({"form": form_name, "keys": keys}).encode()
|
||||
)
|
||||
else:
|
||||
self._serve_err404()
|
||||
|
||||
def _set_content_type(self, subtype):
|
||||
if subtype not in self._VALID_CONTENT_SUBTYPES:
|
||||
raise ValueError(f"RequestHandler._set_content_type: _type must be one of {self._VALID_CONTENT_SUBTYPES}.")
|
||||
raise ValueError(
|
||||
f"RequestHandler._set_content_type: _type must be one of {self._VALID_CONTENT_SUBTYPES}."
|
||||
)
|
||||
_type = "text"
|
||||
if subtype == "json":
|
||||
_type = "application"
|
||||
|
@ -336,14 +434,18 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
self._serve_form_keys()
|
||||
elif self.path == "/key/srv":
|
||||
# May raise GPGMEerror
|
||||
key = self.server.gpg_context.key_export(self.server.db.get_config("server_key"))
|
||||
key = self.server.gpg_context.key_export(
|
||||
self.server.db.get_config("server_key")
|
||||
)
|
||||
self.send_response(200)
|
||||
self._set_cors()
|
||||
self._set_content_type("plain")
|
||||
self.end_headers()
|
||||
self.wfile.write(key)
|
||||
elif self.path.startswith("/key/"):
|
||||
key_fingerprint = re.match(self._RE_KEY_FINGERPRINT, self.path.replace("/key/", ""))
|
||||
key_fingerprint = re.match(
|
||||
self._RE_KEY_FINGERPRINT, self.path.replace("/key/", "")
|
||||
)
|
||||
key = self.server.gpg_context.key_export(key_fingerprint.group(1))
|
||||
key = key if key is not None else b"Unknown"
|
||||
self.send_response(200)
|
||||
|
@ -365,9 +467,12 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
|
||||
class Server(HTTPServer):
|
||||
def __init__(self, listen_to, gpg_home, db):
|
||||
self.gpg_context = gpg.Context(armor=True, home_dir=gpg_home, offline=True)
|
||||
self.gpg_context = gpg.Context(
|
||||
armor=True, home_dir=gpg_home, offline=True
|
||||
)
|
||||
self.db = DataBase(db)
|
||||
self.initSrvKeys()
|
||||
super().__init__(listen_to, RequestHandler)
|
||||
|
@ -376,11 +481,19 @@ class Server(HTTPServer):
|
|||
fingerprint = self.db.get_config("server_key")
|
||||
if not fingerprint:
|
||||
try:
|
||||
result = self.gpg_context.create_key(userid="PoC server", algorithm="ed25519",
|
||||
expires=False, sign=True)
|
||||
result = self.gpg_context.create_key(
|
||||
userid="PoC server",
|
||||
algorithm="ed25519",
|
||||
expires=False,
|
||||
sign=True,
|
||||
)
|
||||
fingerprint = result.fpr
|
||||
result = self.gpg_context.create_subkey(key=self.gpg_context.get_key(fingerprint),
|
||||
algorithm="cv25519", expires=False, encrypt=True)
|
||||
result = self.gpg_context.create_subkey(
|
||||
key=self.gpg_context.get_key(fingerprint),
|
||||
algorithm="cv25519",
|
||||
expires=False,
|
||||
encrypt=True,
|
||||
)
|
||||
self.db.set_config("server_key", fingerprint)
|
||||
self.key = self.gpg_context.get_key(fingerprint)
|
||||
except gpg.errors.GPGMEError as error:
|
||||
|
@ -392,17 +505,20 @@ class Server(HTTPServer):
|
|||
else:
|
||||
try:
|
||||
self.key = self.gpg_context.get_key(fingerprint)
|
||||
except gpg.errors.KeyError as error:
|
||||
except KeyError as error:
|
||||
print(f"Cannot find server key {fingerprint}.")
|
||||
sysexit(error)
|
||||
print(f"Loaded server key {fingerprint}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def run(listen_to, gpg_home, db):
|
||||
httpd = Server(listen_to, gpg_home, db)
|
||||
try:
|
||||
print(f"Starting server… listening to {listen_to[0]}:{listen_to[1]} with keystore {gpg_home}, db {db}")
|
||||
print(
|
||||
f"Starting server… listening to {listen_to[0]}:{listen_to[1]} with keystore {gpg_home}, db {db}"
|
||||
)
|
||||
httpd.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
|
232
websocket.py
232
websocket.py
|
@ -3,35 +3,44 @@
|
|||
# https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers
|
||||
# https://gist.github.com/SevenW/47be2f9ab74cac26bf21/ (SevenW/HTTPWebSocketsHandler.py)
|
||||
|
||||
from base64 import b64decode, b64encode
|
||||
import json
|
||||
from gpg import Context as GPGContext
|
||||
from gpg.errors import GPGMEError
|
||||
from hashlib import sha1
|
||||
from base64 import b64decode, b64encode
|
||||
from binascii import Error as BinasciiError
|
||||
from hashlib import sha1
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from io import BufferedIOBase
|
||||
from time import time
|
||||
|
||||
from gpg import Context as GPGContext
|
||||
from gpg.errors import GPGMEError
|
||||
|
||||
_WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
OPCODE = { "continueation": 0x0,
|
||||
"text": 0x1,
|
||||
"binary": 0x2,
|
||||
"close": 0x8,
|
||||
"ping": 0x9,
|
||||
"pong": 0xa }
|
||||
OPCODE = {
|
||||
"continueation": 0x0,
|
||||
"text": 0x1,
|
||||
"binary": 0x2,
|
||||
"close": 0x8,
|
||||
"ping": 0x9,
|
||||
"pong": 0xA,
|
||||
}
|
||||
CONTROL_OPCODES = [OPCODE["close"], OPCODE["ping"], OPCODE["pong"]]
|
||||
NONCONTROL_OPCODES = [OPCODE["continueation"], OPCODE["text"], OPCODE["binary"]]
|
||||
NONCONTROL_OPCODES = [
|
||||
OPCODE["continueation"],
|
||||
OPCODE["text"],
|
||||
OPCODE["binary"],
|
||||
]
|
||||
PROTOCOL_VERSION = 13
|
||||
HTTP_VERSION = "HTTP/1.1"
|
||||
|
||||
PGP_JSON_SIGNATURE_VALIDITY = 60
|
||||
|
||||
|
||||
class WebSocketException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FrameError(WebSocketException):
|
||||
def __init__(message: str, frame: dict):
|
||||
def __init__(self, message: str, frame: dict):
|
||||
assert "FIN" in frame and frame["FIN"] is not None
|
||||
assert "RSV1" in frame and frame["RSV1"] is not None
|
||||
assert "RSV2" in frame and frame["RSV2"] is not None
|
||||
|
@ -40,10 +49,13 @@ class FrameError(WebSocketException):
|
|||
self.message = message
|
||||
self.frame = frame
|
||||
|
||||
|
||||
class HandshakeError(WebSocketException):
|
||||
reason = { "invalid_header": 0,
|
||||
"missing_header": 1,
|
||||
"incompatible_version": 2 }
|
||||
reason = {
|
||||
"invalid_header": 0,
|
||||
"missing_header": 1,
|
||||
"incompatible_version": 2,
|
||||
}
|
||||
|
||||
def __init__(self, what: str, value: str, why: int):
|
||||
assert why in self.reason.values()
|
||||
|
@ -56,20 +68,23 @@ class HandshakeError(WebSocketException):
|
|||
if v == code:
|
||||
return k
|
||||
|
||||
|
||||
class InvalidLengthError(FrameError):
|
||||
def __init__(length: int):
|
||||
def __init__(self, length: int):
|
||||
self.length = length
|
||||
|
||||
|
||||
class WebSocketCloseException(WebSocketException):
|
||||
def __init__(self, code: int, reason: str):
|
||||
self.code = code
|
||||
self.reason = reason
|
||||
|
||||
class WebSocketMessage():
|
||||
|
||||
class WebSocketMessage:
|
||||
def __add__(self, other):
|
||||
assert isinstance(other, WebSocketMessage)
|
||||
assert other["opcode"] == OPCODE["continueation"]
|
||||
if opcode == OPCODE["text"]:
|
||||
if self.opcode == OPCODE["text"]:
|
||||
# may raise UnicodeError
|
||||
self.payload += other.payload.decode()
|
||||
else:
|
||||
|
@ -86,21 +101,35 @@ class WebSocketMessage():
|
|||
def __str__(self):
|
||||
return f"WebSocket Message opcode:0x{self.opcode:x} payload:{self.payload}"
|
||||
|
||||
|
||||
def _check_required_headers(request_handler: BaseHTTPRequestHandler):
|
||||
if request_handler.request_version != HTTP_VERSION:
|
||||
request_handler.send_response(400)
|
||||
request_handler.end_headers()
|
||||
raise HandshakeError("HTTP", request_handler.request_version,
|
||||
HandshakeError.reason["incompatible_version"])
|
||||
raise HandshakeError(
|
||||
"HTTP",
|
||||
request_handler.request_version,
|
||||
HandshakeError.reason["incompatible_version"],
|
||||
)
|
||||
|
||||
for header in [["Host", None],
|
||||
["Upgrade", "websocket"],
|
||||
["Connection", "upgrade"],
|
||||
["Sec-WebSocket-Key", None],
|
||||
["Sec-WebSocket-Version", None]]:
|
||||
for header in [
|
||||
["Host", None],
|
||||
["Upgrade", "websocket"],
|
||||
["Connection", "upgrade"],
|
||||
["Sec-WebSocket-Key", None],
|
||||
["Sec-WebSocket-Version", None],
|
||||
]:
|
||||
value = request_handler.headers.get(header[0])
|
||||
requisite = header[1] in value.lower().split(", ") if header[1] is not None else value is not None
|
||||
reason = HandshakeError.reason["invalid_header"] if value is not None else HandshakeError.reason["missing_header"]
|
||||
requisite = (
|
||||
header[1] in value.lower().split(", ")
|
||||
if header[1] is not None
|
||||
else value is not None
|
||||
)
|
||||
reason = (
|
||||
HandshakeError.reason["invalid_header"]
|
||||
if value is not None
|
||||
else HandshakeError.reason["missing_header"]
|
||||
)
|
||||
if requisite is False:
|
||||
request_handler.send_response(400)
|
||||
request_handler.end_headers()
|
||||
|
@ -110,9 +139,15 @@ def _check_required_headers(request_handler: BaseHTTPRequestHandler):
|
|||
if ws_key is None:
|
||||
request_handler.send_response(400)
|
||||
request_handler.end_headers()
|
||||
raise HandshakeError("Sec-WebSocket-Key", None, HandshakeError.reason["missing_header"])
|
||||
raise HandshakeError(
|
||||
"Sec-WebSocket-Key", None, HandshakeError.reason["missing_header"]
|
||||
)
|
||||
try:
|
||||
invalid_key_error = HandshakeError("Sec-WebSocket-Key", ws_key, HandshakeError.reason["invalid_header"])
|
||||
invalid_key_error = HandshakeError(
|
||||
"Sec-WebSocket-Key",
|
||||
ws_key,
|
||||
HandshakeError.reason["invalid_header"],
|
||||
)
|
||||
decoded = b64decode(s=ws_key, altchars=None, validate=True)
|
||||
if len(decoded) != 16:
|
||||
request_handler.send_response(400)
|
||||
|
@ -123,25 +158,37 @@ def _check_required_headers(request_handler: BaseHTTPRequestHandler):
|
|||
request_handler.end_headers()
|
||||
raise invalid_key_error
|
||||
|
||||
|
||||
def _decode_payload(masking_key: bytes, encoded_payload: bytes):
|
||||
decoded_payload = bytearray()
|
||||
for byte in encoded_payload:
|
||||
decoded_payload += bytes([byte ^ masking_key[len(decoded_payload) % 4]])
|
||||
decoded_payload += bytes(
|
||||
[byte ^ masking_key[len(decoded_payload) % 4]]
|
||||
)
|
||||
return decoded_payload
|
||||
|
||||
def _encode_data_frame(fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, payload: bytes):
|
||||
|
||||
def _encode_data_frame(
|
||||
fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, payload: bytes
|
||||
):
|
||||
assert isinstance(payload, bytes)
|
||||
if not opcode in OPCODE.values():
|
||||
raise ValueError(f"Unsupported opcode {opcode}. Valid values {list(OPCODE.values())}.")
|
||||
if opcode not in OPCODE.values():
|
||||
raise ValueError(
|
||||
f"Unsupported opcode {opcode}. Valid values {list(OPCODE.values())}."
|
||||
)
|
||||
if fin == 0 and opcode in CONTROL_OPCODES:
|
||||
raise ValueError(f"Control frames cannot be fragmented")
|
||||
raise ValueError("Control frames cannot be fragmented")
|
||||
if opcode in CONTROL_OPCODES and len(payload) > 125:
|
||||
raise ValueError("Control frame cannot have a payload bigger than 125 bytes.")
|
||||
raise ValueError(
|
||||
"Control frame cannot have a payload bigger than 125 bytes."
|
||||
)
|
||||
payload_length = len(payload)
|
||||
length_bits = 7
|
||||
if payload_length > 0x7fffffffffffffff:
|
||||
raise ValueError(f"Payload maximal size exceeded (provided {payload_length} bytes).")
|
||||
elif payload_length > 0xffff:
|
||||
if payload_length > 0x7FFFFFFFFFFFFFFF:
|
||||
raise ValueError(
|
||||
f"Payload maximal size exceeded (provided {payload_length} bytes)."
|
||||
)
|
||||
elif payload_length > 0xFFFF:
|
||||
length_bits = 7 + 64
|
||||
elif payload_length > 125:
|
||||
length_bits = 7 + 16
|
||||
|
@ -162,10 +209,13 @@ def _encode_data_frame(fin: int, opcode: int, rsv1: int, rsv2: int, rsv3: int, p
|
|||
|
||||
return frame
|
||||
|
||||
|
||||
def _handle_control_frame(wfile: BufferedIOBase, frame: dict):
|
||||
if frame["opcode"] == OPCODE["close"]:
|
||||
code = frame["status_code"] if "status_code" in frame else None
|
||||
payload = code.to_bytes(2, byteorder="big") if code is not None else b""
|
||||
payload = (
|
||||
code.to_bytes(2, byteorder="big") if code is not None else b""
|
||||
)
|
||||
reason = frame["close_reason"] if "close_reason" in frame else "-"
|
||||
send_message(wfile, OPCODE["close"], payload)
|
||||
raise WebSocketCloseException(code, reason)
|
||||
|
@ -173,26 +223,29 @@ def _handle_control_frame(wfile: BufferedIOBase, frame: dict):
|
|||
payload = frame["payload"] if "payload" in frame else b""
|
||||
send_message(wfile, OPCODE["pong"], payload)
|
||||
|
||||
|
||||
def _read_data_frame(rfile: BufferedIOBase):
|
||||
frame = {}
|
||||
#char = rfile.read(1)
|
||||
#if len(char) == 0:
|
||||
# return
|
||||
net_bytes = ord(rfile.read(1))
|
||||
frame["FIN"] = net_bytes >> 7
|
||||
frame["RSV1"] = (net_bytes & 0x40) >> 6
|
||||
frame["RSV2"] = (net_bytes & 0x20) >> 5
|
||||
frame["RSV3"] = (net_bytes & 0x10) >> 4
|
||||
frame["opcode"] = net_bytes & 0x0f
|
||||
frame["opcode"] = net_bytes & 0x0F
|
||||
|
||||
if frame["RSV1"] != 0 or frame["RSV2"] != 0 or frame["RSV3"] != 0:
|
||||
raise FrameError("Unsupported feature. RSV1, RSV2 or RSV3 has a non-zero value.", frame)
|
||||
raise FrameError(
|
||||
"Unsupported feature. RSV1, RSV2 or RSV3 has a non-zero value.",
|
||||
frame,
|
||||
)
|
||||
|
||||
if not frame["opcode"] in OPCODE.values():
|
||||
raise FrameError("Unsupported opcode value.", frame)
|
||||
|
||||
if frame["FIN"] == 0 and frame["opcode"] != OPCODE["continueation"]:
|
||||
raise FrameError("FIN bit not set for a non-continueation frame.", frame)
|
||||
raise FrameError(
|
||||
"FIN bit not set for a non-continueation frame.", frame
|
||||
)
|
||||
|
||||
if frame["opcode"] in CONTROL_OPCODES and frame["FIN"] == 0:
|
||||
raise FrameError("FIN bit not set for a control frame.", frame)
|
||||
|
@ -203,7 +256,7 @@ def _read_data_frame(rfile: BufferedIOBase):
|
|||
if mask_bit == 0:
|
||||
raise FrameError("Unmasked frame from client.", frame)
|
||||
|
||||
length1 = net_bytes & 0x7f
|
||||
length1 = net_bytes & 0x7F
|
||||
|
||||
if frame["opcode"] in CONTROL_OPCODES and length1 > 125:
|
||||
raise FrameError("Control frame with invalid payload length.", frame)
|
||||
|
@ -211,20 +264,25 @@ def _read_data_frame(rfile: BufferedIOBase):
|
|||
try:
|
||||
length = _read_payload_length(length1, rfile)
|
||||
except InvalidLengthError as error:
|
||||
raise FrameError(f"Invalid payload length of {error.length} bytes.", frame)
|
||||
raise FrameError(
|
||||
f"Invalid payload length of {error.length} bytes.", frame
|
||||
)
|
||||
|
||||
masking_key = rfile.read(4)
|
||||
encoded_payload = rfile.read(length)
|
||||
frame["payload"] = _decode_payload(masking_key, encoded_payload)
|
||||
|
||||
if frame["opcode"] == OPCODE["close"] and frame["payload"]:
|
||||
frame["status_code"] = int.from_bytes(frame["payload"][0:2], byteorder="big")
|
||||
frame["status_code"] = int.from_bytes(
|
||||
frame["payload"][0:2], byteorder="big"
|
||||
)
|
||||
if length > 2:
|
||||
# /!\ may raise UnicodeError /!\
|
||||
frame["close_reason"] = frame["payload"][2:].decode()
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def _read_payload_length(payload_length1: int, rfile: BufferedIOBase):
|
||||
final_length = payload_length1
|
||||
if payload_length1 == 126:
|
||||
|
@ -235,33 +293,51 @@ def _read_payload_length(payload_length1: int, rfile: BufferedIOBase):
|
|||
raise InvalidLengthError(final_length)
|
||||
return final_length
|
||||
|
||||
|
||||
def close(wfile: BufferedIOBase, code=1000, reason=None):
|
||||
code_bytes = code.to_bytes(2, byteorder="big")
|
||||
payload = code_bytes if reason is None else code_bytes + reason.encode()
|
||||
frame = _encode_data_frame(1, OPCODE["close"], 0, 0, 0, payload)
|
||||
wfile.write(frame)
|
||||
|
||||
def decrypt_pgp_json(wfile: BufferedIOBase, gpg_context: GPGContext, message: WebSocketMessage, signature_key=None):
|
||||
|
||||
def decrypt_pgp_json(
|
||||
wfile: BufferedIOBase,
|
||||
gpg_context: GPGContext,
|
||||
message: WebSocketMessage,
|
||||
signature_key=None,
|
||||
):
|
||||
if message.opcode != OPCODE["text"]:
|
||||
close(wfile, 1003, "Only text datatype is allowed.")
|
||||
raise WebSocketCloseException(1003, "Received no-text datatype.")
|
||||
try:
|
||||
plaintext, result, verify_result = gpg_context.decrypt(message.payload.encode())
|
||||
plaintext, result, verify_result = gpg_context.decrypt(
|
||||
message.payload.encode()
|
||||
)
|
||||
except GPGMEError as error:
|
||||
close(wfile, 1002, f"Failed to decrypt message.")
|
||||
raise WebSocketCloseException(1002, f"Failed to decrypt websocket message: {error}.")
|
||||
close(wfile, 1002, "Failed to decrypt message.")
|
||||
raise WebSocketCloseException(
|
||||
1002, f"Failed to decrypt websocket message: {error}."
|
||||
)
|
||||
|
||||
if len(verify_result.signatures) == 0:
|
||||
close(wfile, 1008, "No signature recognized.")
|
||||
raise WebSocketCloseException(1008, "Message with missing or unknown signature.")
|
||||
if signature_key is not None and signature_key.fpr != verify_result.signatures[0].fpr:
|
||||
raise WebSocketCloseException(
|
||||
1008, "Message with missing or unknown signature."
|
||||
)
|
||||
if (
|
||||
signature_key is not None
|
||||
and signature_key.fpr != verify_result.signatures[0].fpr
|
||||
):
|
||||
close(wfile, 1008, "Message signature check failed.")
|
||||
raise WebSocketCloseException(1008, "Signing keys doesn't match.")
|
||||
|
||||
signature_age = int(time()) - verify_result.signatures[0].timestamp
|
||||
if signature_age > PGP_JSON_SIGNATURE_VALIDITY:
|
||||
close(wfile, 1008, "Signature too old.")
|
||||
raise WebSocketCloseException(1008, f"Message with a {signature_age}s old signature.")
|
||||
raise WebSocketCloseException(
|
||||
1008, f"Message with a {signature_age}s old signature."
|
||||
)
|
||||
|
||||
try:
|
||||
json_ = json.loads(plaintext)
|
||||
|
@ -271,27 +347,36 @@ def decrypt_pgp_json(wfile: BufferedIOBase, gpg_context: GPGContext, message: We
|
|||
|
||||
return json_, verify_result.signatures
|
||||
|
||||
|
||||
def encrypt_pgp_json(obj: dict, recipient, gpg_context: GPGContext):
|
||||
json_ = json.dumps(obj).encode()
|
||||
ciphertext, result, sign_result = gpg_context.encrypt(json_,
|
||||
recipients=[recipient],
|
||||
sign=True,
|
||||
always_trust=True)
|
||||
ciphertext, result, sign_result = gpg_context.encrypt(
|
||||
json_, recipients=[recipient], sign=True, always_trust=True
|
||||
)
|
||||
return ciphertext
|
||||
|
||||
|
||||
def handshake(request_handler: BaseHTTPRequestHandler, subprotocols=[]):
|
||||
request_handler.protocol_version = HTTP_VERSION
|
||||
_check_required_headers(request_handler)
|
||||
websocket_key = request_handler.headers.get("Sec-WebSocket-Key")
|
||||
digest = b64encode(sha1((websocket_key + _WEBSOCKET_GUID).encode()).digest())
|
||||
digest = b64encode(
|
||||
sha1((websocket_key + _WEBSOCKET_GUID).encode()).digest()
|
||||
)
|
||||
websocket_version = request_handler.headers.get("Sec-WebSocket-Version")
|
||||
if int(websocket_version) != PROTOCOL_VERSION:
|
||||
request_handler.send_response(400)
|
||||
request_handler.send_header("Sec-WebSocket-Version", PROTOCOL_VERSION)
|
||||
request_handler.end_headers()
|
||||
raise HandshakeError("Sec-WebSocket-Version", websocket_version, HandshakeError.reason["incompatible_version"])
|
||||
raise HandshakeError(
|
||||
"Sec-WebSocket-Version",
|
||||
websocket_version,
|
||||
HandshakeError.reason["incompatible_version"],
|
||||
)
|
||||
selected_subprotocol = None
|
||||
requested_subprotocols = request_handler.headers.get("Sec-WebSocket-Protocol")
|
||||
requested_subprotocols = request_handler.headers.get(
|
||||
"Sec-WebSocket-Protocol"
|
||||
)
|
||||
if requested_subprotocols:
|
||||
requested_subprotocols = requested_subprotocols.split(",")
|
||||
for proto in requested_subprotocols:
|
||||
|
@ -304,11 +389,14 @@ def handshake(request_handler: BaseHTTPRequestHandler, subprotocols=[]):
|
|||
request_handler.send_header("Connection", "Upgrade")
|
||||
request_handler.send_header("Sec-WebSocket-Accept", digest.decode())
|
||||
if requested_subprotocols and selected_subprotocol is not None:
|
||||
request_handler.send_header("Sec-WebSocket-Protocol", selected_subprotocol)
|
||||
request_handler.send_header(
|
||||
"Sec-WebSocket-Protocol", selected_subprotocol
|
||||
)
|
||||
request_handler.end_headers()
|
||||
|
||||
return selected_subprotocol
|
||||
|
||||
|
||||
def read_next_message(rfile: BufferedIOBase, wfile: BufferedIOBase):
|
||||
frame = _read_data_frame(rfile)
|
||||
message = WebSocketMessage(frame["opcode"], frame["payload"])
|
||||
|
@ -319,10 +407,20 @@ def read_next_message(rfile: BufferedIOBase, wfile: BufferedIOBase):
|
|||
_handle_control_frame(wfile, frame)
|
||||
return read_next_message(rfile, wfile)
|
||||
else:
|
||||
return message + read_next_message(rfile)
|
||||
return message + read_next_message(rfile, wfile)
|
||||
|
||||
def send_message(wfile: BufferedIOBase, payload: bytes, opcode=OPCODE["text"], rsv1=0, rsv2=0, rsv3=0):
|
||||
if len(payload) > 0x7fffffffffffffff:
|
||||
raise ValueError(f"Payload to big. Sending fragmented messages not implemented.")
|
||||
|
||||
def send_message(
|
||||
wfile: BufferedIOBase,
|
||||
payload: bytes,
|
||||
opcode=OPCODE["text"],
|
||||
rsv1=0,
|
||||
rsv2=0,
|
||||
rsv3=0,
|
||||
):
|
||||
if len(payload) > 0x7FFFFFFFFFFFFFFF:
|
||||
raise ValueError(
|
||||
"Payload to big. Sending fragmented messages not implemented."
|
||||
)
|
||||
frame = _encode_data_frame(1, opcode, rsv1, rsv2, rsv3, payload)
|
||||
wfile.write(frame)
|
||||
|
|
Loading…
Reference in New Issue