diff --git a/.gitignore b/.gitignore index e6f64b5..1c00b2f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ mlflow.db # Par sécurité documents_projet/ chroma_db/ +AgentReact/rapports_resumes/ # Python __pycache__/ \ No newline at end of file diff --git a/AgentReact/agent.py b/AgentReact/agent.py index 7720f57..5f24eef 100644 --- a/AgentReact/agent.py +++ b/AgentReact/agent.py @@ -2,7 +2,7 @@ from langgraph.graph import START, END from langgraph.graph.state import CompiledStateGraph from langgraph.checkpoint.memory import InMemorySaver -from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node +from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node, preparation_docs, weekly_report_tools from utils.state import getState from utils.tools import getTools @@ -17,10 +17,19 @@ def getGraph()->CompiledStateGraph: # Définition des sommets du graphe workflow.add_node(call_to_LLM) + workflow.add_node(preparation_docs) workflow.add_node("tool_node", tool_node)# BasicToolNode(tools=getTools())) # N'est pas une fonction, mais une classe instanciée, je dois précisier le nom du node + workflow.add_node("weekly_report_tools", weekly_report_tools) # Arrêtes - workflow.set_entry_point("call_to_LLM") + workflow.set_entry_point("preparation_docs") + workflow.add_conditional_edges("preparation_docs", should_continue, { + "tools":"weekly_report_tools", + "no_tools":"call_to_LLM" + }) + + #workflow.set_entry_point("call_to_LLM") + workflow.add_edge("weekly_report_tools", "preparation_docs") workflow.add_edge("tool_node", "call_to_LLM") workflow.add_conditional_edges("call_to_LLM", should_continue, { "tools":"tool_node", diff --git a/AgentReact/start.py b/AgentReact/start.py index 4020500..aba9b12 100644 --- a/AgentReact/start.py +++ b/AgentReact/start.py @@ -13,7 +13,7 @@ mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER" mlflow.langchain.autolog() initial_input = { - 'messages':[HumanMessage("Recherche 'Recette de Monster' sur internet")] + 'messages':[SystemMessage("Salut")] } config={"configurable": {"thread_id": 'yes'}} diff --git a/AgentReact/utils/StreamGraph.py b/AgentReact/utils/StreamGraph.py index 2f7194e..4e287d7 100644 --- a/AgentReact/utils/StreamGraph.py +++ b/AgentReact/utils/StreamGraph.py @@ -5,7 +5,7 @@ from langgraph.types import Command from .InterruptPayload import InterruptPayload # Une fonction pour stream et gérer proprement le graphe -def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph): +def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph, lastMsgIndex=0): # https://docs.langchain.com/oss/python/langgraph/interrupts#stream-with-human-in-the-loop-hitl-interrupts for mode, state in graphe.stream( initial_input, @@ -15,8 +15,11 @@ def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph): ): if mode == "values": # Handle streaming message content - msg = state['messages'][-1] - msg.pretty_print() + i=0 + for msg in state['messages'][lastMsgIndex:]: # Permet de gérer plusieurs nouveaux messages d'un coup + msg.pretty_print() + i+=1 + lastMsgIndex+=i elif mode == "updates": # Check for interrupts @@ -25,7 +28,7 @@ def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph): payload = InterruptPayload.fromJSON(payload) # Chargement de la requête depuis sa version JSON payload.humanDisplay() # L'utilisateur peut accepter/modifier/refuser ici - streamGraph(Command(resume=payload.toJSON()), config, graphe) # Je renvois la chaîne JSON, qui sera reconvertie en objet dans l'outil, et je relance le stream récursivement + streamGraph(Command(resume=payload.toJSON()), config, graphe, lastMsgIndex) # Je renvois la chaîne JSON, qui sera reconvertie en objet dans l'outil, et je relance le stream récursivement return # Fin de cette fonction récursive else: diff --git a/AgentReact/utils/nodes.py b/AgentReact/utils/nodes.py index 5171f9a..6513bae 100644 --- a/AgentReact/utils/nodes.py +++ b/AgentReact/utils/nodes.py @@ -3,8 +3,10 @@ from langgraph.graph import MessagesState from langgraph.prebuilt import ToolNode from langchain.chat_models import init_chat_model from langgraph.graph import START, END +from langchain.messages import HumanMessage, AIMessage, SystemMessage -from .tools import getTools +from .tools import getTools, getWeeklyReportTools +from .state import CustomState # LLM principal llm = ChatMistralAI( # LLM sans outils @@ -14,6 +16,22 @@ llm = ChatMistralAI( # LLM sans outils ) # NODES +def preparation_docs(state: CustomState): + """Noeud en charge de préparer les résumés pour chaque semaine des rapports, et la liste des outils et méthodes utilisées""" + model = llm.bind_tools(getWeeklyReportTools()) # LLM en charge de générer des rapports hebdomadaires sur le stage + print(len(state['messages'])) + messages = [m for m in state['messages']] # Tous les messages du stage + + if 'documentsGenerationStarted' not in state.keys(): + # Si ce noeud en est à son premier lancement, je lui donne la consigne de départ + messages.append(HumanMessage("Ton but est de lire les fichiers présents dans la base de données en utilisant l'outil 'search_in_files',\ + afin de générer des rapports sur chaque semaine du stage qui y est décrit. Pour enregistrer chaque semaine du stage, utilise l'outil 'write_week_report'.\ + Une fois terminé, fais une liste de tous les outils, logiciels, méthodes, entreprises, techniques, ect.. utilisés,\ + et fais en une liste avec quelques descriptions que tu devras enregistrer avec l'outil 'write_library_tools_details_on_internship'.")) + + messages.append(model.invoke(messages)) # Invocation LLM + return {'messages': messages, 'documentsGenerationStarted':True} # Je passe une liste de messages, ce qui ne devrait pas ajouter un message mais redéfinir toute la liste + def call_to_LLM(state: MessagesState): """Noeud qui s'occupe de gérer les appels au LLM""" # Initialisation du LLM @@ -53,6 +71,7 @@ def task_ended(state: MessagesState): return END return "continue" +weekly_report_tools = ToolNode(tools=getWeeklyReportTools()) tool_node = ToolNode(tools=getTools()) diff --git a/AgentReact/utils/state.py b/AgentReact/utils/state.py index 3224440..e91964a 100644 --- a/AgentReact/utils/state.py +++ b/AgentReact/utils/state.py @@ -10,6 +10,8 @@ class CustomState(MessagesState): ragQuery: str # Requête envoyée au RAG, pour le cross-encodeur ragDocuments: List[str] # Documents retrouvés par le RAG, pour le cross-encodeur + documentsGenerationStarted:bool# Permet d'indiquer que la consigne de génération des documents a été envoyée + # TODO: Ajouter la source des documents sélectionnés pour la fin du rapport ? diff --git a/AgentReact/utils/tools.py b/AgentReact/utils/tools.py index 082b612..354e4ec 100644 --- a/AgentReact/utils/tools.py +++ b/AgentReact/utils/tools.py @@ -263,4 +263,4 @@ def getWeeklyReportTools()->List['Tools']: """ Récupérer la liste des tools, POUR LE LLM EN CHARGE DE FAIRE LES RAPPORTS DE CHAQUE SEMAINE """ - return [write_week_report, write_library_tools_details_on_internship, internet_search] \ No newline at end of file + return [write_week_report, write_library_tools_details_on_internship, internet_search, search_in_files] \ No newline at end of file diff --git a/agent.png b/agent.png index c920050..1a00776 100644 Binary files a/agent.png and b/agent.png differ