Skip to content

Commit

Permalink
Rewrite the program structure (remove bot separately)And update the l…
Browse files Browse the repository at this point in the history
…og to independent records for each server
  • Loading branch information
starpig1129 committed Jul 1, 2024
1 parent d4507e3 commit 9e5ef89
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 247 deletions.
249 changes: 249 additions & 0 deletions bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
import discord
import sys
import os
import re
import traceback
import aiohttp
import update
import function as func
import json
import logging
from zhconv import convert
from discord.ext import commands
from web import IPCServer
from motor.motor_asyncio import AsyncIOMotorClient
from datetime import datetime
from voicelink import VoicelinkException
from gpt.choose_act import choose_act
from gpt.sendmessage import gpt_message, load_and_index_dialogue_history, save_vector_store, vector_store
from logs import TimedRotatingFileHandler
class Translator(discord.app_commands.Translator):
async def load(self):
print("Loaded Translator")

async def unload(self):
print("Unload Translator")

async def translate(self, string: discord.app_commands.locale_str, locale: discord.Locale, context: discord.app_commands.TranslationContext):
if str(locale) in func.LOCAL_LANGS:
return func.LOCAL_LANGS[str(locale)].get(string.message, None)
return None
# 配置 logging
def setup_logger(server_name):
logger = logging.getLogger(server_name)
logger.setLevel(logging.INFO)
handler = TimedRotatingFileHandler(server_name)
formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
class PigPig(commands.Bot):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dialogue_history_file = './data/dialogue_history.json'
self.vector_store_path = './data/vector_store'
self.load_dialogue_history()
load_and_index_dialogue_history(self.dialogue_history_file)
self.ipc = IPCServer(
self,
host=func.settings.ipc_server["host"],
port=func.settings.ipc_server["port"],
sercet_key=func.tokens.sercet_key
)
self.loggers = {}

def setup_logger_for_guild(self, guild_name):
if guild_name not in self.loggers:
self.loggers[guild_name] = setup_logger(guild_name)

def get_logger_for_guild(self, guild_name):
if guild_name in self.loggers:
return self.loggers[guild_name]
else:
self.setup_logger_for_guild(guild_name)
return self.loggers[guild_name]

def setup_logger_for_guild(self, guild_name):
if guild_name not in self.loggers:
self.loggers[guild_name] = setup_logger(guild_name)

def load_dialogue_history(self):
"""從檔案中讀取對話歷史"""
try:
with open(self.dialogue_history_file, 'r', encoding='utf-8') as file:
self.dialogue_history = json.load(file)
except FileNotFoundError:
self.dialogue_history = {}

def save_dialogue_history(self):
"""將對話歷史保存到檔案中"""
with open(self.dialogue_history_file, 'w', encoding='utf-8') as file:
json.dump(self.dialogue_history, file, ensure_ascii=False, indent=4)
save_vector_store(vector_store, self.vector_store_path)

async def on_message(self, message: discord.Message, /) -> None:
if message.author.bot or not message.guild:
return

guild_name = message.guild.name
self.setup_logger_for_guild(guild_name)
logger = self.loggers[guild_name]

logger.info(f'收到訊息: {message.content} (來自:伺服器:{message.guild},頻道:{message.channel.name},{message.author.name})')
await self.process_commands(message)

channel_id = str(message.channel.id)
if channel_id not in self.dialogue_history:
self.dialogue_history[channel_id] = []

try:
match = re.search(r"<@\d+>\s*(.*)", message.content)
prompt = match.group(1)
except AttributeError: # 如果正則表達式沒有匹配到,會拋出 AttributeError
prompt = message.content

self.dialogue_history[channel_id].append({"role": "user", "content": prompt})
# 實現生成回應的邏輯
if self.user.id in message.raw_mentions and not message.mention_everyone:
# 發送初始訊息
message_to_edit = await message.reply("思考中...")
try:
execute_action = await choose_act(self,prompt, message, message_to_edit)
await execute_action(message_to_edit, self.dialogue_history, channel_id, prompt, message)
except Exception as e:
print(e)
self.save_dialogue_history()

async def on_message_edit(self, before: discord.Message, after: discord.Message):
if before.author.bot or not before.guild:
return

logger = self.get_logger_for_guild(before.guild.name)
logger.info(
f"訊息修改: 原訊息({before.content}) 新訊息({after.content}) 頻道:{before.channel.name}, 作者:{before.author}"
)
channel_id = str(after.channel.id)
if channel_id not in self.dialogue_history:
self.dialogue_history[channel_id] = []

try:
match = re.search(r"<@\d+>\s*(.*)", after.content)
prompt = match.group(1)
except AttributeError: # 如果正則表達式沒有匹配到,會拋出 AttributeError
prompt = after.content

self.dialogue_history[channel_id].append({"role": "user", "content": prompt})

# 實現生成回應的邏輯
if self.user.id in after.raw_mentions and not after.mention_everyone:
try:
# Fetch the bot's previous reply
async for msg in after.channel.history(limit=50):
if msg.reference and msg.reference.message_id == before.id and msg.author.id == self.user.id:
await msg.delete() # 删除之前的回复

message_to_edit = await after.reply("思考中...") # 创建新的回复
execute_action = await choose_act(self,prompt, after, message_to_edit)
await execute_action(message_to_edit, self.dialogue_history, channel_id, prompt, after)
except Exception as e:
print(e)
self.save_dialogue_history()


async def connect_db(self) -> None:
if not ((db_name := func.tokens.mongodb_name) and (db_url := func.tokens.mongodb_url)):
raise Exception("MONGODB_NAME and MONGODB_URL can't not be empty in settings.json")

try:
func.MONGO_DB = AsyncIOMotorClient(host=db_url, serverSelectionTimeoutMS=5000)
await func.MONGO_DB.server_info()
print("Successfully connected to MongoDB!")

except Exception as e:
raise Exception("Not able to connect MongoDB! Reason:", e)

func.SETTINGS_DB = func.MONGO_DB[db_name]["Settings"]
func.USERS_DB = func.MONGO_DB[db_name]["Users"]

async def setup_hook(self) -> None:
func.langs_setup()

# Connecting to MongoDB
await self.connect_db()
# Loading all the module in `cogs` folder
for module in os.listdir(func.ROOT_DIR + '/cogs'):
if module.endswith('.py'):
try:
await self.load_extension(f"cogs.{module[:-3]}")
print(f"Loaded {module[:-3]}")
except Exception as e:
print(traceback.format_exc())

if func.settings.ipc_server.get("enable", False):
await self.ipc.start()

if not func.settings.version or func.settings.version != update.__version__:
func.update_json("settings.json", new_data={"version": update.__version__})

await self.tree.set_translator(Translator())
await self.tree.sync()

async def on_ready(self):
print("------------------")
print(f"Logging As {self.user}")
print(f"Bot ID: {self.user.id}")
print("------------------")
print(f"Discord Version: {discord.__version__}")
print(f"Python Version: {sys.version}")
print("------------------")
data = {}
data['guilds'] = []
for guild in self.guilds:
guild_info = {
'guild_name': guild.name,'guild_id': guild.id,
'channels': []
}
for channel in guild.channels:
channel_info =f"channel_name: {channel.name},channel_id: {channel.id},channel_type: {str(channel.type)}"
guild_info['channels'].append(channel_info)
data['guilds'].append(guild_info)
self.setup_logger_for_guild(guild.name) # 設置每個伺服器的 logger

# 將資料寫入 JSON 文件
with open('logs/guilds_and_channels.json', 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
print('update succesfully guilds_and_channels.json')
func.tokens.client_id = self.user.id
func.LOCAL_LANGS.clear()

async def on_command_error(self, ctx: commands.Context, exception, /) -> None:
error = getattr(exception, 'original', exception)
if ctx.interaction:
error = getattr(error, 'original', error)
if isinstance(error, (commands.CommandNotFound, aiohttp.client_exceptions.ClientOSError)):
return

elif isinstance(error, (commands.CommandOnCooldown, commands.MissingPermissions, commands.RangeError, commands.BadArgument)):
pass

elif isinstance(error, (commands.MissingRequiredArgument, commands.MissingRequiredAttachment)):
command = f"{ctx.prefix}" + (f"{ctx.command.parent.qualified_name} " if ctx.command.parent else "") + f"{ctx.command.name} {ctx.command.signature}"
position = command.find(f"<{ctx.current_parameter.name}>") + 1
description = f"**Correct Usage:**\n```{command}\n" + " " * position + "^" * len(ctx.current_parameter.name) + "```\n"
if ctx.command.aliases:
description += f"**Aliases:**\n`{', '.join([f'{ctx.prefix}{alias}' for alias in ctx.command.aliases])}`\n\n"
description += f"**Description:**\n{ctx.command.help}\n\u200b"

embed = discord.Embed(description=description, color=func.settings.embed_color)
embed.set_footer(icon_url=ctx.me.display_avatar.url, text=f"More Help: {func.settings.invite_link}")
return await ctx.reply(embed=embed)

elif not issubclass(error.__class__, VoicelinkException):
error = func.get_lang(ctx.guild.id, "unknownException") + func.settings.invite_link
if (guildId := ctx.guild.id) not in func.ERROR_LOGS:
func.ERROR_LOGS[guildId][round(datetime.timestamp(datetime.now()))] = "".join(traceback.format_exception(type(exception), exception, exception.__traceback__))

try:
return await ctx.reply(error, ephemeral=True)
except:
pass
17 changes: 9 additions & 8 deletions gpt/choose_act.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import aiohttp
import logging
from gpt.gpt_response_gen import generate_response
from gpt.sendmessage import gpt_message
from gpt.vqa import vqa_answer
Expand All @@ -21,7 +20,7 @@ def internet_search(query: str, search_type: str):
If the conversation contains a URL, select url
Args:
query (str): Query to search the web with
search_type (str): Type of search to perform (one of [eat,url,general, image, youtube])
search_type (str): Type of search to perform (one of [general,eat,url, image, youtube])
"""
pass
```
Expand Down Expand Up @@ -115,7 +114,8 @@ def manage_user_data(user_id: str, user_data: str = None, action: str = 'read'):
'''
async def generate_image(message_to_edit, message,prompt: str, n_steps: int = 40, high_noise_frac: float = 0.8):
await message_to_edit.edit(content="畫畫修練中")
async def choose_act(prompt, message,message_to_edit):
async def choose_act(bot,prompt, message,message_to_edit):
logger = bot.get_logger_for_guild(message.guild.name)
prompt = f"msgtime:[{str(datetime.now())[:-7]}]{prompt}"
global system_prompt
default_action_list = [
Expand Down Expand Up @@ -148,7 +148,7 @@ async def choose_act(prompt, message,message_to_edit):
responses += response
# 解析 JSON 字符串
thread.join()
#logging.info(responses)
logger.info(responses)
try:
# 提取 JSON 部分
json_start = responses.find("[")
Expand All @@ -162,9 +162,10 @@ async def choose_act(prompt, message,message_to_edit):
action_list = default_action_list

async def execute_action(message_to_edit, dialogue_history, channel_id, original_prompt, message):
logger = bot.get_logger_for_guild(message_to_edit.guild.name)
nonlocal action_list, tool_func_dict
final_results = []
logging.info(action_list)
logger.info(action_list)
try:
for action in action_list:
tool_name = action["tool_name"]
Expand All @@ -181,13 +182,13 @@ async def execute_action(message_to_edit, dialogue_history, channel_id, original
if result is not None and tool_name != "directly_answer":
final_results.append(result)
except Exception as e:
logging.info(e)
logger.info(e)
else:
logging.info(f"未知的工具函数: {tool_name}")
logger.info(f"未知的工具函数: {tool_name}")
finally:
integrated_results = "\n".join(final_results)
final_prompt = f'<<information:\n{integrated_results}\n{original_prompt}>>'
gptresponses = await gpt_message(message_to_edit, message, final_prompt)
dialogue_history[channel_id].append({"role": "assistant", "content": gptresponses})
logging.info(f'PigPig:{gptresponses}')
logger.info(f'PigPig:{gptresponses}')
return execute_action
2 changes: 1 addition & 1 deletion gpt/internet/google_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def google_search(message_to_edit,message,query):
soup = BeautifulSoup(html, 'html.parser')
search_results = soup.select('.g')
search = ""
for result in search_results[:5]:
for result in search_results[:8]:
title_element = result.select_one('h3')
title = title_element.text if title_element else 'No Title'
snippet_element = result.select_one('.VwiC3b')
Expand Down
2 changes: 1 addition & 1 deletion gpt/sendmessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def gpt_message(message_to_edit,message,prompt):
related_data = search_vector_database(prompt) # 使用 LangChain 搜尋相關資料
# 讀取該訊息頻道最近的歷史紀錄
history = []
async for msg in channel.history(limit=10):
async for msg in channel.history(limit=5):
history.append(msg)
history.reverse()
history = history[:-2]
Expand Down
28 changes: 19 additions & 9 deletions logs.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import logging
import os
from datetime import datetime

class TimedRotatingFileHandler(logging.Handler):
def __init__(self):
def __init__(self, server_name):
super().__init__()
self.server_name = server_name
self.current_date = datetime.now().strftime('%Y%m%d')
self.current_hour = datetime.now().strftime('%H')
self._create_new_folder()
self._open_new_file()

def _create_new_folder(self):
log_directory = f'logs/{self.current_date}'
log_directory = f'logs/{self.server_name}/{self.current_date}'
if not os.path.exists(log_directory):
os.makedirs(log_directory)
self.log_directory = log_directory
Expand All @@ -22,20 +24,28 @@ def _open_new_file(self):
def emit(self, record):
current_date = datetime.now().strftime('%Y%m%d')
current_hour = datetime.now().strftime('%H')
if current_date != self.current_date:
if current_date != self.current_date or current_hour != self.current_hour:
self.stream.close()
self.current_date = current_date
self._create_new_folder()
self.current_hour = current_hour
self._open_new_file()
elif current_hour != self.current_hour:
self.stream.close()
self.current_hour = current_hour
self._create_new_folder()
self._open_new_file()
msg = self.format(record)
self.stream.write(msg + '\n')
self.stream.flush()

def close(self):
self.stream.close()
super().close()
super().close()

def setup_logger(server_name):
logger = logging.getLogger(server_name)
logger.setLevel(logging.INFO)
handler = TimedRotatingFileHandler(server_name)
formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
handler.setFormatter(formatter)
# 移除所有默認的處理程序
if logger.hasHandlers():
logger.handlers.clear()
logger.addHandler(handler)
return logger
Loading

0 comments on commit 9e5ef89

Please sign in to comment.