Docstrings, types, config reader

This commit is contained in:
Egor Guslyancev 2023-12-18 08:01:00 -03:00
parent 7b4c677eae
commit c525c42487
GPG Key ID: D7E709AA465A55F9
2 changed files with 76 additions and 54 deletions

89
bot.py
View File

@ -11,14 +11,14 @@ import typing
from sys import stderr, stdout, stdin from sys import stderr, stdout, stdin
from threading import Thread from threading import Thread
import telebot import telebot
from config_reader import config import config_reader as cr
import db_classes import db_classes
# TODO more backends (redis at least) # TODO more backends (redis at least)
db = db_classes.PickleDB(".db") db = db_classes.PickleDB(".db")
db.load() db.load()
CURRENT_VERSION = "v1.0rc7" CURRENT_VERSION = "v1.0rc8"
VERSION = db.read("about.version", CURRENT_VERSION) VERSION = db.read("about.version", CURRENT_VERSION)
db.write("about.updatedfrom", VERSION) db.write("about.updatedfrom", VERSION)
db.write("about.version", CURRENT_VERSION) db.write("about.version", CURRENT_VERSION)
@ -34,18 +34,19 @@ if (db.read("about.host") is None) and __debug__:
bot = telebot.TeleBot( bot = telebot.TeleBot(
config["tokens"]["devel" if __debug__ else "prod"], parse_mode="MarkdownV2" cr.read(f"tokens.{("devel" if __debug__ else "prod")}"), parse_mode="MarkdownV2"
) )
def get_time(forum: int): def get_time(forum: int) -> dt.datetime:
"Get datetime.now in forum's timezone. Default timezone is UTC+3." "Get datetime.now in forum's timezone. Default timezone is UTC+3."
return dt.datetime.now(dt.UTC) + dt.timedelta( return dt.datetime.now(dt.UTC) + dt.timedelta(
hours=db.read(str(forum) + ".settings.timezone", 3) hours=db.read(str(forum) + ".settings.timezone", 3)
) )
def change_phase(forum: int, date: dt.datetime = None): def change_phase(forum: int, date: dt.datetime = None) -> None:
"Changes forum's current phase."
if date is None: if date is None:
date = get_time(forum).date() date = get_time(forum).date()
phase = db.read(str(forum) + ".schedule.phase") phase = db.read(str(forum) + ".schedule.phase")
@ -104,7 +105,8 @@ def get_chat(
return None return None
def check_if_admin(message: telebot.types.Message) -> bool: def check_if_admin(message: telebot.types.Message) -> bool | None:
"Checks if the message is sent by the forum's admin."
forum = message.chat.id forum = message.chat.id
admin = db.read(str(forum) + ".settings.admin") admin = db.read(str(forum) + ".settings.admin")
if admin is None: if admin is None:
@ -114,26 +116,26 @@ def check_if_admin(message: telebot.types.Message) -> bool:
return admin["id"] == message.from_user.id return admin["id"] == message.from_user.id
def mention(forum: int, user_id: int) -> str: def mention(forum: int, uid: int) -> str | None:
user_id = str(user_id) "Returns markdown formatted string with user's mention."
if db.read(str(forum) + ".people." + user_id) is None: uid = str(uid)
stderr.write("Пользователя с ID " + user_id + " нет в базе.\n") if db.read(str(forum) + ".people." + uid) is None:
return stderr.write("Пользователя с ID " + uid + " нет в базе.\n")
return None
return ( return (
"[" "["
+ db.read(str(forum) + ".people." + user_id + ".name") + db.read(str(forum) + ".people." + uid + ".name")
+ " " + " "
+ db.read(str(forum) + ".people." + user_id + ".surname") + db.read(str(forum) + ".people." + uid + ".surname")
+ "](tg://user?id=" + "](tg://user?id="
+ str(user_id) + str(uid)
+ ")" + ")"
) )
def find_uids(forum: int, s: str): def find_uids(forum: int, s: str) -> list | None:
"Find user's id by nickname, name or surname."
people = db.read(str(forum) + ".people") people = db.read(str(forum) + ".people")
if people is None:
return
if people is None: if people is None:
return None return None
if s[0] == "@": if s[0] == "@":
@ -144,7 +146,8 @@ def find_uids(forum: int, s: str):
return f return f
def format_user_info(forum: int, uid): def format_user_info(forum: int, uid: int) -> str:
"Returns markdown formatted string with all user's info by their id."
uid = str(uid) uid = str(uid)
person = db.read(str(forum) + ".people." + uid) person = db.read(str(forum) + ".people." + uid)
if person is None: if person is None:
@ -156,21 +159,24 @@ def format_user_info(forum: int, uid):
return r return r
def prepend_user(forum: int, ulist_s: str, uid): def prepend_user(forum: int, ulist_s: str, uid: int) -> None:
"Inserts user id at the start of provided db list in forum's context."
uid = str(uid) uid = str(uid)
ulist = db.read(str(forum) + "." + ulist_s, []) ulist = db.read(str(forum) + "." + ulist_s, [])
ulist = list(set([uid] + ulist)) ulist = list(set([uid] + ulist))
db.write(str(forum) + "." + ulist_s, ulist) db.write(str(forum) + "." + ulist_s, ulist)
def append_user(forum: int, ulist_s: str, uid): def append_user(forum: int, ulist_s: str, uid: int) -> None:
"Inserts user id at the end of provided db list in forum's context."
uid = str(uid) uid = str(uid)
ulist = db.read(str(forum) + "." + ulist_s, []) ulist = db.read(str(forum) + "." + ulist_s, [])
ulist = list(set(ulist + [uid])) ulist = list(set(ulist + [uid]))
db.write(str(forum) + "." + ulist_s, ulist) db.write(str(forum) + "." + ulist_s, ulist)
def pop_user(forum: int, ulist_s: str): def pop_user(forum: int, ulist_s: str) -> dict | None:
"Removes user id from the start of provided db list in forum's context. Returns user id."
ulist = db.read(str(forum) + "." + ulist_s, []) ulist = db.read(str(forum) + "." + ulist_s, [])
r = None r = None
if len(ulist) > 0: if len(ulist) > 0:
@ -179,7 +185,7 @@ def pop_user(forum: int, ulist_s: str):
return r return r
def insert_user_in_current_order(forum: int, uid) -> bool: def insert_user_in_current_order(forum: int, uid: int) -> bool:
uid = str(uid) uid = str(uid)
order = db.read(str(forum) + ".rookies.order", []) order = db.read(str(forum) + ".rookies.order", [])
people = db.read(str(forum) + ".people", {}) people = db.read(str(forum) + ".people", {})
@ -203,12 +209,11 @@ def insert_user_in_current_order(forum: int, uid) -> bool:
pos = list(order.keys()).index(current) pos = list(order.keys()).index(current)
if pos == 0: if pos == 0:
return False return False
else: db.write(str(forum) + ".rookies.order", list(order.keys())[1:])
db.write(str(forum) + ".rookies.order", list(order.keys())[1:]) return True
return True
def parse_dates(forum: int, args): def parse_dates(forum: int, args: typing.Iterable) -> list | str:
dates = [] dates = []
cur_date = get_time(forum).date() - dt.timedelta(days=1) cur_date = get_time(forum).date() - dt.timedelta(days=1)
cur_year = cur_date.year cur_year = cur_date.year
@ -258,7 +263,7 @@ def parse_dates(forum: int, args):
return dates return dates
def mod_days(message: telebot.types.Message, target, neighbour): def mod_days(message: telebot.types.Message, target: str, neighbour: str) -> None:
forum = message.chat.id forum = message.chat.id
chat = get_chat(message) chat = get_chat(message)
if chat is not None: if chat is not None:
@ -270,7 +275,7 @@ def mod_days(message: telebot.types.Message, target, neighbour):
dates = [get_time(forum).date()] dates = [get_time(forum).date()]
else: else:
dates = parse_dates(forum, args) dates = parse_dates(forum, args)
if type(dates) is str: if isinstance(dates, str):
bot.reply_to( bot.reply_to(
chat, chat,
telebot.formatting.escape_markdown(dates) telebot.formatting.escape_markdown(dates)
@ -355,12 +360,14 @@ def start_bot(message: telebot.types.Message):
if message.chat.is_forum: if message.chat.is_forum:
bot.reply_to( bot.reply_to(
chat, chat,
"Привет\\! Я бот для управления дежурствами и напоминания о них\\. Напиши /link, чтобы привязать комнату\\.", "Привет\\! Я бот для управления дежурствами и напоминания о них\\. "
+ "Напиши /link, чтобы привязать комнату\\.",
) )
else: else:
bot.reply_to( bot.reply_to(
chat, chat,
"Я работаю только на форумах \\(супергруппах с комнатами\\)\\. Пригласи меня в один из них и напиши /start", "Я работаю только на форумах \\(супергруппах с комнатами\\)\\. "
+ "Пригласи меня в один из них и напиши /start",
) )
@ -442,14 +449,6 @@ if __debug__:
bot.delete_message(forum, message.id) bot.delete_message(forum, message.id)
@bot.message_handler(commands=["fuck", "fuck_you", "fuck-you", "fuckyou"])
def rude(message: telebot.types.Message):
forum = message.chat.id
chat = get_chat(message, True)
if chat is not None:
bot.delete_message(forum, message.id)
@bot.message_handler(commands=["backup"]) @bot.message_handler(commands=["backup"])
def backup_db(message: telebot.types.Message): def backup_db(message: telebot.types.Message):
forum = message.chat.id forum = message.chat.id
@ -1333,7 +1332,7 @@ def get_hours() -> tuple:
return (8, 20) return (8, 20)
def stack_update(forum: int, force_reset=False): def stack_update(forum: int, force_reset: bool = False) -> None:
now = get_time(forum) now = get_time(forum)
now_date = now.date() now_date = now.date()
order = db.read(str(forum) + ".rookies.order", []) order = db.read(str(forum) + ".rookies.order", [])
@ -1395,14 +1394,14 @@ def stack_update(forum: int, force_reset=False):
) )
def clean_old_dates(date: dt.datetime, array: str): def clean_old_dates(date: dt.datetime, array: str) -> None:
"Removes dates from db's array which are older than `date`." "Removes dates from db's array which are older than `date`."
a = db.read(array) a = db.read(array)
a = a.filter(lambda x: x >= date, a) a = a.filter(lambda x: x >= date, a)
db.write(array, a) db.write(array, a)
def update(forum: int): def update(forum: int) -> None:
now = get_time(forum) now = get_time(forum)
now_date = now.date() now_date = now.date()
now_time = now.time() now_time = now.time()
@ -1424,7 +1423,7 @@ def update(forum: int):
remind_users(forum) remind_users(forum)
def update_notify(forum: int): def update_notify(forum: int) -> None:
"Notifies the forum about bot's new version." "Notifies the forum about bot's new version."
bot.reply_to( bot.reply_to(
get_chat(forum), get_chat(forum),
@ -1439,13 +1438,13 @@ def process1():
def process2(): def process2():
"The process updates duty order for every forum once a `period` seconds." "The process updates duty order for every forum once a `period` seconds."
period = int(config["settings"]["notify_period"]) period = int(cr.read("settings.notify_period"))
prev_time = time.time() prev_time = time.time()
while True: while True:
cur_time = time.time() cur_time = time.time()
if cur_time - prev_time >= period: if cur_time - prev_time >= period:
prev_time = cur_time prev_time = cur_time
stdout.write("Update\n") stdout.write("Process 2 update\n")
for i in db.read("").keys(): for i in db.read("").keys():
try: try:
update(int(i)) update(int(i))
@ -1463,7 +1462,7 @@ if VERSION != CURRENT_VERSION:
for FORUM in db.read("").keys(): for FORUM in db.read("").keys():
try: try:
update_notify(int(FORUM)) update_notify(int(FORUM))
print("Notified", FORUM) print("New version notification", FORUM)
except ValueError: except ValueError:
pass pass

View File

@ -6,17 +6,40 @@
import configparser import configparser
FIELDS = {
"tokens.prod": (str, "" if __debug__ else None, None),
"tokens.devel": (str, None if __debug__ else "", None),
"settings.notify_period": (int, 120, None),
}
CONFIG_FILENAME = "config.ini" CONFIG_FILENAME = "config.ini"
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read(CONFIG_FILENAME) config.read(CONFIG_FILENAME)
def check_field(field: str, t: type):
"Checks if the config field does exist and has the correct type."
a = config
for i in field.split('.'):
a = a[i]
a = t(a)
check_field("tokens.prod", str) def check_field(field: str) -> None:
check_field("tokens.devel", str) "Checks if the config field does exist and has the correct type."
check_field("settings.notify_period", int) read(field)
def read(field: str) -> any:
"Reads field and converts the value to the correct type."
a = config
if field not in FIELDS:
raise ValueError(f"Field {field} does not exist.")
t, d, allowed = FIELDS[field]
for i in field.split("."):
a = a[i]
if a is None:
if d is None:
raise ValueError(f"Field {field} does not have value.")
return d
a = t(a)
if allowed is not None:
if a not in allowed:
raise ValueError(f"Field {field} does not have allowed value.")
return a
for f in FIELDS:
check_field(f)