# failed_discord_bot/bot.py
# (metadata: 581 lines, 27 KiB, Python)

import os
import requests
import json
import yaml
import asyncio
import nextcord
from nextcord.ext import commands
import re
import threading
import logging
import traceback
from datetime import timedelta
import tempfile
import filetype
import aiohttp
import random
import unicodedata
import string
from html import escape
from difflib import SequenceMatcher
# self
from llama_tokenizer_lite import LlamaTokenizerLite
from run import failed_discord_bot
from get_messages import MessageProcessor
def similar(a, b):
    """Return the difflib similarity ratio of *a* and *b* (0.0 to 1.0)."""
    matcher = SequenceMatcher(None, a, b)
    return matcher.ratio()
logging.basicConfig(level=logging.INFO, format='%(message)s')

# --- Mancer LLM API ----------------------------------------------------------
# API key from the environment; None when MANCER_KEY is not set.
Mancer_Key = os.getenv("MANCER_KEY")
Mancer_Model = "mytholite"
Mancer_URL = f"https://neuro.mancer.tech/webui/{Mancer_Model}/api"

# --- Discord client objects, provided by run.failed_discord_bot --------------
run = failed_discord_bot()
intents = run.get_intents()
bot = run.get_bot()
tokenizer = LlamaTokenizerLite()

# --- Shared module state used by the event handlers -------------------------
oldest_message_timestamp = None
last_bot_message_id = None
image_cache = {}
current_task = None

# Module-level asyncio lock (not acquired anywhere in the visible code).
lock = asyncio.Lock()
def get_mancer_body(prompt, stop_string):
    """Build the JSON request body for a Mancer ``/v1/generate`` call.

    The base generation settings live in a JSON file inside
    ``<config_directory>/mancer_settings``. Which file is used is the first
    entry of the ``mancer_settings`` list in ``<config_directory>/config.yaml``.
    Those settings are parsed on every call, then ``prompt`` and
    ``stopping_strings`` are overwritten with the given values.

    Args:
        prompt: full prompt text to send to the model.
        stop_string: value stored under ``stopping_strings``.

    Returns:
        dict: the settings with ``prompt`` and ``stopping_strings`` filled in.
    """
    config = read_yaml_file(os.path.join(config_directory, 'config.yaml'))
    settings_dir = os.path.join(config_directory, 'mancer_settings')
    ensure_directory_exists(settings_dir)
    settings_name = get_first_filename(config, 'mancer_settings')
    body = json.loads(read_file_contents(settings_dir, settings_name))
    body['prompt'] = prompt
    body['stopping_strings'] = stop_string
    return body
def read_yaml_file(filename, encoding="utf-8"):
    """Parse the YAML file *filename* with ``yaml.safe_load`` and return the result.

    The file is decoded explicitly (UTF-8 by default) instead of with the
    platform's locale encoding, so config files behave the same on every OS.

    Args:
        filename: path of the YAML file.
        encoding: text encoding used to decode the file (default UTF-8).

    Raises:
        OSError: if the file cannot be opened.
    """
    with open(filename, 'r', encoding=encoding) as file:
        return yaml.safe_load(file)
def get_first_filename(data, field):
    """Return the first element of the sequence stored under *field* in *data*.

    Raises KeyError when *field* is missing and IndexError when its list is empty.
    """
    entries = data[field]
    return entries[0]
def read_file_contents(directory, filename, encoding="utf-8"):
    """Return the full text of the file *filename* inside *directory*.

    The file is decoded as UTF-8 by default instead of with the platform's
    locale encoding, so prompt files with non-ASCII text (emoji, accents)
    read identically on every OS.

    Args:
        directory: folder containing the file.
        filename: name of the file inside *directory*.
        encoding: text encoding used to decode the file (default UTF-8).

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
        UnicodeDecodeError: if the bytes are not valid in *encoding*.
    """
    with open(os.path.join(directory, filename), 'r', encoding=encoding) as file:
        return file.read()
def ensure_directory_exists(directory):
    """Create *directory* (and any missing parents) unless it already exists.

    ``exist_ok=True`` makes creation atomic with respect to the existence
    check. This avoids the race where another process or task creates the
    directory between an ``os.path.exists`` test and ``os.makedirs``.
    """
    os.makedirs(directory, exist_ok=True)
# Root folder for config.yaml and the prompt/settings subfolders.
config_directory = "config"
class MessageQueue:
    """FIFO of Discord messages that are processed strictly one at a time.

    Event handlers enqueue messages with ``add_message``. A single
    ``process_messages`` consumer passes each one to
    ``process_images_and_generate_responses`` in arrival order.
    """

    def __init__(self):
        self.queue = asyncio.Queue()

    async def add_message(self, message):
        """Enqueue *message* for the consumer."""
        await self.queue.put(message)

    async def process_messages(self):
        """Consume the queue forever, handling one message per iteration.

        Any ``Exception`` from the handler is logged with its traceback and
        swallowed, so one bad message does not stop the consumer.
        ``task_done`` is always called for the dequeued item.
        """
        while True:
            pending = await self.queue.get()
            try:
                await process_images_and_generate_responses(pending)
            except Exception as err:
                logging.exception("Failed to process message", exc_info=err)
            finally:
                self.queue.task_done()


# Single shared queue used by the event handlers.
message_queue = MessageQueue()
@bot.event
async def on_ready():
    """Announce the bot in each guild and seed the shared conversation state.

    In every guild, the first text channel where the bot may send messages
    receives two messages: the boot prompt, then the initial prompt. Both
    files are named in ``config/config.yaml``. The initial-prompt message
    becomes the conversation anchor:
      * ``oldest_message_timestamp`` is set one microsecond before it;
      * ``last_bot_message_id`` is set to its id;
      * the message itself is queued for processing.
    With several guilds, these globals end up describing the last guild handled.
    """
    global oldest_message_timestamp, last_bot_message_id
    for guild in bot.guilds:
        # First channel (in guild order) the bot is allowed to post in, if any.
        target = next(
            (ch for ch in guild.text_channels
             if ch.permissions_for(guild.me).send_messages),
            None,
        )
        if target is None:
            continue
        config = read_yaml_file(os.path.join(config_directory, 'config.yaml'))
        boot_dir = os.path.join(config_directory, 'boot_prompt')
        initial_dir = os.path.join(config_directory, 'initial_prompt')
        ensure_directory_exists(boot_dir)
        ensure_directory_exists(initial_dir)
        boot_name = get_first_filename(config, 'boot_prompt')
        initial_name = get_first_filename(config, 'initial_prompt')
        boot_text = read_file_contents(boot_dir, boot_name)
        initial_text = read_file_contents(initial_dir, initial_name)
        await target.send(boot_text)
        anchor = await target.send(initial_text)
        oldest_message_timestamp = anchor.created_at - timedelta(microseconds=1)
        last_bot_message_id = anchor.id
        await message_queue.add_message(anchor)
    logging.info('Bot is ready')
@bot.event
async def on_message(message):
    """Queue every incoming message, including the bot's own, and log its arrival.

    No filtering happens here. The processing step returns early when the
    newest channel message was written by the bot.
    """
    queue = message_queue
    await queue.add_message(message)
    logging.info('New message received')
# NOTE(review): this coroutine is invoked directly by MessageQueue.process_messages
# rather than dispatched as a gateway event; the @bot.event decorator looks
# unnecessary here — confirm before removing it.
@bot.event
async def process_images_and_generate_responses(message):
    """Generate and post the bot's reply to the conversation in ``message.channel``.

    Makes up to 7 attempts. On an attempt where the previous reply did not end
    in continuation punctuation, it:
      * re-reads channel history after ``oldest_message_timestamp`` (on the very
        first run it uses only ``message``);
      * returns early if the newest message is the bot's own;
      * converts the history with ``MessageProcessor`` and merges consecutive
        messages from the same author;
      * fills the prompt templates named in config/config.yaml;
      * repeatedly drops the oldest message until the locally counted prompt
        is at most 2410 tokens.
    It then POSTs to the Mancer ``/v1/generate`` endpoint, strips one trailing
    stop string from the reply, and sends the reply to the channel.

    If the reply ends in ",", "!", "?" or ";" (checked before stop-string
    stripping), ``reply_punctuation`` is set. The reply is then appended to
    ``prompt``, and the next attempt continues generating from that prompt
    instead of re-reading history.

    Args:
        message: the queued Discord message that triggered processing.

    Side effects: rebinds these module globals:
    ``oldest_message_timestamp``, ``last_bot_message_id``, ``image_cache``,
    ``discord_msgs``, ``formatted_users``, ``user_names``, ``system_prompt``,
    ``input_segment``, ``output_segment``, ``end_loop_prompt`` and ``prompt``.
    ``prompt`` is the value the /showprompt command displays.

    Any ``Exception`` raised inside is caught and logged at INFO level;
    nothing propagates to the caller.
    """
    logging.info("process function initialised")
    global oldest_message_timestamp
    global last_bot_message_id
    global image_cache
    # True when the last reply ended in continuation punctuation and the next
    # attempt should extend it rather than rebuild the prompt from history.
    reply_punctuation = False
    try:
        for _ in range(7): # Limit to 7 attempts
            # `_` is the attempt index; it is used below for the inter-attempt
            # sleep and the early return from the third attempt on.
            logging.info(f"Number of _: {_}")
            if not reply_punctuation:
                channel = message.channel
                messages = []
                if oldest_message_timestamp is None:
                    # First run: anchor the history window at this message.
                    oldest_message_timestamp = message.created_at
                    messages.append(message)
                else:
                    async for m in channel.history(limit=200, after=oldest_message_timestamp):
                        messages.append(m)
                logging.info(f"Number of messages it's reading: {len(messages)}")
                try:
                    latest_message = messages[-1]
                    logging.info(f"latest message: {latest_message}")
                    latest_author_id = latest_message.author.id
                except:
                    # Bare except: e.g. no messages were read (messages[-1] -> IndexError).
                    return
                # Never reply to ourselves.
                if latest_author_id == bot.user.id:
                    logging.info("latest author is bot")
                    return
                # Rebuild the prompt, dropping the oldest message on each pass,
                # until it fits the 2410-token budget.
                while True:
                    # Process the messages
                    # `global` statements apply to the whole function, wherever they appear.
                    global discord_msgs, formatted_users, user_names
                    discord_msgs = []
                    # formatted_users/user_names are first set below in the same
                    # pass that first sets `prompt`, so until a prompt exists a
                    # fresh processor is used; afterwards the accumulated state is passed back in.
                    if 'prompt' not in globals():
                        message_processor = MessageProcessor(bot=bot)
                    else:
                        message_processor = MessageProcessor(
                            bot=bot,
                            formatted_users=formatted_users,
                            image_cache=image_cache,
                            discord_msgs=discord_msgs,
                            user_names=user_names
                        )
                    logging.info("0 image stuff")
                    await message_processor.process_messages(messages)
                    # Access the processed messages and image cache
                    discord_msgs = message_processor.get_discord_msgs()
                    image_cache = message_processor.get_image_cache()
                    formatted_users = message_processor.get_formatted_users()
                    user_names = message_processor.get_user_names()
                    global system_prompt
                    global input_segment
                    global output_segment
                    global end_loop_prompt
                    global prompt
                    # Local shadow of the module-level config_directory (same value).
                    config_directory = 'config'
                    ensure_directory_exists(config_directory)
                    data = read_yaml_file(os.path.join(config_directory, 'config.yaml'))
                    system_prompt_directory = os.path.join(config_directory, 'system_prompt')
                    input_segment_directory = os.path.join(config_directory, 'input_segment')
                    output_segment_directory = os.path.join(config_directory, 'output_segment')
                    input_sequence_directory = os.path.join(config_directory, 'input_sequence')
                    output_sequence_directory = os.path.join(config_directory, 'output_sequence')
                    end_loop_prompt_directory = os.path.join(config_directory, 'end_loop_prompt')
                    ensure_directory_exists(system_prompt_directory)
                    ensure_directory_exists(input_segment_directory)
                    ensure_directory_exists(output_segment_directory)
                    ensure_directory_exists(input_sequence_directory)
                    ensure_directory_exists(output_sequence_directory)
                    ensure_directory_exists(end_loop_prompt_directory)
                    system_prompt = get_first_filename(data, 'system_prompt')
                    input_segment = get_first_filename(data, 'input_segment')
                    output_segment = get_first_filename(data, 'output_segment')
                    input_sequence = get_first_filename(data, 'input_sequence')
                    output_sequence = get_first_filename(data, 'output_sequence')
                    end_loop_prompt = get_first_filename(data, 'end_loop_prompt')
                    system_prompt = read_file_contents(system_prompt_directory, system_prompt)
                    input_segment = read_file_contents(input_segment_directory, input_segment)
                    output_segment = read_file_contents(output_segment_directory, output_segment)
                    input_sequence = read_file_contents(input_sequence_directory, input_sequence)
                    output_sequence = read_file_contents(output_sequence_directory, output_sequence)
                    end_loop_prompt = read_file_contents(end_loop_prompt_directory, end_loop_prompt)
                    # One non-empty line of the end_loop_prompt file is picked at random
                    # on every build; random.choice raises IndexError if there is none.
                    end_loop_list = end_loop_prompt.split("\n")
                    end_loop_list = [item for item in end_loop_list if item != '']
                    end_loop_prompt = random.choice(end_loop_list)
                    logging.info("2 prompt stuff")
                    # Merge consecutive messages by the same author into the previous
                    # entry. They are joined with " " when the merged message's last
                    # character is not a letter/digit; otherwise with ", " after
                    # lowercasing its first letter.
                    # NOTE(review): last_char is computed but never used. The category
                    # check indexes m['content'][-1] directly, so empty content raises
                    # IndexError (handled only by the outer except). It also inspects
                    # the merged message rather than the previous one — confirm intent.
                    prev_author = None
                    i = 0
                    while i < len(discord_msgs):
                        m = discord_msgs[i]
                        logging.info(f"discord_msgs: {discord_msgs}")
                        try:
                            last_char = m['content'][-1]
                        except:
                            last_char = '.'
                        if m['author'] == prev_author:
                            if unicodedata.category(m['content'][-1])[0] not in ('L', 'N'):
                                discord_msgs[i-1]['content'] += f" {m['content']}"
                            else:
                                # Lowercase first letter before concatenating
                                lowercase_content = m['content'][0].lower() + m['content'][1:]
                                discord_msgs[i-1]['content'] += f", {lowercase_content}"
                            del discord_msgs[i]
                        else:
                            prev_author = m['author']
                            i += 1
                    # Render the transcript. input_sequence precedes each run of non-bot
                    # messages; output_sequence precedes a bot message that follows
                    # non-bot messages.
                    message_prompt = ""
                    append_input_seq = True
                    for m in discord_msgs:
                        if m['author'] == bot.user.name:
                            if not append_input_seq:
                                append_input_seq = True
                                message_prompt += output_sequence
                        else:
                            if append_input_seq:
                                message_prompt += input_sequence
                                append_input_seq = False
                        message_prompt += f"{m['author']}: {m['content']}\n"
                    # last_bot_index is only logged.
                    last_bot_index = None
                    for i, line in enumerate(message_prompt.split('\n')):
                        if line.startswith(bot.user.name + ":"):
                            last_bot_index = i
                    logging.info(f"last bot index: {last_bot_index}")
                    prompt = (f"{system_prompt}{input_segment}{message_prompt}{output_segment}")
                    prompt = prompt.replace('{{users}}', formatted_users)
                    prompt = prompt.replace('{{bot}}', bot.user.name)
                    prompt = prompt.replace('{{endloop}}', end_loop_prompt)
                    # Removing the newline character
                    prompt = prompt.rstrip("\n")
                    logging.info("3 more prompt stuff")
                    # Stop on any participant's name at the start of a line, on the bot's
                    # own "name:", on several end-of-turn markers, and on single
                    # punctuation marks.
                    stop_string = [f"\n{name}" for name in set(user_names)]
                    stop_string += [f"{bot.user.name}:"]
                    stop_string += ["</s>", "<|", "\n#", "\n*", "\n\n\n", "<START>", ",", "?", "!", ";", "."]
                    headers = {
                        "X-API-KEY": Mancer_Key,
                        "Content-Type": "application/json",
                    }
                    new_body = get_mancer_body(prompt, stop_string)
                    # NOTE: tokenizer.encode(prompt) is evaluated here, on the event loop;
                    # only len() runs in the worker thread.
                    token_count = await asyncio.to_thread(len, tokenizer.encode(prompt))
                    logging.info(f"token count (local): {token_count}")
                    if token_count <= 2410:
                        break
                    else:
                        logging.info(f"The too many messages: {messages}")
                        # NOTE(review): if messages is already empty and the prompt is
                        # still over budget, this while loop may never terminate.
                        if len(messages) != 0:
                            messages.pop(0)
                            oldest_message_timestamp = messages[0].created_at - timedelta(microseconds=1)
                # Snapshot of messages after the bot's last message; compared below to
                # detect users posting while the reply is generated.
                # NOTE(review): the `if last_bot_message_id` clause filters items only;
                # nextcord.Object(id=last_bot_message_id) and the history call still run
                # when it is None — confirm that is safe.
                old_messages_object = [msg async for msg in channel.history(after=nextcord.Object(id=last_bot_message_id)) if last_bot_message_id]
                old_messages = "\n".join([f"{m.author}: \n{m.content}" for m in old_messages_object]) + "\n"
                # NOTE(review): old_messages always ends with "\n", so this never returns.
                if len(old_messages) == 0:
                    return
            logging.info("4 before generating stuff")
            # Brief pause between attempts.
            if _ != 0:
                await asyncio.sleep(0.5)
            async with channel.typing():
                async with aiohttp.ClientSession() as session:
                    async with session.post(Mancer_URL + "/v1/generate", json=new_body, headers=headers) as response:
                        logging.info(f"Request Body: {json.dumps(new_body, indent=2)}")
                        # NOTE(review): the HTTP status is not checked before parsing.
                        api_response = await response.json()
                        # NOTE(review): these are read from the JSON body. If the API only
                        # reports them as HTTP headers they are None, and the 150-token
                        # check below never fires — confirm.
                        x_input_tokens = api_response.get('x-input-tokens')
                        x_output_tokens = api_response.get('x-output-tokens')
                        logging.info(f"Response: {json.dumps(api_response, indent=2)}")
                        reply_content = api_response.get('results')[0].get('text')
                        reply_content = str(reply_content)
                        # Remove stop_string if it's at the end of reply_content
                        # (each stop string is tried with a trailing newline first, then as-is)
                        stop_string_python_sucks = [s if s.endswith('\n') else s+'\n' for s in stop_string]
                        stop_string_python_sucks.extend(stop_string)
                        if reply_content:
                            if reply_content.endswith((",", "!", "?", ";")):
                                reply_punctuation = True
                            else:
                                reply_punctuation = False
                            # NOTE(review): the loop variable rebinds `stop_string` (the list
                            # above) to a single string. On a continuation attempt the list is
                            # not rebuilt, so the comprehension above then iterates that
                            # string's characters. get_mancer_body below also receives the
                            # string as stopping_strings.
                            for stop_string in stop_string_python_sucks:
                                if reply_content.endswith(stop_string):
                                    reply_content = reply_content[:-len(stop_string)]
                                    break
                        if not reply_punctuation:
                            # Check for new messages after the bot's last message
                            new_messages_object = [msg async for msg in channel.history(after=nextcord.Object(id=last_bot_message_id)) if last_bot_message_id]
                            new_messages = "\n".join([f"{m.author}: \n{m.content}" for m in new_messages_object]) + "\n"
                            try:
                                latest_message = new_messages_object[-1]
                                latest_author_id = latest_message.author.id
                            except:
                                # Bare except: e.g. nothing after last_bot_message_id -> stop.
                                return
                        else:
                            new_messages = old_messages
                        # If new messages are found, regenerate the response
                        # NOTE(review): precedence is
                        #   old == new  or  (not reply_punctuation and reply is not only "!"/"?")
                        if old_messages == new_messages or not reply_punctuation and not re.fullmatch(r'[!?]+', reply_content):
                            logging.info('old_messages == new_messages')
                            # Check if the response consists of max amount of max_new_tokens
                            if x_output_tokens == 150:
                                logging.info("Generated message contained too many tokens, regenerating...")
                            else:
                                # bot_messages is not used afterwards.
                                bot_messages = [msg['content'] for msg in discord_msgs if msg['is_bot']]
                                try:
                                    # An empty reply raises IndexError here. That is not caught by
                                    # the HTTPException handler below, so it ends processing via
                                    # the outer handler.
                                    last_punctuation = reply_content[-1]
                                    # A trailing comma is not sent; it is re-appended to the prompt below.
                                    if reply_punctuation and last_punctuation == ',':
                                        reply_content = reply_content[:-1]
                                    sent_message = await channel.send(reply_content) # Store the sent message
                                    # Users posted meanwhile (reachable via the second clause of
                                    # the condition above): stop all attempts.
                                    if old_messages != new_messages and not reply_punctuation:
                                        break
                                    if not reply_punctuation:
                                        last_bot_message_id = sent_message.id # Update the ID of the last message sent by the bot
                                    else:
                                        # Continuation: restore the punctuation and extend the prompt
                                        # so the next attempt continues this reply.
                                        reply_content = reply_content + last_punctuation
                                        prompt = prompt + reply_content
                                        # Re-trim until the extended prompt fits the token budget.
                                        while True:
                                            new_body = get_mancer_body(prompt, stop_string)
                                            token_count = await asyncio.to_thread(len, tokenizer.encode(prompt))
                                            logging.info(f"token count (local): {token_count}")
                                            if token_count <= 2410:
                                                break
                                            else:
                                                # NOTE(review): this rebuild reads `messages` directly
                                                # instead of going through MessageProcessor. It also joins
                                                # lines without input_sequence/output_sequence, so it differs
                                                # from the first-pass prompt. If messages is empty, this loop
                                                # may never terminate.
                                                discord_msgs = []
                                                if len(messages) != 0:
                                                    messages.pop(0)
                                                logging.info(f"messages variable: {messages}")
                                                prev_author = None
                                                i = 0
                                                # messages initialise
                                                for m in messages:
                                                    if m.attachments:
                                                        msg_content = m.content
                                                        logging.info(f"please tell me there is any image cache in here {image_cache}")
                                                        if m.id in image_cache:
                                                            for image in image_cache[m.id]:
                                                                if msg_content:
                                                                    msg_content += f' *sends an image of {image["output"]}*'
                                                                else:
                                                                    msg_content += f'*sends an image of {image["output"]}*'
                                                    else:
                                                        logging.info(f"no image for {m.content}")
                                                        msg_content = m.content
                                                    discord_msgs.append({
                                                        "content": msg_content,
                                                        "author": m.author.name,
                                                        "is_bot": m.author.bot
                                                    })
                                                # discord_msgs initialise
                                                while i < len(discord_msgs):
                                                    m = discord_msgs[i]
                                                    try:
                                                        last_char = m['content'][-1]
                                                    except:
                                                        last_char = '.'
                                                    if m['author'] == prev_author:
                                                        if unicodedata.category(m['content'][-1])[0] not in ('L', 'N'):
                                                            discord_msgs[i-1]['content'] += f" {m['content']}"
                                                        else:
                                                            # Lowercase first letter before concatenating
                                                            lowercase_content = m['content'][0].lower() + m['content'][1:]
                                                            discord_msgs[i-1]['content'] += f", {lowercase_content}"
                                                        del discord_msgs[i]
                                                    else:
                                                        prev_author = m['author']
                                                        i += 1
                                                message_prompt = "\n".join([f"{m['author']}: {m['content']}" for m in discord_msgs]) + '\n'
                                                last_bot_index = None
                                                for i, line in enumerate(message_prompt.split('\n')):
                                                    if line.startswith(bot.user.name + ":"):
                                                        last_bot_index = i
                                                logging.info(f"message_prompt: {message_prompt}")
                                                logging.info(f"bot.user.name: {bot.user.name}")
                                                logging.info(f"last bot index: {last_bot_index}")
                                                prompt = (f"{system_prompt}{input_segment}{message_prompt}{output_segment}")
                                                prompt = prompt.replace('{{users}}', formatted_users)
                                                prompt = prompt.replace('{{bot}}', bot.user.name)
                                                prompt = prompt.replace('{{endloop}}', end_loop_prompt)
                                                # Removing the newline character
                                                prompt = prompt.rstrip("\n")
                                                prompt = prompt + reply_content
                                                new_body["prompt"] = prompt
                                    # From the third attempt on, a reply ending in "!" or "?" ends processing.
                                    if reply_content.endswith(("!", "?")) and _ >= 2:
                                        return
                                except nextcord.errors.HTTPException:
                                    # The Discord send failed; fall through to the next attempt.
                                    pass
                        else:
                            logging.info('old_messages != new_messages')
                            try:
                                last_bot_message_id = sent_message.id
                            except:
                                # sent_message may be unbound here; the bare except ignores that.
                                pass
                        # Refresh the baseline for the next attempt.
                        if not reply_punctuation:
                            old_messages_object = [msg async for msg in channel.history(after=nextcord.Object(id=last_bot_message_id)) if last_bot_message_id]
                            old_messages = "\n".join([f"{m.author}: \n{m.content}" for m in old_messages_object]) + "\n"
    except Exception as error:
        # Catch-all: log the error and traceback at INFO level, then give up.
        logging.info(str(error))
        logging.info(str(traceback.format_exc()))
@bot.slash_command(name="ping", description="Check if I'm alive!")
async def ping(interaction:nextcord.Interaction):
    """Reply (ephemerally) with the gateway latency and the LLM API round-trip time.

    The LLM check is an HTTP GET on ``Mancer_URL``. ``requests`` is
    synchronous, so the call runs in a worker thread (``asyncio.to_thread``)
    instead of freezing the event loop. A timeout keeps an unresponsive API
    from hanging the handler. A measurement that fails is reported as its
    error text, without a trailing "ms".
    """
    try:
        latency = f"{round(bot.latency * 1000)}ms"
    except Exception as e:
        # e.g. latency is NaN/inf before the gateway heartbeat is established
        latency = "Error getting bot latency: " + str(e)
    try:
        # Short timeout: Discord expects the initial interaction response within
        # a few seconds, so don't wait long on the API.
        r = await asyncio.to_thread(requests.get, Mancer_URL, timeout=2.5)
        mancer_ping = f"{round(r.elapsed.total_seconds() * 1000)}ms"
    except Exception as e:
        mancer_ping = "Error pinging LLM API: " + str(e)
    response = f"Bot Latency: {latency}\nLLM Latency: {mancer_ping}"
    await interaction.response.send_message(response, ephemeral=True)
def _render_prompt_html(prompt_text):
    """Return a standalone UTF-8 HTML page showing *prompt_text*, one numbered <p> per line."""
    parts = ['<!DOCTYPE html><html><head><meta charset="utf-8"><style>']
    # CSS for dark mode
    parts.append("""
body {
    background: white;
    color: black;
}
@media (prefers-color-scheme: dark) {
    body {
        background: black;
        color: white;
    }
}
""")
    # CSS for formatting
    parts.append("""
body {
    font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
}
p {
    font-size: 18px;
    line-height: 1.3;
}
""")
    # CSS for line numbers: each <p> increments a counter shown in its :before box
    parts.append("""
p {
    counter-increment: line;
    padding-left: 3em;
    border-left: 2px solid transparent;
    position: relative;
}
p:before {
    content: counter(line);
    position: absolute;
    left: 0;
}
""")
    parts.append("</style></head><body>")
    for line in prompt_text.split("\n"):
        # Blank lines get &nbsp; so the <p> still renders and keeps its number;
        # everything else is HTML-escaped.
        parts.append(f"<p>{escape(line) if line.strip() else '&nbsp;'}</p>")
    parts.append("</body></html>")
    return "".join(parts)


@bot.slash_command(name="showprompt", description="Check my current prompt in full!")
async def showprompt(interaction:nextcord.Interaction):
    """Upload the most recent LLM prompt as an HTML page and reply with a link.

    Uses the module-level ``prompt`` built by message processing. Until that
    exists, the user is asked to send a message first. Otherwise:
      * the prompt is rendered to HTML;
      * the HTML is staged in a private temporary file;
      * the file is POSTed to https://0x0.st with ``expires=1``;
      * the returned URL is sent back as an ephemeral markdown link.
    The upload handle is always closed and the temporary file is always
    deleted, even if the upload fails.
    """
    global prompt
    if 'prompt' not in globals():
        await interaction.response.send_message("Send a message to me first to see the prompt!", ephemeral=True)
        return
    html = _render_prompt_html(prompt)
    # Stage the page in a private temp file rather than the working directory.
    fd, filename = tempfile.mkstemp(suffix='.html')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            f.write(html)
        with open(filename, 'rb') as upload:
            async with aiohttp.ClientSession() as session:
                async with session.post("https://0x0.st",
                                        data={"file": upload,
                                              "expires": "1",
                                              "secret": ""
                                              }) as resp:
                    url = await resp.text()
    finally:
        os.remove(filename)  # Delete the file after uploading (or on failure)
    # Wrap url in markdown with text "View Prompt Here!"
    markdown_url = f"[View Prompt Here!]({url.strip()})"
    # Add line about link expiration
    message = f"{markdown_url}\n\n`Link expires after an hour`"
    await interaction.response.send_message(message, ephemeral=True)
@bot.event
async def on_resume():
    """Rebuild default intents with message content enabled, then log the resume.

    NOTE(review): nextcord (like discord.py) appears to name the session-resume
    event ``on_resumed``; if so, this handler is never dispatched. Also,
    ``Client.intents`` is normally a read-only property, so the assignment below
    may raise AttributeError if it ever runs. Confirm both against the class
    returned by ``run.get_bot()``.
    """
    # Re-register intents
    fresh_intents = nextcord.Intents.default()
    fresh_intents.message_content = True
    bot.intents = fresh_intents
    logging.info('Bot is resumed')
def main():
    """Schedule the message-queue consumer, then run the bot.

    The Discord token comes from the DISCORD_BOT_TOKEN environment variable;
    a KeyError is raised if it is unset.
    """
    consumer = message_queue.process_messages()
    bot.loop.create_task(consumer)
    # Run bot
    bot.run(os.environ['DISCORD_BOT_TOKEN'])


if __name__ == "__main__":
    main()