sorryops/server.py
2024-05-03 15:35:57 -03:00

404 lines
16 KiB
Python

# SPDX-FileCopyrightText: 2024 Egor Guslyancev <electromagneticcyclone@disroot.org>
# SPDX-FileCopyrightText: 2023 Miel Donkers
#
# SPDX-License-Identifier: AGPL-3.0-or-later
from http.server import BaseHTTPRequestHandler, HTTPServer
import json
import sys
from threading import Thread
from os import listdir
from time import sleep, time
from itertools import combinations
from markdown import markdown
import db_classes
import timeout as tmo
# --- Simple configuration ---------------------------------------------------
# Server version string; it is sent back to clients with every answers payload.
VERSION = "20240503.2"
# Characters that name answer options; a question with n options uses the
# first n of them (see comb() and resolve_logs()).
CHARSET = (
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
)
# When True, only users listed under ``users.vip`` may read test answers (GET).
GET_ONLY_FOR_VIP = False
# When True, only users listed under ``users.vip`` may submit results (POST).
POST_ONLY_FOR_VIP = True
# Key/value store shared by the HTTP handler and the resolver thread. It is
# loaded once, at import time. ".db" is presumably the on-disk file name.
db = db_classes.PickleDB(".db")
db.load()
def comb(s):
    """Return every non-empty combination of the characters of *s*.

    Results are ordered by size (singletons first, then pairs, ...). Within
    one size they follow :func:`itertools.combinations`, which keeps the
    order of *s*.

    >>> comb("abc")
    ['a', 'b', 'c', 'ab', 'ac', 'bc', 'abc']
    """
    return [
        "".join(chosen)
        for size in range(1, len(s) + 1)
        for chosen in combinations(s, size)
    ]
def resolve_logs(tries = 3, backup = True, backup_name = None):
print("Resolving…")
if tries < 1:
try:
db.pop("pending")
except KeyError:
pass
try:
db.pop("contributors")
except KeyError:
pass
print("Resolved")
return
pending = db.read("pending", [])
if len(pending) == 0:
print("Resolved")
return
if backup:
# backup_name = f"backup{str(int(time()))}.db"
backup_name = "backup.db"
db.save(backup_name)
try:
for test in pending:
logs = db.pop(f"tests.{test}.logs")
if logs is None:
continue
correct = db.pop(f"tests.{test}.correct")
if correct is None:
correct = {}
incorrect = db.pop(f"tests.{test}.incorrect")
if incorrect is None:
incorrect = {}
for l in logs:
if 'shadow' not in l:
l['shadow'] = l['answers']
for a in l['answers']:
if a == "dummy":
continue
if a[0] in correct:
if a[2] == correct[a[0]]:
l['correct'] -= 1
a[0] = '-'
if a[0] in incorrect:
if a[2] in incorrect[a[0]]:
a[0] = '-'
l['answers'] = list(filter(lambda a: a[0] != '-', l['answers']))
if len(l['answers']) == l['correct']:
for a in l['answers']:
if a == "dummy":
continue
correct[a[0]] = a[2]
if a[0] in incorrect:
incorrect.pop(a[0])
l['answers'] = []
elif l['correct'] == 0:
for a in l['answers']:
if a == "dummy":
continue
if a[0] not in incorrect:
incorrect[a[0]] = []
incorrect[a[0]] = list(set(incorrect[a[0]] + [a[2]]))
match a[2][0]:
case "[":
pass
case "{":
if len(incorrect[a[0]]) == (2**a[3] - 2): # [- 2]: Exclude empty and unknown
for c in comb(CHARSET[:a[3]]):
if c not in incorrect[a[0]]:
correct[a[0]] = c
break
incorrect.pop(a[0])
if len(incorrect[a[0]]) > (2**a[3] - 2):
incorrect[a[0]] = []
case _:
if len(incorrect[a[0]]) == (a[3] - 1):
for c in CHARSET[:a[3]]:
if c not in incorrect[a[0]]:
correct[a[0]] = c
break
incorrect.pop(a[0])
l['answers'] = []
elif l['correct'] < 0:
for a in l['shadow']:
if a[0] in incorrect:
incorrect.pop(a[0])
if a[0] in correct:
correct.pop(a[0])
l['answers'] = l['shadow']
logs = list(filter(lambda l: (len(l['answers']) != l['answers'].count("dummy")), logs))
new_logs = db.read(f"tests.{test}.logs")
if new_logs is None:
new_logs = []
db.write(f"tests.{test}.logs", logs + new_logs)
db.write(f"tests.{test}.correct", correct)
db.write(f"tests.{test}.incorrect", incorrect)
except Exception as e:
print("Something really bad happend. Recovering from backup")
print(e)
if backup_name is not None:
db.load(backup_name)
resolve_logs(tries - 1, False, backup_name)
def parse_request(r):
try:
results = json.loads(r)
# for field in ['type', 'id', 'uid', 'answers', 'correct', 'all']:
# if field not in results:
# return 400
rtype = results['type']
test_id = results['id']
user_id = results['uid']
stud_id = results['sid']
data = db.read(f"tests.{test_id}", None)
blacklist = db.read('users.blacklist', set())
vip = db.read('users.vip', set())
if user_id not in vip and POST_ONLY_FOR_VIP:
return 403
if user_id in blacklist:
return 403
match rtype:
case "test_results":
if 'access' in data:
if len(data['access']) > 0:
if f"{user_id}{stud_id}" in data['access']:
pass
elif user_id in data['access']:
data['access'].remove(user_id)
data['access'].append(f"{user_id}{stud_id}")
db.write(f"tests.{test_id}", data)
else:
return 403
answers = results['answers']
all_answ = int(results['all'])
while len(answers) != all_answ:
answers.append("dummy")
log = {
'answers': answers,
'shadow': answers,
'correct': int(results['correct']),
}
logs = db.read(f'tests.{test_id}.logs', [])
logs = logs + [log]
db.write(f'tests.{test_id}.logs', logs)
pending = db.read('pending', set())
pending = set(list(pending) + [test_id])
db.write('pending', pending)
contributors = db.read('contributors', set())
contributors = set(list(contributors) + [user_id])
db.write('contributors', contributors)
return 202
# case "add_vip":
# if user_id not in vip:
# return 403
# vip = set(list(vip) + [results['user']])
# db.write('users.vip', vip)
# case "del_vip":
# if user_id not in vip:
# return 403
# vip.remove(results['user'])
# db.write('users.vip', vip)
# case "add_blacklist":
# if user_id not in vip:
# return 403
# blacklist = set(list(blacklist) + [results['user']])
# db.write('users.blacklist', blacklist)
# case "del_blacklist":
# if user_id not in vip:
# return 403
# blacklist.remove(results['user'])
# db.write('users.blacklist', blacklist)
case _:
raise KeyError()
except KeyError:
print("Invalid request")
return 400
except json.decoder.JSONDecodeError:
print("Bad request")
return 400
except Exception as e:
print("Something bad happend")
print(e)
return 400
class S(BaseHTTPRequestHandler):
def _set_response(self, status):
self.send_response(status)
def do_GET(self):
sp = self.path.split('?')
parameters = {}
if len(sp) < 2:
self_path = sp[0]
else:
self_path, params = sp[:2]
for p in params.split('&'):
p = p.split('=')
if len(p) > 1:
parameters[p[0]] = p[1]
match self_path:
case "/":
self._set_response(200)
self.send_header('Content-type', 'text/html; charset=utf-8')
self.end_headers()
# TODO Autodetect browser language
with open("README.ru.md", "r", encoding='utf-8') as fi:
self.wfile.write(
markdown(fi.read(), extensions=['fenced_code', 'codehilite'])
.encode('utf-8'))
case "/license":
self._set_response(200)
self.send_header('Content-type', 'text/html; charset=utf-8')
self.end_headers()
for i in listdir("LICENSES"):
with open(f"LICENSES/{i}", "r", encoding='utf-8') as fi:
self.wfile.write(fi.read().encode('utf-8'))
self.wfile.write("-----------------------------------".encode('utf-8'))
case "/source":
self.send_response(301)
self.send_header(
'Location',
'https://git.disroot.org/electromagneticcyclone/sorryops'
)
self.end_headers()
case _ if self_path.startswith("/TOS."):
self._set_response(200)
self.send_header('Content-type', 'text/html; charset=utf-8')
self.end_headers()
path = self_path[1:] + ("" if self_path.endswith(".md") else ".md")
if path in listdir("./"):
with open(path, "r", encoding='utf-8') as fi:
self.wfile.write(
markdown(fi.read(), extensions=['fenced_code', 'codehilite'])
.encode('utf-8')
)
else:
self._set_response(404)
self.end_headers()
self.wfile.write("404 Not found".encode('utf-8'))
case "/favicon.ico":
self._set_response(200)
with open("favicon.ico", "rb") as fi:
data = fi.read()
self.send_header('Accept-Ranges', 'bytes')
self.send_header('Content-Disposition', 'attachment')
self.send_header('Content-Length', len(data))
self.end_headers()
self.wfile.write(data)
case _ if self_path.startswith("/yandex_"):
self._set_response(200)
with open(self_path[1:], "rb") as fi:
data = fi.read()
self.send_header('Content-type', 'text/html; charset=utf-8')
self.end_headers()
self.wfile.write(data)
case _ if self_path.startswith("/assets/"):
self._set_response(200)
with open(self_path[1:], "rb") as fi:
data = fi.read()
self.send_header('Accept-Ranges', 'bytes')
self.send_header('Content-Disposition', 'attachment')
self.send_header('Content-Length', len(data))
self.end_headers()
self.wfile.write(data)
case "/add_access":
user = parameters.get('uid', "")
target_user = parameters.get('tuid', "")
test_id = parameters.get('test', "")
if user not in db.read('users.vip'):
self._set_response(403)
self.end_headers()
self.wfile.write("403 Forbidden".encode('utf-8'))
return
if len(target_user) != 36:
self._set_response(400)
self.end_headers()
self.wfile.write("400 Bad request".encode('utf-8'))
return
data = db.read(f'tests.{test_id}', None)
if data is None:
self._set_response(400)
self.end_headers()
self.wfile.write("400 Bad request".encode('utf-8'))
return
if 'access' not in data:
data['access'] = []
data['access'].append(target_user)
data['access'] = list(set(data['access']))
db.write(f'tests.{test_id}', data)
self._set_response(200)
self.end_headers()
self.wfile.write("200 OK".encode('utf-8'))
case _:
db_path = 'tests.' + '.'.join(self_path[1:].split('/'))
data = db.read(db_path, None)
user = parameters.get('uid', "")
stud = parameters.get('sid', "")
if (user not in db.read('users.vip') and GET_ONLY_FOR_VIP) or (len(user) != 36):
self._set_response(403)
self.end_headers()
self.wfile.write("403 Forbidden".encode('utf-8'))
return
if data is None:
self._set_response(404)
self.end_headers()
self.wfile.write("404 Not found".encode('utf-8'))
return
if 'access' in data:
if len(data['access']) > 0:
if f"{user}{stud}" in data['access']:
pass
elif user in data['access']:
data['access'].remove(user)
data['access'].append(f"{user}{stud}")
db.write(db_path, data)
else:
self._set_response(403)
self.end_headers()
self.wfile.write("403 Forbidden".encode('utf-8'))
return
else:
data['access'] = []
send_data = {}
if 'correct' in data:
send_data['correct'] = data['correct']
if 'incorrect' in data:
send_data['incorrect'] = data['incorrect']
send_data['version'] = VERSION
self._set_response(200)
self.send_header('Content-type', 'text/json; charset=utf-8')
self.end_headers()
self.wfile.write(json.dumps(send_data).encode('utf-8'))
def do_POST(self):
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length).decode('utf-8')
self._set_response(parse_request(post_data))
self.wfile.write(f"POST request for {self.path}".encode('utf-8'))
def run_resolver():
    """Run :func:`resolve_logs` periodically, forever (meant for a thread).

    A ``tmo.Timeout(2 * 60)`` is polled in an endless loop. ``check`` receives
    ``resolve_logs`` and a no-op. Presumably it calls the first when the
    2-minute timeout has elapsed and the second otherwise; the ``timeout``
    module is not visible here, so confirm.
    """
    p2_tmo = tmo.Timeout(2 * 60)
    while True:
        try:
            p2_tmo.check(resolve_logs, lambda: None)
        except Exception as e:
            # Keep the resolver alive, but make the failure visible.
            print("Resolver failed")
            print(e)
        # Avoid a busy loop if check() returns immediately before the timeout
        # elapses. One second is far below the 2-minute period.
        sleep(1)
def run_http_server(httpd):
    """Serve *httpd* forever, re-entering the serve loop if it ever raises.

    The listening socket is deliberately left open between restarts: once it
    is closed, serve_forever() can no longer accept connections. Closing it
    is left to the caller.
    """
    while True:
        try:
            httpd.serve_forever()
        except Exception as e:
            print("HTTP server failed, restarting")
            print(e)
            sleep(1)  # don't spin if the failure repeats immediately
def run(server_class=HTTPServer, handler_class=S, port=8000):
    """Start the HTTP server and the log resolver, then idle until Ctrl+C.

    Both workers run in daemon threads, so they end with the main thread. On
    KeyboardInterrupt, only the listening socket is closed explicitly.

    Args:
        server_class: server type, instantiated with the address and handler.
        handler_class: request handler class.
        port: TCP port to listen on; the server binds to 127.0.0.1.
    """
    server_address = ('127.0.0.1', port)
    httpd = server_class(server_address, handler_class)
    workers = [lambda: run_http_server(httpd), run_resolver]
    # A list, not a lazy map(): the threads must be reusable after starting.
    threads = [Thread(target=worker, daemon=True) for worker in workers]
    for thread in threads:
        thread.start()
    try:
        while True:
            sleep(1)
    except KeyboardInterrupt:
        pass
    # threading.Thread has no stop(); the daemon threads end with the process.
    httpd.server_close()
# Auto-start only when run as a script, and not under `python -i`
# (presumably so the module, e.g. `db`, can be used interactively instead).
if __name__ == '__main__' and not sys.flags.interactive:
    run()