125 lines
6.1 KiB
Python
125 lines
6.1 KiB
Python
import aiohttp
|
|
import asyncio
|
|
import filetype
|
|
import base64
|
|
import logging
|
|
import json
|
|
import tempfile
|
|
import os
|
|
from typing import List, Dict
|
|
|
|
class MessageProcessor:
    """Turn Discord messages into plain dicts, captioning image attachments.

    Image captions are produced by the Gemini ``generateContent`` endpoint and
    memoized in ``image_cache`` keyed by attachment URL. A caller may pass the
    same dict to several processors to avoid captioning an image twice.
    """

    # Gemini endpoint used for captions. The API key is sent in the
    # ``x-goog-api-key`` header instead of the query string so it cannot end
    # up in logged URLs or exception messages.
    GEMINI_URL = (
        'https://generativelanguage.googleapis.com/v1beta/models/'
        'gemini-pro-vision:generateContent'
    )
    # Image types that are sent to Gemini as inline data (the set the original
    # code mapped explicitly). Anything else -- unknown type, GIF, other image
    # formats -- is not sent at all.
    SUPPORTED_IMAGE_MIME_TYPES = frozenset(
        {'image/jpeg', 'image/png', 'image/webp', 'image/heic', 'image/heif'}
    )
    CAPTION_ATTEMPTS = 5
    RETRY_DELAY_SECONDS = 6
    # Returned when no description can be produced, so one odd attachment
    # does not abort processing of the whole message batch.
    FALLBACK_CAPTION = 'an attachment that could not be described'

    def __init__(self, bot=None, formatted_users=None, image_cache=None, discord_msgs=None, user_names=None):
        """Create a processor.

        Args:
            bot: client whose ``bot.user.name`` identifies the bot's own
                messages (presumably a discord client). May be ``None``, in
                which case no author is excluded from ``user_names``.
            formatted_users: initial value for ``formatted_users``.
            image_cache: dict mapping attachment URL -> caption. Shared with
                the caller if given; otherwise a fresh dict per instance.
            discord_msgs: list that processed message dicts are appended to.
            user_names: list that non-bot author names are appended to.

        Raises:
            KeyError: if the ``GEMINI_KEY`` environment variable is not set.
        """
        self.bot = bot
        self.formatted_users = formatted_users
        # ``is None`` checks (not truthiness) so a caller-supplied empty
        # container is actually used and filled, and so no default container
        # is shared between instances.
        self.image_cache = {} if image_cache is None else image_cache
        self.discord_msgs = [] if discord_msgs is None else discord_msgs
        self.user_names = [] if user_names is None else user_names
        self.gemini_key = os.environ['GEMINI_KEY']

    async def get_image_caption(self, attachment) -> str:
        """Download an attachment and ask Gemini to describe it.

        Args:
            attachment: object exposing ``.url`` (presumably a
                ``discord.Attachment``; only ``url`` is read).

        Returns:
            The model's description (the prompt asks it to start with
            "an image of"). Returns ``FALLBACK_CAPTION`` when the download
            fails, when the file is not one of ``SUPPORTED_IMAGE_MIME_TYPES``,
            or when every API attempt fails or yields empty text.
        """
        image_url = attachment.url
        async with aiohttp.ClientSession() as session:
            try:
                async with session.get(image_url) as response:
                    response.raise_for_status()
                    file_data = await response.read()
            except (aiohttp.ClientError, asyncio.TimeoutError):
                logging.warning("Could not download attachment %s", image_url, exc_info=True)
                return self.FALLBACK_CAPTION

            kind = filetype.guess(file_data)
            if kind is None or kind.mime not in self.SUPPORTED_IMAGE_MIME_TYPES:
                # Sending these would only fail (mime_type would be unknown),
                # so skip the API round-trips entirely.
                return self.FALLBACK_CAPTION
            logging.info("is image")

            caption = await self._request_caption(session, image_url, file_data, kind.mime)
        return caption or self.FALLBACK_CAPTION

    async def _request_caption(self, session, image_url: str, file_data: bytes, mime_type: str) -> str:
        """POST the image to Gemini with retries; return its text or ``""``."""
        # Everything below is loop-invariant, so build it once.
        prompt = (
            "Describe the image in full detail, from the type of image to the "
            "texts if any, as lengthy as possible, in English only. Start your "
            "response with 'an image of' in lowercase, in as many text as "
            "possible. Classify persons/characters based on image URL if any. "
            f"Image URL with filename in the end: {image_url}"
        )
        request_body = json.dumps({
            "contents": [
                {
                    "parts": [
                        {"text": prompt},
                        {
                            "inline_data": {
                                "mime_type": mime_type,
                                "data": base64.b64encode(file_data).decode('utf-8'),
                            }
                        },
                    ]
                }
            ],
            "generationConfig": {
                "temperature": 0.0,
            },
        })
        headers = {'Content-Type': 'application/json', 'x-goog-api-key': self.gemini_key}

        for attempt in range(1, self.CAPTION_ATTEMPTS + 1):
            try:
                # ``async with`` releases the connection back to the pool.
                async with session.post(self.GEMINI_URL, headers=headers, data=request_body) as response:
                    json_data = await response.json()
                # Error payloads lack 'candidates' -> KeyError, handled below.
                text = json_data['candidates'][0]['content']['parts'][0]['text']
                if text:
                    return text
            except (aiohttp.ClientError, asyncio.TimeoutError, ValueError, KeyError, IndexError, TypeError):
                # Narrow catch: cancellation and programming errors propagate.
                logging.warning("An error occured while processing image, retrying...", exc_info=True)
                if attempt < self.CAPTION_ATTEMPTS:
                    await asyncio.sleep(self.RETRY_DELAY_SECONDS)
        return ""

    async def process_messages(self, messages: List['Message']) -> List[dict]:
        """Convert messages into dicts and append them to ``discord_msgs``.

        Each attachment appends `` *sends <caption>*`` to the (stripped)
        message content, using ``image_cache`` before calling Gemini. Authors
        whose name differs from the bot's user name are appended to
        ``user_names``; ``formatted_users`` becomes the unique names joined by
        ", " in first-seen order.

        Returns:
            ``self.discord_msgs`` (the same list object, now extended).
        """
        bot_name = self.bot.user.name if self.bot is not None else None
        for m in messages:
            msg_content = m.content
            for attachment in m.attachments or ():
                if attachment.url in self.image_cache:
                    image_caption = self.image_cache[attachment.url]
                else:
                    image_caption = await self.get_image_caption(attachment)
                    self.image_cache[attachment.url] = image_caption
                    logging.info("url: %s", attachment.url)
                msg_content = msg_content.strip() + f' *sends {image_caption}*'

            self.discord_msgs.append({
                "content": msg_content,
                "author": m.author.name,
                "is_bot": m.author.bot,
            })
            if m.author.name != bot_name:
                self.user_names.append(m.author.name)

        # dict.fromkeys de-duplicates while keeping a deterministic order;
        # set() ordering varies between interpreter runs (hash randomization).
        self.formatted_users = ', '.join(dict.fromkeys(self.user_names))
        # NOTE(review): image_cache is never pruned and grows with every new
        # URL. The former "evict" loop here was inverted dead code and was
        # removed -- add real eviction if memory becomes a concern.
        return self.discord_msgs

    def get_discord_msgs(self):
        """Return the list of processed message dicts."""
        return self.discord_msgs

    def get_image_cache(self):
        """Return the URL -> caption cache dict."""
        return self.image_cache

    def get_formatted_users(self):
        """Return the comma-joined unique non-bot author names (or the initial value)."""
        return self.formatted_users

    def get_user_names(self):
        """Return the list of non-bot author names, duplicates included."""
        return self.user_names