# failed_discord_bot/bot.py
#
# 618 lines
# 29 KiB
# Python

import os
import requests
import json
import yaml
import asyncio
import nextcord
from nextcord.ext import commands
import re
import threading
from llama_tokenizer_lite import LlamaTokenizerLite
import logging
import traceback
from datetime import timedelta
import tempfile
import filetype
import aiohttp
import random
import unicodedata
import string
from html import escape
from difflib import SequenceMatcher
def similar(a, b):
    """Return the difflib similarity ratio of ``a`` and ``b`` (0.0 to 1.0).

    1.0 means identical sequences; two empty sequences also score 1.0.
    """
    matcher = SequenceMatcher(None, a, b)
    score = matcher.ratio()
    return score
logging.basicConfig(level=logging.INFO, format='%(message)s')

# Mancer (hosted LLM) credentials and endpoint. Mancer_Key is None when the
# MANCER_KEY environment variable is not set.
Mancer_Key = os.getenv("MANCER_KEY")
Mancer_Model = "mytholite"
Mancer_URL = "https://neuro.mancer.tech/webui/" + Mancer_Model + "/api"

# The bot reads message text, so it needs the message-content intent. Intents
# only take effect when they are handed to the client constructor.
intents = nextcord.Intents.default()
intents.message_content = True
bot = nextcord.Client(intents=intents)

tokenizer = LlamaTokenizerLite()

# Conversation state shared (via `global`) by the event handlers.
oldest_message_timestamp = None  # lower bound passed to channel.history(after=...)
last_bot_message_id = None       # id of the bot's most recent message
image_cache = {}                 # message id -> list of {'url': ..., 'output': caption}
current_task = None              # NOTE(review): not referenced in the visible code
# Create a lock
lock = asyncio.Lock()            # NOTE(review): not referenced in the visible code
def get_mancer_body(prompt, stop_string):
    """Build the JSON request body for the Mancer text-generation endpoint.

    The generation settings are the JSON file named first under the
    ``mancer_settings`` key of ``config.yaml``, read from the
    ``mancer_settings`` sub-directory of ``config_directory``.

    Args:
        prompt: Full prompt text to send.
        stop_string: Stored unchanged under ``stopping_strings``; callers
            normally pass a list of strings.

    Returns:
        dict: The parsed settings with ``prompt`` and ``stopping_strings`` set.
    """
    config = read_yaml_file(os.path.join(config_directory, 'config.yaml'))
    settings_dir = os.path.join(config_directory, 'mancer_settings')
    ensure_directory_exists(settings_dir)
    settings_name = get_first_filename(config, 'mancer_settings')
    raw_settings = read_file_contents(settings_dir, settings_name)
    body = json.loads(raw_settings)
    body['prompt'] = prompt
    body['stopping_strings'] = stop_string
    return body
def read_yaml_file(filename, encoding='utf-8'):
    """Parse the YAML file at ``filename`` with ``yaml.safe_load`` and return it.

    The file is decoded explicitly (UTF-8 by default) rather than with the
    platform's locale encoding. Config values containing non-ASCII text then
    load the same way on every OS.

    Args:
        filename: Path of the YAML file.
        encoding: Text encoding used to decode the file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.
    """
    with open(filename, 'r', encoding=encoding) as file:
        return yaml.safe_load(file)
def get_first_filename(data, field):
    """Return the first entry of the sequence stored under ``field`` in ``data``.

    Raises KeyError if ``field`` is missing and IndexError if its list is empty.
    """
    entries = data[field]
    first = entries[0]
    return first
def read_file_contents(directory, filename, encoding='utf-8'):
    """Return the full text of ``directory/filename``.

    The file is decoded as UTF-8 by default instead of the locale's preferred
    encoding. On some platforms (notably Windows) that is a legacy code page,
    which cannot decode prompt files containing non-ASCII text.

    Args:
        directory: Directory containing the file.
        filename: Name of the file inside ``directory``.
        encoding: Text encoding used to decode the file.
    """
    path = os.path.join(directory, filename)
    with open(path, 'r', encoding=encoding) as file:
        return file.read()
def ensure_directory_exists(directory):
    """Create ``directory`` (and any missing parents) if it does not exist yet.

    ``exist_ok=True`` makes this a single atomic call. A separate exists()
    check followed by makedirs() races with any other creator of the same
    path.
    """
    os.makedirs(directory, exist_ok=True)
# Root folder holding config.yaml and the per-prompt sub-directories
# (boot_prompt, initial_prompt, mancer_settings, ...).
config_directory = "config"
class MessageQueue:
    """FIFO of incoming messages drained by a single background worker.

    Producers call ``add_message``. ``process_messages`` loops forever, hands
    each message to ``process_images_and_generate_responses`` one at a time,
    and logs (rather than propagates) any exception raised while handling it.
    """

    def __init__(self):
        self.queue = asyncio.Queue()  # unbounded

    async def add_message(self, message):
        """Enqueue ``message`` for the worker."""
        await self.queue.put(message)

    async def process_messages(self):
        """Consume the queue forever, one message at a time."""
        pending = self.queue
        while True:
            item = await pending.get()
            try:
                await process_images_and_generate_responses(item)
            except Exception:
                logging.exception("Failed to process message")
            finally:
                pending.task_done()


# Shared instance: event handlers enqueue into it and main() starts its worker.
message_queue = MessageQueue()
def _load_startup_prompts():
    """Read the boot and initial prompt texts named in config.yaml.

    Returns:
        tuple[str, str]: ``(boot_prompt, initial_prompt)``. Each is the
        contents of the first file listed under that key, read from the
        matching sub-directory of ``config_directory``.
    """
    data = read_yaml_file(os.path.join(config_directory, 'config.yaml'))
    boot_prompt_directory = os.path.join(config_directory, 'boot_prompt')
    initial_prompt_directory = os.path.join(config_directory, 'initial_prompt')
    ensure_directory_exists(boot_prompt_directory)
    ensure_directory_exists(initial_prompt_directory)
    boot_name = get_first_filename(data, 'boot_prompt')
    initial_name = get_first_filename(data, 'initial_prompt')
    boot_prompt = read_file_contents(boot_prompt_directory, boot_name)
    initial_prompt = read_file_contents(initial_prompt_directory, initial_name)
    return boot_prompt, initial_prompt


@bot.event
async def on_ready():
    """Post the boot and initial prompts in the first writable channel of each guild.

    The initial prompt message becomes the start of the conversation window:
    ``oldest_message_timestamp`` is set just before it, it becomes
    ``last_bot_message_id``, and it is queued for processing.

    The config is read at most once per call, and only if some channel is
    writable. The previous version re-read and re-parsed every file for each
    guild.
    """
    global oldest_message_timestamp, last_bot_message_id
    prompts = None
    for guild in bot.guilds:
        for channel in guild.text_channels:
            if not channel.permissions_for(guild.me).send_messages:
                continue
            if prompts is None:
                prompts = _load_startup_prompts()
            boot_prompt, initial_prompt = prompts
            await channel.send(boot_prompt)
            sent_message = await channel.send(initial_prompt)
            # Start history reads just before the initial prompt so it is included.
            oldest_message_timestamp = sent_message.created_at - timedelta(microseconds=1)
            last_bot_message_id = sent_message.id
            await message_queue.add_message(sent_message)
            break  # only the first writable channel per guild
    logging.info('Bot is ready')
@bot.event
async def on_message(message):
    """Queue every incoming message (including the bot's own) for the worker.

    The worker decides whether a reply is needed.
    """
    queue = message_queue
    await queue.add_message(message)
    logging.info('New message received')
# NOTE(review): this coroutine is called directly by MessageQueue.process_messages;
# registering it with @bot.event does not appear to be needed.
@bot.event
async def process_images_and_generate_responses(message):
    """Build an LLM prompt from recent channel history and post the model's reply.

    Runs up to 7 attempts. Each attempt does the following:

    1. Unless the previous attempt ended mid-sentence (``reply_punctuation``),
       read the channel history after ``oldest_message_timestamp``. Stop if
       there is none, or if the newest message is the bot's own.
    2. Caption non-GIF image attachments with the Hugging Face BLIP inference
       API and cache the results in ``image_cache``.
    3. Flatten the history to ``author: content`` lines and merge consecutive
       lines from the same author.
    4. Wrap the result in the prompt templates from ``config/``. Drop the
       oldest message until the local token count is <= 2410.
    5. POST the prompt to Mancer ``/v1/generate`` and, subject to the
       new-message and length checks below, send the reply. If the reply ends
       with ``,``, ``!``, ``?`` or ``;``, it is appended to the prompt and the
       next attempt continues generating from there.

    Any exception is logged and swallowed.

    Args:
        message: The queued nextcord message. Its channel is used. When no
            history window exists yet, the message itself is the only history
            item.

    Side effects: rebinds the module globals ``oldest_message_timestamp``,
    ``last_bot_message_id``, ``image_cache``, ``system_prompt``,
    ``input_segment``, ``output_segment``, ``end_loop_prompt`` and ``prompt``.
    """
    logging.info("process function initialised")
    global oldest_message_timestamp
    global last_bot_message_id
    global image_cache
    # True when the previous reply ended with , ! ? or ;. The next attempt then
    # skips re-reading the channel and continues from the extended prompt.
    reply_punctuation = False
    try:
        # NOTE(review): `_` is also rebound by the image-retry loop below
        # (`for _ in range(5)`). The later `_ != 0` / `_ >= 2` checks can
        # therefore read that inner counter instead of the attempt number.
        for _ in range(7): # Limit to 7 attempts
            if not reply_punctuation:
                channel = message.channel
                messages = []
                if oldest_message_timestamp is None:
                    oldest_message_timestamp = message.created_at
                    messages.append(message)
                else:
                    async for m in channel.history(limit=200, after=oldest_message_timestamp):
                        messages.append(m)
                logging.info(f"Number of messages it's reading: {len(messages)}")
                try:
                    latest_message = messages[-1]
                    logging.info(f"latest message: {latest_message}")
                    latest_author_id = latest_message.author.id
                except:
                    # No history to respond to.
                    return
                if latest_author_id == bot.user.id:
                    # The bot already has the last word.
                    logging.info("latest author is bot")
                    return
                # Build the prompt; each pass that is over the token budget drops
                # the oldest message and rebuilds.
                while True:
                    discord_msgs = []
                    user_names = []
                    # NOTE(review): loop-invariant; could be hoisted out of the loops.
                    HUGGING_API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
                    hugging_key = os.getenv("HUGGING_KEY")
                    hugging_headers = {"Authorization": f"Bearer {hugging_key}"}
                    logging.info("0 initialise")
                    for m in messages:
                        if m.attachments:
                            if m.id not in image_cache:
                                image_cache[m.id] = []
                            for attachment in m.attachments:
                                # If the image has been processed before, skip the processing part
                                if any(img['url'] == attachment.url for img in image_cache[m.id]):
                                    continue
                                async with aiohttp.ClientSession() as session:
                                    async with session.get(attachment.url) as response:
                                        response.raise_for_status()
                                        file_data = await response.read()
                                kind = filetype.guess(file_data)
                                if kind is not None:
                                    if filetype.is_image(file_data) and kind.mime != 'image/gif':
                                        logging.info("is image")
                                        extension = '.' + kind.extension
                                        # NOTE(review): delete=False and the file is never
                                        # removed, so each captioned image leaves a temp file
                                        # behind. `data` should equal `file_data`, so this
                                        # disk round-trip looks unnecessary.
                                        with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file:
                                            temp_file.write(file_data)
                                            temp_file.seek(0)
                                            with open(temp_file.name, "rb") as f:
                                                data = f.read()
                                        # Caption the image, retrying up to 5 times.
                                        # NOTE(review): requests.post is synchronous and has no
                                        # timeout, so it blocks the event loop while it waits.
                                        for _ in range(5):
                                            try:
                                                response = requests.post(HUGGING_API_URL, headers=hugging_headers, data=data)
                                                if response.json()[0]["generated_text"]:
                                                    break
                                            except:
                                                logging.info("An error occured while processing image, retrying...")
                                                await asyncio.sleep(6)
                                        # NOTE(review): outside any try. If every attempt raised,
                                        # `response` is still the aiohttp download response above.
                                        # If the last body was not JSON, this raises and aborts
                                        # the whole reply.
                                        output = response.json()
                                        try:
                                            image_cache[m.id].append({'url': attachment.url, 'output': output[0]["generated_text"]})
                                        except:
                                            # Placeholder caption when no usable text came back.
                                            image_cache[m.id].append({'url': attachment.url, 'output': "something random"})
                                        logging.info([img['output'] for img in image_cache[m.id]])
                                    else:
                                        logging.info("is not image")
                            # Append each cached caption to the message text.
                            msg_content = m.content
                            if m.id in image_cache:
                                for image in image_cache[m.id]:
                                    if msg_content:
                                        msg_content += f' *sends an image of {image["output"]}*'
                                    else:
                                        msg_content += f'*sends an image of {image["output"]}*'
                        else:
                            msg_content = m.content
                        discord_msgs.append({
                            "content": msg_content,
                            "author": m.author.name,
                            "is_bot": m.author.bot
                        })
                        if m.author.name != bot.user.name:
                            user_names.append(m.author.name)
                    # Remove cache for messages that no longer exist
                    for msg_id in list(image_cache.keys()):
                        if not any(msg.id == msg_id for msg in messages):
                            del image_cache[msg_id]
                    logging.info("1 image stuff")
                    formatted_users = ', '.join(set(user_names))
                    # The prompt pieces are published as module globals
                    # (the showprompt command reads `prompt`).
                    global system_prompt
                    global input_segment
                    global output_segment
                    global end_loop_prompt
                    global prompt
                    # NOTE(review): local shadow of the module-level config_directory
                    # (same value).
                    config_directory = 'config'
                    ensure_directory_exists(config_directory)
                    data = read_yaml_file(os.path.join(config_directory, 'config.yaml'))
                    system_prompt_directory = os.path.join(config_directory, 'system_prompt')
                    input_segment_directory = os.path.join(config_directory, 'input_segment')
                    output_segment_directory = os.path.join(config_directory, 'output_segment')
                    end_loop_prompt_directory = os.path.join(config_directory, 'end_loop_prompt')
                    ensure_directory_exists(system_prompt_directory)
                    ensure_directory_exists(input_segment_directory)
                    ensure_directory_exists(output_segment_directory)
                    ensure_directory_exists(end_loop_prompt_directory)
                    system_prompt = get_first_filename(data, 'system_prompt')
                    input_segment = get_first_filename(data, 'input_segment')
                    output_segment = get_first_filename(data, 'output_segment')
                    end_loop_prompt = get_first_filename(data, 'end_loop_prompt')
                    system_prompt = read_file_contents(system_prompt_directory, system_prompt)
                    input_segment = read_file_contents(input_segment_directory, input_segment)
                    output_segment = read_file_contents(output_segment_directory, output_segment)
                    end_loop_prompt = read_file_contents(end_loop_prompt_directory, end_loop_prompt)
                    # Use one randomly chosen non-empty line of the end-loop file.
                    end_loop_list = end_loop_prompt.split("\n")
                    end_loop_list = [item for item in end_loop_list if item != '']
                    end_loop_prompt = random.choice(end_loop_list)
                    logging.info("2 prompt stuff")
                    # Merge consecutive entries from the same author into one entry.
                    prev_author = None
                    i = 0
                    while i < len(discord_msgs):
                        m = discord_msgs[i]
                        # NOTE(review): `last_char` (with its empty-content fallback) is
                        # never used. The lines below index m['content'][-1] and [0]
                        # directly, and raise IndexError when a repeated author's
                        # message is empty. The check also tests the *current*
                        # message's last character; possibly the previous one's was
                        # intended. Confirm.
                        try:
                            last_char = m['content'][-1]
                        except:
                            last_char = '.'
                        if m['author'] == prev_author:
                            if unicodedata.category(m['content'][-1])[0] not in ('L', 'N'):
                                discord_msgs[i-1]['content'] += f" {m['content']}"
                            else:
                                # Lowercase first letter before concatenating
                                lowercase_content = m['content'][0].lower() + m['content'][1:]
                                discord_msgs[i-1]['content'] += f", {lowercase_content}"
                            del discord_msgs[i]
                        else:
                            prev_author = m['author']
                            i += 1
                    message_prompt = "\n".join([f"{m['author']}: {m['content']}" for m in discord_msgs]) + '\n'
                    # Index of the last line spoken by the bot.
                    last_bot_index = None
                    for i, line in enumerate(message_prompt.split('\n')):
                        if line.startswith(bot.user.name + ":"):
                            last_bot_index = i
                    logging.info(f"last bot index: {last_bot_index}")
                    if last_bot_index is not None:
                        # Cover next user message: wrap the line after the bot's last
                        # line in "{{ ... }}".
                        # NOTE(review): indices are physical lines. If the bot's last
                        # message spans several lines, the wrapped line is part of it.
                        prompt_lines = message_prompt.split('\n')
                        next_user_line_index = last_bot_index + 1
                        if next_user_line_index < len(prompt_lines):
                            prompt_lines[next_user_line_index] = '{{ ' + prompt_lines[next_user_line_index] + ' }}'
                        message_prompt = '\n'.join(prompt_lines)
                    prompt = (f"{system_prompt}{input_segment}{message_prompt}{output_segment}")
                    prompt = prompt.replace('{{users}}', formatted_users)
                    prompt = prompt.replace('{{bot}}', bot.user.name)
                    prompt = prompt.replace('{{endloop}}', end_loop_prompt)
                    # Removing the newline character
                    prompt = prompt.rstrip("\n")
                    logging.info("3 more prompt stuff")
                    # Stop strings sent to the API: every speaker tag, template markers
                    # and single punctuation marks (presumably so replies arrive one
                    # clause at a time).
                    stop_string = [f"\n{name}" for name in set(user_names)]
                    stop_string += [f"{bot.user.name}:"]
                    stop_string += ["</s>", "<|", "\n#", "\n\n\n", "<START>", "\n<", "\n{", ",", "?", "!", ";", "."]
                    headers = {
                        "X-API-KEY": Mancer_Key,
                        "Content-Type": "application/json",
                    }
                    new_body = get_mancer_body(prompt, stop_string)
                    # NOTE(review): tokenizer.encode() runs on the event loop; only
                    # len() is sent to the worker thread.
                    token_count = await asyncio.to_thread(len, tokenizer.encode(prompt))
                    logging.info(f"token count (local): {token_count}")
                    if token_count <= 2410:
                        break
                    else:
                        # Over budget: drop the oldest message and rebuild.
                        # NOTE(review): messages[0] raises IndexError once pop() empties
                        # the list. If `messages` is already empty and the prompt is
                        # still over budget, nothing changes and this loop never exits.
                        if len(messages) != 0:
                            messages.pop(0)
                            oldest_message_timestamp = messages[0].created_at - timedelta(microseconds=1)
                        # Remove entries from image_cache that are no longer in messages
                        message_ids = [m.id for m in messages]
                        for id in list(image_cache.keys()):
                            if id not in message_ids:
                                del image_cache[id]
                # Snapshot of everything after the bot's last message. It is compared
                # after generation to detect messages that arrived in the meantime.
                # NOTE(review): the trailing `if last_bot_message_id` filters items; it
                # does not stop nextcord.Object(id=last_bot_message_id) from being
                # evaluated when the id is None.
                old_messages_object = [msg async for msg in channel.history(after=nextcord.Object(id=last_bot_message_id)) if last_bot_message_id]
                old_messages = "\n".join([f"{m.author}: \n{m.content}" for m in old_messages_object]) + "\n"
                # NOTE(review): never true, because old_messages always ends with "\n".
                if len(old_messages) == 0:
                    return
            logging.info("4 before generating stuff")
            if _ != 0:
                await asyncio.sleep(0.5)
            async with channel.typing():
                async with aiohttp.ClientSession() as session:
                    async with session.post(Mancer_URL + "/v1/generate", json=new_body, headers=headers) as response:
                        logging.info(f"Request Body: {json.dumps(new_body, indent=2)}")
                        api_response = await response.json()
                        # NOTE(review): these look like HTTP header names but are read
                        # from the JSON body. Confirm the API actually returns them there.
                        x_input_tokens = api_response.get('x-input-tokens')
                        x_output_tokens = api_response.get('x-output-tokens')
                        logging.info(f"Response: {json.dumps(api_response, indent=2)}")
                        reply_content = api_response.get('results')[0].get('text')
                        reply_content = str(reply_content)
            # Remove stop_string if it's at the end of reply_content
            stop_string_python_sucks = [s if s.endswith('\n') else s+'\n' for s in stop_string]
            stop_string_python_sucks.extend(stop_string)
            if reply_content:
                if reply_content.endswith((",", "!", "?", ";")):
                    reply_punctuation = True
                else:
                    reply_punctuation = False
            # NOTE(review): this loop rebinds `stop_string` from the list to a single
            # string. The get_mancer_body(prompt, stop_string) call further down, and
            # the comprehension above on a continued attempt, then see that string.
            for stop_string in stop_string_python_sucks:
                if reply_content.endswith(stop_string):
                    reply_content = reply_content[:-len(stop_string)]
                    break
            if not reply_punctuation:
                # Check for new messages after the bot's last message
                new_messages_object = [msg async for msg in channel.history(after=nextcord.Object(id=last_bot_message_id)) if last_bot_message_id]
                new_messages = "\n".join([f"{m.author}: \n{m.content}" for m in new_messages_object]) + "\n"
                try:
                    latest_message = new_messages_object[-1]
                    latest_author_id = latest_message.author.id
                except:
                    return
            else:
                new_messages = old_messages
            # If new messages are found, regenerate the response
            # NOTE(review): `and` binds tighter than `or`, so this reads
            # old == new or (not reply_punctuation and not fullmatch(...)).
            if old_messages == new_messages or not reply_punctuation and not re.fullmatch(r'[!?]+', reply_content):
                logging.info('old_messages == new_messages')
                # Check if the response consists of max amount of max_new_tokens
                # (150 presumably mirrors max_new_tokens in the Mancer settings; confirm)
                if x_output_tokens == 150:
                    logging.info("Generated message contained too many tokens, regenerating...")
                else:
                    bot_messages = [msg['content'] for msg in discord_msgs if msg['is_bot']]  # NOTE(review): unused
                    try:
                        # NOTE(review): IndexError on an empty reply; only HTTPException
                        # is caught by this try.
                        last_punctuation = reply_content[-1]
                        # A trailing comma is not posted.
                        if reply_punctuation and last_punctuation == ',':
                            reply_content = reply_content[:-1]
                        sent_message = await channel.send(reply_content) # Store the sent message
                        if old_messages != new_messages and not reply_punctuation:
                            break
                        if not reply_punctuation:
                            last_bot_message_id = sent_message.id # Update the ID of the last message sent by the bot
                        else:
                            # Mid-sentence reply: append it to the prompt so the next
                            # attempt continues from it.
                            # NOTE(review): unless last_punctuation was the stripped comma,
                            # this duplicates the reply's final character.
                            reply_content = reply_content + last_punctuation
                            prompt = prompt + reply_content
                            while True:
                                new_body = get_mancer_body(prompt, stop_string)
                                token_count = await asyncio.to_thread(len, tokenizer.encode(prompt))
                                logging.info(f"token count (local): {token_count}")
                                if token_count <= 2410:
                                    break
                                else:
                                    # Over budget: drop the oldest message and rebuild the
                                    # prompt from already-cached captions (no downloads).
                                    discord_msgs = []
                                    if len(messages) != 0:
                                        messages.pop(0)
                                    logging.info(f"messages variable: {messages}")
                                    prev_author = None
                                    i = 0
                                    # messages initialise
                                    for m in messages:
                                        if m.attachments:
                                            msg_content = m.content
                                            if m.id in image_cache:
                                                for image in image_cache[m.id]:
                                                    if msg_content:
                                                        msg_content += f' *sends an image of {image["output"]}*'
                                                    else:
                                                        msg_content += f'*sends an image of {image["output"]}*'
                                        else:
                                            msg_content = m.content
                                        discord_msgs.append({
                                            "content": msg_content,
                                            "author": m.author.name,
                                            "is_bot": m.author.bot
                                        })
                                    # discord_msgs initialise
                                    while i < len(discord_msgs):
                                        m = discord_msgs[i]
                                        try:
                                            last_char = m['content'][-1]
                                        except:
                                            last_char = '.'
                                        if m['author'] == prev_author:
                                            if unicodedata.category(m['content'][-1])[0] not in ('L', 'N'):
                                                discord_msgs[i-1]['content'] += f" {m['content']}"
                                            else:
                                                # Lowercase first letter before concatenating
                                                lowercase_content = m['content'][0].lower() + m['content'][1:]
                                                discord_msgs[i-1]['content'] += f", {lowercase_content}"
                                            del discord_msgs[i]
                                        else:
                                            prev_author = m['author']
                                            i += 1
                                    message_prompt = "\n".join([f"{m['author']}: {m['content']}" for m in discord_msgs]) + '\n'
                                    last_bot_index = None
                                    for i, line in enumerate(message_prompt.split('\n')):
                                        if line.startswith(bot.user.name + ":"):
                                            last_bot_index = i
                                    logging.info(f"last bot index: {last_bot_index}")
                                    # NOTE(review): unlike the first build, no "{{ }}" wrapping
                                    # happens here. Only the latest reply fragment is
                                    # re-appended; earlier continuation fragments are dropped.
                                    prompt = (f"{system_prompt}{input_segment}{message_prompt}{output_segment}")
                                    prompt = prompt.replace('{{users}}', formatted_users)
                                    prompt = prompt.replace('{{bot}}', bot.user.name)
                                    prompt = prompt.replace('{{endloop}}', end_loop_prompt)
                                    # Removing the newline character
                                    prompt = prompt.rstrip("\n")
                                    prompt = prompt + reply_content
                                    new_body["prompt"] = prompt
                        if reply_content.endswith(("!", "?")) and _ >= 2:
                            return
                    except nextcord.errors.HTTPException:
                        # Discord rejected the send; this attempt is skipped.
                        pass
            else:
                logging.info('old_messages != new_messages')
                try:
                    last_bot_message_id = sent_message.id
                except:
                    # No message has been sent yet in this call.
                    pass
            if not reply_punctuation:
                # Refresh the snapshot for the next attempt.
                old_messages_object = [msg async for msg in channel.history(after=nextcord.Object(id=last_bot_message_id)) if last_bot_message_id]
                old_messages = "\n".join([f"{m.author}: \n{m.content}" for m in old_messages_object]) + "\n"
    except Exception as error:
        logging.info(str(error))
        logging.info(str(traceback.format_exc()))
@bot.slash_command(name="ping", description="Check if I'm alive!")
async def ping(interaction:nextcord.Interaction):
    """Reply (ephemerally) with the gateway latency and the Mancer API round-trip time.

    If either measurement fails, it is replaced by an error description. The
    "ms" unit is only attached to real numbers.
    """
    try:
        latency = f"{round(bot.latency * 1000)}ms"
    except Exception as e:
        latency = "Error getting bot latency: " + str(e)
    try:
        # requests is synchronous, so it runs in a worker thread instead of
        # blocking the event loop. The timeout keeps the reply inside Discord's
        # short window for the initial interaction response.
        r = await asyncio.to_thread(requests.get, Mancer_URL, timeout=2.5)
        mancer_ping = f"{round(r.elapsed.total_seconds() * 1000)}ms"
    except Exception as e:
        mancer_ping = "Error pinging LLM API: " + str(e)
    response = f"Bot Latency: {latency}\nLLM Latency: {mancer_ping}"
    await interaction.response.send_message(response, ephemeral=True)
def _build_prompt_html(prompt_text):
    """Return a standalone HTML page showing *prompt_text* with line numbers.

    Each line becomes one HTML-escaped ``<p>``. Blank lines are rendered as
    ``&nbsp;`` so they still occupy a numbered row. Colours follow the
    viewer's light/dark preference, and the page declares UTF-8 so non-ASCII
    prompt text renders correctly.
    """
    css = """
body { background: white; color: black; }
@media (prefers-color-scheme: dark) {
    body { background: black; color: white; }
}
body { font-family: "Helvetica Neue", Helvetica, Arial, sans-serif; }
p { font-size: 18px; line-height: 1.3; }
p {
    counter-increment: line;
    padding-left: 3em;
    border-left: 2px solid transparent;
    position: relative;
}
p:before { content: counter(line); position: absolute; left: 0; }
"""
    rows = [
        f"<p>{escape(line) if line.strip() else '&nbsp;'}</p>"
        for line in prompt_text.split("\n")
    ]
    return (
        '<!DOCTYPE html><html><head><meta charset="utf-8"><style>'
        + css
        + "</style></head><body>"
        + "".join(rows)
        + "</body></html>"
    )


@bot.slash_command(name="showprompt", description="Check my current prompt in full!")
async def showprompt(interaction:nextcord.Interaction):
    """Upload the most recent LLM prompt to 0x0.st and reply with a link.

    The page is uploaded directly from memory, with a one-hour expiry.
    Nothing is written to the working directory, so there is no file handle
    or leftover file to clean up. If the upload is rejected, the user gets an
    error instead of a broken link. All replies are ephemeral.
    """
    if 'prompt' not in globals():
        await interaction.response.send_message("Send a message to me first to see the prompt!", ephemeral=True)
        return
    page = _build_prompt_html(prompt)

    form = aiohttp.FormData()
    form.add_field("file", page.encode("utf-8"), filename="prompt.html", content_type="text/html")
    form.add_field("expires", "1")  # one hour, matching the message below
    form.add_field("secret", "")
    async with aiohttp.ClientSession() as session:
        async with session.post("https://0x0.st", data=form) as resp:
            body = await resp.text()
            status = resp.status

    if status != 200:
        logging.info(f"Prompt upload failed (HTTP {status}): {body}")
        await interaction.response.send_message(f"Couldn't upload the prompt (HTTP {status}).", ephemeral=True)
        return

    # Wrap url in markdown with text "View Prompt Here!"
    markdown_url = f"[View Prompt Here!]({body.strip()})"
    message = f"{markdown_url}\n\n`Link expires after an hour`"
    await interaction.response.send_message(message, ephemeral=True)
@bot.event
async def on_resumed():
    """Log that the gateway session has been resumed.

    nextcord dispatches this event as ``on_resumed``; a coroutine named
    ``on_resume`` is never called. Intents are not touched here. They are
    fixed by the ``intents`` argument of ``nextcord.Client`` when the client
    identifies, and ``Client.intents`` is a read-only property, so they
    cannot be re-registered on a live session.
    """
    logging.info('Bot is resumed')


# Keep the old module-level name available for any external references.
on_resume = on_resumed
def main():
    """Start the message-queue worker task, then run the Discord client.

    The bot token comes from the DISCORD_BOT_TOKEN environment variable;
    ``os.environ[...]`` raises KeyError if it is unset.
    """
    worker = message_queue.process_messages()
    bot.loop.create_task(worker)
    discord_token = os.environ['DISCORD_BOT_TOKEN']
    # Run bot (blocking call)
    bot.run(discord_token)


if __name__ == "__main__":
    main()