Docstrings, types, config reader
This commit is contained in:
parent
7b4c677eae
commit
c525c42487
89
bot.py
89
bot.py
|
@ -11,14 +11,14 @@ import typing
|
||||||
from sys import stderr, stdout, stdin
|
from sys import stderr, stdout, stdin
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
import telebot
|
import telebot
|
||||||
from config_reader import config
|
import config_reader as cr
|
||||||
import db_classes
|
import db_classes
|
||||||
|
|
||||||
|
|
||||||
# TODO more backends (redis at least)
|
# TODO more backends (redis at least)
|
||||||
db = db_classes.PickleDB(".db")
|
db = db_classes.PickleDB(".db")
|
||||||
db.load()
|
db.load()
|
||||||
CURRENT_VERSION = "v1.0rc7"
|
CURRENT_VERSION = "v1.0rc8"
|
||||||
VERSION = db.read("about.version", CURRENT_VERSION)
|
VERSION = db.read("about.version", CURRENT_VERSION)
|
||||||
db.write("about.updatedfrom", VERSION)
|
db.write("about.updatedfrom", VERSION)
|
||||||
db.write("about.version", CURRENT_VERSION)
|
db.write("about.version", CURRENT_VERSION)
|
||||||
|
@ -34,18 +34,19 @@ if (db.read("about.host") is None) and __debug__:
|
||||||
|
|
||||||
|
|
||||||
bot = telebot.TeleBot(
|
bot = telebot.TeleBot(
|
||||||
config["tokens"]["devel" if __debug__ else "prod"], parse_mode="MarkdownV2"
|
cr.read(f"tokens.{("devel" if __debug__ else "prod")}"), parse_mode="MarkdownV2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_time(forum: int):
|
def get_time(forum: int) -> dt.datetime:
|
||||||
"Get datetime.now in forum's timezone. Default timezone is UTC+3."
|
"Get datetime.now in forum's timezone. Default timezone is UTC+3."
|
||||||
return dt.datetime.now(dt.UTC) + dt.timedelta(
|
return dt.datetime.now(dt.UTC) + dt.timedelta(
|
||||||
hours=db.read(str(forum) + ".settings.timezone", 3)
|
hours=db.read(str(forum) + ".settings.timezone", 3)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def change_phase(forum: int, date: dt.datetime = None):
|
def change_phase(forum: int, date: dt.datetime = None) -> None:
|
||||||
|
"Changes forum's current phase."
|
||||||
if date is None:
|
if date is None:
|
||||||
date = get_time(forum).date()
|
date = get_time(forum).date()
|
||||||
phase = db.read(str(forum) + ".schedule.phase")
|
phase = db.read(str(forum) + ".schedule.phase")
|
||||||
|
@ -104,7 +105,8 @@ def get_chat(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def check_if_admin(message: telebot.types.Message) -> bool:
|
def check_if_admin(message: telebot.types.Message) -> bool | None:
|
||||||
|
"Checks if the message is sent by the forum's admin."
|
||||||
forum = message.chat.id
|
forum = message.chat.id
|
||||||
admin = db.read(str(forum) + ".settings.admin")
|
admin = db.read(str(forum) + ".settings.admin")
|
||||||
if admin is None:
|
if admin is None:
|
||||||
|
@ -114,26 +116,26 @@ def check_if_admin(message: telebot.types.Message) -> bool:
|
||||||
return admin["id"] == message.from_user.id
|
return admin["id"] == message.from_user.id
|
||||||
|
|
||||||
|
|
||||||
def mention(forum: int, user_id: int) -> str:
|
def mention(forum: int, uid: int) -> str | None:
|
||||||
user_id = str(user_id)
|
"Returns markdown formatted string with user's mention."
|
||||||
if db.read(str(forum) + ".people." + user_id) is None:
|
uid = str(uid)
|
||||||
stderr.write("Пользователя с ID " + user_id + " нет в базе.\n")
|
if db.read(str(forum) + ".people." + uid) is None:
|
||||||
return
|
stderr.write("Пользователя с ID " + uid + " нет в базе.\n")
|
||||||
|
return None
|
||||||
return (
|
return (
|
||||||
"["
|
"["
|
||||||
+ db.read(str(forum) + ".people." + user_id + ".name")
|
+ db.read(str(forum) + ".people." + uid + ".name")
|
||||||
+ " "
|
+ " "
|
||||||
+ db.read(str(forum) + ".people." + user_id + ".surname")
|
+ db.read(str(forum) + ".people." + uid + ".surname")
|
||||||
+ "](tg://user?id="
|
+ "](tg://user?id="
|
||||||
+ str(user_id)
|
+ str(uid)
|
||||||
+ ")"
|
+ ")"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def find_uids(forum: int, s: str):
|
def find_uids(forum: int, s: str) -> list | None:
|
||||||
|
"Find user's id by nickname, name or surname."
|
||||||
people = db.read(str(forum) + ".people")
|
people = db.read(str(forum) + ".people")
|
||||||
if people is None:
|
|
||||||
return
|
|
||||||
if people is None:
|
if people is None:
|
||||||
return None
|
return None
|
||||||
if s[0] == "@":
|
if s[0] == "@":
|
||||||
|
@ -144,7 +146,8 @@ def find_uids(forum: int, s: str):
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
def format_user_info(forum: int, uid):
|
def format_user_info(forum: int, uid: int) -> str:
|
||||||
|
"Returns markdown formatted string with all user's info by their id."
|
||||||
uid = str(uid)
|
uid = str(uid)
|
||||||
person = db.read(str(forum) + ".people." + uid)
|
person = db.read(str(forum) + ".people." + uid)
|
||||||
if person is None:
|
if person is None:
|
||||||
|
@ -156,21 +159,24 @@ def format_user_info(forum: int, uid):
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
def prepend_user(forum: int, ulist_s: str, uid):
|
def prepend_user(forum: int, ulist_s: str, uid: int) -> None:
|
||||||
|
"Inserts user id at the start of provided db list in forum's context."
|
||||||
uid = str(uid)
|
uid = str(uid)
|
||||||
ulist = db.read(str(forum) + "." + ulist_s, [])
|
ulist = db.read(str(forum) + "." + ulist_s, [])
|
||||||
ulist = list(set([uid] + ulist))
|
ulist = list(set([uid] + ulist))
|
||||||
db.write(str(forum) + "." + ulist_s, ulist)
|
db.write(str(forum) + "." + ulist_s, ulist)
|
||||||
|
|
||||||
|
|
||||||
def append_user(forum: int, ulist_s: str, uid):
|
def append_user(forum: int, ulist_s: str, uid: int) -> None:
|
||||||
|
"Inserts user id at the end of provided db list in forum's context."
|
||||||
uid = str(uid)
|
uid = str(uid)
|
||||||
ulist = db.read(str(forum) + "." + ulist_s, [])
|
ulist = db.read(str(forum) + "." + ulist_s, [])
|
||||||
ulist = list(set(ulist + [uid]))
|
ulist = list(set(ulist + [uid]))
|
||||||
db.write(str(forum) + "." + ulist_s, ulist)
|
db.write(str(forum) + "." + ulist_s, ulist)
|
||||||
|
|
||||||
|
|
||||||
def pop_user(forum: int, ulist_s: str):
|
def pop_user(forum: int, ulist_s: str) -> dict | None:
|
||||||
|
"Removes user id from the start of provided db list in forum's context. Returns user id."
|
||||||
ulist = db.read(str(forum) + "." + ulist_s, [])
|
ulist = db.read(str(forum) + "." + ulist_s, [])
|
||||||
r = None
|
r = None
|
||||||
if len(ulist) > 0:
|
if len(ulist) > 0:
|
||||||
|
@ -179,7 +185,7 @@ def pop_user(forum: int, ulist_s: str):
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
def insert_user_in_current_order(forum: int, uid) -> bool:
|
def insert_user_in_current_order(forum: int, uid: int) -> bool:
|
||||||
uid = str(uid)
|
uid = str(uid)
|
||||||
order = db.read(str(forum) + ".rookies.order", [])
|
order = db.read(str(forum) + ".rookies.order", [])
|
||||||
people = db.read(str(forum) + ".people", {})
|
people = db.read(str(forum) + ".people", {})
|
||||||
|
@ -203,12 +209,11 @@ def insert_user_in_current_order(forum: int, uid) -> bool:
|
||||||
pos = list(order.keys()).index(current)
|
pos = list(order.keys()).index(current)
|
||||||
if pos == 0:
|
if pos == 0:
|
||||||
return False
|
return False
|
||||||
else:
|
db.write(str(forum) + ".rookies.order", list(order.keys())[1:])
|
||||||
db.write(str(forum) + ".rookies.order", list(order.keys())[1:])
|
return True
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def parse_dates(forum: int, args):
|
def parse_dates(forum: int, args: typing.Iterable) -> list | str:
|
||||||
dates = []
|
dates = []
|
||||||
cur_date = get_time(forum).date() - dt.timedelta(days=1)
|
cur_date = get_time(forum).date() - dt.timedelta(days=1)
|
||||||
cur_year = cur_date.year
|
cur_year = cur_date.year
|
||||||
|
@ -258,7 +263,7 @@ def parse_dates(forum: int, args):
|
||||||
return dates
|
return dates
|
||||||
|
|
||||||
|
|
||||||
def mod_days(message: telebot.types.Message, target, neighbour):
|
def mod_days(message: telebot.types.Message, target: str, neighbour: str) -> None:
|
||||||
forum = message.chat.id
|
forum = message.chat.id
|
||||||
chat = get_chat(message)
|
chat = get_chat(message)
|
||||||
if chat is not None:
|
if chat is not None:
|
||||||
|
@ -270,7 +275,7 @@ def mod_days(message: telebot.types.Message, target, neighbour):
|
||||||
dates = [get_time(forum).date()]
|
dates = [get_time(forum).date()]
|
||||||
else:
|
else:
|
||||||
dates = parse_dates(forum, args)
|
dates = parse_dates(forum, args)
|
||||||
if type(dates) is str:
|
if isinstance(dates, str):
|
||||||
bot.reply_to(
|
bot.reply_to(
|
||||||
chat,
|
chat,
|
||||||
telebot.formatting.escape_markdown(dates)
|
telebot.formatting.escape_markdown(dates)
|
||||||
|
@ -355,12 +360,14 @@ def start_bot(message: telebot.types.Message):
|
||||||
if message.chat.is_forum:
|
if message.chat.is_forum:
|
||||||
bot.reply_to(
|
bot.reply_to(
|
||||||
chat,
|
chat,
|
||||||
"Привет\\! Я бот для управления дежурствами и напоминания о них\\. Напиши /link, чтобы привязать комнату\\.",
|
"Привет\\! Я бот для управления дежурствами и напоминания о них\\. "
|
||||||
|
+ "Напиши /link, чтобы привязать комнату\\.",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
bot.reply_to(
|
bot.reply_to(
|
||||||
chat,
|
chat,
|
||||||
"Я работаю только на форумах \\(супергруппах с комнатами\\)\\. Пригласи меня в один из них и напиши /start",
|
"Я работаю только на форумах \\(супергруппах с комнатами\\)\\. "
|
||||||
|
+ "Пригласи меня в один из них и напиши /start",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -442,14 +449,6 @@ if __debug__:
|
||||||
bot.delete_message(forum, message.id)
|
bot.delete_message(forum, message.id)
|
||||||
|
|
||||||
|
|
||||||
@bot.message_handler(commands=["fuck", "fuck_you", "fuck-you", "fuckyou"])
|
|
||||||
def rude(message: telebot.types.Message):
|
|
||||||
forum = message.chat.id
|
|
||||||
chat = get_chat(message, True)
|
|
||||||
if chat is not None:
|
|
||||||
bot.delete_message(forum, message.id)
|
|
||||||
|
|
||||||
|
|
||||||
@bot.message_handler(commands=["backup"])
|
@bot.message_handler(commands=["backup"])
|
||||||
def backup_db(message: telebot.types.Message):
|
def backup_db(message: telebot.types.Message):
|
||||||
forum = message.chat.id
|
forum = message.chat.id
|
||||||
|
@ -1333,7 +1332,7 @@ def get_hours() -> tuple:
|
||||||
return (8, 20)
|
return (8, 20)
|
||||||
|
|
||||||
|
|
||||||
def stack_update(forum: int, force_reset=False):
|
def stack_update(forum: int, force_reset: bool = False) -> None:
|
||||||
now = get_time(forum)
|
now = get_time(forum)
|
||||||
now_date = now.date()
|
now_date = now.date()
|
||||||
order = db.read(str(forum) + ".rookies.order", [])
|
order = db.read(str(forum) + ".rookies.order", [])
|
||||||
|
@ -1395,14 +1394,14 @@ def stack_update(forum: int, force_reset=False):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def clean_old_dates(date: dt.datetime, array: str):
|
def clean_old_dates(date: dt.datetime, array: str) -> None:
|
||||||
"Removes dates from db's array which are older than `date`."
|
"Removes dates from db's array which are older than `date`."
|
||||||
a = db.read(array)
|
a = db.read(array)
|
||||||
a = a.filter(lambda x: x >= date, a)
|
a = a.filter(lambda x: x >= date, a)
|
||||||
db.write(array, a)
|
db.write(array, a)
|
||||||
|
|
||||||
|
|
||||||
def update(forum: int):
|
def update(forum: int) -> None:
|
||||||
now = get_time(forum)
|
now = get_time(forum)
|
||||||
now_date = now.date()
|
now_date = now.date()
|
||||||
now_time = now.time()
|
now_time = now.time()
|
||||||
|
@ -1424,7 +1423,7 @@ def update(forum: int):
|
||||||
remind_users(forum)
|
remind_users(forum)
|
||||||
|
|
||||||
|
|
||||||
def update_notify(forum: int):
|
def update_notify(forum: int) -> None:
|
||||||
"Notifies the forum about bot's new version."
|
"Notifies the forum about bot's new version."
|
||||||
bot.reply_to(
|
bot.reply_to(
|
||||||
get_chat(forum),
|
get_chat(forum),
|
||||||
|
@ -1439,13 +1438,13 @@ def process1():
|
||||||
|
|
||||||
def process2():
|
def process2():
|
||||||
"The process updates duty order for every forum once a `period` seconds."
|
"The process updates duty order for every forum once a `period` seconds."
|
||||||
period = int(config["settings"]["notify_period"])
|
period = int(cr.read("settings.notify_period"))
|
||||||
prev_time = time.time()
|
prev_time = time.time()
|
||||||
while True:
|
while True:
|
||||||
cur_time = time.time()
|
cur_time = time.time()
|
||||||
if cur_time - prev_time >= period:
|
if cur_time - prev_time >= period:
|
||||||
prev_time = cur_time
|
prev_time = cur_time
|
||||||
stdout.write("Update\n")
|
stdout.write("Process 2 update\n")
|
||||||
for i in db.read("").keys():
|
for i in db.read("").keys():
|
||||||
try:
|
try:
|
||||||
update(int(i))
|
update(int(i))
|
||||||
|
@ -1463,7 +1462,7 @@ if VERSION != CURRENT_VERSION:
|
||||||
for FORUM in db.read("").keys():
|
for FORUM in db.read("").keys():
|
||||||
try:
|
try:
|
||||||
update_notify(int(FORUM))
|
update_notify(int(FORUM))
|
||||||
print("Notified", FORUM)
|
print("New version notification", FORUM)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -6,17 +6,40 @@
|
||||||
|
|
||||||
import configparser
|
import configparser
|
||||||
|
|
||||||
|
FIELDS = {
|
||||||
|
"tokens.prod": (str, "" if __debug__ else None, None),
|
||||||
|
"tokens.devel": (str, None if __debug__ else "", None),
|
||||||
|
"settings.notify_period": (int, 120, None),
|
||||||
|
}
|
||||||
|
|
||||||
CONFIG_FILENAME = "config.ini"
|
CONFIG_FILENAME = "config.ini"
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read(CONFIG_FILENAME)
|
config.read(CONFIG_FILENAME)
|
||||||
|
|
||||||
def check_field(field: str, t: type):
|
|
||||||
"Checks if the config field does exist and has the correct type."
|
|
||||||
a = config
|
|
||||||
for i in field.split('.'):
|
|
||||||
a = a[i]
|
|
||||||
a = t(a)
|
|
||||||
|
|
||||||
check_field("tokens.prod", str)
|
def check_field(field: str) -> None:
|
||||||
check_field("tokens.devel", str)
|
"Checks if the config field does exist and has the correct type."
|
||||||
check_field("settings.notify_period", int)
|
read(field)
|
||||||
|
|
||||||
|
|
||||||
|
def read(field: str) -> any:
|
||||||
|
"Reads field and converts the value to the correct type."
|
||||||
|
a = config
|
||||||
|
if field not in FIELDS:
|
||||||
|
raise ValueError(f"Field {field} does not exist.")
|
||||||
|
t, d, allowed = FIELDS[field]
|
||||||
|
for i in field.split("."):
|
||||||
|
a = a[i]
|
||||||
|
if a is None:
|
||||||
|
if d is None:
|
||||||
|
raise ValueError(f"Field {field} does not have value.")
|
||||||
|
return d
|
||||||
|
a = t(a)
|
||||||
|
if allowed is not None:
|
||||||
|
if a not in allowed:
|
||||||
|
raise ValueError(f"Field {field} does not have allowed value.")
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
for f in FIELDS:
|
||||||
|
check_field(f)
|
||||||
|
|
Loading…
Reference in New Issue