From 47168e8c9331a424b17ef54605e2510f302c82f1 Mon Sep 17 00:00:00 2001 From: cyi1341 Date: Fri, 12 Jan 2024 21:27:27 +0800 Subject: [PATCH] Added gemini and a little classes --- Dockerfile | 5 +- bot.py | 145 +++++++++++++++++---------------------------- config/config.yaml | 5 +- get_messages.py | 124 ++++++++++++++++++++++++++++++++++++++ run.py | 13 ++++ 5 files changed, 198 insertions(+), 94 deletions(-) create mode 100644 get_messages.py create mode 100644 run.py diff --git a/Dockerfile b/Dockerfile index 7b7715d..f6ceeb3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,7 @@ FROM python:3.11-alpine # Install dependencies COPY requirements.txt ./ + RUN pip install -r requirements.txt # Set working directory @@ -11,8 +12,8 @@ WORKDIR /app # Bundle app source COPY . . -# Expose port -EXPOSE 8080 +# Expose port +EXPOSE 8080 # Run bot CMD ["python", "bot.py"] diff --git a/bot.py b/bot.py index ab2f5ab..226b0ee 100644 --- a/bot.py +++ b/bot.py @@ -7,7 +7,6 @@ import nextcord from nextcord.ext import commands import re import threading -from llama_tokenizer_lite import LlamaTokenizerLite import logging import traceback from datetime import timedelta @@ -19,6 +18,10 @@ import unicodedata import string from html import escape from difflib import SequenceMatcher +# self +from llama_tokenizer_lite import LlamaTokenizerLite +from run import failed_discord_bot +from get_messages import MessageProcessor def similar(a, b): return SequenceMatcher(None, a, b).ratio() @@ -29,9 +32,9 @@ Mancer_Key = os.getenv("MANCER_KEY") Mancer_Model = "mytholite" Mancer_URL = "https://neuro.mancer.tech/webui/"+Mancer_Model+"/api" -intents = nextcord.Intents.default() -intents.message_content = True -bot = nextcord.Client() +run = failed_discord_bot() +intents = run.get_intents() +bot = run.get_bot() tokenizer = LlamaTokenizerLite() @@ -140,6 +143,7 @@ async def process_images_and_generate_responses(message): try: for _ in range(7): # Limit to 7 attempts + logging.info(f"Number of _: {_}") if not reply_punctuation: channel = message.channel messages = [] @@ -162,79 +166,27 @@ async def process_images_and_generate_responses(message): return while True: + # Process the messages + global discord_msgs, formatted_users, user_names discord_msgs = [] - user_names = [] - HUGGING_API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large" - hugging_key = os.getenv("HUGGING_KEY") - hugging_headers = {"Authorization": f"Bearer {hugging_key}"} - logging.info("0 initialise") + if 'prompt' not in globals(): + message_processor = MessageProcessor(bot=bot) + else: + message_processor = MessageProcessor( + bot=bot, + formatted_users=formatted_users, + image_cache=image_cache, + discord_msgs=discord_msgs, + user_names=user_names + ) + logging.info("0 image stuff") + await message_processor.process_messages(messages) - for m in messages: - if m.attachments: - if m.id not in image_cache: - image_cache[m.id] = [] - for attachment in m.attachments: - # If the image has been processed before, skip the processing part - if any(img['url'] == attachment.url for img in image_cache[m.id]): - continue - async with aiohttp.ClientSession() as session: - async with session.get(attachment.url) as response: - response.raise_for_status() - file_data = await response.read() - kind = filetype.guess(file_data) - if kind is not None: - if filetype.is_image(file_data) and kind.mime != 'image/gif': - logging.info("is image") - extension = '.' + kind.extension - with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file: - temp_file.write(file_data) - temp_file.seek(0) - with open(temp_file.name, "rb") as f: - data = f.read() - for _ in range(5): - try: - response = requests.post(HUGGING_API_URL, headers=hugging_headers, data=data) - if response.json()[0]["generated_text"]: - break - except: - logging.info("An error occured while processing image, retrying...") - await asyncio.sleep(6) - - output = response.json() - try: - image_cache[m.id].append({'url': attachment.url, 'output': output[0]["generated_text"]}) - except: - image_cache[m.id].append({'url': attachment.url, 'output': "something random"}) - - - logging.info([img['output'] for img in image_cache[m.id]]) - else: - logging.info("is not image") - - msg_content = m.content - if m.id in image_cache: - for image in image_cache[m.id]: - if msg_content: - msg_content += f' *sends an image of {image["output"]}*' - else: - msg_content += f'*sends an image of {image["output"]}*' - else: - msg_content = m.content - - discord_msgs.append({ - "content": msg_content, - "author": m.author.name, - "is_bot": m.author.bot - }) - if m.author.name != bot.user.name: - user_names.append(m.author.name) - - # Remove cache for messages that no longer exist - for msg_id in list(image_cache.keys()): - if not any(msg.id == msg_id for msg in messages): - del image_cache[msg_id] - logging.info("1 image stuff") - formatted_users = ', '.join(set(user_names)) + # Access the processed messages and image cache + discord_msgs = message_processor.get_discord_msgs() + image_cache = message_processor.get_image_cache() + formatted_users = message_processor.get_formatted_users() + user_names = message_processor.get_user_names() global system_prompt global input_segment @@ -250,21 +202,29 @@ async def process_images_and_generate_responses(message): system_prompt_directory = os.path.join(config_directory, 'system_prompt') input_segment_directory = os.path.join(config_directory, 'input_segment') output_segment_directory = os.path.join(config_directory, 'output_segment') + input_sequence_directory = os.path.join(config_directory, 'input_sequence') + output_sequence_directory = os.path.join(config_directory, 'output_sequence') end_loop_prompt_directory = os.path.join(config_directory, 'end_loop_prompt') ensure_directory_exists(system_prompt_directory) ensure_directory_exists(input_segment_directory) ensure_directory_exists(output_segment_directory) + ensure_directory_exists(input_sequence_directory) + ensure_directory_exists(output_sequence_directory) ensure_directory_exists(end_loop_prompt_directory) system_prompt = get_first_filename(data, 'system_prompt') input_segment = get_first_filename(data, 'input_segment') output_segment = get_first_filename(data, 'output_segment') + input_sequence = get_first_filename(data, 'input_sequence') + output_sequence = get_first_filename(data, 'output_sequence') end_loop_prompt = get_first_filename(data, 'end_loop_prompt') system_prompt = read_file_contents(system_prompt_directory, system_prompt) input_segment = read_file_contents(input_segment_directory, input_segment) output_segment = read_file_contents(output_segment_directory, output_segment) + input_sequence = read_file_contents(input_sequence_directory, input_sequence) + output_sequence = read_file_contents(output_sequence_directory, output_sequence) end_loop_prompt = read_file_contents(end_loop_prompt_directory, end_loop_prompt) end_loop_list = end_loop_prompt.split("\n") end_loop_list = [item for item in end_loop_list if item != ''] @@ -276,6 +236,7 @@ async def process_images_and_generate_responses(message): while i < len(discord_msgs): m = discord_msgs[i] + logging.info(f"discord_msgs: {discord_msgs}") try: last_char = m['content'][-1] except: @@ -293,7 +254,18 @@ async def process_images_and_generate_responses(message): prev_author = m['author'] i += 1 - message_prompt = "\n".join([f"{m['author']}: {m['content']}" for m in discord_msgs]) + '\n' + message_prompt = "" + append_input_seq = True + for m in discord_msgs: + if m['author'] == bot.user.name: + if not append_input_seq: + append_input_seq = True + message_prompt += output_sequence + else: + if append_input_seq: + message_prompt += input_sequence + append_input_seq = False + message_prompt += f"{m['author']}: {m['content']}\n" last_bot_index = None for i, line in enumerate(message_prompt.split('\n')): @@ -301,14 +273,6 @@ async def process_images_and_generate_responses(message): last_bot_index = i logging.info(f"last bot index: {last_bot_index}") - if last_bot_index is not None: - # Cover next user message - prompt_lines = message_prompt.split('\n') - next_user_line_index = last_bot_index + 1 - if next_user_line_index < len(prompt_lines): - prompt_lines[next_user_line_index] = '{{ ' + prompt_lines[next_user_line_index] + ' }}' - message_prompt = '\n'.join(prompt_lines) - prompt = (f"{system_prompt}{input_segment}{message_prompt}{output_segment}") prompt = prompt.replace('{{users}}', formatted_users) prompt = prompt.replace('{{bot}}', bot.user.name) @@ -319,7 +283,7 @@ async def process_images_and_generate_responses(message): stop_string = [f"\n{name}" for name in set(user_names)] stop_string += [f"{bot.user.name}:"] - stop_string += ["", "<|", "\n#", "\n\n\n", "", "\n<", "\n{", ",", "?", "!", ";", "."] + stop_string += ["", "<|", "\n#", "\n*", "\n\n\n", "", ",", "?", "!", ";", "."] headers = { @@ -334,16 +298,11 @@ async def process_images_and_generate_responses(message): if token_count <= 2410: break else: + logging.info(f"The too many messages: {messages}") if len(messages) != 0: messages.pop(0) oldest_message_timestamp = messages[0].created_at - timedelta(microseconds=1) - - # Remove entries from image_cache that are no longer in messages - message_ids = [m.id for m in messages] - for id in list(image_cache.keys()): - if id not in message_ids: - del image_cache[id] old_messages_object = [msg async for msg in channel.history(after=nextcord.Object(id=last_bot_message_id)) if last_bot_message_id] old_messages = "\n".join([f"{m.author}: \n{m.content}" for m in old_messages_object]) + "\n" if len(old_messages) == 0: @@ -424,6 +383,7 @@ async def process_images_and_generate_responses(message): for m in messages: if m.attachments: msg_content = m.content + logging.info(f"please tell me there is any image cache in here {image_cache}") if m.id in image_cache: for image in image_cache[m.id]: if msg_content: @@ -431,6 +391,7 @@ async def process_images_and_generate_responses(message): else: msg_content += f'*sends an image of {image["output"]}*' else: + logging.info(f"no image for {m.content}") msg_content = m.content discord_msgs.append({ @@ -463,6 +424,8 @@ async def process_images_and_generate_responses(message): for i, line in enumerate(message_prompt.split('\n')): if line.startswith(bot.user.name + ":"): last_bot_index = i + logging.info(f"message_prompt: {message_prompt}") + logging.info(f"bot.user.name: {bot.user.name}") logging.info(f"last bot index: {last_bot_index}") prompt = (f"{system_prompt}{input_segment}{message_prompt}{output_segment}") prompt = prompt.replace('{{users}}', formatted_users) diff --git a/config/config.yaml b/config/config.yaml index 75b8288..b6c3027 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -16,4 +16,7 @@ boot_prompt: - boot_prompt.txt initial_prompt: - initial_prompt.txt - +input_sequence: + - input_sequence.txt +output_sequence: + - output_sequence.txt diff --git a/get_messages.py b/get_messages.py new file mode 100644 index 0000000..716853c --- /dev/null +++ b/get_messages.py @@ -0,0 +1,124 @@ +import aiohttp +import asyncio +import filetype +import base64 +import logging +import json +import tempfile +import os +from typing import List, Dict + +class MessageProcessor: + def __init__(self, bot=None, formatted_users=None, image_cache=dict(), discord_msgs=None, user_names=None): + self.bot = bot + self.formatted_users = formatted_users + self.image_cache = image_cache + self.discord_msgs = discord_msgs if discord_msgs else [] + self.user_names = user_names if user_names else [] + self.gemini_key = os.environ['GEMINI_KEY'] + + async def get_image_caption(self, attachment) -> str: + image_url = attachment.url + async with aiohttp.ClientSession() as session: + async with session.get(image_url) as response: + response.raise_for_status() + file_data = await response.read() + kind = filetype.guess(file_data) + if kind is not None: + if filetype.is_image(file_data) and kind.mime != 'image/gif': + if kind.mime == 'image/jpeg': + mime_type = 'image/jpeg' + elif kind.mime == 'image/png': + mime_type = 'image/png' + elif kind.mime == 'image/webp': + mime_type = 'image/webp' + elif kind.mime == 'image/heic': + mime_type = 'image/heic' + elif kind.mime == 'image/heif': + mime_type = 'image/heif' + else: + # Handle unsupported image formats here + mime_type = None + logging.info("is image") + extension = '.' + kind.extension + with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file: + with open(temp_file.name, "wb") as f: + f.write(file_data) + with open(temp_file.name, "rb") as f: + data = f.read() + for _ in range(5): + try: + text = f"Describe the image in full detail, from the type of image to the texts if any, as lengthy as possible, in English only. Start your response with 'an image of' in lowercase, in as many text as possible. Classify persons/characters based on image URL if any. Image URL with filename in the end: {image_url}" + base64data = base64.b64encode(data).decode('utf-8') + jsonRequest = json.dumps({ + "contents":[ + { + "parts":[ + {"text": text}, + { + "inline_data": { + "mime_type": mime_type, + "data": base64data + } + } + ] + } + ], + "generationConfig": { + "temperature": 0.0, + } + }) + response = await session.post('https://generativelanguage.googleapis.com/v1beta/models/gemini-pro-vision:generateContent?key=' + self.gemini_key, headers={'Content-Type': 'application/json'}, data=jsonRequest) + json_data = await response.json() + if json_data['candidates'][0]['content']['parts'][0]['text']: + break + except: + logging.info("An error occured while processing image, retrying...") + await asyncio.sleep(6) + + output = json_data['candidates'][0]['content']['parts'][0]['text'] + return output + + async def process_messages(self, messages: List['Message']) -> List[dict]: + for m in messages: + msg_content = m.content + if m.attachments: + for attachment in m.attachments: + if attachment.url in self.image_cache: + image_caption = self.image_cache[attachment.url] + self.image_cache[attachment.url] = image_caption + msg_content = msg_content.strip() + f' *sends {image_caption}*' + else: + image_caption = await self.get_image_caption(attachment) + self.image_cache[attachment.url] = image_caption + logging.info(f"url: {attachment.url}") + msg_content = msg_content.strip() + f' *sends {image_caption}*' + + self.discord_msgs.append({ + "content": msg_content, + "author": m.author.name, + "is_bot": m.author.bot + }) + if m.author.name != self.bot.user.name: + self.user_names.append(m.author.name) + + self.formatted_users = ', '.join(set(self.user_names)) + for m in messages: + if m.attachments: + for attachment in m.attachments: + if attachment.url not in self.image_cache: + del self.image_cache[attachment.url] + + return self.discord_msgs + + def get_discord_msgs(self): + return self.discord_msgs + + def get_image_cache(self): + return self.image_cache + + def get_formatted_users(self): + return self.formatted_users + + def get_user_names(self): + return self.user_names diff --git a/run.py b/run.py new file mode 100644 index 0000000..00a0510 --- /dev/null +++ b/run.py @@ -0,0 +1,13 @@ +import nextcord + +class failed_discord_bot: + def __init__(self): + self.intents = nextcord.Intents.default() + self.intents.message_content = True + self.bot = nextcord.Client(intents=self.intents) + + def get_bot(self): + return self.bot + + def get_intents(self): + return self.intents