Agent_V2 #2
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,6 +5,7 @@ mlflow.db
|
|||||||
# Par sécurité
|
# Par sécurité
|
||||||
documents_projet/
|
documents_projet/
|
||||||
chroma_db/
|
chroma_db/
|
||||||
|
AgentReact/rapports_resumes/
|
||||||
|
|
||||||
# Python
|
# Python
|
||||||
__pycache__/
|
__pycache__/
|
||||||
@@ -2,7 +2,7 @@ from langgraph.graph import START, END
|
|||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node
|
from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node, preparation_docs, weekly_report_tools
|
||||||
from utils.state import getState
|
from utils.state import getState
|
||||||
from utils.tools import getTools
|
from utils.tools import getTools
|
||||||
|
|
||||||
@@ -17,10 +17,19 @@ def getGraph()->CompiledStateGraph:
|
|||||||
|
|
||||||
# Définition des sommets du graphe
|
# Définition des sommets du graphe
|
||||||
workflow.add_node(call_to_LLM)
|
workflow.add_node(call_to_LLM)
|
||||||
|
workflow.add_node(preparation_docs)
|
||||||
workflow.add_node("tool_node", tool_node)# BasicToolNode(tools=getTools())) # N'est pas une fonction, mais une classe instanciée, je dois précisier le nom du node
|
workflow.add_node("tool_node", tool_node)# BasicToolNode(tools=getTools())) # N'est pas une fonction, mais une classe instanciée, je dois précisier le nom du node
|
||||||
|
workflow.add_node("weekly_report_tools", weekly_report_tools)
|
||||||
|
|
||||||
# Arrêtes
|
# Arrêtes
|
||||||
workflow.set_entry_point("call_to_LLM")
|
workflow.set_entry_point("preparation_docs")
|
||||||
|
workflow.add_conditional_edges("preparation_docs", should_continue, {
|
||||||
|
"tools":"weekly_report_tools",
|
||||||
|
"no_tools":"call_to_LLM"
|
||||||
|
})
|
||||||
|
|
||||||
|
#workflow.set_entry_point("call_to_LLM")
|
||||||
|
workflow.add_edge("weekly_report_tools", "preparation_docs")
|
||||||
workflow.add_edge("tool_node", "call_to_LLM")
|
workflow.add_edge("tool_node", "call_to_LLM")
|
||||||
workflow.add_conditional_edges("call_to_LLM", should_continue, {
|
workflow.add_conditional_edges("call_to_LLM", should_continue, {
|
||||||
"tools":"tool_node",
|
"tools":"tool_node",
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER"
|
|||||||
mlflow.langchain.autolog()
|
mlflow.langchain.autolog()
|
||||||
|
|
||||||
initial_input = {
|
initial_input = {
|
||||||
'messages':[HumanMessage("Recherche 'Recette de Monster' sur internet")]
|
'messages':[SystemMessage("Salut")]
|
||||||
}
|
}
|
||||||
|
|
||||||
config={"configurable": {"thread_id": 'yes'}}
|
config={"configurable": {"thread_id": 'yes'}}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from langgraph.types import Command
|
|||||||
from .InterruptPayload import InterruptPayload
|
from .InterruptPayload import InterruptPayload
|
||||||
|
|
||||||
# Une fonction pour stream et gérer proprement le graphe
|
# Une fonction pour stream et gérer proprement le graphe
|
||||||
def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph):
|
def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph, lastMsgIndex=0):
|
||||||
# https://docs.langchain.com/oss/python/langgraph/interrupts#stream-with-human-in-the-loop-hitl-interrupts
|
# https://docs.langchain.com/oss/python/langgraph/interrupts#stream-with-human-in-the-loop-hitl-interrupts
|
||||||
for mode, state in graphe.stream(
|
for mode, state in graphe.stream(
|
||||||
initial_input,
|
initial_input,
|
||||||
@@ -15,8 +15,11 @@ def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph):
|
|||||||
):
|
):
|
||||||
if mode == "values":
|
if mode == "values":
|
||||||
# Handle streaming message content
|
# Handle streaming message content
|
||||||
msg = state['messages'][-1]
|
i=0
|
||||||
msg.pretty_print()
|
for msg in state['messages'][lastMsgIndex:]: # Permet de gérer plusieurs nouveaux messages d'un coup
|
||||||
|
msg.pretty_print()
|
||||||
|
i+=1
|
||||||
|
lastMsgIndex+=i
|
||||||
|
|
||||||
elif mode == "updates":
|
elif mode == "updates":
|
||||||
# Check for interrupts
|
# Check for interrupts
|
||||||
@@ -25,7 +28,7 @@ def streamGraph(initial_input:Dict, config:Dict, graphe:CompiledStateGraph):
|
|||||||
|
|
||||||
payload = InterruptPayload.fromJSON(payload) # Chargement de la requête depuis sa version JSON
|
payload = InterruptPayload.fromJSON(payload) # Chargement de la requête depuis sa version JSON
|
||||||
payload.humanDisplay() # L'utilisateur peut accepter/modifier/refuser ici
|
payload.humanDisplay() # L'utilisateur peut accepter/modifier/refuser ici
|
||||||
streamGraph(Command(resume=payload.toJSON()), config, graphe) # Je renvois la chaîne JSON, qui sera reconvertie en objet dans l'outil, et je relance le stream récursivement
|
streamGraph(Command(resume=payload.toJSON()), config, graphe, lastMsgIndex) # Je renvois la chaîne JSON, qui sera reconvertie en objet dans l'outil, et je relance le stream récursivement
|
||||||
return # Fin de cette fonction récursive
|
return # Fin de cette fonction récursive
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ from langgraph.graph import MessagesState
|
|||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
|
from langchain.messages import HumanMessage, AIMessage, SystemMessage
|
||||||
|
|
||||||
from .tools import getTools
|
from .tools import getTools, getWeeklyReportTools
|
||||||
|
from .state import CustomState
|
||||||
|
|
||||||
# LLM principal
|
# LLM principal
|
||||||
llm = ChatMistralAI( # LLM sans outils
|
llm = ChatMistralAI( # LLM sans outils
|
||||||
@@ -14,6 +16,22 @@ llm = ChatMistralAI( # LLM sans outils
|
|||||||
)
|
)
|
||||||
|
|
||||||
# NODES
|
# NODES
|
||||||
|
def preparation_docs(state: CustomState):
|
||||||
|
"""Noeud en charge de préparer les résumés pour chaque semaine des rapports, et la liste des outils et méthodes utilisées"""
|
||||||
|
model = llm.bind_tools(getWeeklyReportTools()) # LLM en charge de générer des rapports hebdomadaires sur le stage
|
||||||
|
print(len(state['messages']))
|
||||||
|
messages = [m for m in state['messages']] # Tous les messages du stage
|
||||||
|
|
||||||
|
if 'documentsGenerationStarted' not in state.keys():
|
||||||
|
# Si ce noeud en est à son premier lancement, je lui donne la consigne de départ
|
||||||
|
messages.append(HumanMessage("Ton but est de lire les fichiers présents dans la base de données en utilisant l'outil 'search_in_files',\
|
||||||
|
afin de générer des rapports sur chaque semaine du stage qui y est décrit. Pour enregistrer chaque semaine du stage, utilise l'outil 'write_week_report'.\
|
||||||
|
Une fois terminé, fais une liste de tous les outils, logiciels, méthodes, entreprises, techniques, ect.. utilisés,\
|
||||||
|
et fais en une liste avec quelques descriptions que tu devras enregistrer avec l'outil 'write_library_tools_details_on_internship'."))
|
||||||
|
|
||||||
|
messages.append(model.invoke(messages)) # Invocation LLM
|
||||||
|
return {'messages': messages, 'documentsGenerationStarted':True} # Je passe une liste de messages, ce qui ne devrait pas ajouter un message mais redéfinir toute la liste
|
||||||
|
|
||||||
def call_to_LLM(state: MessagesState):
|
def call_to_LLM(state: MessagesState):
|
||||||
"""Noeud qui s'occupe de gérer les appels au LLM"""
|
"""Noeud qui s'occupe de gérer les appels au LLM"""
|
||||||
# Initialisation du LLM
|
# Initialisation du LLM
|
||||||
@@ -53,6 +71,7 @@ def task_ended(state: MessagesState):
|
|||||||
return END
|
return END
|
||||||
return "continue"
|
return "continue"
|
||||||
|
|
||||||
|
weekly_report_tools = ToolNode(tools=getWeeklyReportTools())
|
||||||
tool_node = ToolNode(tools=getTools())
|
tool_node = ToolNode(tools=getTools())
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ class CustomState(MessagesState):
|
|||||||
ragQuery: str # Requête envoyée au RAG, pour le cross-encodeur
|
ragQuery: str # Requête envoyée au RAG, pour le cross-encodeur
|
||||||
ragDocuments: List[str] # Documents retrouvés par le RAG, pour le cross-encodeur
|
ragDocuments: List[str] # Documents retrouvés par le RAG, pour le cross-encodeur
|
||||||
|
|
||||||
|
documentsGenerationStarted:bool# Permet d'indiquer que la consigne de génération des documents a été envoyée
|
||||||
|
|
||||||
# TODO: Ajouter la source des documents sélectionnés pour la fin du rapport ?
|
# TODO: Ajouter la source des documents sélectionnés pour la fin du rapport ?
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -263,4 +263,4 @@ def getWeeklyReportTools()->List['Tools']:
|
|||||||
"""
|
"""
|
||||||
Récupérer la liste des tools, POUR LE LLM EN CHARGE DE FAIRE LES RAPPORTS DE CHAQUE SEMAINE
|
Récupérer la liste des tools, POUR LE LLM EN CHARGE DE FAIRE LES RAPPORTS DE CHAQUE SEMAINE
|
||||||
"""
|
"""
|
||||||
return [write_week_report, write_library_tools_details_on_internship, internet_search]
|
return [write_week_report, write_library_tools_details_on_internship, internet_search, search_in_files]
|
||||||
Reference in New Issue
Block a user