Docstrings, types, config reader
This commit is contained in:
parent
7b4c677eae
commit
c525c42487
89
bot.py
89
bot.py
|
@ -11,14 +11,14 @@ import typing
|
|||
from sys import stderr, stdout, stdin
|
||||
from threading import Thread
|
||||
import telebot
|
||||
from config_reader import config
|
||||
import config_reader as cr
|
||||
import db_classes
|
||||
|
||||
|
||||
# TODO more backends (redis at least)
|
||||
db = db_classes.PickleDB(".db")
|
||||
db.load()
|
||||
CURRENT_VERSION = "v1.0rc7"
|
||||
CURRENT_VERSION = "v1.0rc8"
|
||||
VERSION = db.read("about.version", CURRENT_VERSION)
|
||||
db.write("about.updatedfrom", VERSION)
|
||||
db.write("about.version", CURRENT_VERSION)
|
||||
|
@ -34,18 +34,19 @@ if (db.read("about.host") is None) and __debug__:
|
|||
|
||||
|
||||
bot = telebot.TeleBot(
|
||||
config["tokens"]["devel" if __debug__ else "prod"], parse_mode="MarkdownV2"
|
||||
cr.read(f"tokens.{("devel" if __debug__ else "prod")}"), parse_mode="MarkdownV2"
|
||||
)
|
||||
|
||||
|
||||
def get_time(forum: int):
|
||||
def get_time(forum: int) -> dt.datetime:
|
||||
"Get datetime.now in forum's timezone. Default timezone is UTC+3."
|
||||
return dt.datetime.now(dt.UTC) + dt.timedelta(
|
||||
hours=db.read(str(forum) + ".settings.timezone", 3)
|
||||
)
|
||||
|
||||
|
||||
def change_phase(forum: int, date: dt.datetime = None):
|
||||
def change_phase(forum: int, date: dt.datetime = None) -> None:
|
||||
"Changes forum's current phase."
|
||||
if date is None:
|
||||
date = get_time(forum).date()
|
||||
phase = db.read(str(forum) + ".schedule.phase")
|
||||
|
@ -104,7 +105,8 @@ def get_chat(
|
|||
return None
|
||||
|
||||
|
||||
def check_if_admin(message: telebot.types.Message) -> bool:
|
||||
def check_if_admin(message: telebot.types.Message) -> bool | None:
|
||||
"Checks if the message is sent by the forum's admin."
|
||||
forum = message.chat.id
|
||||
admin = db.read(str(forum) + ".settings.admin")
|
||||
if admin is None:
|
||||
|
@ -114,26 +116,26 @@ def check_if_admin(message: telebot.types.Message) -> bool:
|
|||
return admin["id"] == message.from_user.id
|
||||
|
||||
|
||||
def mention(forum: int, user_id: int) -> str:
|
||||
user_id = str(user_id)
|
||||
if db.read(str(forum) + ".people." + user_id) is None:
|
||||
stderr.write("Пользователя с ID " + user_id + " нет в базе.\n")
|
||||
return
|
||||
def mention(forum: int, uid: int) -> str | None:
|
||||
"Returns markdown formatted string with user's mention."
|
||||
uid = str(uid)
|
||||
if db.read(str(forum) + ".people." + uid) is None:
|
||||
stderr.write("Пользователя с ID " + uid + " нет в базе.\n")
|
||||
return None
|
||||
return (
|
||||
"["
|
||||
+ db.read(str(forum) + ".people." + user_id + ".name")
|
||||
+ db.read(str(forum) + ".people." + uid + ".name")
|
||||
+ " "
|
||||
+ db.read(str(forum) + ".people." + user_id + ".surname")
|
||||
+ db.read(str(forum) + ".people." + uid + ".surname")
|
||||
+ "](tg://user?id="
|
||||
+ str(user_id)
|
||||
+ str(uid)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
|
||||
def find_uids(forum: int, s: str):
|
||||
def find_uids(forum: int, s: str) -> list | None:
|
||||
"Find user's id by nickname, name or surname."
|
||||
people = db.read(str(forum) + ".people")
|
||||
if people is None:
|
||||
return
|
||||
if people is None:
|
||||
return None
|
||||
if s[0] == "@":
|
||||
|
@ -144,7 +146,8 @@ def find_uids(forum: int, s: str):
|
|||
return f
|
||||
|
||||
|
||||
def format_user_info(forum: int, uid):
|
||||
def format_user_info(forum: int, uid: int) -> str:
|
||||
"Returns markdown formatted string with all user's info by their id."
|
||||
uid = str(uid)
|
||||
person = db.read(str(forum) + ".people." + uid)
|
||||
if person is None:
|
||||
|
@ -156,21 +159,24 @@ def format_user_info(forum: int, uid):
|
|||
return r
|
||||
|
||||
|
||||
def prepend_user(forum: int, ulist_s: str, uid):
|
||||
def prepend_user(forum: int, ulist_s: str, uid: int) -> None:
|
||||
"Inserts user id at the start of provided db list in forum's context."
|
||||
uid = str(uid)
|
||||
ulist = db.read(str(forum) + "." + ulist_s, [])
|
||||
ulist = list(set([uid] + ulist))
|
||||
db.write(str(forum) + "." + ulist_s, ulist)
|
||||
|
||||
|
||||
def append_user(forum: int, ulist_s: str, uid):
|
||||
def append_user(forum: int, ulist_s: str, uid: int) -> None:
|
||||
"Inserts user id at the end of provided db list in forum's context."
|
||||
uid = str(uid)
|
||||
ulist = db.read(str(forum) + "." + ulist_s, [])
|
||||
ulist = list(set(ulist + [uid]))
|
||||
db.write(str(forum) + "." + ulist_s, ulist)
|
||||
|
||||
|
||||
def pop_user(forum: int, ulist_s: str):
|
||||
def pop_user(forum: int, ulist_s: str) -> dict | None:
|
||||
"Removes user id from the start of provided db list in forum's context. Returns user id."
|
||||
ulist = db.read(str(forum) + "." + ulist_s, [])
|
||||
r = None
|
||||
if len(ulist) > 0:
|
||||
|
@ -179,7 +185,7 @@ def pop_user(forum: int, ulist_s: str):
|
|||
return r
|
||||
|
||||
|
||||
def insert_user_in_current_order(forum: int, uid) -> bool:
|
||||
def insert_user_in_current_order(forum: int, uid: int) -> bool:
|
||||
uid = str(uid)
|
||||
order = db.read(str(forum) + ".rookies.order", [])
|
||||
people = db.read(str(forum) + ".people", {})
|
||||
|
@ -203,12 +209,11 @@ def insert_user_in_current_order(forum: int, uid) -> bool:
|
|||
pos = list(order.keys()).index(current)
|
||||
if pos == 0:
|
||||
return False
|
||||
else:
|
||||
db.write(str(forum) + ".rookies.order", list(order.keys())[1:])
|
||||
return True
|
||||
db.write(str(forum) + ".rookies.order", list(order.keys())[1:])
|
||||
return True
|
||||
|
||||
|
||||
def parse_dates(forum: int, args):
|
||||
def parse_dates(forum: int, args: typing.Iterable) -> list | str:
|
||||
dates = []
|
||||
cur_date = get_time(forum).date() - dt.timedelta(days=1)
|
||||
cur_year = cur_date.year
|
||||
|
@ -258,7 +263,7 @@ def parse_dates(forum: int, args):
|
|||
return dates
|
||||
|
||||
|
||||
def mod_days(message: telebot.types.Message, target, neighbour):
|
||||
def mod_days(message: telebot.types.Message, target: str, neighbour: str) -> None:
|
||||
forum = message.chat.id
|
||||
chat = get_chat(message)
|
||||
if chat is not None:
|
||||
|
@ -270,7 +275,7 @@ def mod_days(message: telebot.types.Message, target, neighbour):
|
|||
dates = [get_time(forum).date()]
|
||||
else:
|
||||
dates = parse_dates(forum, args)
|
||||
if type(dates) is str:
|
||||
if isinstance(dates, str):
|
||||
bot.reply_to(
|
||||
chat,
|
||||
telebot.formatting.escape_markdown(dates)
|
||||
|
@ -355,12 +360,14 @@ def start_bot(message: telebot.types.Message):
|
|||
if message.chat.is_forum:
|
||||
bot.reply_to(
|
||||
chat,
|
||||
"Привет\\! Я бот для управления дежурствами и напоминания о них\\. Напиши /link, чтобы привязать комнату\\.",
|
||||
"Привет\\! Я бот для управления дежурствами и напоминания о них\\. "
|
||||
+ "Напиши /link, чтобы привязать комнату\\.",
|
||||
)
|
||||
else:
|
||||
bot.reply_to(
|
||||
chat,
|
||||
"Я работаю только на форумах \\(супергруппах с комнатами\\)\\. Пригласи меня в один из них и напиши /start",
|
||||
"Я работаю только на форумах \\(супергруппах с комнатами\\)\\. "
|
||||
+ "Пригласи меня в один из них и напиши /start",
|
||||
)
|
||||
|
||||
|
||||
|
@ -442,14 +449,6 @@ if __debug__:
|
|||
bot.delete_message(forum, message.id)
|
||||
|
||||
|
||||
@bot.message_handler(commands=["fuck", "fuck_you", "fuck-you", "fuckyou"])
|
||||
def rude(message: telebot.types.Message):
|
||||
forum = message.chat.id
|
||||
chat = get_chat(message, True)
|
||||
if chat is not None:
|
||||
bot.delete_message(forum, message.id)
|
||||
|
||||
|
||||
@bot.message_handler(commands=["backup"])
|
||||
def backup_db(message: telebot.types.Message):
|
||||
forum = message.chat.id
|
||||
|
@ -1333,7 +1332,7 @@ def get_hours() -> tuple:
|
|||
return (8, 20)
|
||||
|
||||
|
||||
def stack_update(forum: int, force_reset=False):
|
||||
def stack_update(forum: int, force_reset: bool = False) -> None:
|
||||
now = get_time(forum)
|
||||
now_date = now.date()
|
||||
order = db.read(str(forum) + ".rookies.order", [])
|
||||
|
@ -1395,14 +1394,14 @@ def stack_update(forum: int, force_reset=False):
|
|||
)
|
||||
|
||||
|
||||
def clean_old_dates(date: dt.datetime, array: str):
|
||||
def clean_old_dates(date: dt.datetime, array: str) -> None:
|
||||
"Removes dates from db's array which are older than `date`."
|
||||
a = db.read(array)
|
||||
a = a.filter(lambda x: x >= date, a)
|
||||
db.write(array, a)
|
||||
|
||||
|
||||
def update(forum: int):
|
||||
def update(forum: int) -> None:
|
||||
now = get_time(forum)
|
||||
now_date = now.date()
|
||||
now_time = now.time()
|
||||
|
@ -1424,7 +1423,7 @@ def update(forum: int):
|
|||
remind_users(forum)
|
||||
|
||||
|
||||
def update_notify(forum: int):
|
||||
def update_notify(forum: int) -> None:
|
||||
"Notifies the forum about bot's new version."
|
||||
bot.reply_to(
|
||||
get_chat(forum),
|
||||
|
@ -1439,13 +1438,13 @@ def process1():
|
|||
|
||||
def process2():
|
||||
"The process updates duty order for every forum once a `period` seconds."
|
||||
period = int(config["settings"]["notify_period"])
|
||||
period = int(cr.read("settings.notify_period"))
|
||||
prev_time = time.time()
|
||||
while True:
|
||||
cur_time = time.time()
|
||||
if cur_time - prev_time >= period:
|
||||
prev_time = cur_time
|
||||
stdout.write("Update\n")
|
||||
stdout.write("Process 2 update\n")
|
||||
for i in db.read("").keys():
|
||||
try:
|
||||
update(int(i))
|
||||
|
@ -1463,7 +1462,7 @@ if VERSION != CURRENT_VERSION:
|
|||
for FORUM in db.read("").keys():
|
||||
try:
|
||||
update_notify(int(FORUM))
|
||||
print("Notified", FORUM)
|
||||
print("New version notification", FORUM)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
|
|
@ -6,17 +6,40 @@
|
|||
|
||||
import configparser
|
||||
|
||||
FIELDS = {
|
||||
"tokens.prod": (str, "" if __debug__ else None, None),
|
||||
"tokens.devel": (str, None if __debug__ else "", None),
|
||||
"settings.notify_period": (int, 120, None),
|
||||
}
|
||||
|
||||
CONFIG_FILENAME = "config.ini"
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG_FILENAME)
|
||||
|
||||
def check_field(field: str, t: type):
|
||||
"Checks if the config field does exist and has the correct type."
|
||||
a = config
|
||||
for i in field.split('.'):
|
||||
a = a[i]
|
||||
a = t(a)
|
||||
|
||||
check_field("tokens.prod", str)
|
||||
check_field("tokens.devel", str)
|
||||
check_field("settings.notify_period", int)
|
||||
def check_field(field: str) -> None:
|
||||
"Checks if the config field does exist and has the correct type."
|
||||
read(field)
|
||||
|
||||
|
||||
def read(field: str) -> any:
|
||||
"Reads field and converts the value to the correct type."
|
||||
a = config
|
||||
if field not in FIELDS:
|
||||
raise ValueError(f"Field {field} does not exist.")
|
||||
t, d, allowed = FIELDS[field]
|
||||
for i in field.split("."):
|
||||
a = a[i]
|
||||
if a is None:
|
||||
if d is None:
|
||||
raise ValueError(f"Field {field} does not have value.")
|
||||
return d
|
||||
a = t(a)
|
||||
if allowed is not None:
|
||||
if a not in allowed:
|
||||
raise ValueError(f"Field {field} does not have allowed value.")
|
||||
return a
|
||||
|
||||
|
||||
for f in FIELDS:
|
||||
check_field(f)
|
||||
|
|
Loading…
Reference in New Issue