Compare commits
4 Commits
82a5491188
...
8655359add
| Author | SHA1 | Date | |
|---|---|---|---|
|
8655359add
|
|||
|
8b32c0ac64
|
|||
|
bdf5b7dd98
|
|||
|
e0bd50a15b
|
@@ -16,24 +16,35 @@ def getGraph()->CompiledStateGraph:
|
|||||||
workflow = getState() # State prêt à utiliser
|
workflow = getState() # State prêt à utiliser
|
||||||
|
|
||||||
# Définition des sommets du graphe
|
# Définition des sommets du graphe
|
||||||
workflow.add_node(call_to_LLM)
|
workflow.add_node(user_prompt)
|
||||||
|
workflow.add_node(LLM_central)
|
||||||
workflow.add_node(preparation_docs)
|
workflow.add_node(preparation_docs)
|
||||||
workflow.add_node(inject_preparation_prompt)
|
workflow.add_node(inject_preparation_prompt)
|
||||||
workflow.add_node("tool_node", tool_node)# BasicToolNode(tools=getTools())) # N'est pas une fonction, mais une classe instanciée, je dois précisier le nom du node
|
workflow.add_node("tool_node", tool_node)# BasicToolNode(tools=getTools())) # N'est pas une fonction, mais une classe instanciée, je dois précisier le nom du node
|
||||||
workflow.add_node("weekly_report_tools", weekly_report_tools)
|
workflow.add_node("weekly_report_tools", weekly_report_tools)
|
||||||
|
workflow.add_node(context_shortener) # Réduit la taille du contexte
|
||||||
|
workflow.add_node("context_shortener_2", context_shortener) # Le même, sous un autre nom pour le différencier dans le graphe
|
||||||
|
|
||||||
# Arrêtes
|
# Arrêtes
|
||||||
workflow.set_entry_point("inject_preparation_prompt")
|
workflow.set_conditional_entry_point(is_resumes_reports_already_initialised, {
|
||||||
|
"résumés non disponibles": "inject_preparation_prompt", # Résumés non générés
|
||||||
|
"résumés déjà générés": "user_prompt" # Résumés déjà prêts, je peux aller direct à la partie principale
|
||||||
|
})
|
||||||
workflow.add_edge("inject_preparation_prompt", "preparation_docs")
|
workflow.add_edge("inject_preparation_prompt", "preparation_docs")
|
||||||
workflow.add_conditional_edges("preparation_docs", should_continue, {
|
workflow.add_conditional_edges("preparation_docs", should_continue, {
|
||||||
"tools":"weekly_report_tools",
|
"tools":"weekly_report_tools",
|
||||||
"no_tools":"call_to_LLM"
|
"no_tools":"context_shortener" # FIN de la préparation, on réduit le contexte avant de passer à la suite
|
||||||
})
|
})
|
||||||
|
workflow.add_edge("context_shortener", "user_prompt") # Et ici, je rejoins la partie principale qui rédigera le rapport
|
||||||
|
workflow.add_edge("user_prompt", "LLM_central")
|
||||||
|
|
||||||
#workflow.set_entry_point("call_to_LLM")
|
|
||||||
workflow.add_edge("weekly_report_tools", "preparation_docs")
|
workflow.add_edge("weekly_report_tools", "preparation_docs")
|
||||||
workflow.add_edge("tool_node", "call_to_LLM")
|
workflow.add_conditional_edges("tool_node", should_shorten, {
|
||||||
workflow.add_conditional_edges("call_to_LLM", should_continue, {
|
'sous la limite': "LLM_central",
|
||||||
|
'réduire contexte': "context_shortener_2"
|
||||||
|
})
|
||||||
|
workflow.add_edge("context_shortener_2", "LLM_central")
|
||||||
|
workflow.add_conditional_edges("LLM_central", should_continue, {
|
||||||
"tools":"tool_node",
|
"tools":"tool_node",
|
||||||
"no_tools":END
|
"no_tools":END
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -10,9 +10,21 @@ class InterruptPayload():
|
|||||||
#EDITED = 2
|
#EDITED = 2
|
||||||
DENIED = 3
|
DENIED = 3
|
||||||
|
|
||||||
def __init__(self, fields:Dict, state:int=0):
|
TOOL_CALL = 999
|
||||||
|
USER_PROMPT = 998
|
||||||
|
|
||||||
|
def __init__(self, fields:Dict, state:int=0, payload_type:int=TOOL_CALL):
|
||||||
|
"""
|
||||||
|
Créer unne nouvelle instance de payload pour interrupt()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fields (Dict): Un dictionnaire d'arguments pour un call d'outil, ou {'prompt':str} pour une requête de prompt
|
||||||
|
state (int, optional): État de la requête. Defaults to 0. Définit en variables statiques de l'objet.
|
||||||
|
payload_type (int, optional): Type d'interuption, appel d'outil ou requête humaine. Defaults to TOOL_CALL. Définit en variables statiques de l'objet.
|
||||||
|
"""
|
||||||
self.__fields = fields
|
self.__fields = fields
|
||||||
self.__state = state
|
self.__state = state
|
||||||
|
self.__type = payload_type
|
||||||
|
|
||||||
def get(self, key:str)->str:
|
def get(self, key:str)->str:
|
||||||
"""
|
"""
|
||||||
@@ -39,7 +51,22 @@ class InterruptPayload():
|
|||||||
"""
|
"""
|
||||||
Afficher la requête proprement, permettant à l'utilisateur d'accepter, refuser ou modifier une requête
|
Afficher la requête proprement, permettant à l'utilisateur d'accepter, refuser ou modifier une requête
|
||||||
"""
|
"""
|
||||||
|
if self.__type == InterruptPayload.USER_PROMPT: # C'est une demande de prompt humain
|
||||||
|
self.__human_prompt_display()
|
||||||
|
else: # C'est un appel d'outil
|
||||||
|
self.__tool_query_display()
|
||||||
|
|
||||||
|
def __human_prompt_display(self):
|
||||||
|
print("=== L'AGENT DEMANDE DES CONSIGNES! ===\n")
|
||||||
|
|
||||||
|
print("Veuillez saisir un prompt pour l'agent...\n")
|
||||||
|
prompt = input("Prompt...")
|
||||||
|
|
||||||
|
self.__fields = {'prompt': prompt}
|
||||||
|
print("\nMerci, l'exécution va reprendre.\n")
|
||||||
|
print("======")
|
||||||
|
|
||||||
|
def __tool_query_display(self):
|
||||||
print("=== L'AGENT DEMANDE À UTILISER UN OUTIL RESTREINT! ===\n")
|
print("=== L'AGENT DEMANDE À UTILISER UN OUTIL RESTREINT! ===\n")
|
||||||
|
|
||||||
keys = list(self.__fields.keys())
|
keys = list(self.__fields.keys())
|
||||||
@@ -86,7 +113,7 @@ class InterruptPayload():
|
|||||||
Returns:
|
Returns:
|
||||||
str: String sérialisable via la méthode statique InterruptPayload.strImport(string)
|
str: String sérialisable via la méthode statique InterruptPayload.strImport(string)
|
||||||
"""
|
"""
|
||||||
return '{"state":'+ str(self.__state) +', "fields": ' + json.dumps(self.__fields, ensure_ascii=False, indent=indent) +'}'
|
return '{"state":'+ str(self.__state) +', "type": '+str(self.__type)+', "fields": ' + json.dumps(self.__fields, ensure_ascii=False, indent=indent) +'}'
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -104,8 +131,9 @@ class InterruptPayload():
|
|||||||
|
|
||||||
state_ = data.get("state", 0)
|
state_ = data.get("state", 0)
|
||||||
fields_ = data.get("fields", {})
|
fields_ = data.get("fields", {})
|
||||||
|
type_ = data.get("type", InterruptPayload.TOOL_CALL)
|
||||||
|
|
||||||
return InterruptPayload(fields=fields_, state=state_)
|
return InterruptPayload(fields=fields_, state=state_, payload_type=type_)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,12 @@ from langgraph.graph import MessagesState
|
|||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
from langchain.messages import HumanMessage, AIMessage, SystemMessage
|
from langchain.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
|
||||||
from langgraph.types import interrupt
|
from langgraph.types import interrupt
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
|
||||||
from .tools import getTools, getWeeklyReportTools
|
from .tools import getTools, getWeeklyReportTools
|
||||||
from .state import CustomState
|
from .state import CustomState
|
||||||
@@ -22,7 +26,9 @@ def inject_preparation_prompt(state: CustomState):
|
|||||||
""" Noeud qui vise juste à insérer le message indiquant au LLM comment travailler sur les résumés de comptes-rendus """
|
""" Noeud qui vise juste à insérer le message indiquant au LLM comment travailler sur les résumés de comptes-rendus """
|
||||||
return {'messages': HumanMessage(
|
return {'messages': HumanMessage(
|
||||||
"Ton but est de lire les fichiers présents dans la base de données en utilisant l'outil 'search_in_files',\
|
"Ton but est de lire les fichiers présents dans la base de données en utilisant l'outil 'search_in_files',\
|
||||||
afin de générer des rapports sur chaque semaine du stage qui y est décrit. Pour enregistrer chaque semaine du stage, utilise l'outil 'write_week_report'.\
|
afin de générer des rapports sur chaque semaine du stage qui y est décrit. Il y a 25 semaines, tu dois toutes les résumer,\
|
||||||
|
avec des détails et des informations complémentaires.\
|
||||||
|
Pour enregistrer chaque semaine du stage, utilise l'outil 'write_week_report'. Tu DOIS les enregistrer avec cet outil.\
|
||||||
Une fois terminé, fais une liste de tous les outils, logiciels, méthodes, entreprises, techniques, ect.. utilisés,\
|
Une fois terminé, fais une liste de tous les outils, logiciels, méthodes, entreprises, techniques, ect.. utilisés,\
|
||||||
et fais en une liste avec quelques descriptions que tu devras enregistrer avec l'outil 'write_library_tools_details_on_internship'."
|
et fais en une liste avec quelques descriptions que tu devras enregistrer avec l'outil 'write_library_tools_details_on_internship'."
|
||||||
)}
|
)}
|
||||||
@@ -33,7 +39,27 @@ def preparation_docs(state: CustomState):
|
|||||||
|
|
||||||
return {'messages': model.invoke(state['messages'])}
|
return {'messages': model.invoke(state['messages'])}
|
||||||
|
|
||||||
def call_to_LLM(state: MessagesState):
|
def user_prompt(state: CustomState):
|
||||||
|
""" Dans ce nœud, l'utilisateur peut écrire un HumanMessage pour l'IA """
|
||||||
|
|
||||||
|
messages = [msg for msg in state['messages']] # Je récupère la liste des messages
|
||||||
|
|
||||||
|
sys_message = SystemMessage("Salut") # TODO: Anti-injections
|
||||||
|
user_message = HumanMessage(
|
||||||
|
InterruptPayload.fromJSON(
|
||||||
|
interrupt(
|
||||||
|
InterruptPayload({'prompt':''}, payload_type=InterruptPayload.USER_PROMPT).toJSON()
|
||||||
|
)
|
||||||
|
).get("prompt")
|
||||||
|
) # Récupérer un prompt
|
||||||
|
|
||||||
|
messages.append(sys_message) # Rajout des nouveaux messages dans le système
|
||||||
|
messages.append(user_message)
|
||||||
|
|
||||||
|
return {'messages': messages}# Je passe unen liste, devrait écraser tous les messages précédent au lieu d'ajouter à la liste du State
|
||||||
|
|
||||||
|
|
||||||
|
def LLM_central(state: MessagesState):
|
||||||
"""Noeud qui s'occupe de gérer les appels au LLM"""
|
"""Noeud qui s'occupe de gérer les appels au LLM"""
|
||||||
# Initialisation du LLM
|
# Initialisation du LLM
|
||||||
model = llm.bind_tools(getTools())
|
model = llm.bind_tools(getTools())
|
||||||
@@ -41,6 +67,30 @@ def call_to_LLM(state: MessagesState):
|
|||||||
# Appel du LLM
|
# Appel du LLM
|
||||||
return {"messages": [model.invoke(state["messages"])]}
|
return {"messages": [model.invoke(state["messages"])]}
|
||||||
|
|
||||||
|
def context_shortener(state: CustomState):
|
||||||
|
""" Noeud visant à réduire la taille du contexte pour éviter une explosion de la taille de la mémoire court-terme/contexte """
|
||||||
|
raise NotImplementedError('TODO, faut que je le fasse')
|
||||||
|
|
||||||
|
# fonction de routage
|
||||||
|
def should_shorten(state: CustomState)->str:
|
||||||
|
"""
|
||||||
|
Fonction de routage, permet de savoir s'il est temps de résumer la contexte de la conversation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (CustomState): Le State actuel
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Faut-il réduire le contexte ?
|
||||||
|
"""
|
||||||
|
TAILLE_CONTEXTE_MAX = 20000 #charactères
|
||||||
|
count = 0
|
||||||
|
for msg in state['messages']: count += len(msg.content) # Compter le nombre total de caractères dans le contexte
|
||||||
|
|
||||||
|
if count < TAILLE_CONTEXTE_MAX:
|
||||||
|
# OK
|
||||||
|
return 'sous la limite'
|
||||||
|
return 'réduire contexte'
|
||||||
|
|
||||||
# fonction de routage : Après reponse_question, si le LLM veut appeler un outil, on va au tool_node
|
# fonction de routage : Après reponse_question, si le LLM veut appeler un outil, on va au tool_node
|
||||||
def should_continue(state: MessagesState):
|
def should_continue(state: MessagesState):
|
||||||
"""
|
"""
|
||||||
@@ -57,21 +107,6 @@ def should_continue(state: MessagesState):
|
|||||||
return "tools"
|
return "tools"
|
||||||
return "no_tools"
|
return "no_tools"
|
||||||
|
|
||||||
def task_ended(state: MessagesState):
|
|
||||||
"""
|
|
||||||
Vérifier si l'agent a terminé son cycle, ou s'il faut le relancer
|
|
||||||
"""
|
|
||||||
if isinstance(state, list):
|
|
||||||
ai_message = state[-1]
|
|
||||||
elif messages := state.get("messages", []):
|
|
||||||
ai_message = messages[-1]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"No messages found in input state to tool_edge: {state}")
|
|
||||||
|
|
||||||
if "terminé" in ai_message.content.lower():
|
|
||||||
return END
|
|
||||||
return "continue"
|
|
||||||
|
|
||||||
weekly_report_tools = ToolNode(tools=getWeeklyReportTools())
|
weekly_report_tools = ToolNode(tools=getWeeklyReportTools())
|
||||||
tool_node = ToolNode(tools=getTools())
|
tool_node = ToolNode(tools=getTools())
|
||||||
|
|
||||||
@@ -101,3 +136,18 @@ class BasicToolNode: # De mon ancien projet, https://github.com/LJ5O/Assistant/b
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
return {"messages": outputs}
|
return {"messages": outputs}
|
||||||
|
|
||||||
|
# fonction de routage
|
||||||
|
def is_resumes_reports_already_initialised(state: CustomState)->str:
|
||||||
|
"""Permet de savoirr si les résumés de comptes-rendu ont déjà été générés.
|
||||||
|
S'ils le sont, inutile de recréer ce dossier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Faut-il générer les résumés ?
|
||||||
|
"""
|
||||||
|
# Récupérer le chemin vers le point d'entrée
|
||||||
|
base_dir: Path = Path(sys.argv[0]).resolve().parent
|
||||||
|
reports_dir = base_dir / "rapports_resumes" # Chemin du dossier des rapports
|
||||||
|
if os.path.isdir(reports_dir):
|
||||||
|
return "résumés déjà générés"
|
||||||
|
else: return "résumés non disponibles"
|
||||||
BIN
imgs/agent.png
BIN
imgs/agent.png
Binary file not shown.
|
Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 50 KiB |
@@ -16,13 +16,14 @@
|
|||||||
- [X] Préparation des nœuds
|
- [X] Préparation des nœuds
|
||||||
- [X] Branchement des nœuds entre-eux, **MVP**
|
- [X] Branchement des nœuds entre-eux, **MVP**
|
||||||
- [X] Human in the loop
|
- [X] Human in the loop
|
||||||
- [ ] Amélioration du workflow
|
- [X] Amélioration du workflow
|
||||||
|
- [ ] Gestion de la taille du contexte - Résumé de l'historique des messages
|
||||||
|
|
||||||
## Amélioration de l'agent
|
## Amélioration de l'agent
|
||||||
- [ ] Cross-encoding sur la sortie du **RAG**
|
- [ ] Cross-encoding sur la sortie du **RAG**
|
||||||
- [ ] Sauvegarde de l'état de l'agent
|
- [ ] Sauvegarde de l'état de l'agent
|
||||||
|
- [ ] Lecture d'un `skills.md`
|
||||||
- [ ] Système de redémarrage après un arrêt
|
- [ ] Système de redémarrage après un arrêt
|
||||||
- [ ] Gestion de la taille du contexte - Résumé de l'historique des messages
|
|
||||||
- [ ] Détection de *prompt injection*
|
- [ ] Détection de *prompt injection*
|
||||||
- [ ] Génération d'un PDF en sortie du système
|
- [ ] Génération d'un PDF en sortie du système
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user