Human in the loop
Implémentation fonctionelle de l'HITL !
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from langgraph.graph import START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node
|
||||
from utils.state import getState
|
||||
@@ -26,7 +27,7 @@ def getGraph()->CompiledStateGraph:
|
||||
"no_tools":END
|
||||
})
|
||||
|
||||
return workflow.compile()
|
||||
return workflow.compile(checkpointer=InMemorySaver()) # TODO: Rempalcer par une vrai BDD de prod
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Affichage du graphe
|
||||
|
||||
@@ -5,11 +5,18 @@ from langchain.messages import HumanMessage, SystemMessage, AIMessage, ToolMessa
|
||||
import mlflow
|
||||
|
||||
from agent import getGraph
|
||||
from utils.InterruptPayload import InterruptPayload
|
||||
from utils.StreamGraph import streamGraph
|
||||
|
||||
# MLFLOW
|
||||
mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER"
|
||||
mlflow.langchain.autolog()
|
||||
|
||||
out_state = getGraph().invoke({'messages':[HumanMessage("Observe la base de documents, et génère un rapport de stage à partir de celle-ci. Ecris le dans un fichier markdown.")]})
|
||||
for message in out_state['messages']:
|
||||
message.pretty_print()
|
||||
initial_input = {
|
||||
'messages':[HumanMessage("Recherche 'Recette de Monster' sur internet")]
|
||||
}
|
||||
|
||||
config={"configurable": {"thread_id": 'yes'}}
|
||||
|
||||
# Et je lance !
|
||||
streamGraph(initial_input, config, getGraph())
|
||||
@@ -14,6 +14,18 @@ class InterruptPayload():
|
||||
self.__fields = fields
|
||||
self.__state = state
|
||||
|
||||
def get(self, key:str)->str:
|
||||
"""
|
||||
Récupérer une valeur passée dans la payload
|
||||
|
||||
Args:
|
||||
key (str): Clé de la valeur
|
||||
|
||||
Returns:
|
||||
str: Valeur, en String. Il faudra la reconvertir en int si besoin
|
||||
"""
|
||||
return self.__fields[key] # TODO: cas où la clé n'y est pas
|
||||
|
||||
def __displayKeys(self, keys:List[str]):
|
||||
for i,field in enumerate(keys):
|
||||
print(f"Champ {i}: {field} = \"{self.__fields[field]}\"\n")
|
||||
|
||||
33
AgentReact/utils/StreamGraph.py
Normal file
33
AgentReact/utils/StreamGraph.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Dict
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.types import Command
|
||||
|
||||
from .InterruptPayload import InterruptPayload
|
||||
|
||||
# Une fonction pour stream et gérer proprement le graphe
|
||||
def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph):
|
||||
# https://docs.langchain.com/oss/python/langgraph/interrupts#stream-with-human-in-the-loop-hitl-interrupts
|
||||
for mode, state in graphe.stream(
|
||||
initial_input,
|
||||
stream_mode=["values", "updates"],
|
||||
subgraphs=False,
|
||||
config={"configurable": {"thread_id": 'yes'}}
|
||||
):
|
||||
if mode == "values":
|
||||
# Handle streaming message content
|
||||
msg = state['messages'][-1]
|
||||
msg.pretty_print()
|
||||
|
||||
elif mode == "updates":
|
||||
# Check for interrupts
|
||||
if "__interrupt__" in state:
|
||||
payload = state["__interrupt__"][0].value
|
||||
|
||||
payload = InterruptPayload.fromJSON(payload) # Chargement de la requête depuis sa version JSON
|
||||
payload.humanDisplay() # L'utilisateur peut accepter/modifier/refuser ici
|
||||
streamGraph(Command(resume=payload.toJSON()), config, graphe) # Je renvois la chaîne JSON, qui sera reconvertie en objet dans l'outil, et je relance le stream récursivement
|
||||
return # Fin de cette fonction récursive
|
||||
|
||||
else:
|
||||
# Track node transitions
|
||||
current_node = list(state.keys())[0]
|
||||
@@ -4,9 +4,11 @@ from tavily import TavilyClient
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Annotated
|
||||
import sys
|
||||
from .StateElements.TodoElement import TodoElement
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from .StateElements.TodoElement import TodoElement
|
||||
from .VectorDatabase import VectorDatabase
|
||||
from .InterruptPayload import InterruptPayload
|
||||
|
||||
@tool
|
||||
def internet_search(query: str)->dict:
|
||||
@@ -18,7 +20,16 @@ def internet_search(query: str)->dict:
|
||||
Returns:
|
||||
dict: Retour de la recherche
|
||||
"""
|
||||
return TavilyClient().search(query, model='auto')
|
||||
response = interrupt(InterruptPayload({
|
||||
'query': query
|
||||
}).toJSON())
|
||||
|
||||
resp = InterruptPayload.fromJSON(response) # Je reforme mon objet depuis la string json
|
||||
|
||||
if resp.isAccepted():
|
||||
return TavilyClient().search(resp.get("query"), model='auto')
|
||||
else:
|
||||
return {'error': "Utilisation de cet outil refusée par l'utilisateur"}
|
||||
|
||||
|
||||
@tool
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
- [X] Développement des outils de l'agent
|
||||
- [X] Préparation des nœuds
|
||||
- [X] Branchement des nœuds entre-eux, **MVP**
|
||||
- [ ] Human in the loop
|
||||
- [X] Human in the loop
|
||||
- [ ] Amélioration du workflow
|
||||
|
||||
## Amélioration de l'agent
|
||||
|
||||
Reference in New Issue
Block a user