Agent_V2 #2
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,6 +5,7 @@ mlflow.db
|
||||
# Par sécurité
|
||||
documents_projet/
|
||||
chroma_db/
|
||||
AgentReact/rapports_resumes/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
@@ -2,7 +2,7 @@ from langgraph.graph import START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node
|
||||
from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node, preparation_docs, weekly_report_tools
|
||||
from utils.state import getState
|
||||
from utils.tools import getTools
|
||||
|
||||
@@ -17,10 +17,19 @@ def getGraph()->CompiledStateGraph:
|
||||
|
||||
# Définition des sommets du graphe
|
||||
workflow.add_node(call_to_LLM)
|
||||
workflow.add_node(preparation_docs)
|
||||
workflow.add_node("tool_node", tool_node)# BasicToolNode(tools=getTools())) # N'est pas une fonction, mais une classe instanciée, je dois précisier le nom du node
|
||||
workflow.add_node("weekly_report_tools", weekly_report_tools)
|
||||
|
||||
# Arrêtes
|
||||
workflow.set_entry_point("call_to_LLM")
|
||||
workflow.set_entry_point("preparation_docs")
|
||||
workflow.add_conditional_edges("preparation_docs", should_continue, {
|
||||
"tools":"weekly_report_tools",
|
||||
"no_tools":"call_to_LLM"
|
||||
})
|
||||
|
||||
#workflow.set_entry_point("call_to_LLM")
|
||||
workflow.add_edge("weekly_report_tools", "preparation_docs")
|
||||
workflow.add_edge("tool_node", "call_to_LLM")
|
||||
workflow.add_conditional_edges("call_to_LLM", should_continue, {
|
||||
"tools":"tool_node",
|
||||
|
||||
@@ -13,7 +13,7 @@ mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER"
|
||||
mlflow.langchain.autolog()
|
||||
|
||||
initial_input = {
|
||||
'messages':[HumanMessage("Recherche 'Recette de Monster' sur internet")]
|
||||
'messages':[SystemMessage("Salut")]
|
||||
}
|
||||
|
||||
config={"configurable": {"thread_id": 'yes'}}
|
||||
|
||||
@@ -5,7 +5,7 @@ from langgraph.types import Command
|
||||
from .InterruptPayload import InterruptPayload
|
||||
|
||||
# Une fonction pour stream et gérer proprement le graphe
|
||||
def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph):
|
||||
def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph, lastMsgIndex=0):
|
||||
# https://docs.langchain.com/oss/python/langgraph/interrupts#stream-with-human-in-the-loop-hitl-interrupts
|
||||
for mode, state in graphe.stream(
|
||||
initial_input,
|
||||
@@ -15,8 +15,11 @@ def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph):
|
||||
):
|
||||
if mode == "values":
|
||||
# Handle streaming message content
|
||||
msg = state['messages'][-1]
|
||||
msg.pretty_print()
|
||||
i=0
|
||||
for msg in state['messages'][lastMsgIndex:]: # Permet de gérer plusieurs nouveaux messages d'un coup
|
||||
msg.pretty_print()
|
||||
i+=1
|
||||
lastMsgIndex+=i
|
||||
|
||||
elif mode == "updates":
|
||||
# Check for interrupts
|
||||
@@ -25,7 +28,7 @@ def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph):
|
||||
|
||||
payload = InterruptPayload.fromJSON(payload) # Chargement de la requête depuis sa version JSON
|
||||
payload.humanDisplay() # L'utilisateur peut accepter/modifier/refuser ici
|
||||
streamGraph(Command(resume=payload.toJSON()), config, graphe) # Je renvois la chaîne JSON, qui sera reconvertie en objet dans l'outil, et je relance le stream récursivement
|
||||
streamGraph(Command(resume=payload.toJSON()), config, graphe, lastMsgIndex) # Je renvois la chaîne JSON, qui sera reconvertie en objet dans l'outil, et je relance le stream récursivement
|
||||
return # Fin de cette fonction récursive
|
||||
|
||||
else:
|
||||
|
||||
@@ -3,8 +3,10 @@ from langgraph.graph import MessagesState
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langgraph.graph import START, END
|
||||
from langchain.messages import HumanMessage, AIMessage, SystemMessage
|
||||
|
||||
from .tools import getTools
|
||||
from .tools import getTools, getWeeklyReportTools
|
||||
from .state import CustomState
|
||||
|
||||
# LLM principal
|
||||
llm = ChatMistralAI( # LLM sans outils
|
||||
@@ -14,6 +16,22 @@ llm = ChatMistralAI( # LLM sans outils
|
||||
)
|
||||
|
||||
# NODES
|
||||
def preparation_docs(state: CustomState):
|
||||
"""Noeud en charge de préparer les résumés pour chaque semaine des rapports, et la liste des outils et méthodes utilisées"""
|
||||
model = llm.bind_tools(getWeeklyReportTools()) # LLM en charge de générer des rapports hebdomadaires sur le stage
|
||||
print(len(state['messages']))
|
||||
messages = [m for m in state['messages']] # Tous les messages du stage
|
||||
|
||||
if 'documentsGenerationStarted' not in state.keys():
|
||||
# Si ce noeud en est à son premier lancement, je lui donne la consigne de départ
|
||||
messages.append(HumanMessage("Ton but est de lire les fichiers présents dans la base de données en utilisant l'outil 'search_in_files',\
|
||||
afin de générer des rapports sur chaque semaine du stage qui y est décrit. Pour enregistrer chaque semaine du stage, utilise l'outil 'write_week_report'.\
|
||||
Une fois terminé, fais une liste de tous les outils, logiciels, méthodes, entreprises, techniques, ect.. utilisés,\
|
||||
et fais en une liste avec quelques descriptions que tu devras enregistrer avec l'outil 'write_library_tools_details_on_internship'."))
|
||||
|
||||
messages.append(model.invoke(messages)) # Invocation LLM
|
||||
return {'messages': messages, 'documentsGenerationStarted':True} # Je passe une liste de messages, ce qui ne devrait pas ajouter un message mais redéfinir toute la liste
|
||||
|
||||
def call_to_LLM(state: MessagesState):
|
||||
"""Noeud qui s'occupe de gérer les appels au LLM"""
|
||||
# Initialisation du LLM
|
||||
@@ -53,6 +71,7 @@ def task_ended(state: MessagesState):
|
||||
return END
|
||||
return "continue"
|
||||
|
||||
weekly_report_tools = ToolNode(tools=getWeeklyReportTools())
|
||||
tool_node = ToolNode(tools=getTools())
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ class CustomState(MessagesState):
|
||||
ragQuery: str # Requête envoyée au RAG, pour le cross-encodeur
|
||||
ragDocuments: List[str] # Documents retrouvés par le RAG, pour le cross-encodeur
|
||||
|
||||
documentsGenerationStarted:bool# Permet d'indiquer que la consigne de génération des documents a été envoyée
|
||||
|
||||
# TODO: Ajouter la source des documents sélectionnés pour la fin du rapport ?
|
||||
|
||||
|
||||
|
||||
@@ -263,4 +263,4 @@ def getWeeklyReportTools()->List['Tools']:
|
||||
"""
|
||||
Récupérer la liste des tools, POUR LE LLM EN CHARGE DE FAIRE LES RAPPORTS DE CHAQUE SEMAINE
|
||||
"""
|
||||
return [write_week_report, write_library_tools_details_on_internship, internet_search]
|
||||
return [write_week_report, write_library_tools_details_on_internship, internet_search, search_in_files]
|
||||
Reference in New Issue
Block a user