Use pagination

This commit is contained in:
夜坂雅 2023-06-10 14:54:18 +08:00
parent 10134ab920
commit 2184af2055
1 changed files with 44 additions and 32 deletions

View File

@ -5,9 +5,8 @@ import re
import time
from asyncio import create_subprocess_exec
from asyncio.subprocess import PIPE
from collections import namedtuple
from datetime import datetime, timedelta, timezone
from io import BytesIO, StringIO
from io import BytesIO
from typing import Optional
from nio import AsyncClient, MatrixRoom, RoomMessageText, UploadResponse
@ -77,7 +76,7 @@ async def send_wordcloud(
end_date = None
if days is not None:
end_date = start_date - timedelta(days=days)
texts = MessageIter(room, sender, end_date)
texts = MessageIter(room, event.server_timestamp, sender, end_date)
freqs = await get_word_freqs(texts)
st3 = time.time()
@ -169,6 +168,7 @@ class MessageIter:
def __init__(
self,
room: MatrixRoom,
base_ts: int,
sender: Optional[str],
end_date: Optional[datetime],
):
@ -179,37 +179,49 @@ class MessageIter:
self.room = room
self.users = set()
self.count = 0
self.last_ts = base_ts
def batches(self, limit: int):
"""Return a iterator for query pagination."""
while not self.done:
if self.sender is None:
msg_items = (
MatrixMessage.select()
.where(
(MatrixMessage.room_id == self.room.room_id)
& (MatrixMessage.origin_server_ts < self.last_ts)
)
.order_by(MatrixMessage.origin_server_ts.desc())
.limit(limit)
)
else:
msg_items = (
MatrixMessage.select()
.where(
(MatrixMessage.room_id == self.room.room_id)
& (MatrixMessage.sender == self.sender)
& (MatrixMessage.origin_server_ts < self.last_ts)
)
.order_by(MatrixMessage.origin_server_ts.desc())
.limit(limit)
)
if msg_items.count() < limit:
self.done = True
yield msg_items
def __iter__(self):
return self.into_iter()
def into_iter(self):
if self.sender is None:
msg_items = (
MatrixMessage.select()
.where(MatrixMessage.room_id == self.room.room_id)
.order_by(MatrixMessage.origin_server_ts.desc())
)
else:
msg_items = (
MatrixMessage.select()
.where(
(MatrixMessage.room_id == self.room.room_id)
& (MatrixMessage.sender == self.sender)
)
.order_by(MatrixMessage.origin_server_ts.desc())
)
for msg_item in msg_items.namedtuples().iterator():
if msg_item.sender in DROP_USERS: # XXX: Special case for Arch Linux CN
continue
if self.end_date is not None:
if msg_item.datetime < self.end_date:
return
self.count += 1
string = process_message(msg_item)
self.users.add(msg_item.sender)
self.last_ts = msg_item.origin_server_ts
yield string
for msg_items in self.batches(self.LIMIT):
for msg_item in msg_items.namedtuples().iterator():
if msg_item.sender in DROP_USERS: # XXX: Special case for Arch Linux CN
continue
if self.end_date is not None:
if msg_item.datetime < self.end_date:
return
self.count += 1
string = process_message(msg_item)
self.users.add(msg_item.sender)
self.last_ts = msg_item.origin_server_ts
yield string
def process_message(msg_item):