# chat/plugin.py — Supybot "Chat" plugin
###
# Copyright (c) 2023, John Burwell
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions, and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions, and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the author of this software nor the name of
# contributors to this software may be used to endorse or promote products
# derived from this software without specific prior written consent.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
###
import json
import re
import requests
import supybot
from supybot import callbacks, conf, ircutils
from supybot.commands import *
import logging
try:
from supybot.i18n import PluginInternationalization
_ = PluginInternationalization('Chat')
except ImportError:
# Placeholder that allows to run the plugin on a bot
# without the i18n module
_ = lambda x: x
def truncate_messages(messages, max_tokens):
    """
    Trim *messages* so the kept tail fits an approximate token budget.

    Token counts are approximated by whitespace-splitting each message's
    "content" field.  The walk goes newest-to-oldest and stops as soon as
    the next message would exceed the budget, so the result is always a
    contiguous suffix of the input (most recent messages are preferred).

    Args:
        messages (list): Message dicts, each with a "content" string.
        max_tokens (int): Approximate token budget for the kept messages.

    Returns:
        list: The retained suffix of *messages*, in original order.
    """
    kept = []
    budget = max_tokens
    for entry in reversed(messages):
        cost = len(entry["content"].split())
        if cost > budget:
            break
        budget -= cost
        kept.append(entry)
    kept.reverse()
    return kept
class Chat(callbacks.Plugin):
    """Sends message to ChatGPT and replies with the response
    """
    def __init__(self, irc):
        self.__parent = super(Chat, self)
        self.__parent.__init__(irc)
        # Honor the user-configured log level; unknown names fall back to INFO.
        log_level = self.registryValue('log_level').upper()
        self.log.setLevel(getattr(logging, log_level, logging.INFO))
        self.log.info("Chat plugin initialized with log level: %s", log_level)

    def filter_prefix(self, msg, prefix):
        """Return *msg* with a leading *prefix* removed; otherwise unchanged."""
        if msg.startswith(prefix):
            return msg[len(prefix):]
        else:
            return msg

    def chat(self, irc, msg, args, string):
        """
        <message>

        Sends a message to ChatGPT and returns the response. The bot will include recent
        conversation history from the channel to provide context.

        Example:
            @bot chat What is the capital of France?
        """
        # The text users type to invoke this command (prefix char + plugin
        # name); used below to strip command echoes out of the scrollback.
        invocation_string = f"{conf.supybot.reply.whenAddressedBy.chars()}{self.name().lower()} "
        self.log.debug(f"Invocation string: {invocation_string} | User: {msg.nick} | Channel: {msg.args[0]}")
        # Retrieve model and token settings from the plugin's configuration.
        model = self.registryValue("model")
        max_tokens = self.registryValue("max_tokens")
        # Use a default system prompt if none is configured.
        default_prompt = "You are a helpful assistant."
        system_prompt = self.registryValue("system_prompt") or default_prompt
        # Substitute dynamic placeholders with the live bot/channel names.
        system_prompt = system_prompt.replace("$bot_name", irc.nick).replace("$channel_name", msg.args[0])
        # Retrieve the last few lines of the chat scrollback to provide context.
        history = irc.state.history[-self.registryValue("scrollback_lines"):]
        self.log.debug(f"Raw history: {history}")
        # Keep only PRIVMSGs for this channel; [:-1] drops the final entry,
        # which is the message that triggered this command (it is appended
        # separately as the closing "user" turn below).
        filtered_messages = [
            (message.nick, self.filter_prefix(message.args[1], invocation_string))
            for message in history
            if message.command == 'PRIVMSG' and message.args[0] == msg.args[0]
        ][:-1]
        if not filtered_messages:
            # Log a warning if no relevant messages are found in the scrollback.
            self.log.warning(f"No messages found in scrollback for channel {msg.args[0]}")
        # Format the conversation history for the API request. An empty nick
        # is treated as a line the bot itself sent (presumably outgoing
        # messages carry no prefix in supybot's history — TODO confirm);
        # those become "assistant" turns with any leading "nick: " stripped.
        # BUGFIX: the strip pattern was r'^.+?:\\s' — in a raw string that
        # matches a literal backslash followed by 's', so it never matched a
        # real "nick: " prefix. Corrected to r'^.+?:\s'.
        conversation_history = [
            {
                "role": "assistant" if nick == "" else "user",
                "content": re.sub(r'^.+?:\s', '', msg) if nick == "" else f"{nick}: {msg}"
            }
            for nick, msg in filtered_messages
        ]
        # Combine the system prompt, the scrollback, and the current message.
        # NOTE(review): the wrapped `string` argument (command text without
        # the invocation prefix) is unused; the raw line msg.args[1] is sent
        # instead, so the prefix may be included in the prompt — verify intent.
        messages = [{"role": "system", "content": system_prompt}] + conversation_history + [{"role": "user", "content": msg.args[1]}]
        # Truncate messages to keep the approximate token total within the
        # model's context budget (8192 here).
        messages = truncate_messages(messages, 8192)
        self.log.debug(f"API Request: {json.dumps(messages)}")
        try:
            # Send the request to the OpenAI API.
            res = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.registryValue('api_key')}"
                },
                json={
                    "model": model,
                    "messages": messages,
                    "max_tokens": max_tokens,
                },
                timeout=10  # Fail fast instead of hanging the bot's event loop.
            )
            res.raise_for_status()  # Raise an HTTPError for 4xx/5xx responses.
            res = res.json()
            self.log.debug(f"API Response: {json.dumps(res)}")
            if "error" in res:
                # Log and reply with the error message if the API returns an error.
                error_message = res["error"].get("message", "Unknown error")
                self.log.error(f"API error: {error_message} | User input: {msg.args[1]} | Channel: {msg.args[0]}")
                irc.reply(f"API error: {error_message}")
                return
            # Extract and format the response from the API.
            response = res['choices'][0]['message']['content'].strip()
            # IRC replies are single-line: collapse multi-line completions
            # using the configured join_string, skipping blank lines.
            lines = response.splitlines()
            if len(lines) > 1:
                response = self.registryValue("join_string").join(line.strip() for line in lines if line.strip())
            irc.reply(response)
            self.log.info(f"Successfully processed request for user {msg.nick} in channel {msg.args[0]}")
        except requests.exceptions.Timeout:
            # Handle and log timeout errors.
            self.log.error("Request timed out.")
            irc.reply("The request to the API timed out. Please try again later.")
        except requests.exceptions.HTTPError as e:
            # Handle and log HTTP errors.
            self.log.error(f"HTTP error: {e}")
            irc.reply("An HTTP error occurred while contacting the API.")
        except requests.exceptions.RequestException as e:
            # Handle and log other request exceptions (connection errors etc.).
            self.log.error(f"Request exception: {e}")
            irc.reply("An error occurred while contacting the API.")
    chat = wrap(chat, ['text'])
Class = Chat