Added gemini and a little classes

This commit is contained in:
cyi1341 2024-01-12 21:27:27 +08:00
parent a82cbe2b01
commit 47168e8c93
5 changed files with 198 additions and 94 deletions

View File

@ -3,6 +3,7 @@ FROM python:3.11-alpine
# Install dependencies
COPY requirements.txt ./
RUN pip install -r requirements.txt
# Set working directory
@ -11,8 +12,8 @@ WORKDIR /app
# Bundle app source
COPY . .
# Expose port
EXPOSE 8080
# Expose port
EXPOSE 8080
# Run bot
CMD ["python", "bot.py"]

145
bot.py
View File

@ -7,7 +7,6 @@ import nextcord
from nextcord.ext import commands
import re
import threading
from llama_tokenizer_lite import LlamaTokenizerLite
import logging
import traceback
from datetime import timedelta
@ -19,6 +18,10 @@ import unicodedata
import string
from html import escape
from difflib import SequenceMatcher
# self
from llama_tokenizer_lite import LlamaTokenizerLite
from run import failed_discord_bot
from get_messages import MessageProcessor
def similar(a, b):
return SequenceMatcher(None, a, b).ratio()
@ -29,9 +32,9 @@ Mancer_Key = os.getenv("MANCER_KEY")
Mancer_Model = "mytholite"
Mancer_URL = "https://neuro.mancer.tech/webui/"+Mancer_Model+"/api"
intents = nextcord.Intents.default()
intents.message_content = True
bot = nextcord.Client()
run = failed_discord_bot()
intents = run.get_intents()
bot = run.get_bot()
tokenizer = LlamaTokenizerLite()
@ -140,6 +143,7 @@ async def process_images_and_generate_responses(message):
try:
for _ in range(7): # Limit to 7 attempts
logging.info(f"Number of _: {_}")
if not reply_punctuation:
channel = message.channel
messages = []
@ -162,79 +166,27 @@ async def process_images_and_generate_responses(message):
return
while True:
# Process the messages
global discord_msgs, formatted_users, user_names
discord_msgs = []
user_names = []
HUGGING_API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
hugging_key = os.getenv("HUGGING_KEY")
hugging_headers = {"Authorization": f"Bearer {hugging_key}"}
logging.info("0 initialise")
if 'prompt' not in globals():
message_processor = MessageProcessor(bot=bot)
else:
message_processor = MessageProcessor(
bot=bot,
formatted_users=formatted_users,
image_cache=image_cache,
discord_msgs=discord_msgs,
user_names=user_names
)
logging.info("0 image stuff")
await message_processor.process_messages(messages)
for m in messages:
if m.attachments:
if m.id not in image_cache:
image_cache[m.id] = []
for attachment in m.attachments:
# If the image has been processed before, skip the processing part
if any(img['url'] == attachment.url for img in image_cache[m.id]):
continue
async with aiohttp.ClientSession() as session:
async with session.get(attachment.url) as response:
response.raise_for_status()
file_data = await response.read()
kind = filetype.guess(file_data)
if kind is not None:
if filetype.is_image(file_data) and kind.mime != 'image/gif':
logging.info("is image")
extension = '.' + kind.extension
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file:
temp_file.write(file_data)
temp_file.seek(0)
with open(temp_file.name, "rb") as f:
data = f.read()
for _ in range(5):
try:
response = requests.post(HUGGING_API_URL, headers=hugging_headers, data=data)
if response.json()[0]["generated_text"]:
break
except:
logging.info("An error occured while processing image, retrying...")
await asyncio.sleep(6)
output = response.json()
try:
image_cache[m.id].append({'url': attachment.url, 'output': output[0]["generated_text"]})
except:
image_cache[m.id].append({'url': attachment.url, 'output': "something random"})
logging.info([img['output'] for img in image_cache[m.id]])
else:
logging.info("is not image")
msg_content = m.content
if m.id in image_cache:
for image in image_cache[m.id]:
if msg_content:
msg_content += f' *sends an image of {image["output"]}*'
else:
msg_content += f'*sends an image of {image["output"]}*'
else:
msg_content = m.content
discord_msgs.append({
"content": msg_content,
"author": m.author.name,
"is_bot": m.author.bot
})
if m.author.name != bot.user.name:
user_names.append(m.author.name)
# Remove cache for messages that no longer exist
for msg_id in list(image_cache.keys()):
if not any(msg.id == msg_id for msg in messages):
del image_cache[msg_id]
logging.info("1 image stuff")
formatted_users = ', '.join(set(user_names))
# Access the processed messages and image cache
discord_msgs = message_processor.get_discord_msgs()
image_cache = message_processor.get_image_cache()
formatted_users = message_processor.get_formatted_users()
user_names = message_processor.get_user_names()
global system_prompt
global input_segment
@ -250,21 +202,29 @@ async def process_images_and_generate_responses(message):
system_prompt_directory = os.path.join(config_directory, 'system_prompt')
input_segment_directory = os.path.join(config_directory, 'input_segment')
output_segment_directory = os.path.join(config_directory, 'output_segment')
input_sequence_directory = os.path.join(config_directory, 'input_sequence')
output_sequence_directory = os.path.join(config_directory, 'output_sequence')
end_loop_prompt_directory = os.path.join(config_directory, 'end_loop_prompt')
ensure_directory_exists(system_prompt_directory)
ensure_directory_exists(input_segment_directory)
ensure_directory_exists(output_segment_directory)
ensure_directory_exists(input_sequence_directory)
ensure_directory_exists(output_sequence_directory)
ensure_directory_exists(end_loop_prompt_directory)
system_prompt = get_first_filename(data, 'system_prompt')
input_segment = get_first_filename(data, 'input_segment')
output_segment = get_first_filename(data, 'output_segment')
input_sequence = get_first_filename(data, 'input_sequence')
output_sequence = get_first_filename(data, 'output_sequence')
end_loop_prompt = get_first_filename(data, 'end_loop_prompt')
system_prompt = read_file_contents(system_prompt_directory, system_prompt)
input_segment = read_file_contents(input_segment_directory, input_segment)
output_segment = read_file_contents(output_segment_directory, output_segment)
input_sequence = read_file_contents(input_sequence_directory, input_sequence)
output_sequence = read_file_contents(output_sequence_directory, output_sequence)
end_loop_prompt = read_file_contents(end_loop_prompt_directory, end_loop_prompt)
end_loop_list = end_loop_prompt.split("\n")
end_loop_list = [item for item in end_loop_list if item != '']
@ -276,6 +236,7 @@ async def process_images_and_generate_responses(message):
while i < len(discord_msgs):
m = discord_msgs[i]
logging.info(f"discord_msgs: {discord_msgs}")
try:
last_char = m['content'][-1]
except:
@ -293,7 +254,18 @@ async def process_images_and_generate_responses(message):
prev_author = m['author']
i += 1
message_prompt = "\n".join([f"{m['author']}: {m['content']}" for m in discord_msgs]) + '\n'
message_prompt = ""
append_input_seq = True
for m in discord_msgs:
if m['author'] == bot.user.name:
if not append_input_seq:
append_input_seq = True
message_prompt += output_sequence
else:
if append_input_seq:
message_prompt += input_sequence
append_input_seq = False
message_prompt += f"{m['author']}: {m['content']}\n"
last_bot_index = None
for i, line in enumerate(message_prompt.split('\n')):
@ -301,14 +273,6 @@ async def process_images_and_generate_responses(message):
last_bot_index = i
logging.info(f"last bot index: {last_bot_index}")
if last_bot_index is not None:
# Cover next user message
prompt_lines = message_prompt.split('\n')
next_user_line_index = last_bot_index + 1
if next_user_line_index < len(prompt_lines):
prompt_lines[next_user_line_index] = '{{ ' + prompt_lines[next_user_line_index] + ' }}'
message_prompt = '\n'.join(prompt_lines)
prompt = (f"{system_prompt}{input_segment}{message_prompt}{output_segment}")
prompt = prompt.replace('{{users}}', formatted_users)
prompt = prompt.replace('{{bot}}', bot.user.name)
@ -319,7 +283,7 @@ async def process_images_and_generate_responses(message):
stop_string = [f"\n{name}" for name in set(user_names)]
stop_string += [f"{bot.user.name}:"]
stop_string += ["</s>", "<|", "\n#", "\n\n\n", "<START>", "\n<", "\n{", ",", "?", "!", ";", "."]
stop_string += ["</s>", "<|", "\n#", "\n*", "\n\n\n", "<START>", ",", "?", "!", ";", "."]
headers = {
@ -334,16 +298,11 @@ async def process_images_and_generate_responses(message):
if token_count <= 2410:
break
else:
logging.info(f"The too many messages: {messages}")
if len(messages) != 0:
messages.pop(0)
oldest_message_timestamp = messages[0].created_at - timedelta(microseconds=1)
# Remove entries from image_cache that are no longer in messages
message_ids = [m.id for m in messages]
for id in list(image_cache.keys()):
if id not in message_ids:
del image_cache[id]
old_messages_object = [msg async for msg in channel.history(after=nextcord.Object(id=last_bot_message_id)) if last_bot_message_id]
old_messages = "\n".join([f"{m.author}: \n{m.content}" for m in old_messages_object]) + "\n"
if len(old_messages) == 0:
@ -424,6 +383,7 @@ async def process_images_and_generate_responses(message):
for m in messages:
if m.attachments:
msg_content = m.content
logging.info(f"please tell me there is any image cache in here {image_cache}")
if m.id in image_cache:
for image in image_cache[m.id]:
if msg_content:
@ -431,6 +391,7 @@ async def process_images_and_generate_responses(message):
else:
msg_content += f'*sends an image of {image["output"]}*'
else:
logging.info(f"no image for {m.content}")
msg_content = m.content
discord_msgs.append({
@ -463,6 +424,8 @@ async def process_images_and_generate_responses(message):
for i, line in enumerate(message_prompt.split('\n')):
if line.startswith(bot.user.name + ":"):
last_bot_index = i
logging.info(f"message_prompt: {message_prompt}")
logging.info(f"bot.user.name: {bot.user.name}")
logging.info(f"last bot index: {last_bot_index}")
prompt = (f"{system_prompt}{input_segment}{message_prompt}{output_segment}")
prompt = prompt.replace('{{users}}', formatted_users)

View File

@ -16,4 +16,7 @@ boot_prompt:
- boot_prompt.txt
initial_prompt:
- initial_prompt.txt
input_sequence:
- input_sequence.txt
output_sequence:
- output_sequence.txt

124
get_messages.py Normal file
View File

@ -0,0 +1,124 @@
import aiohttp
import asyncio
import filetype
import base64
import logging
import json
import tempfile
import os
from typing import List, Dict
class MessageProcessor:
def __init__(self, bot=None, formatted_users=None, image_cache=dict(), discord_msgs=None, user_names=None):
self.bot = bot
self.formatted_users = formatted_users
self.image_cache = image_cache
self.discord_msgs = discord_msgs if discord_msgs else []
self.user_names = user_names if user_names else []
self.gemini_key = os.environ['GEMINI_KEY']
async def get_image_caption(self, attachment) -> str:
image_url = attachment.url
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as response:
response.raise_for_status()
file_data = await response.read()
kind = filetype.guess(file_data)
if kind is not None:
if filetype.is_image(file_data) and kind.mime != 'image/gif':
if kind.mime == 'image/jpeg':
mime_type = 'image/jpeg'
elif kind.mime == 'image/png':
mime_type = 'image/png'
elif kind.mime == 'image/webp':
mime_type = 'image/webp'
elif kind.mime == 'image/heic':
mime_type = 'image/heic'
elif kind.mime == 'image/heif':
mime_type = 'image/heif'
else:
# Handle unsupported image formats here
mime_type = None
logging.info("is image")
extension = '.' + kind.extension
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file:
with open(temp_file.name, "wb") as f:
f.write(file_data)
with open(temp_file.name, "rb") as f:
data = f.read()
for _ in range(5):
try:
text = f"Describe the image in full detail, from the type of image to the texts if any, as lengthy as possible, in English only. Start your response with 'an image of' in lowercase, in as many text as possible. Classify persons/characters based on image URL if any. Image URL with filename in the end: {image_url}"
base64data = base64.b64encode(data).decode('utf-8')
jsonRequest = json.dumps({
"contents":[
{
"parts":[
{"text": text},
{
"inline_data": {
"mime_type": mime_type,
"data": base64data
}
}
]
}
],
"generationConfig": {
"temperature": 0.0,
}
})
response = await session.post('https://generativelanguage.googleapis.com/v1beta/models/gemini-pro-vision:generateContent?key=' + self.gemini_key, headers={'Content-Type': 'application/json'}, data=jsonRequest)
json_data = await response.json()
if json_data['candidates'][0]['content']['parts'][0]['text']:
break
except:
logging.info("An error occured while processing image, retrying...")
await asyncio.sleep(6)
output = json_data['candidates'][0]['content']['parts'][0]['text']
return output
async def process_messages(self, messages: List['Message']) -> List[dict]:
for m in messages:
msg_content = m.content
if m.attachments:
for attachment in m.attachments:
if attachment.url in self.image_cache:
image_caption = self.image_cache[attachment.url]
self.image_cache[attachment.url] = image_caption
msg_content = msg_content.strip() + f' *sends {image_caption}*'
else:
image_caption = await self.get_image_caption(attachment)
self.image_cache[attachment.url] = image_caption
logging.info(f"url: {attachment.url}")
msg_content = msg_content.strip() + f' *sends {image_caption}*'
self.discord_msgs.append({
"content": msg_content,
"author": m.author.name,
"is_bot": m.author.bot
})
if m.author.name != self.bot.user.name:
self.user_names.append(m.author.name)
self.formatted_users = ', '.join(set(self.user_names))
for m in messages:
if m.attachments:
for attachment in m.attachments:
if attachment.url not in self.image_cache:
del self.image_cache[attachment.url]
return self.discord_msgs
def get_discord_msgs(self):
return self.discord_msgs
def get_image_cache(self):
return self.image_cache
def get_formatted_users(self):
return self.formatted_users
def get_user_names(self):
return self.user_names

13
run.py Normal file
View File

@ -0,0 +1,13 @@
import nextcord
class failed_discord_bot:
def __init__(self):
self.intents = nextcord.Intents.default()
self.intents.message_content = True
self.bot = nextcord.Client(intents=self.intents)
def get_bot(self):
return self.bot
def get_intents(self):
return self.intents