Gestion du contexte/mémoire courte
Gestion de la mémoire courte
This commit is contained in:
@@ -5,6 +5,7 @@ from langchain.chat_models import init_chat_model
|
||||
from langgraph.graph import START, END
|
||||
from langchain.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
|
||||
from langgraph.types import interrupt
|
||||
from shutil import rmtree
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -14,6 +15,12 @@ from .tools import getTools, getWeeklyReportTools
|
||||
from .state import CustomState
|
||||
from .InterruptPayload import InterruptPayload
|
||||
|
||||
# Variables principales
|
||||
TAILLE_CONTEXTE_MAX = 20000 #charactères
|
||||
PROMPT_SUMMARY = """Tu dois résumer le message qui te sera envoyé, de façon à préserver le plus d'informations, et en deux ou trois phrases.
|
||||
En écrivant ta réponse, n'inclus QUE le message qui a été résumé, seulement ton résumé et rien d'autre.
|
||||
Voici le message sur lequel tu dois travailler, fais le résumé :\n"""
|
||||
|
||||
# LLM principal
|
||||
llm = ChatMistralAI( # LLM sans outils
|
||||
model="mistral-large-latest",
|
||||
@@ -69,7 +76,51 @@ def LLM_central(state: MessagesState):
|
||||
|
||||
def context_shortener(state: CustomState):
|
||||
""" Noeud visant à réduire la taille du contexte pour éviter une explosion de la taille de la mémoire court-terme/contexte """
|
||||
raise NotImplementedError('TODO, faut que je le fasse')
|
||||
# Récupérer le chemin vers le point d'entrée
|
||||
base_dir: Path = Path(sys.argv[0]).resolve().parent
|
||||
reports_dir = base_dir / "outils_resumes" # Chemin du dossier des rapports
|
||||
|
||||
lastSummarizedMessage = 0 # 0, je ne résume pas le premier message système
|
||||
if 'lastSummarizedMessage' in state.keys():
|
||||
lastSummarizedMessage = state['lastSummarizedMessage'] # Récupérer l'index du dernier message qui a été résumé
|
||||
else:
|
||||
# Premier passage, je supprime les anciens outils si besoin
|
||||
rmtree(reports_dir.as_posix()) # Supprimer le dossier
|
||||
reports_dir.mkdir(parents=True, exist_ok=False) # Créer le dossier
|
||||
|
||||
messages = [msg for msg in state['messages'][lastSummarizedMessage+1:]] # Récupérer tous les messages après lastSummarizedMessage sans l'inclure
|
||||
newMessages = [msg for msg in state['messages'][:lastSummarizedMessage]] # Pré-remplir les anciens messages déjà revus
|
||||
|
||||
for msg in messages: # Boucle principale
|
||||
if isinstance(msg, HumanMessage) or isinstance(msg, AIMessage):
|
||||
# Message pouvant être directement résumé
|
||||
if len(msg.content) > 0: # s'il y a un contenu dans ce message
|
||||
msg.content = llm.invoke(PROMPT_SUMMARY + msg.content).content # Je le résume
|
||||
newMessages.append(msg)
|
||||
elif isinstance(msg, ToolMessage):
|
||||
# Outil, sera placé dans un fichier
|
||||
|
||||
file_name = f"resume_{msg.tool_call_id}.txt" # Nom unique
|
||||
full_path = reports_dir / file_name
|
||||
|
||||
with open(full_path, "w", encoding="utf-8") as f:
|
||||
# Ecriture
|
||||
f.write(f"""
|
||||
Tu as utilisé un outil, qui a retourné ceci:
|
||||
{msg.content}
|
||||
""") # TODO: Trouver un moyen d'ajouter le nom de l'outil depuis les ToolCall vers ce write
|
||||
|
||||
msg.content = f"Pour voir le compte-rendu complet de cet outil, utilise ton outil 'read_file(file_path=\"outils_resumes/{file_name}\")'.\n Résumé:\n" + \
|
||||
llm.invoke(PROMPT_SUMMARY + msg.content).content
|
||||
newMessages.append(msg)
|
||||
|
||||
else:
|
||||
# SystemMessage. Je ne les modifie pas
|
||||
newMessages.append(msg)
|
||||
|
||||
lastSummarizedMessage+=1
|
||||
|
||||
return {'messages': newMessages, 'lastSummarizedMessage': lastSummarizedMessage} # Je retourne une liste entière, ce qui devrait remplacer toute la liste au lieu d'ajouter un simple message
|
||||
|
||||
# fonction de routage
|
||||
def should_shorten(state: CustomState)->str:
|
||||
@@ -82,7 +133,6 @@ def should_shorten(state: CustomState)->str:
|
||||
Returns:
|
||||
str: Faut-il réduire le contexte ?
|
||||
"""
|
||||
TAILLE_CONTEXTE_MAX = 20000 #charactères
|
||||
count = 0
|
||||
for msg in state['messages']: count += len(msg.content) # Compter le nombre total de caractères dans le contexte
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ class CustomState(MessagesState):
|
||||
ragQuery: str # Requête envoyée au RAG, pour le cross-encodeur
|
||||
ragDocuments: List[str] # Documents retrouvés par le RAG, pour le cross-encodeur
|
||||
|
||||
lastSummarizedMessage: int # Index du message où l'on s'était arrêté de résumer
|
||||
|
||||
# TODO: Ajouter la source des documents sélectionnés pour la fin du rapport ?
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user