# Listing metadata: 581 lines, 27 KiB, Python
import asyncio
import contextlib
import json
import logging
import os
import random
import re
import string
import tempfile
import threading
import traceback
import unicodedata
from datetime import timedelta
from difflib import SequenceMatcher
from html import escape

import aiohttp
import filetype
import nextcord
import requests
import yaml
from nextcord.ext import commands

# self
from llama_tokenizer_lite import LlamaTokenizerLite
from run import failed_discord_bot
from get_messages import MessageProcessor


def similar(a, b):
    """Return the difflib similarity ratio of *a* and *b*.

    The value is ``SequenceMatcher(None, a, b).ratio()``: a float in
    ``[0, 1]`` where 1.0 means the sequences are identical.
    """
    return SequenceMatcher(None, a, b).ratio()


# Emit bare log messages at INFO level.
logging.basicConfig(level=logging.INFO, format='%(message)s')

# Mancer text-generation endpoint; the API key is read from the environment
# and is None when MANCER_KEY is unset.
Mancer_Key = os.getenv("MANCER_KEY")
Mancer_Model = "mytholite"
Mancer_URL = f"https://neuro.mancer.tech/webui/{Mancer_Model}/api"

# Bot construction is delegated to the project-local `run` module.
run = failed_discord_bot()
intents = run.get_intents()
bot = run.get_bot()

tokenizer = LlamaTokenizerLite()

# Mutable module-level state.
oldest_message_timestamp = None
last_bot_message_id = None
image_cache = {}
current_task = None

# Create a lock
lock = asyncio.Lock()


def get_mancer_body(prompt, stop_string):
    """Build the JSON body for a Mancer generate request.

    Reads ``config/config.yaml``, takes the first entry listed under
    ``mancer_settings``, loads that file from ``config/mancer_settings`` as
    JSON, and then sets its ``prompt`` and ``stopping_strings`` keys.

    Args:
        prompt: Prompt text stored under ``"prompt"``.
        stop_string: Value stored under ``"stopping_strings"``.

    Returns:
        The parsed settings object with the two keys filled in.
    """
    config = read_yaml_file(os.path.join(config_directory, 'config.yaml'))

    settings_directory = os.path.join(config_directory, 'mancer_settings')
    ensure_directory_exists(settings_directory)

    settings_filename = get_first_filename(config, 'mancer_settings')
    body = json.loads(read_file_contents(settings_directory, settings_filename))
    body['prompt'] = prompt
    body['stopping_strings'] = stop_string
    return body


def read_yaml_file(filename):
    """Parse the YAML file at *filename* with ``yaml.safe_load``.

    The file is decoded as UTF-8 rather than with the platform default, so
    the result does not depend on the host locale.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    with open(filename, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)


def get_first_filename(data, field):
    """Return the first entry stored under ``data[field]``.

    A plain string value is returned whole; indexing it would otherwise
    yield only its first character. Any other value is indexed with ``[0]``.

    Raises:
        KeyError: *field* is missing from *data*.
        IndexError: the value is an empty sequence.
    """
    entry = data[field]
    if isinstance(entry, str):
        return entry
    return entry[0]


def read_file_contents(directory, filename):
    """Return the full text of ``directory/filename``.

    The file is decoded as UTF-8 so non-ASCII text reads the same on every
    platform.
    """
    with open(os.path.join(directory, filename), 'r', encoding='utf-8') as file:
        return file.read()


def ensure_directory_exists(directory):
    """Create *directory* (and any missing parents) if it is not there yet.

    ``exist_ok=True`` makes the call safe when the directory already exists
    or is created by someone else between a check and the create.
    """
    os.makedirs(directory, exist_ok=True)


# Directory (relative to the working directory) holding config.yaml and the
# per-setting sub-directories.
config_directory = 'config'


class MessageQueue:
    """Queue of messages drained one at a time by ``process_messages``."""

    def __init__(self):
        self.queue = asyncio.Queue()

    async def add_message(self, message):
        """Put *message* on the queue."""
        await self.queue.put(message)

    async def process_messages(self):
        """Consume the queue forever, handling one message at a time.

        Each message is passed to ``process_images_and_generate_responses``.
        An exception from that call is logged with its traceback and does
        not stop the loop; ``task_done`` is always called.
        """
        while True:
            message = await self.queue.get()
            try:
                await process_images_and_generate_responses(message)
            except Exception:
                logging.exception("Failed to process message")
            finally:
                self.queue.task_done()


# Single shared queue: event handlers enqueue, one consumer task drains.
message_queue = MessageQueue()


@bot.event
async def on_ready():
    """Post the boot and initial prompts once the bot is ready.

    For each guild, the first text channel the bot may send to receives the
    configured boot prompt followed by the initial prompt. The initial
    prompt message becomes the history anchor (``oldest_message_timestamp``
    and ``last_bot_message_id``) and is queued for processing.
    """
    global oldest_message_timestamp, last_bot_message_id

    for guild in bot.guilds:
        for channel in guild.text_channels:
            if not channel.permissions_for(guild.me).send_messages:
                continue

            data = read_yaml_file(os.path.join(config_directory, 'config.yaml'))

            boot_prompt_directory = os.path.join(config_directory, 'boot_prompt')
            initial_prompt_directory = os.path.join(config_directory, 'initial_prompt')

            ensure_directory_exists(boot_prompt_directory)
            ensure_directory_exists(initial_prompt_directory)

            boot_prompt = get_first_filename(data, 'boot_prompt')
            initial_prompt = get_first_filename(data, 'initial_prompt')

            boot_prompt = read_file_contents(boot_prompt_directory, boot_prompt)
            initial_prompt = read_file_contents(initial_prompt_directory, initial_prompt)

            await channel.send(boot_prompt)
            sent_message = await channel.send(initial_prompt)

            # Anchor just before the initial prompt so a later
            # history(after=...) fetch still includes that message.
            oldest_message_timestamp = sent_message.created_at - timedelta(microseconds=1)
            last_bot_message_id = sent_message.id
            await message_queue.add_message(sent_message)
            # Only leaves the channel loop: every guild gets one greeting.
            break
    logging.info('Bot is ready')


@bot.event
async def on_message(message):
    """Queue every received message for the background consumer."""
    await message_queue.add_message(message)
    logging.info('New message received')


# Prompts whose local token count exceeds this are shortened by dropping the
# oldest message and rebuilding.
_MAX_PROMPT_TOKENS = 2410


def _merge_consecutive_messages(msgs):
    """Fold each entry into the previous one when both share an author.

    Mutates *msgs* in place. If the folded text ends in a letter or digit it
    is appended as ``", <text with first letter lowercased>"``; otherwise it
    is appended after a single space. Empty content is treated as ending in
    ``'.'`` so it never raises.
    """
    prev_author = None
    i = 0
    while i < len(msgs):
        entry = msgs[i]
        if entry['author'] != prev_author:
            prev_author = entry['author']
            i += 1
            continue

        content = entry['content']
        last_char = content[-1] if content else '.'
        if unicodedata.category(last_char)[0] not in ('L', 'N'):
            msgs[i - 1]['content'] += f" {content}"
        else:
            # Lowercase first letter before concatenating
            lowercase_content = content[0].lower() + content[1:]
            msgs[i - 1]['content'] += f", {lowercase_content}"
        # No index bump: deleting shifts the next entry into slot i.
        del msgs[i]


def _load_prompt_part(data, field):
    """Read the first file listed under *field* from ``config/<field>``."""
    directory = os.path.join(config_directory, field)
    ensure_directory_exists(directory)
    return read_file_contents(directory, get_first_filename(data, field))


def _assemble_prompt(message_prompt):
    """Wrap *message_prompt* in the configured segments and fill placeholders."""
    text = f"{system_prompt}{input_segment}{message_prompt}{output_segment}"
    text = text.replace('{{users}}', formatted_users)
    text = text.replace('{{bot}}', bot.user.name)
    text = text.replace('{{endloop}}', end_loop_prompt)
    # Removing the newline character
    return text.rstrip("\n")


def _last_bot_line_index(message_prompt):
    """Return the index of the last line starting with ``"<bot name>:"``, or None."""
    last_index = None
    for i, line in enumerate(message_prompt.split('\n')):
        if line.startswith(bot.user.name + ":"):
            last_index = i
    return last_index


async def _count_tokens(text):
    """Tokenize *text* in a worker thread and return the number of tokens."""
    return len(await asyncio.to_thread(tokenizer.encode, text))


async def _messages_since_last_bot_message(channel):
    """Return channel history after ``last_bot_message_id`` ([] if unset)."""
    if not last_bot_message_id:
        return []
    anchor = nextcord.Object(id=last_bot_message_id)
    return [msg async for msg in channel.history(after=anchor)]


def _format_history(history):
    """Render messages as ``"<author>: \\n<content>"`` lines plus a final newline."""
    return "\n".join([f"{m.author}: \n{m.content}" for m in history]) + "\n"


# NOTE(review): this coroutine's name does not look like a gateway event, so
# the decorator appears to have no dispatch effect -- kept to leave
# module-level behaviour unchanged; confirm before removing.
# NOTE(review): the source listing had lost all indentation; the nesting
# below is reconstructed from control-flow keywords, surviving blank lines
# and data flow. Confirm against the original file.
@bot.event
async def process_images_and_generate_responses(message):
    """Build a prompt from recent channel history, generate and send a reply.

    Makes up to seven attempts. An attempt that follows a reply ending in
    ``,``, ``!``, ``?`` or ``;`` continues that reply using the existing
    prompt; any other attempt re-reads history and config and rebuilds the
    prompt, dropping the oldest messages until the local token count is at
    most ``_MAX_PROMPT_TOKENS``. The function returns early when there is
    nothing to answer or when the newest message is the bot's own. Any
    exception is logged, not raised.

    Args:
        message: The message that triggered processing; only its
            ``channel`` and ``created_at`` are read here.
    """
    global oldest_message_timestamp, last_bot_message_id, image_cache
    global discord_msgs, formatted_users, user_names
    global system_prompt, input_segment, output_segment, end_loop_prompt, prompt

    logging.info("process function initialised")
    reply_punctuation = False
    sent_message = None

    try:
        for attempt in range(7):  # Limit to 7 attempts
            logging.info(f"Number of _: {attempt}")
            if not reply_punctuation:
                channel = message.channel
                messages = []
                if oldest_message_timestamp is None:
                    oldest_message_timestamp = message.created_at
                    messages.append(message)
                else:
                    async for m in channel.history(limit=200, after=oldest_message_timestamp):
                        messages.append(m)
                logging.info(f"Number of messages it's reading: {len(messages)}")
                if not messages:
                    return
                latest_message = messages[-1]
                logging.info(f"latest message: {latest_message}")
                if latest_message.author.id == bot.user.id:
                    logging.info("latest author is bot")
                    return

                while True:
                    # Process the messages
                    discord_msgs = []
                    if 'prompt' not in globals():
                        message_processor = MessageProcessor(bot=bot)
                    else:
                        message_processor = MessageProcessor(
                            bot=bot,
                            formatted_users=formatted_users,
                            image_cache=image_cache,
                            discord_msgs=discord_msgs,
                            user_names=user_names,
                        )
                    logging.info("0 image stuff")
                    await message_processor.process_messages(messages)

                    # Access the processed messages and image cache
                    discord_msgs = message_processor.get_discord_msgs()
                    image_cache = message_processor.get_image_cache()
                    formatted_users = message_processor.get_formatted_users()
                    user_names = message_processor.get_user_names()

                    ensure_directory_exists(config_directory)
                    data = read_yaml_file(os.path.join(config_directory, 'config.yaml'))

                    system_prompt = _load_prompt_part(data, 'system_prompt')
                    input_segment = _load_prompt_part(data, 'input_segment')
                    output_segment = _load_prompt_part(data, 'output_segment')
                    input_sequence = _load_prompt_part(data, 'input_sequence')
                    output_sequence = _load_prompt_part(data, 'output_sequence')
                    end_loop_list = [
                        item
                        for item in _load_prompt_part(data, 'end_loop_prompt').split("\n")
                        if item != ''
                    ]
                    end_loop_prompt = random.choice(end_loop_list)
                    logging.info("2 prompt stuff")

                    logging.info(f"discord_msgs: {discord_msgs}")
                    _merge_consecutive_messages(discord_msgs)

                    # Insert the input/output sequence markers whenever the
                    # speaker switches between the bot and everyone else.
                    message_prompt = ""
                    append_input_seq = True
                    for m in discord_msgs:
                        if m['author'] == bot.user.name:
                            if not append_input_seq:
                                append_input_seq = True
                                message_prompt += output_sequence
                        else:
                            if append_input_seq:
                                message_prompt += input_sequence
                                append_input_seq = False
                        message_prompt += f"{m['author']}: {m['content']}\n"

                    logging.info(f"last bot index: {_last_bot_line_index(message_prompt)}")

                    prompt = _assemble_prompt(message_prompt)
                    logging.info("3 more prompt stuff")

                    stop_strings = [f"\n{name}" for name in set(user_names)]
                    stop_strings += [f"{bot.user.name}:"]
                    stop_strings += ["</s>", "<|", "\n#", "\n*", "\n\n\n", "<START>", ",", "?", "!", ";", "."]

                    headers = {
                        "X-API-KEY": Mancer_Key,
                        "Content-Type": "application/json",
                    }

                    new_body = get_mancer_body(prompt, stop_strings)

                    token_count = await _count_tokens(prompt)
                    logging.info(f"token count (local): {token_count}")
                    if token_count <= _MAX_PROMPT_TOKENS:
                        break

                    logging.info(f"The too many messages: {messages}")
                    if messages:
                        messages.pop(0)
                    if not messages:
                        # Nothing left to drop; indexing messages[0] below
                        # would raise.
                        logging.info("Prompt is too long even without any messages; giving up")
                        return

                # Start the next history fetch just before the oldest
                # message that still fits in the prompt.
                oldest_message_timestamp = messages[0].created_at - timedelta(microseconds=1)
                old_messages_object = await _messages_since_last_bot_message(channel)
                if not old_messages_object:
                    return
                old_messages = _format_history(old_messages_object)

            logging.info("4 before generating stuff")
            if attempt != 0:
                await asyncio.sleep(0.5)
            async with channel.typing():
                async with aiohttp.ClientSession() as session:
                    async with session.post(Mancer_URL + "/v1/generate", json=new_body, headers=headers) as response:
                        logging.info(f"Request Body: {json.dumps(new_body, indent=2)}")
                        api_response = await response.json()

            x_output_tokens = api_response.get('x-output-tokens')
            logging.info(f"Response: {json.dumps(api_response, indent=2)}")
            reply_content = str(api_response.get('results')[0].get('text'))
            if not reply_content:
                continue

            reply_punctuation = reply_content.endswith((",", "!", "?", ";"))
            if not reply_punctuation:
                # Remove stop_string if it's at the end of reply_content;
                # newline-terminated variants are tried first.
                stop_suffixes = [s if s.endswith('\n') else s + '\n' for s in stop_strings]
                stop_suffixes.extend(stop_strings)
                for suffix in stop_suffixes:
                    if reply_content.endswith(suffix):
                        reply_content = reply_content[:-len(suffix)]
                        break

                # Check for new messages after the bot's last message
                new_messages_object = await _messages_since_last_bot_message(channel)
                if not new_messages_object:
                    return
                new_messages = _format_history(new_messages_object)
            else:
                new_messages = old_messages

            # If new messages are found, regenerate the response
            if old_messages == new_messages or (
                not reply_punctuation and not re.fullmatch(r'[!?]+', reply_content)
            ):
                logging.info('old_messages == new_messages')
                # Check if the response consists of max amount of max_new_tokens
                if x_output_tokens == 150:
                    logging.info("Generated message contained too many tokens, regenerating...")
                    continue
                if not reply_content:
                    # Stripping the stop string left nothing to send.
                    continue

                try:
                    # A trailing comma is not shown in the chat but is
                    # restored in the prompt so the continuation reads on.
                    stripped_comma = reply_punctuation and reply_content[-1] == ','
                    if stripped_comma:
                        reply_content = reply_content[:-1]
                    sent_message = await channel.send(reply_content)  # Store the sent message
                    if old_messages != new_messages and not reply_punctuation:
                        break
                    if not reply_punctuation:
                        last_bot_message_id = sent_message.id  # Update the ID of the last message sent by the bot
                    else:
                        if stripped_comma:
                            reply_content += ','
                        prompt = prompt + reply_content
                        while True:
                            new_body = get_mancer_body(prompt, stop_strings)
                            token_count = await _count_tokens(prompt)
                            logging.info(f"token count (local): {token_count}")
                            if token_count <= _MAX_PROMPT_TOKENS:
                                break
                            if not messages:
                                # Nothing left to drop; without this the
                                # loop would never terminate.
                                logging.info("Prompt is too long even without any messages; giving up")
                                return

                            messages.pop(0)
                            logging.info(f"messages variable: {messages}")
                            # messages initialise
                            discord_msgs = []
                            for m in messages:
                                msg_content = m.content
                                if m.attachments:
                                    logging.info(f"please tell me there is any image cache in here {image_cache}")
                                    for image in image_cache.get(m.id, ()):
                                        if msg_content:
                                            msg_content += f' *sends an image of {image["output"]}*'
                                        else:
                                            msg_content += f'*sends an image of {image["output"]}*'
                                else:
                                    logging.info(f"no image for {m.content}")

                                discord_msgs.append({
                                    "content": msg_content,
                                    "author": m.author.name,
                                    "is_bot": m.author.bot,
                                })
                            # discord_msgs initialise
                            _merge_consecutive_messages(discord_msgs)

                            message_prompt = "\n".join([f"{m['author']}: {m['content']}" for m in discord_msgs]) + '\n'
                            logging.info(f"message_prompt: {message_prompt}")
                            logging.info(f"bot.user.name: {bot.user.name}")
                            logging.info(f"last bot index: {_last_bot_line_index(message_prompt)}")
                            prompt = _assemble_prompt(message_prompt) + reply_content
                    if reply_content.endswith(("!", "?")) and attempt >= 2:
                        return
                except nextcord.errors.HTTPException:
                    logging.exception("Failed to send reply")
            else:
                logging.info('old_messages != new_messages')
                if sent_message is not None:
                    last_bot_message_id = sent_message.id
                if not reply_punctuation:
                    old_messages_object = await _messages_since_last_bot_message(channel)
                    old_messages = _format_history(old_messages_object)
    except Exception as error:
        logging.info(str(error))
        logging.info(str(traceback.format_exc()))


@bot.slash_command(name="ping", description="Check if I'm alive!")
async def ping(interaction: nextcord.Interaction):
    """Reply (ephemerally) with the gateway latency and the LLM API latency.

    Either measurement is replaced by an error description if it fails; the
    ``ms`` suffix is only attached to real measurements.
    """
    try:
        latency = f"{round(bot.latency * 1000)}ms"
    except Exception as e:
        latency = "Error getting bot latency: " + str(e)

    try:
        # requests is blocking, so run it in a worker thread. The timeout
        # keeps the reply inside Discord's 3-second window for the initial
        # interaction response.
        r = await asyncio.to_thread(requests.get, Mancer_URL, timeout=2.5)
        mancer_ping = f"{round(r.elapsed.total_seconds() * 1000)}ms"
    except Exception as e:
        mancer_ping = "Error pinging LLM API: " + str(e)

    response = f"Bot Latency: {latency}\nLLM Latency: {mancer_ping}"

    await interaction.response.send_message(response, ephemeral=True)


@bot.slash_command(name="showprompt", description="Check my current prompt in full!")
async def showprompt(interaction: nextcord.Interaction):
    """Upload the current prompt as an HTML page and reply with its link.

    Each prompt line becomes a numbered ``<p>``. The page is written to a
    temporary file, posted to ``https://0x0.st`` with ``expires=1``, and the
    temporary file is removed whether or not the upload succeeds. If no
    prompt has been built yet the user is told to send a message first.
    """
    global prompt
    if 'prompt' not in globals():
        await interaction.response.send_message("Send a message to me first to see the prompt!", ephemeral=True)
        return

    # CSS for dark mode, formatting and line numbers
    css = """
body {
  background: white;
  color: black;
}
@media (prefers-color-scheme: dark) {
  body {
    background: black;
    color: white;
  }
}
body {
  font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
}

p {
  font-size: 18px;
  line-height: 1.3;
}
p {
  counter-increment: line;
  padding-left: 3em;
  border-left: 2px solid transparent;
  position: relative;
}
p:before {
  content: counter(line);
  position: absolute;
  left: 0;
}
"""

    # Add prompt with line number spans
    paragraphs = []
    for line in prompt.split("\n"):
        # A whitespace-only paragraph collapses, so blank prompt lines are
        # kept visible with a non-breaking space.
        text = escape(line) if line.strip() else "&nbsp;"
        paragraphs.append(f"<p>{text}</p>")

    html = "".join([
        '<!DOCTYPE html><html><head><meta charset="utf-8"><style>',
        css,
        "</style></head><body>",
        *paragraphs,
        "</body></html>",
    ])

    # Generate a random filename
    filename = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(10)) + '.html'
    path = os.path.join(tempfile.gettempdir(), filename)

    try:
        with open(path, 'w', encoding='utf-8') as f:
            f.write(html)

        with open(path, 'rb') as upload:
            async with aiohttp.ClientSession() as session:
                async with session.post("https://0x0.st",
                                        data={"file": upload,
                                              "expires": "1",
                                              "secret": ""
                                              }) as resp:
                    url = await resp.text()
    finally:
        # Delete the file after uploading (or after a failed attempt)
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)

    # Wrap url in markdown with text "View Prompt Here!"
    markdown_url = f"[View Prompt Here!]({url.strip()})"

    # Add line about link expiration
    message = f"{markdown_url}\n\n`Link expires after an hour`"

    await interaction.response.send_message(message, ephemeral=True)


# NOTE(review): the gateway "resumed" event is presumably delivered to a
# handler named `on_resumed`; if so this handler never runs -- confirm
# against the nextcord event reference before renaming.
@bot.event
async def on_resume():
    """Assign default intents plus message content to the bot and log it."""
    # Re-register intents
    intents = nextcord.Intents.default()
    intents.message_content = True
    # NOTE(review): assumes `bot.intents` is assignable -- TODO confirm.
    bot.intents = intents

    logging.info('Bot is resumed')


def main():
    """Schedule the queue consumer on the bot's loop and run the bot.

    Raises:
        KeyError: ``DISCORD_BOT_TOKEN`` is not set in the environment.
    """
    # Read the token first so a missing variable fails before any task is
    # scheduled.
    token = os.environ['DISCORD_BOT_TOKEN']

    bot.loop.create_task(message_queue.process_messages())

    # Run bot
    bot.run(token)


if __name__ == "__main__":
    main()