Files
Projet-Agent-IA/AgentReact/utils/nodes.py
2026-02-08 22:21:51 +01:00

125 lines
5.0 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain_mistralai import ChatMistralAI
from langgraph.graph import MessagesState
from langgraph.prebuilt import ToolNode
from langchain.chat_models import init_chat_model
from langgraph.graph import START, END
from langchain.messages import HumanMessage, AIMessage, SystemMessage
from langgraph.types import interrupt
from .tools import getTools, getWeeklyReportTools
from .state import CustomState
from .InterruptPayload import InterruptPayload
# LLM principal
llm = ChatMistralAI( # LLM sans outils
model="mistral-large-latest",
temperature=0,
max_retries=2,
)
# NODES
def inject_preparation_prompt(state: CustomState):
""" Noeud qui vise juste à insérer le message indiquant au LLM comment travailler sur les résumés de comptes-rendus """
return {'messages': HumanMessage(
"Ton but est de lire les fichiers présents dans la base de données en utilisant l'outil 'search_in_files',\
afin de générer des rapports sur chaque semaine du stage qui y est décrit. Il y a 25 semaines, tu dois toutes les résumer,\
avec des détails et des informations complémentaires.\
Pour enregistrer chaque semaine du stage, utilise l'outil 'write_week_report'. Tu DOIS les enregistrer avec cet outil.\
Une fois terminé, fais une liste de tous les outils, logiciels, méthodes, entreprises, techniques, ect.. utilisés,\
et fais en une liste avec quelques descriptions que tu devras enregistrer avec l'outil 'write_library_tools_details_on_internship'."
)}
def preparation_docs(state: CustomState):
"""Noeud en charge de préparer les résumés pour chaque semaine des rapports, et la liste des outils et méthodes utilisées"""
model = llm.bind_tools(getWeeklyReportTools()) # LLM en charge de générer des rapports hebdomadaires sur le stage
return {'messages': model.invoke(state['messages'])}
def user_prompt(state: CustomState):
""" Dans ce nœud, l'utilisateur peut écrire un HumanMessage pour l'IA """
messages = [msg for msg in state['messages']] # Je récupère la liste des messages
sys_message = SystemMessage("Salut") # TODO: Anti-injections
user_message = HumanMessage(
InterruptPayload.fromJSON(
interrupt(
InterruptPayload({'prompt':''}, payload_type=InterruptPayload.USER_PROMPT).toJSON()
)
).get("prompt")
) # Récupérer un prompt
messages.append(sys_message) # Rajout des nouveaux messages dans le système
messages.append(user_message)
return {'messages': messages}# Je passe unen liste, devrait écraser tous les messages précédent au lieu d'ajouter à la liste du State
def LLM_central(state: MessagesState):
"""Noeud qui s'occupe de gérer les appels au LLM"""
# Initialisation du LLM
model = llm.bind_tools(getTools())
# Appel du LLM
return {"messages": [model.invoke(state["messages"])]}
# fonction de routage : Après reponse_question, si le LLM veut appeler un outil, on va au tool_node
def should_continue(state: MessagesState):
"""
Vérifier s'il y a un appel aux outils dans le dernier message
"""
if isinstance(state, list):
ai_message = state[-1]
elif messages := state.get("messages", []):
ai_message = messages[-1]
else:
raise ValueError(f"No messages found in input state to tool_edge: {state}")
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
return "tools"
return "no_tools"
def task_ended(state: MessagesState):
"""
Vérifier si l'agent a terminé son cycle, ou s'il faut le relancer
"""
if isinstance(state, list):
ai_message = state[-1]
elif messages := state.get("messages", []):
ai_message = messages[-1]
else:
raise ValueError(f"No messages found in input state to tool_edge: {state}")
if "terminé" in ai_message.content.lower():
return END
return "continue"
weekly_report_tools = ToolNode(tools=getWeeklyReportTools())
tool_node = ToolNode(tools=getTools())
class BasicToolNode: # De mon ancien projet, https://github.com/LJ5O/Assistant/blob/main/modules/Brain/src/LLM/graph/nodes/BasicToolNode.py
"""A node that runs the tools requested in the last AIMessage."""
def __init__(self, tools: list) -> None:
self.tools_by_name = {tool.name: tool for tool in tools}
def __call__(self, inputs: dict):
if messages := inputs.get("messages", []):
message = messages[-1]
else:
raise ValueError("No message found in input")
outputs = []
for tool_call in message.tool_calls:
#print(tool_call["args"])
tool_result = self.tools_by_name[tool_call["name"]].invoke(
tool_call["args"]
)
outputs.append(
ToolMessage(
content=json.dumps(tool_result),
name=tool_call["name"],
tool_call_id=tool_call["id"],
)
)
return {"messages": outputs}