102 lines
4.2 KiB
Python
102 lines
4.2 KiB
Python
from langchain_mistralai import ChatMistralAI
|
|
from langgraph.graph import MessagesState
|
|
from langgraph.prebuilt import ToolNode
|
|
from langchain.chat_models import init_chat_model
|
|
from langgraph.graph import START, END
|
|
from langchain.messages import HumanMessage, AIMessage, SystemMessage
|
|
|
|
from .tools import getTools, getWeeklyReportTools
|
|
from .state import CustomState
|
|
|
|
# LLM principal
|
|
llm = ChatMistralAI( # LLM sans outils
|
|
model="mistral-large-latest",
|
|
temperature=0,
|
|
max_retries=2,
|
|
)
|
|
|
|
# NODES
|
|
def preparation_docs(state: CustomState):
|
|
"""Noeud en charge de préparer les résumés pour chaque semaine des rapports, et la liste des outils et méthodes utilisées"""
|
|
model = llm.bind_tools(getWeeklyReportTools()) # LLM en charge de générer des rapports hebdomadaires sur le stage
|
|
print(len(state['messages']))
|
|
messages = [m for m in state['messages']] # Tous les messages du stage
|
|
|
|
if 'documentsGenerationStarted' not in state.keys():
|
|
# Si ce noeud en est à son premier lancement, je lui donne la consigne de départ
|
|
messages.append(HumanMessage("Ton but est de lire les fichiers présents dans la base de données en utilisant l'outil 'search_in_files',\
|
|
afin de générer des rapports sur chaque semaine du stage qui y est décrit. Pour enregistrer chaque semaine du stage, utilise l'outil 'write_week_report'.\
|
|
Une fois terminé, fais une liste de tous les outils, logiciels, méthodes, entreprises, techniques, ect.. utilisés,\
|
|
et fais en une liste avec quelques descriptions que tu devras enregistrer avec l'outil 'write_library_tools_details_on_internship'."))
|
|
|
|
messages.append(model.invoke(messages)) # Invocation LLM
|
|
return {'messages': messages, 'documentsGenerationStarted':True} # Je passe une liste de messages, ce qui ne devrait pas ajouter un message mais redéfinir toute la liste
|
|
|
|
def call_to_LLM(state: MessagesState):
|
|
"""Noeud qui s'occupe de gérer les appels au LLM"""
|
|
# Initialisation du LLM
|
|
model = llm.bind_tools(getTools())
|
|
|
|
# Appel du LLM
|
|
return {"messages": [model.invoke(state["messages"])]}
|
|
|
|
# fonction de routage : Après reponse_question, si le LLM veut appeler un outil, on va au tool_node
|
|
def should_continue(state: MessagesState):
|
|
"""
|
|
Vérifier s'il y a un appel aux outils dans le dernier message
|
|
"""
|
|
if isinstance(state, list):
|
|
ai_message = state[-1]
|
|
elif messages := state.get("messages", []):
|
|
ai_message = messages[-1]
|
|
else:
|
|
raise ValueError(f"No messages found in input state to tool_edge: {state}")
|
|
|
|
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
|
|
return "tools"
|
|
return "no_tools"
|
|
|
|
def task_ended(state: MessagesState):
|
|
"""
|
|
Vérifier si l'agent a terminé son cycle, ou s'il faut le relancer
|
|
"""
|
|
if isinstance(state, list):
|
|
ai_message = state[-1]
|
|
elif messages := state.get("messages", []):
|
|
ai_message = messages[-1]
|
|
else:
|
|
raise ValueError(f"No messages found in input state to tool_edge: {state}")
|
|
|
|
if "terminé" in ai_message.content.lower():
|
|
return END
|
|
return "continue"
|
|
|
|
weekly_report_tools = ToolNode(tools=getWeeklyReportTools())
|
|
tool_node = ToolNode(tools=getTools())
|
|
|
|
|
|
class BasicToolNode: # De mon ancien projet, https://github.com/LJ5O/Assistant/blob/main/modules/Brain/src/LLM/graph/nodes/BasicToolNode.py
|
|
"""A node that runs the tools requested in the last AIMessage."""
|
|
|
|
def __init__(self, tools: list) -> None:
|
|
self.tools_by_name = {tool.name: tool for tool in tools}
|
|
|
|
def __call__(self, inputs: dict):
|
|
if messages := inputs.get("messages", []):
|
|
message = messages[-1]
|
|
else:
|
|
raise ValueError("No message found in input")
|
|
outputs = []
|
|
for tool_call in message.tool_calls:
|
|
#print(tool_call["args"])
|
|
tool_result = self.tools_by_name[tool_call["name"]].invoke(
|
|
tool_call["args"]
|
|
)
|
|
outputs.append(
|
|
ToolMessage(
|
|
content=json.dumps(tool_result),
|
|
name=tool_call["name"],
|
|
tool_call_id=tool_call["id"],
|
|
)
|
|
)
|
|
return {"messages": outputs} |