Use an efficient query method
This commit is contained in:
parent
a6366657bc
commit
d657b58733
|
@ -6,8 +6,8 @@ import time
|
|||
from asyncio import create_subprocess_exec
|
||||
from asyncio.subprocess import PIPE
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from io import BytesIO, StringIO
|
||||
from typing import Optional, Tuple
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from nio import AsyncClient, MatrixRoom, RoomMessageText, UploadResponse
|
||||
from wand.image import Image
|
||||
|
@ -24,16 +24,19 @@ TIMEZONE = timezone(timedelta(hours=8)) # UTC+8
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_word_freqs(text):
|
||||
async def get_word_freqs(texts):
|
||||
proc = await create_subprocess_exec(
|
||||
CUTWORDS_EXE,
|
||||
stdin=PIPE,
|
||||
stdout=PIPE,
|
||||
)
|
||||
|
||||
stdout, _ = await proc.communicate(input=text.encode("utf-8"))
|
||||
for i in texts:
|
||||
proc.stdin.write(i.encode("utf-8"))
|
||||
proc.stdin.write(b"\n")
|
||||
|
||||
freqs = {}
|
||||
stdout = await proc.stdout.read()
|
||||
lines = stdout.decode().splitlines()
|
||||
for line in lines:
|
||||
word, freq = line.split(None, 1)
|
||||
|
@ -71,9 +74,13 @@ async def send_wordcloud(
|
|||
end_date = None
|
||||
if days is not None:
|
||||
end_date = start_date - timedelta(days=days)
|
||||
(texts, count, users) = gather_messages(room, sender, end_date)
|
||||
st2 = time.time()
|
||||
logger.info("Gathered message using %.3f seconds", st2 - st.timestamp())
|
||||
texts = MessageIter(room, event.server_timestamp, sender, end_date)
|
||||
|
||||
freqs = await get_word_freqs(texts)
|
||||
st3 = time.time()
|
||||
logger.info("Analyzed message using %.3f seconds", st3 - st)
|
||||
|
||||
count = texts.count
|
||||
if count == 0:
|
||||
await send_text_to_room(
|
||||
client,
|
||||
|
@ -85,9 +92,8 @@ async def send_wordcloud(
|
|||
literal_text=True,
|
||||
)
|
||||
return
|
||||
freqs = await get_word_freqs(texts)
|
||||
st3 = time.time()
|
||||
logger.info("Analyzed message using %.3f seconds", st3 - st2)
|
||||
|
||||
users = texts.users
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, make_image, freqs, bytesio)
|
||||
|
@ -103,7 +109,7 @@ async def send_wordcloud(
|
|||
|
||||
# Seek again
|
||||
bytesio.seek(0)
|
||||
resp, maybe_keys = await client.upload(
|
||||
resp, _ = await client.upload(
|
||||
bytesio,
|
||||
content_type="image/png",
|
||||
filename="image.png",
|
||||
|
@ -154,57 +160,93 @@ async def send_wordcloud(
|
|||
DROP_USERS = {"@telegram_1454289754:nichi.co", "@variation:matrix.org", "@bot:bgme.me"}
|
||||
|
||||
|
||||
def gather_messages(
|
||||
room: MatrixRoom,
|
||||
sender: Optional[str],
|
||||
end_date: Optional[datetime],
|
||||
) -> Tuple[str, int, int]:
|
||||
stringio = StringIO()
|
||||
count = 0
|
||||
if sender is None:
|
||||
msg_items = (
|
||||
MatrixMessage.select()
|
||||
.where(MatrixMessage.room_id == room.room_id)
|
||||
.order_by(MatrixMessage.origin_server_ts.desc())
|
||||
)
|
||||
else:
|
||||
msg_items = (
|
||||
MatrixMessage.select()
|
||||
.where(
|
||||
(MatrixMessage.room_id == room.room_id)
|
||||
& (MatrixMessage.sender == sender)
|
||||
)
|
||||
.order_by(MatrixMessage.origin_server_ts.desc())
|
||||
)
|
||||
users = set()
|
||||
for msg_item in msg_items:
|
||||
if end_date is not None:
|
||||
if msg_item.datetime < end_date:
|
||||
break
|
||||
if msg_item.sender in DROP_USERS: # XXX: Special case for Arch Linux CN
|
||||
continue
|
||||
if msg_item.formatted_body is not None:
|
||||
string = re.sub(r"<mx-reply>.*</mx-reply>", "", msg_item.formatted_body)
|
||||
fwd_match = re.match(
|
||||
r"Forwarded message from .*<tg-forward>(.*)</tg-forward>",
|
||||
string,
|
||||
)
|
||||
if fwd_match is not None:
|
||||
string = fwd_match.group(1)
|
||||
print(strip_tags(string), file=stringio)
|
||||
count += 1
|
||||
users.add(msg_item.sender)
|
||||
elif msg_item.body is not None:
|
||||
# XXX: Special case for Arch Linux CN
|
||||
if msg_item.sender == "@matterbridge:nichi.co":
|
||||
data = re.sub(r"^\[.*\] ", "", msg_item.body)
|
||||
print(data.strip(), file=stringio)
|
||||
else:
|
||||
print(msg_item.body, file=stringio)
|
||||
count += 1
|
||||
users.add(msg_item.sender)
|
||||
else:
|
||||
continue
|
||||
class MessageIter:
|
||||
|
||||
ret = stringio.getvalue()
|
||||
return (ret, count, len(users))
|
||||
LIMIT = 1000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
room: MatrixRoom,
|
||||
base_ts: int,
|
||||
sender: Optional[str],
|
||||
end_date: Optional[datetime],
|
||||
):
|
||||
self.msg_items = None
|
||||
self.final_batch = False
|
||||
self.sender = sender
|
||||
self.base_ts = base_ts
|
||||
self.last_ts = self.base_ts
|
||||
self.end_date = end_date
|
||||
self.users = set()
|
||||
self.count = 0
|
||||
|
||||
def order_next_batch(self):
|
||||
if self.sender is None:
|
||||
msg_items = (
|
||||
MatrixMessage.select()
|
||||
.where(
|
||||
(MatrixMessage.room_id == self.room.room_id)
|
||||
& (MatrixMessage.origin_server_ts < self.last_ts)
|
||||
)
|
||||
.order_by(MatrixMessage.origin_server_ts.desc())
|
||||
.limit(self.LIMIT)
|
||||
)
|
||||
else:
|
||||
msg_items = (
|
||||
MatrixMessage.select()
|
||||
.where(
|
||||
(MatrixMessage.room_id == self.room.room_id)
|
||||
& (MatrixMessage.sender == self.sender)
|
||||
& (MatrixMessage.origin_server_ts < self.last_ts)
|
||||
)
|
||||
.order_by(MatrixMessage.origin_server_ts.desc())
|
||||
.limit(self.LIMIT)
|
||||
)
|
||||
|
||||
self.msg_items = msg_items
|
||||
if msg_items.count() < self.LIMIT:
|
||||
self.final_batch = True
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.msg_items is None:
|
||||
self.order_next_batch()
|
||||
try:
|
||||
msg_item = next(self.msg_items.iterator())
|
||||
while msg_item.sender in DROP_USERS: # XXX: Special case for Arch Linux CN
|
||||
msg_item = next(self.msg_items.iterator())
|
||||
if self.end_date is not None:
|
||||
if msg_item.datetime < self.end_date:
|
||||
raise StopIteration
|
||||
self.count += 1
|
||||
string = process_message(msg_item)
|
||||
self.users.add(msg_item.sender)
|
||||
self.last_ts = msg_item.origin_server_ts
|
||||
return string
|
||||
except StopIteration:
|
||||
if self.final_batch:
|
||||
raise StopIteration
|
||||
else:
|
||||
self.order_next_batch()
|
||||
return next(self)
|
||||
|
||||
|
||||
def process_message(msg_item):
|
||||
if msg_item.formatted_body is not None:
|
||||
string = re.sub(r"<mx-reply>.*</mx-reply>", "", msg_item.formatted_body)
|
||||
fwd_match = re.match(
|
||||
r"Forwarded message from .*<tg-forward>(.*)</tg-forward>",
|
||||
string,
|
||||
)
|
||||
if fwd_match is not None:
|
||||
string = fwd_match.group(1)
|
||||
return strip_tags(string)
|
||||
elif msg_item.body is not None:
|
||||
# XXX: Special case for Arch Linux CN
|
||||
if msg_item.sender == "@matterbridge:nichi.co":
|
||||
data = re.sub(r"^\[.*\] ", "", msg_item.body)
|
||||
return data.strip()
|
||||
else:
|
||||
return msg_item.body
|
||||
|
|
Loading…
Reference in New Issue