33 lines
1.5 KiB
Python
33 lines
1.5 KiB
Python
from typing import Dict
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
from langgraph.types import Command
|
|
|
|
from .InterruptPayload import InterruptPayload
|
|
|
|
# A helper that streams a graph and cleanly handles its interrupts
|
|
def streamGraph(initial_input: Dict, config: Dict, graphe: CompiledStateGraph):
    """Stream a compiled LangGraph graph and handle human-in-the-loop interrupts.

    Args:
        initial_input: The initial graph input, or a ``Command(resume=...)``
            when resuming the run after an interrupt.
        config: Run configuration (e.g. ``{"configurable": {"thread_id": ...}}``)
            forwarded to ``graphe.stream``; the interrupt/resume cycle must use
            the same thread_id, so the same dict is reused on recursion.
        graphe: The compiled state graph to execute.
    """
    # https://docs.langchain.com/oss/python/langgraph/interrupts#stream-with-human-in-the-loop-hitl-interrupts
    for mode, state in graphe.stream(
        initial_input,
        stream_mode=["values", "updates"],
        subgraphs=False,
        # Fix: forward the caller-supplied config instead of a hard-coded
        # thread_id ('yes'); the parameter was previously accepted but ignored,
        # even though the recursive resume call below relies on it.
        config=config,
    ):
        if mode == "values":
            # Pretty-print the latest streamed message.
            msg = state['messages'][-1]
            msg.pretty_print()

        elif mode == "updates":
            # An "__interrupt__" key means the graph paused for human input.
            if "__interrupt__" in state:
                payload = state["__interrupt__"][0].value

                # Rebuild the request object from its JSON form.
                payload = InterruptPayload.fromJSON(payload)
                payload.humanDisplay()  # the user may accept/edit/reject here
                # Send back the JSON string (re-converted to an object inside
                # the tool) and restart the stream recursively to resume.
                streamGraph(Command(resume=payload.toJSON()), config, graphe)
                return  # end of this recursive invocation

        else:
            # Unreachable with stream_mode=["values", "updates"]; kept as a
            # defensive hook for tracking node transitions.
            current_node = list(state.keys())[0]