Agent_V2 #2

Merged
Kevin merged 12 commits from Agent_V2 into master 2026-02-09 19:25:11 +01:00
6 changed files with 187 additions and 7 deletions
Showing only changes of commit a9ff56c122 - Show all commits

View File

@@ -1,5 +1,6 @@
from langgraph.graph import START, END from langgraph.graph import START, END
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from langgraph.checkpoint.memory import InMemorySaver
from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node
from utils.state import getState from utils.state import getState
@@ -26,7 +27,7 @@ def getGraph()->CompiledStateGraph:
"no_tools":END "no_tools":END
}) })
return workflow.compile() return workflow.compile(checkpointer=InMemorySaver()) # TODO: Rempalcer par une vrai BDD de prod
if __name__ == "__main__": if __name__ == "__main__":
# Affichage du graphe # Affichage du graphe

View File

@@ -5,11 +5,18 @@ from langchain.messages import HumanMessage, SystemMessage, AIMessage, ToolMessa
import mlflow import mlflow
from agent import getGraph from agent import getGraph
from utils.InterruptPayload import InterruptPayload
from utils.StreamGraph import streamGraph
# MLFLOW # MLFLOW
mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER" mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER"
mlflow.langchain.autolog() mlflow.langchain.autolog()
out_state = getGraph().invoke({'messages':[HumanMessage("Observe la base de documents, et génère un rapport de stage à partir de celle-ci. Ecris le dans un fichier markdown.")]}) initial_input = {
for message in out_state['messages']: 'messages':[HumanMessage("Recherche 'Recette de Monster' sur internet")]
message.pretty_print() }
config={"configurable": {"thread_id": 'yes'}}
# Et je lance !
streamGraph(initial_input, config, getGraph())

View File

@@ -14,6 +14,18 @@ class InterruptPayload():
self.__fields = fields self.__fields = fields
self.__state = state self.__state = state
def get(self, key:str)->str:
"""
Récupérer une valeur passée dans la payload
Args:
key (str): Clé de la valeur
Returns:
str: Valeur, en String. Il faudra la reconvertir en int si besoin
"""
return self.__fields[key] # TODO: cas où la clé n'y est pas
def __displayKeys(self, keys:List[str]): def __displayKeys(self, keys:List[str]):
for i,field in enumerate(keys): for i,field in enumerate(keys):
print(f"Champ {i}: {field} = \"{self.__fields[field]}\"\n") print(f"Champ {i}: {field} = \"{self.__fields[field]}\"\n")

View File

@@ -0,0 +1,33 @@
from typing import Dict
from langgraph.graph.state import CompiledStateGraph
from langgraph.types import Command
from .InterruptPayload import InterruptPayload
# Une fonction pour stream et gérer proprement le graphe
def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph):
# https://docs.langchain.com/oss/python/langgraph/interrupts#stream-with-human-in-the-loop-hitl-interrupts
for mode, state in graphe.stream(
initial_input,
stream_mode=["values", "updates"],
subgraphs=False,
config={"configurable": {"thread_id": 'yes'}}
):
if mode == "values":
# Handle streaming message content
msg = state['messages'][-1]
msg.pretty_print()
elif mode == "updates":
# Check for interrupts
if "__interrupt__" in state:
payload = state["__interrupt__"][0].value
payload = InterruptPayload.fromJSON(payload) # Chargement de la requête depuis sa version JSON
payload.humanDisplay() # L'utilisateur peut accepter/modifier/refuser ici
streamGraph(Command(resume=payload.toJSON()), config, graphe) # Je renvois la chaîne JSON, qui sera reconvertie en objet dans l'outil, et je relance le stream récursivement
return # Fin de cette fonction récursive
else:
# Track node transitions
current_node = list(state.keys())[0]

View File

@@ -4,9 +4,11 @@ from tavily import TavilyClient
from pathlib import Path from pathlib import Path
from typing import List, Dict, Annotated from typing import List, Dict, Annotated
import sys import sys
from .StateElements.TodoElement import TodoElement from langgraph.types import interrupt
from .StateElements.TodoElement import TodoElement
from .VectorDatabase import VectorDatabase from .VectorDatabase import VectorDatabase
from .InterruptPayload import InterruptPayload
@tool @tool
def internet_search(query: str)->dict: def internet_search(query: str)->dict:
@@ -18,7 +20,16 @@ def internet_search(query: str)->dict:
Returns: Returns:
dict: Retour de la recherche dict: Retour de la recherche
""" """
return TavilyClient().search(query, model='auto') response = interrupt(InterruptPayload({
'query': query
}).toJSON())
resp = InterruptPayload.fromJSON(response) # Je reforme mon objet depuis la string json
if resp.isAccepted():
return TavilyClient().search(resp.get("query"), model='auto')
else:
return {'error': "Utilisation de cet outil refusée par l'utilisateur"}
@tool @tool

View File

@@ -15,7 +15,7 @@
- [X] Développement des outils de l'agent - [X] Développement des outils de l'agent
- [X] Préparation des nœuds - [X] Préparation des nœuds
- [X] Branchement des nœuds entre-eux, **MVP** - [X] Branchement des nœuds entre-eux, **MVP**
- [ ] Human in the loop - [X] Human in the loop
- [ ] Amélioration du workflow - [ ] Amélioration du workflow
## Amélioration de l'agent ## Amélioration de l'agent