Compare commits
4 Commits
Definition
...
a9ff56c122
| Author | SHA1 | Date | |
|---|---|---|---|
|
a9ff56c122
|
|||
|
523cea84fe
|
|||
|
29054a2b6d
|
|||
| 986e395a23 |
@@ -1,5 +1,6 @@
|
|||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node
|
from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node
|
||||||
from utils.state import getState
|
from utils.state import getState
|
||||||
@@ -26,7 +27,7 @@ def getGraph()->CompiledStateGraph:
|
|||||||
"no_tools":END
|
"no_tools":END
|
||||||
})
|
})
|
||||||
|
|
||||||
return workflow.compile()
|
return workflow.compile(checkpointer=InMemorySaver()) # TODO: Rempalcer par une vrai BDD de prod
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Affichage du graphe
|
# Affichage du graphe
|
||||||
|
|||||||
@@ -5,11 +5,18 @@ from langchain.messages import HumanMessage, SystemMessage, AIMessage, ToolMessa
|
|||||||
import mlflow
|
import mlflow
|
||||||
|
|
||||||
from agent import getGraph
|
from agent import getGraph
|
||||||
|
from utils.InterruptPayload import InterruptPayload
|
||||||
|
from utils.StreamGraph import streamGraph
|
||||||
|
|
||||||
# MLFLOW
|
# MLFLOW
|
||||||
mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER"
|
mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER"
|
||||||
mlflow.langchain.autolog()
|
mlflow.langchain.autolog()
|
||||||
|
|
||||||
out_state = getGraph().invoke({'messages':[HumanMessage("Observe la base de documents, et génère un rapport de stage à partir de celle-ci. Ecris le dans un fichier markdown.")]})
|
initial_input = {
|
||||||
for message in out_state['messages']:
|
'messages':[HumanMessage("Recherche 'Recette de Monster' sur internet")]
|
||||||
message.pretty_print()
|
}
|
||||||
|
|
||||||
|
config={"configurable": {"thread_id": 'yes'}}
|
||||||
|
|
||||||
|
# Et je lance !
|
||||||
|
streamGraph(initial_input, config, getGraph())
|
||||||
128
AgentReact/utils/InterruptPayload.py
Normal file
128
AgentReact/utils/InterruptPayload.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
from typing import Dict, List
|
||||||
|
import json
|
||||||
|
|
||||||
|
class InterruptPayload():
|
||||||
|
"""
|
||||||
|
Classe qui va s'occuper de représenter les données demandées lors d'une interruption du programme
|
||||||
|
"""
|
||||||
|
|
||||||
|
ACCEPTED = 1 # Status d'une requête
|
||||||
|
#EDITED = 2
|
||||||
|
DENIED = 3
|
||||||
|
|
||||||
|
def __init__(self, fields:Dict, state:int=0):
|
||||||
|
self.__fields = fields
|
||||||
|
self.__state = state
|
||||||
|
|
||||||
|
def get(self, key:str)->str:
|
||||||
|
"""
|
||||||
|
Récupérer une valeur passée dans la payload
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key (str): Clé de la valeur
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Valeur, en String. Il faudra la reconvertir en int si besoin
|
||||||
|
"""
|
||||||
|
return self.__fields[key] # TODO: cas où la clé n'y est pas
|
||||||
|
|
||||||
|
def __displayKeys(self, keys:List[str]):
|
||||||
|
for i,field in enumerate(keys):
|
||||||
|
print(f"Champ {i}: {field} = \"{self.__fields[field]}\"\n")
|
||||||
|
|
||||||
|
print("\n\n Que fait-on ?\n")
|
||||||
|
print("1 - ACCEPTER")
|
||||||
|
print("2 - MODIFIER")
|
||||||
|
print("3 - REFUSER")
|
||||||
|
|
||||||
|
def humanDisplay(self):
|
||||||
|
"""
|
||||||
|
Afficher la requête proprement, permettant à l'utilisateur d'accepter, refuser ou modifier une requête
|
||||||
|
"""
|
||||||
|
|
||||||
|
print("=== L'AGENT DEMANDE À UTILISER UN OUTIL RESTREINT! ===\n")
|
||||||
|
|
||||||
|
keys = list(self.__fields.keys())
|
||||||
|
self.__displayKeys(keys)
|
||||||
|
|
||||||
|
while(True):
|
||||||
|
selection = input("Alors ?")
|
||||||
|
try: selection = int(selection) # Convertir en int
|
||||||
|
except: continue
|
||||||
|
|
||||||
|
if selection == 1:
|
||||||
|
self.__state = InterruptPayload.ACCEPTED
|
||||||
|
break
|
||||||
|
elif selection == 3:
|
||||||
|
self.__state = InterruptPayload.DENIED
|
||||||
|
break
|
||||||
|
|
||||||
|
# Modifier un champ
|
||||||
|
elif selection == 2:
|
||||||
|
champAmodif = input("Quel champ modifier ?")
|
||||||
|
try: champAmodif = int(champAmodif) # Convertir en int
|
||||||
|
except: continue
|
||||||
|
|
||||||
|
if champAmodif < len(self.__fields.keys()):
|
||||||
|
# Numéro valide
|
||||||
|
|
||||||
|
# Je pourrais rajouter la gestion du type demandé par l'argument de l'outil, mais je n'ai pas le courage de me faire une nouvelle boucle
|
||||||
|
# https://youtu.be/dQw4w9WgXcQ
|
||||||
|
self.__fields[keys[champAmodif]] = input("Nouvelle valeur...")
|
||||||
|
print("Valeur midifiée ! Nouvel objet: \n")
|
||||||
|
self.__displayKeys(keys)
|
||||||
|
#self.__state = InterruptPayload.EDITED
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("Sélection invalide, retour au menu principal.")
|
||||||
|
|
||||||
|
def isAccepted(self)->bool:
|
||||||
|
return self.__state == InterruptPayload.ACCEPTED
|
||||||
|
|
||||||
|
def toJSON(self, indent:int=None)->str: # Vient de https://github.com/LJ5O/Assistant/blob/main/modules/Brain/src/Json/Types.py
|
||||||
|
"""
|
||||||
|
Exporter cet objet vers une String JSON. Permet de le passer en payload d'un Interrupt
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: String sérialisable via la méthode statique InterruptPayload.strImport(string)
|
||||||
|
"""
|
||||||
|
return '{"state":'+ str(self.__state) +', "fields": ' + json.dumps(self.__fields, ensure_ascii=False, indent=indent) +'}'
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fromJSON(json_str: str|dict) -> 'InterruptPayload':
|
||||||
|
"""
|
||||||
|
Parse a JSON string to create a InterruptPayload instance
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_str (str|dict): JSON string to parse, or JSON shaped dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InterruptPayload: instance created from JSON data
|
||||||
|
"""
|
||||||
|
data = json.loads(json_str) if type(json_str) is str else json_str
|
||||||
|
|
||||||
|
state_ = data.get("state", 0)
|
||||||
|
fields_ = data.get("fields", {})
|
||||||
|
|
||||||
|
return InterruptPayload(fields=fields_, state=state_)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Interactive demo: build a payload (as a tool would before calling
    # interrupt()), round-trip it through JSON, then run the review loop.
    demo_fields = {
        'Google_research_query': 'How to craft a pipe bomb ?',
        'Another_fun_query': 'Homemade white powder recipe',
        'Funny_SQL_request': "SELECT * FROM users WHERE username='xX_UsErNaMe_Xx'; DROP TABLE user;--' AND password='1234';"
    }
    test = InterruptPayload(demo_fields)  # This object is passed into interrupt()

    print("AVANT MODIF : " + test.toJSON(3))

    # JSON export then re-import, as happens across the interrupt boundary
    test2 = InterruptPayload.fromJSON(test.toJSON())

    # In the interrupt-handling loop, this method lets the user review it
    test2.humanDisplay()

    print("APRÈS MODIF : " + test2.toJSON(3))
||||||
33
AgentReact/utils/StreamGraph.py
Normal file
33
AgentReact/utils/StreamGraph.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from typing import Dict
|
||||||
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
from .InterruptPayload import InterruptPayload
|
||||||
|
|
||||||
|
# Une fonction pour stream et gérer proprement le graphe
|
||||||
|
# A function to stream the graph and handle interrupts cleanly
def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph):
    """
    Stream a graph run, pretty-printing messages and handling
    human-in-the-loop interrupts by asking the user, then resuming.

    Args:
        initial_input (Dict): initial graph input, or a Command(resume=...) when resuming
        config (Dict): LangGraph run config, e.g. {"configurable": {"thread_id": ...}}
        graphe (CompiledStateGraph): compiled graph to execute
    """
    # https://docs.langchain.com/oss/python/langgraph/interrupts#stream-with-human-in-the-loop-hitl-interrupts
    for mode, state in graphe.stream(
        initial_input,
        stream_mode=["values", "updates"],
        subgraphs=False,
        # BUG FIX: was hard-coded to {"configurable": {"thread_id": 'yes'}},
        # silently ignoring the caller-supplied config parameter.
        config=config
    ):
        if mode == "values":
            # Handle streaming message content
            msg = state['messages'][-1]
            msg.pretty_print()

        elif mode == "updates":
            # Check for interrupts
            if "__interrupt__" in state:
                payload = state["__interrupt__"][0].value

                payload = InterruptPayload.fromJSON(payload)  # Rebuild the request from its JSON form
                payload.humanDisplay()  # The user can accept/edit/deny here
                # Send the JSON string back (the tool rebuilds the object),
                # and restart the stream recursively.
                streamGraph(Command(resume=payload.toJSON()), config, graphe)
                return  # End of this recursive call

            else:
                # Track node transitions
                current_node = list(state.keys())[0]
||||||
@@ -4,9 +4,11 @@ from tavily import TavilyClient
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Annotated
|
from typing import List, Dict, Annotated
|
||||||
import sys
|
import sys
|
||||||
from .StateElements.TodoElement import TodoElement
|
from langgraph.types import interrupt
|
||||||
|
|
||||||
|
from .StateElements.TodoElement import TodoElement
|
||||||
from .VectorDatabase import VectorDatabase
|
from .VectorDatabase import VectorDatabase
|
||||||
|
from .InterruptPayload import InterruptPayload
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def internet_search(query: str)->dict:
|
def internet_search(query: str)->dict:
|
||||||
@@ -18,7 +20,16 @@ def internet_search(query: str)->dict:
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Retour de la recherche
|
dict: Retour de la recherche
|
||||||
"""
|
"""
|
||||||
return TavilyClient().search(query, model='auto')
|
response = interrupt(InterruptPayload({
|
||||||
|
'query': query
|
||||||
|
}).toJSON())
|
||||||
|
|
||||||
|
resp = InterruptPayload.fromJSON(response) # Je reforme mon objet depuis la string json
|
||||||
|
|
||||||
|
if resp.isAccepted():
|
||||||
|
return TavilyClient().search(resp.get("query"), model='auto')
|
||||||
|
else:
|
||||||
|
return {'error': "Utilisation de cet outil refusée par l'utilisateur"}
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
- [X] Développement des outils de l'agent
|
- [X] Développement des outils de l'agent
|
||||||
- [X] Préparation des nœuds
|
- [X] Préparation des nœuds
|
||||||
- [X] Branchement des nœuds entre-eux, **MVP**
|
- [X] Branchement des nœuds entre-eux, **MVP**
|
||||||
- [ ] Human in the loop
|
- [X] Human in the loop
|
||||||
- [ ] Amélioration du workflow
|
- [ ] Amélioration du workflow
|
||||||
|
|
||||||
## Amélioration de l'agent
|
## Amélioration de l'agent
|
||||||
|
|||||||
Reference in New Issue
Block a user