36 lines
1.6 KiB
Python
36 lines
1.6 KiB
Python
from typing import Dict
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
from langgraph.types import Command
|
|
|
|
from .InterruptPayload import InterruptPayload
|
|
|
|
# A helper that streams a graph and cleanly handles its output (including human-in-the-loop interrupts)
|
|
def streamGraph(initial_input: Dict, config: Dict, graphe: CompiledStateGraph, lastMsgIndex=0):
    """Stream a compiled LangGraph graph, printing new messages and handling
    human-in-the-loop (HITL) interrupts.

    Args:
        initial_input: The graph's initial input, or a ``Command(resume=...)``
            when resuming after an interrupt.
        config: LangGraph run config (e.g. ``{"configurable": {"thread_id": ...}}``).
            It must be identical across resume calls so the checkpointer finds
            the same thread.
        graphe: The compiled graph to stream.
        lastMsgIndex: Index of the first message not yet printed; lets the
            recursive resume call skip messages already shown to the user.
    """
    # https://docs.langchain.com/oss/python/langgraph/interrupts#stream-with-human-in-the-loop-hitl-interrupts
    for mode, state in graphe.stream(
        initial_input,
        stream_mode=["values", "updates"],
        subgraphs=False,
        # Bug fix: the caller's config was previously ignored in favour of a
        # hard-coded {"configurable": {"thread_id": 'yes'}}, which silently
        # pinned every caller to the same thread. Pass the real config through
        # so resumes and multi-thread usage work as intended.
        config=config,
    ):
        if mode == "values":
            # Print every message not yet shown (several may arrive at once).
            new_messages = state['messages'][lastMsgIndex:]
            for msg in new_messages:
                msg.pretty_print()
            lastMsgIndex += len(new_messages)

        elif mode == "updates":
            # Check for a human-in-the-loop interrupt raised inside a tool.
            if "__interrupt__" in state:
                payload = state["__interrupt__"][0].value

                # Rebuild the request object from its JSON form.
                payload = InterruptPayload.fromJSON(payload)
                # The user can accept / edit / reject the request here.
                payload.humanDisplay()
                # Resume with the (possibly edited) JSON payload — it is
                # converted back into an object inside the tool — and keep
                # streaming recursively on the same thread.
                streamGraph(Command(resume=payload.toJSON()), config, graphe, lastMsgIndex)
                return  # The recursive call finished the stream.

        else:
            # Track node transitions. NOTE(review): unreachable with
            # stream_mode=["values", "updates"]; kept as a hook in case more
            # stream modes are added later.
            current_node = list(state.keys())[0]