failed_discord_bot/get_messages.py

125 lines
6.1 KiB
Python

import asyncio
import base64
import json
import logging
import os
import tempfile
from typing import Dict, List, Optional

import aiohttp
import filetype
class MessageProcessor:
    """Convert Discord messages into plain dicts, captioning image attachments.

    Image attachments are downloaded and described by the Gemini
    ``generateContent`` REST endpoint. Captions are memoised in
    ``image_cache``, keyed by attachment URL, so each URL is fetched at most once.
    """

    # Image MIME types that are forwarded to Gemini. Anything else (GIFs,
    # non-images, unknown formats) is skipped instead of being sent with a
    # null mime_type.
    SUPPORTED_MIME_TYPES = frozenset({
        'image/jpeg',
        'image/png',
        'image/webp',
        'image/heic',
        'image/heif',
    })
    # NOTE(review): 'gemini-pro-vision' was deprecated by Google. Confirm it is
    # still served, or override this attribute with a current vision model.
    gemini_model = 'gemini-pro-vision'
    caption_attempts = 5   # total POST attempts per image
    retry_delay = 6        # seconds to wait between failed attempts

    def __init__(self, bot=None, formatted_users=None, image_cache=None, discord_msgs=None, user_names=None):
        """Store collaborators/state and read the Gemini key from ``GEMINI_KEY``.

        Raises:
            KeyError: if the ``GEMINI_KEY`` environment variable is not set.
        """
        self.bot = bot
        self.formatted_users = formatted_users
        # Fresh containers per instance. A literal ``dict()`` default is
        # evaluated once and would be shared by every MessageProcessor.
        # ``is not None`` keeps caller-supplied (even empty) containers.
        self.image_cache = image_cache if image_cache is not None else {}
        self.discord_msgs = discord_msgs if discord_msgs is not None else []
        self.user_names = user_names if user_names is not None else []
        self.gemini_key = os.environ['GEMINI_KEY']

    async def get_image_caption(self, attachment) -> Optional[str]:
        """Download ``attachment.url`` and ask Gemini to describe it.

        Returns:
            The caption text, or ``None`` when the file is not a supported
            still image or every caption attempt failed.

        Raises:
            aiohttp.ClientResponseError: if downloading the attachment fails.
        """
        image_url = attachment.url
        async with aiohttp.ClientSession() as session:
            async with session.get(image_url) as response:
                response.raise_for_status()
                file_data = await response.read()
            mime_type = self._resolve_mime_type(file_data)
            if mime_type is None:
                return None
            logging.info("is image")
            return await self._request_caption(session, image_url, mime_type, file_data)

    def _resolve_mime_type(self, file_data: bytes) -> Optional[str]:
        """Return the sniffed MIME type of *file_data* if it should be captioned, else None."""
        kind = filetype.guess(file_data)
        if kind is None or kind.mime not in self.SUPPORTED_MIME_TYPES:
            return None
        return kind.mime

    async def _request_caption(self, session: "aiohttp.ClientSession", image_url: str,
                               mime_type: str, file_data: bytes) -> Optional[str]:
        """POST the image to Gemini, retrying, and return the first non-empty caption (or None)."""
        text = f"Describe the image in full detail, from the type of image to the texts if any, as lengthy as possible, in English only. Start your response with 'an image of' in lowercase, in as many text as possible. Classify persons/characters based on image URL if any. Image URL with filename in the end: {image_url}"
        # The request body does not change between attempts, so build it once.
        body = json.dumps({
            "contents": [
                {
                    "parts": [
                        {"text": text},
                        {
                            "inline_data": {
                                "mime_type": mime_type,
                                "data": base64.b64encode(file_data).decode('utf-8'),
                            }
                        },
                    ]
                }
            ],
            "generationConfig": {
                "temperature": 0.0,
            },
        })
        url = f'https://generativelanguage.googleapis.com/v1beta/models/{self.gemini_model}:generateContent'
        # Send the key in a header rather than the query string, so it does not
        # show up in URLs embedded in exceptions and logs.
        headers = {'Content-Type': 'application/json', 'x-goog-api-key': self.gemini_key}

        for attempt in range(1, self.caption_attempts + 1):
            try:
                # ``async with`` releases the connection even if parsing fails.
                async with session.post(url, headers=headers, data=body) as response:
                    json_data = await response.json()
                caption = json_data['candidates'][0]['content']['parts'][0]['text']
                if caption:
                    return caption
            except (aiohttp.ClientError, asyncio.TimeoutError, ValueError,
                    KeyError, IndexError, TypeError):
                # Network errors, non-JSON bodies, or error payloads lacking
                # 'candidates' are retried. Anything else (e.g. cancellation)
                # propagates.
                logging.warning("An error occured while processing image, retrying...", exc_info=True)
            if attempt < self.caption_attempts:
                await asyncio.sleep(self.retry_delay)
        logging.error("Giving up on captioning %s after %d attempts", image_url, self.caption_attempts)
        return None

    async def process_messages(self, messages: List['Message']) -> List[dict]:
        """Append each message to ``discord_msgs``, inlining image captions.

        For every attachment, the caption (cached or freshly fetched) is
        appended to the content as `` *sends <caption>*``. Attachments that
        produced no caption leave the content unchanged. Authors other than
        the bot are recorded in ``user_names``, and ``formatted_users`` becomes
        their de-duplicated, first-seen-order, comma-separated list.

        Returns:
            The accumulated ``discord_msgs`` list (same object across calls).
        """
        bot_name = self.bot.user.name if self.bot is not None else None
        for m in messages:
            msg_content = m.content
            for attachment in m.attachments or ():
                if attachment.url in self.image_cache:
                    image_caption = self.image_cache[attachment.url]
                else:
                    image_caption = await self.get_image_caption(attachment)
                    # Misses (None) are cached too, so unsupported files are
                    # not re-downloaded on every call.
                    self.image_cache[attachment.url] = image_caption
                    logging.info("url: %s", attachment.url)
                if image_caption:
                    msg_content = msg_content.strip() + f' *sends {image_caption}*'
            self.discord_msgs.append({
                "content": msg_content,
                "author": m.author.name,
                "is_bot": m.author.bot
            })
            if m.author.name != bot_name:
                self.user_names.append(m.author.name)
        if messages:
            # dict.fromkeys de-duplicates while keeping first-seen order
            # (a set's order varies between runs).
            self.formatted_users = ', '.join(dict.fromkeys(self.user_names))
        return self.discord_msgs

    def get_discord_msgs(self):
        """Return the accumulated list of message dicts."""
        return self.discord_msgs

    def get_image_cache(self):
        """Return the URL -> caption cache (captions may be None for skipped files)."""
        return self.image_cache

    def get_formatted_users(self):
        """Return the comma-separated list of non-bot authors seen so far."""
        return self.formatted_users

    def get_user_names(self):
        """Return every non-bot author name recorded, in arrival order."""
        return self.user_names