Docstrings, types, config reader

This commit is contained in:
Egor Guslyancev 2023-12-18 08:01:00 -03:00
parent 7b4c677eae
commit c525c42487
GPG Key ID: D7E709AA465A55F9
2 changed files with 76 additions and 54 deletions

89
bot.py
View File

@ -11,14 +11,14 @@ import typing
from sys import stderr, stdout, stdin
from threading import Thread
import telebot
from config_reader import config
import config_reader as cr
import db_classes
# TODO more backends (redis at least)
db = db_classes.PickleDB(".db")
db.load()
CURRENT_VERSION = "v1.0rc7"
CURRENT_VERSION = "v1.0rc8"
VERSION = db.read("about.version", CURRENT_VERSION)
db.write("about.updatedfrom", VERSION)
db.write("about.version", CURRENT_VERSION)
@ -34,18 +34,19 @@ if (db.read("about.host") is None) and __debug__:
bot = telebot.TeleBot(
config["tokens"]["devel" if __debug__ else "prod"], parse_mode="MarkdownV2"
cr.read(f"tokens.{("devel" if __debug__ else "prod")}"), parse_mode="MarkdownV2"
)
def get_time(forum: int):
def get_time(forum: int) -> dt.datetime:
"Get datetime.now in forum's timezone. Default timezone is UTC+3."
return dt.datetime.now(dt.UTC) + dt.timedelta(
hours=db.read(str(forum) + ".settings.timezone", 3)
)
def change_phase(forum: int, date: dt.datetime = None):
def change_phase(forum: int, date: dt.datetime = None) -> None:
"Changes forum's current phase."
if date is None:
date = get_time(forum).date()
phase = db.read(str(forum) + ".schedule.phase")
@ -104,7 +105,8 @@ def get_chat(
return None
def check_if_admin(message: telebot.types.Message) -> bool:
def check_if_admin(message: telebot.types.Message) -> bool | None:
"Checks if the message is sent by the forum's admin."
forum = message.chat.id
admin = db.read(str(forum) + ".settings.admin")
if admin is None:
@ -114,26 +116,26 @@ def check_if_admin(message: telebot.types.Message) -> bool:
return admin["id"] == message.from_user.id
def mention(forum: int, user_id: int) -> str:
user_id = str(user_id)
if db.read(str(forum) + ".people." + user_id) is None:
stderr.write("Пользователя с ID " + user_id + " нет в базе.\n")
return
def mention(forum: int, uid: int) -> str | None:
"Returns markdown formatted string with user's mention."
uid = str(uid)
if db.read(str(forum) + ".people." + uid) is None:
stderr.write("Пользователя с ID " + uid + " нет в базе.\n")
return None
return (
"["
+ db.read(str(forum) + ".people." + user_id + ".name")
+ db.read(str(forum) + ".people." + uid + ".name")
+ " "
+ db.read(str(forum) + ".people." + user_id + ".surname")
+ db.read(str(forum) + ".people." + uid + ".surname")
+ "](tg://user?id="
+ str(user_id)
+ str(uid)
+ ")"
)
def find_uids(forum: int, s: str):
def find_uids(forum: int, s: str) -> list | None:
"Find user's id by nickname, name or surname."
people = db.read(str(forum) + ".people")
if people is None:
return
if people is None:
return None
if s[0] == "@":
@ -144,7 +146,8 @@ def find_uids(forum: int, s: str):
return f
def format_user_info(forum: int, uid):
def format_user_info(forum: int, uid: int) -> str:
"Returns markdown formatted string with all user's info by their id."
uid = str(uid)
person = db.read(str(forum) + ".people." + uid)
if person is None:
@ -156,21 +159,24 @@ def format_user_info(forum: int, uid):
return r
def prepend_user(forum: int, ulist_s: str, uid):
def prepend_user(forum: int, ulist_s: str, uid: int) -> None:
"Inserts user id at the start of provided db list in forum's context."
uid = str(uid)
ulist = db.read(str(forum) + "." + ulist_s, [])
ulist = list(set([uid] + ulist))
db.write(str(forum) + "." + ulist_s, ulist)
def append_user(forum: int, ulist_s: str, uid):
def append_user(forum: int, ulist_s: str, uid: int) -> None:
"Inserts user id at the end of provided db list in forum's context."
uid = str(uid)
ulist = db.read(str(forum) + "." + ulist_s, [])
ulist = list(set(ulist + [uid]))
db.write(str(forum) + "." + ulist_s, ulist)
def pop_user(forum: int, ulist_s: str):
def pop_user(forum: int, ulist_s: str) -> dict | None:
"Removes user id from the start of provided db list in forum's context. Returns user id."
ulist = db.read(str(forum) + "." + ulist_s, [])
r = None
if len(ulist) > 0:
@ -179,7 +185,7 @@ def pop_user(forum: int, ulist_s: str):
return r
def insert_user_in_current_order(forum: int, uid) -> bool:
def insert_user_in_current_order(forum: int, uid: int) -> bool:
uid = str(uid)
order = db.read(str(forum) + ".rookies.order", [])
people = db.read(str(forum) + ".people", {})
@ -203,12 +209,11 @@ def insert_user_in_current_order(forum: int, uid) -> bool:
pos = list(order.keys()).index(current)
if pos == 0:
return False
else:
db.write(str(forum) + ".rookies.order", list(order.keys())[1:])
return True
db.write(str(forum) + ".rookies.order", list(order.keys())[1:])
return True
def parse_dates(forum: int, args):
def parse_dates(forum: int, args: typing.Iterable) -> list | str:
dates = []
cur_date = get_time(forum).date() - dt.timedelta(days=1)
cur_year = cur_date.year
@ -258,7 +263,7 @@ def parse_dates(forum: int, args):
return dates
def mod_days(message: telebot.types.Message, target, neighbour):
def mod_days(message: telebot.types.Message, target: str, neighbour: str) -> None:
forum = message.chat.id
chat = get_chat(message)
if chat is not None:
@ -270,7 +275,7 @@ def mod_days(message: telebot.types.Message, target, neighbour):
dates = [get_time(forum).date()]
else:
dates = parse_dates(forum, args)
if type(dates) is str:
if isinstance(dates, str):
bot.reply_to(
chat,
telebot.formatting.escape_markdown(dates)
@ -355,12 +360,14 @@ def start_bot(message: telebot.types.Message):
if message.chat.is_forum:
bot.reply_to(
chat,
"Привет\\! Я бот для управления дежурствами и напоминания о них\\. Напиши /link, чтобы привязать комнату\\.",
"Привет\\! Я бот для управления дежурствами и напоминания о них\\. "
+ "Напиши /link, чтобы привязать комнату\\.",
)
else:
bot.reply_to(
chat,
"Я работаю только на форумах \\(супергруппах с комнатами\\)\\. Пригласи меня в один из них и напиши /start",
"Я работаю только на форумах \\(супергруппах с комнатами\\)\\. "
+ "Пригласи меня в один из них и напиши /start",
)
@ -442,14 +449,6 @@ if __debug__:
bot.delete_message(forum, message.id)
@bot.message_handler(commands=["fuck", "fuck_you", "fuck-you", "fuckyou"])
def rude(message: telebot.types.Message):
forum = message.chat.id
chat = get_chat(message, True)
if chat is not None:
bot.delete_message(forum, message.id)
@bot.message_handler(commands=["backup"])
def backup_db(message: telebot.types.Message):
forum = message.chat.id
@ -1333,7 +1332,7 @@ def get_hours() -> tuple:
return (8, 20)
def stack_update(forum: int, force_reset=False):
def stack_update(forum: int, force_reset: bool = False) -> None:
now = get_time(forum)
now_date = now.date()
order = db.read(str(forum) + ".rookies.order", [])
@ -1395,14 +1394,14 @@ def stack_update(forum: int, force_reset=False):
)
def clean_old_dates(date: dt.datetime, array: str):
def clean_old_dates(date: dt.datetime, array: str) -> None:
"Removes dates from db's array which are older than `date`."
a = db.read(array)
a = a.filter(lambda x: x >= date, a)
db.write(array, a)
def update(forum: int):
def update(forum: int) -> None:
now = get_time(forum)
now_date = now.date()
now_time = now.time()
@ -1424,7 +1423,7 @@ def update(forum: int):
remind_users(forum)
def update_notify(forum: int):
def update_notify(forum: int) -> None:
"Notifies the forum about bot's new version."
bot.reply_to(
get_chat(forum),
@ -1439,13 +1438,13 @@ def process1():
def process2():
"The process updates duty order for every forum once a `period` seconds."
period = int(config["settings"]["notify_period"])
period = int(cr.read("settings.notify_period"))
prev_time = time.time()
while True:
cur_time = time.time()
if cur_time - prev_time >= period:
prev_time = cur_time
stdout.write("Update\n")
stdout.write("Process 2 update\n")
for i in db.read("").keys():
try:
update(int(i))
@ -1463,7 +1462,7 @@ if VERSION != CURRENT_VERSION:
for FORUM in db.read("").keys():
try:
update_notify(int(FORUM))
print("Notified", FORUM)
print("New version notification", FORUM)
except ValueError:
pass

View File

@ -6,17 +6,40 @@
import configparser
FIELDS = {
"tokens.prod": (str, "" if __debug__ else None, None),
"tokens.devel": (str, None if __debug__ else "", None),
"settings.notify_period": (int, 120, None),
}
CONFIG_FILENAME = "config.ini"
config = configparser.ConfigParser()
config.read(CONFIG_FILENAME)
def check_field(field: str, t: type):
"Checks if the config field does exist and has the correct type."
a = config
for i in field.split('.'):
a = a[i]
a = t(a)
check_field("tokens.prod", str)
check_field("tokens.devel", str)
check_field("settings.notify_period", int)
def check_field(field: str) -> None:
"Checks if the config field does exist and has the correct type."
read(field)
def read(field: str) -> any:
"Reads field and converts the value to the correct type."
a = config
if field not in FIELDS:
raise ValueError(f"Field {field} does not exist.")
t, d, allowed = FIELDS[field]
for i in field.split("."):
a = a[i]
if a is None:
if d is None:
raise ValueError(f"Field {field} does not have value.")
return d
a = t(a)
if allowed is not None:
if a not in allowed:
raise ValueError(f"Field {field} does not have allowed value.")
return a
for f in FIELDS:
check_field(f)