Permet de maintenir la conversation au lieu de la stopper après un message sans outils
215 lines
9.4 KiB
Python
215 lines
9.4 KiB
Python
from langchain_mistralai import ChatMistralAI
|
||
from langgraph.graph import MessagesState
|
||
from langgraph.prebuilt import ToolNode
|
||
from langchain.chat_models import init_chat_model
|
||
from langgraph.graph import START, END
|
||
from langchain.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
|
||
from langgraph.types import interrupt
|
||
from shutil import rmtree
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
import json
|
||
|
||
from .tools import getTools, getWeeklyReportTools
|
||
from .state import CustomState
|
||
from .InterruptPayload import InterruptPayload
|
||
from .StateElements.TodoElement import TodoElement
|
||
|
||
# Variables principales
|
||
TAILLE_CONTEXTE_MAX = 20000 #charactères
|
||
PROMPT_SUMMARY = """Tu dois résumer le message qui te sera envoyé, de façon à préserver le plus d'informations, et en deux ou trois phrases.
|
||
En écrivant ta réponse, n'inclus QUE le message qui a été résumé, seulement ton résumé et rien d'autre.
|
||
Voici le message sur lequel tu dois travailler, fais le résumé :\n"""
|
||
|
||
# LLM principal
|
||
llm = ChatMistralAI( # LLM sans outils
|
||
model="mistral-large-latest",
|
||
temperature=0,
|
||
max_retries=2,
|
||
)
|
||
|
||
# NODES
|
||
def inject_preparation_prompt(state: CustomState):
|
||
""" Noeud qui vise juste à insérer le message indiquant au LLM comment travailler sur les résumés de comptes-rendus """
|
||
return {'messages': HumanMessage(
|
||
"Ton but est de lire les fichiers présents dans la base de données en utilisant l'outil 'search_in_files',\
|
||
afin de générer des rapports sur chaque semaine du stage qui y est décrit. Il y a 25 semaines, tu dois toutes les résumer,\
|
||
avec des détails et des informations complémentaires.\
|
||
Pour enregistrer chaque semaine du stage, utilise l'outil 'write_week_report'. Tu DOIS les enregistrer avec cet outil.\
|
||
Une fois terminé, fais une liste de tous les outils, logiciels, méthodes, entreprises, techniques, ect.. utilisés,\
|
||
et fais en une liste avec quelques descriptions que tu devras enregistrer avec l'outil 'write_library_tools_details_on_internship'."
|
||
)}
|
||
|
||
def preparation_docs(state: CustomState):
|
||
"""Noeud en charge de préparer les résumés pour chaque semaine des rapports, et la liste des outils et méthodes utilisées"""
|
||
model = llm.bind_tools(getWeeklyReportTools()) # LLM en charge de générer des rapports hebdomadaires sur le stage
|
||
|
||
return {'messages': model.invoke(state['messages'])}
|
||
|
||
def user_prompt(state: CustomState):
|
||
""" Dans ce nœud, l'utilisateur peut écrire un HumanMessage pour l'IA """
|
||
|
||
messages = [msg for msg in state['messages']] # Je récupère la liste des messages
|
||
|
||
sys_message = SystemMessage("Salut") # TODO: Anti-injections
|
||
user_message = HumanMessage(
|
||
InterruptPayload.fromJSON(
|
||
interrupt(
|
||
InterruptPayload({'prompt':''}, payload_type=InterruptPayload.USER_PROMPT).toJSON()
|
||
)
|
||
).get("prompt")
|
||
) # Récupérer un prompt
|
||
|
||
end = False # Permet de mettre fin à l'exécution du modèle
|
||
if user_message.content.lower().strip() == "exit":
|
||
end = True
|
||
|
||
messages.append(sys_message) # Rajout des nouveaux messages dans le système
|
||
messages.append(user_message)
|
||
|
||
return {'stop': end, 'messages': messages}# Je passe unen liste, devrait écraser tous les messages précédent au lieu d'ajouter à la liste du State
|
||
|
||
|
||
def LLM_central(state: CustomState):
|
||
"""Noeud qui s'occupe de gérer les appels au LLM"""
|
||
# Initialisation du LLM
|
||
model = llm.bind_tools(getTools())
|
||
#print(state)
|
||
|
||
if "todo" in state.keys(): # S'il y a des TODO, je l'ajoute avant le prompt au LLM
|
||
if len(state['todo'])>0:
|
||
sysmsg = SystemMessage(f"Voici la liste des tâches en cours : {str([f"{i}: {str(TodoElement.fromJSON(todo))}\n" for i,todo in enumerate(state['todo'])])}")
|
||
print(sysmsg.content)
|
||
return {"messages": [model.invoke(state["messages"] + [AIMessage('.'), sysmsg])]} # AIMessage pour que Msitrail ne refuse pas la requête avec un 400
|
||
|
||
# Appel du LLM
|
||
return {"messages": [model.invoke(state["messages"])]}
|
||
|
||
def context_shortener(state: CustomState):
|
||
""" Noeud visant à réduire la taille du contexte pour éviter une explosion de la taille de la mémoire court-terme/contexte """
|
||
# Récupérer le chemin vers le point d'entrée
|
||
base_dir: Path = Path(sys.argv[0]).resolve().parent
|
||
reports_dir = base_dir / "outils_resumes" # Chemin du dossier des rapports
|
||
|
||
lastSummarizedMessage = 0 # 0, je ne résume pas le premier message système
|
||
if 'lastSummarizedMessage' in state.keys():
|
||
lastSummarizedMessage = state['lastSummarizedMessage'] # Récupérer l'index du dernier message qui a été résumé
|
||
else:
|
||
# Premier passage, je supprime les anciens outils si besoin
|
||
rmtree(reports_dir.as_posix()) # Supprimer le dossier
|
||
reports_dir.mkdir(parents=True, exist_ok=False) # Créer le dossier
|
||
|
||
messages = [msg for msg in state['messages'][lastSummarizedMessage+1:]] # Récupérer tous les messages après lastSummarizedMessage sans l'inclure
|
||
newMessages = [msg for msg in state['messages'][:lastSummarizedMessage]] # Pré-remplir les anciens messages déjà revus
|
||
|
||
for msg in messages: # Boucle principale
|
||
if isinstance(msg, HumanMessage) or isinstance(msg, AIMessage):
|
||
# Message pouvant être directement résumé
|
||
if len(msg.content) > 0: # s'il y a un contenu dans ce message
|
||
msg.content = llm.invoke(PROMPT_SUMMARY + msg.content).content # Je le résume
|
||
newMessages.append(msg)
|
||
elif isinstance(msg, ToolMessage):
|
||
# Outil, sera placé dans un fichier
|
||
|
||
file_name = f"resume_{msg.tool_call_id}.txt" # Nom unique
|
||
full_path = reports_dir / file_name
|
||
|
||
with open(full_path, "w", encoding="utf-8") as f:
|
||
# Ecriture
|
||
f.write(f"""
|
||
Tu as utilisé un outil, qui a retourné ceci:
|
||
{msg.content}
|
||
""") # TODO: Trouver un moyen d'ajouter le nom de l'outil depuis les ToolCall vers ce write
|
||
|
||
msg.content = f"Pour voir le compte-rendu complet de cet outil, utilise ton outil 'read_file(file_path=\"outils_resumes/{file_name}\")'.\n Résumé:\n" + \
|
||
llm.invoke(PROMPT_SUMMARY + msg.content).content
|
||
newMessages.append(msg)
|
||
|
||
else:
|
||
# SystemMessage. Je ne les modifie pas
|
||
newMessages.append(msg)
|
||
|
||
lastSummarizedMessage+=1
|
||
|
||
return {'messages': newMessages, 'lastSummarizedMessage': lastSummarizedMessage} # Je retourne une liste entière, ce qui devrait remplacer toute la liste au lieu d'ajouter un simple message
|
||
|
||
# fonction de routage
|
||
def should_shorten(state: CustomState)->str:
|
||
"""
|
||
Fonction de routage, permet de savoir s'il est temps de résumer la contexte de la conversation
|
||
|
||
Args:
|
||
state (CustomState): Le State actuel
|
||
|
||
Returns:
|
||
str: Faut-il réduire le contexte ?
|
||
"""
|
||
count = 0
|
||
for msg in state['messages']: count += len(msg.content) # Compter le nombre total de caractères dans le contexte
|
||
|
||
if count < TAILLE_CONTEXTE_MAX:
|
||
# OK
|
||
return 'sous la limite'
|
||
return 'réduire contexte'
|
||
|
||
# fonction de routage : Après reponse_question, si le LLM veut appeler un outil, on va au tool_node
|
||
def should_continue(state: CustomState):
|
||
"""
|
||
Vérifier s'il y a un appel aux outils dans le dernier message
|
||
"""
|
||
if isinstance(state, list):
|
||
ai_message = state[-1]
|
||
elif messages := state.get("messages", []):
|
||
ai_message = messages[-1]
|
||
else:
|
||
raise ValueError(f"No messages found in input state to tool_edge: {state}")
|
||
|
||
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
|
||
return "tools"
|
||
return "no_tools"
|
||
|
||
weekly_report_tools = ToolNode(tools=getWeeklyReportTools())
|
||
tool_node = ToolNode(tools=getTools())
|
||
|
||
|
||
class BasicToolNode: # De mon ancien projet, https://github.com/LJ5O/Assistant/blob/main/modules/Brain/src/LLM/graph/nodes/BasicToolNode.py
|
||
"""A node that runs the tools requested in the last AIMessage."""
|
||
|
||
def __init__(self, tools: list) -> None:
|
||
self.tools_by_name = {tool.name: tool for tool in tools}
|
||
|
||
def __call__(self, inputs: dict):
|
||
if messages := inputs.get("messages", []):
|
||
message = messages[-1]
|
||
else:
|
||
raise ValueError("No message found in input")
|
||
outputs = []
|
||
for tool_call in message.tool_calls:
|
||
#print(tool_call["args"])
|
||
tool_result = self.tools_by_name[tool_call["name"]].invoke(
|
||
tool_call["args"]
|
||
)
|
||
outputs.append(
|
||
ToolMessage(
|
||
content=json.dumps(tool_result),
|
||
name=tool_call["name"],
|
||
tool_call_id=tool_call["id"],
|
||
)
|
||
)
|
||
return {"messages": outputs}
|
||
|
||
# fonction de routage
|
||
def is_resumes_reports_already_initialised(state: CustomState)->str:
|
||
"""Permet de savoirr si les résumés de comptes-rendu ont déjà été générés.
|
||
S'ils le sont, inutile de recréer ce dossier.
|
||
|
||
Returns:
|
||
str: Faut-il générer les résumés ?
|
||
"""
|
||
# Récupérer le chemin vers le point d'entrée
|
||
base_dir: Path = Path(sys.argv[0]).resolve().parent
|
||
reports_dir = base_dir / "rapports_resumes" # Chemin du dossier des rapports
|
||
if os.path.isdir(reports_dir):
|
||
return "résumés déjà générés"
|
||
else: return "résumés non disponibles" |