diff --git a/AgentReact/agent.py b/AgentReact/agent.py index a21475d..7720f57 100644 --- a/AgentReact/agent.py +++ b/AgentReact/agent.py @@ -1,5 +1,6 @@ from langgraph.graph import START, END from langgraph.graph.state import CompiledStateGraph +from langgraph.checkpoint.memory import InMemorySaver from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node from utils.state import getState @@ -26,7 +27,7 @@ def getGraph()->CompiledStateGraph: "no_tools":END }) - return workflow.compile() + return workflow.compile(checkpointer=InMemorySaver()) # TODO: Rempalcer par une vrai BDD de prod if __name__ == "__main__": # Affichage du graphe diff --git a/AgentReact/start.py b/AgentReact/start.py index 96e66a4..4020500 100644 --- a/AgentReact/start.py +++ b/AgentReact/start.py @@ -5,11 +5,18 @@ from langchain.messages import HumanMessage, SystemMessage, AIMessage, ToolMessa import mlflow from agent import getGraph +from utils.InterruptPayload import InterruptPayload +from utils.StreamGraph import streamGraph # MLFLOW mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER" mlflow.langchain.autolog() -out_state = getGraph().invoke({'messages':[HumanMessage("Observe la base de documents, et génère un rapport de stage à partir de celle-ci. Ecris le dans un fichier markdown.")]}) -for message in out_state['messages']: - message.pretty_print() \ No newline at end of file +initial_input = { + 'messages':[HumanMessage("Recherche 'Recette de Monster' sur internet")] + } + +config={"configurable": {"thread_id": 'yes'}} + +# Et je lance ! +streamGraph(initial_input, config, getGraph()) \ No newline at end of file diff --git a/AgentReact/utils/InterruptPayload.py b/AgentReact/utils/InterruptPayload.py index d76c236..c1a5818 100644 --- a/AgentReact/utils/InterruptPayload.py +++ b/AgentReact/utils/InterruptPayload.py @@ -14,6 +14,18 @@ class InterruptPayload(): self.__fields = fields self.__state = state + def get(self, key:str)->str: + """ + Récupérer une valeur passée dans la payload + + Args: + key (str): Clé de la valeur + + Returns: + str: Valeur, en String. Il faudra la reconvertir en int si besoin + """ + return self.__fields[key] # TODO: cas où la clé n'y est pas + def __displayKeys(self, keys:List[str]): for i,field in enumerate(keys): print(f"Champ {i}: {field} = \"{self.__fields[field]}\"\n") diff --git a/AgentReact/utils/StreamGraph.py b/AgentReact/utils/StreamGraph.py new file mode 100644 index 0000000..2f7194e --- /dev/null +++ b/AgentReact/utils/StreamGraph.py @@ -0,0 +1,33 @@ +from typing import Dict +from langgraph.graph.state import CompiledStateGraph +from langgraph.types import Command + +from .InterruptPayload import InterruptPayload + +# Une fonction pour stream et gérer proprement le graphe +def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph): + # https://docs.langchain.com/oss/python/langgraph/interrupts#stream-with-human-in-the-loop-hitl-interrupts + for mode, state in graphe.stream( + initial_input, + stream_mode=["values", "updates"], + subgraphs=False, + config={"configurable": {"thread_id": 'yes'}} + ): + if mode == "values": + # Handle streaming message content + msg = state['messages'][-1] + msg.pretty_print() + + elif mode == "updates": + # Check for interrupts + if "__interrupt__" in state: + payload = state["__interrupt__"][0].value + + payload = InterruptPayload.fromJSON(payload) # Chargement de la requête depuis sa version JSON + payload.humanDisplay() # L'utilisateur peut accepter/modifier/refuser ici + streamGraph(Command(resume=payload.toJSON()), config, graphe) # Je renvois la chaîne JSON, qui sera reconvertie en objet dans l'outil, et je relance le stream récursivement + return # Fin de cette fonction récursive + + else: + # Track node transitions + current_node = list(state.keys())[0] \ No newline at end of file diff --git a/AgentReact/utils/tools.py b/AgentReact/utils/tools.py index 1d9fef3..0069fd9 100644 --- a/AgentReact/utils/tools.py +++ b/AgentReact/utils/tools.py @@ -4,9 +4,11 @@ from tavily import TavilyClient from pathlib import Path from typing import List, Dict, Annotated import sys -from .StateElements.TodoElement import TodoElement +from langgraph.types import interrupt +from .StateElements.TodoElement import TodoElement from .VectorDatabase import VectorDatabase +from .InterruptPayload import InterruptPayload @tool def internet_search(query: str)->dict: @@ -18,7 +20,16 @@ def internet_search(query: str)->dict: Returns: dict: Retour de la recherche """ - return TavilyClient().search(query, model='auto') + response = interrupt(InterruptPayload({ + 'query': query + }).toJSON()) + + resp = InterruptPayload.fromJSON(response) # Je reforme mon objet depuis la string json + + if resp.isAccepted(): + return TavilyClient().search(resp.get("query"), model='auto') + else: + return {'error': "Utilisation de cet outil refusée par l'utilisateur"} @tool diff --git a/roadmap.md b/roadmap.md index 92a67ef..9911193 100644 --- a/roadmap.md +++ b/roadmap.md @@ -15,7 +15,7 @@ - [X] Développement des outils de l'agent - [X] Préparation des nœuds - [X] Branchement des nœuds entre-eux, **MVP** -- [ ] Human in the loop +- [X] Human in the loop - [ ] Amélioration du workflow ## Amélioration de l'agent