C'est moche, bancal, et mal foutu, mais ça compile et ça crache un rapport de stage dans un fichier
83 lines
2.7 KiB
Python
83 lines
2.7 KiB
Python
from langchain_mistralai import ChatMistralAI
|
|
from langgraph.graph import MessagesState
|
|
from langgraph.prebuilt import ToolNode
|
|
from langchain.chat_models import init_chat_model
|
|
from langgraph.graph import START, END
|
|
|
|
from .tools import getTools
|
|
|
|
# Main LLM instance. No tools are bound here; call_to_LLM binds them per call.
llm = ChatMistralAI(model="mistral-large-latest", temperature=0, max_retries=2)
|
|
|
|
# NODES
|
|
def call_to_LLM(state: MessagesState):
    """Graph node in charge of the calls to the LLM.

    The module-level ``llm`` is bound to the tools returned by ``getTools()``
    on every call, then invoked with the messages held in the state.

    Args:
        state: Graph state; its ``"messages"`` entry is sent to the model.

    Returns:
        A dict whose ``"messages"`` key holds a one-element list containing
        the value returned by the model.
    """
    # Bind the tools to the LLM
    tool_aware_llm = llm.bind_tools(getTools())

    # Query the LLM with the current conversation
    reply = tool_aware_llm.invoke(state["messages"])
    return {"messages": [reply]}
|
|
|
|
# Routing function: after reponse_question, if the LLM wants to call a tool,
# we go to the tool_node.
def should_continue(state: MessagesState):
    """Check whether the last message contains tool calls.

    Args:
        state: Either a list of messages, or a mapping exposing the messages
            under the ``"messages"`` key.

    Returns:
        ``"tools"`` when the last message carries at least one tool call,
        ``"no_tools"`` otherwise.

    Raises:
        ValueError: If the state holds no message at all.
    """
    messages = state if isinstance(state, list) else state.get("messages", [])
    if not messages:
        # Same error for an empty list and an empty/missing "messages" entry.
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    ai_message = messages[-1]

    # getattr + truthiness covers a missing attribute, None and an empty list
    # alike, so a message with ``tool_calls=None`` no longer breaks ``len()``.
    if getattr(ai_message, "tool_calls", None):
        return "tools"
    return "no_tools"
|
|
|
|
def task_ended(state: MessagesState):
    """Check whether the agent finished its cycle or must be run again.

    The last message is considered final when its text contains the word
    "terminé" (case-insensitive).

    Args:
        state: Either a list of messages, or a mapping exposing the messages
            under the ``"messages"`` key.

    Returns:
        ``END`` when the last message contains "terminé", ``"continue"``
        otherwise.

    Raises:
        ValueError: If the state holds no message at all.
    """
    messages = state if isinstance(state, list) else state.get("messages", [])
    if not messages:
        # Same error for an empty list and an empty/missing "messages" entry.
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    content = messages[-1].content

    # LangChain message content is either a plain string or a list of blocks
    # (strings or dicts carrying a "text" entry); calling .lower() directly on
    # a list would raise, so flatten the textual parts first.
    if isinstance(content, str):
        text = content
    else:
        text = " ".join(
            part if isinstance(part, str) else str(part.get("text", ""))
            for part in (content or [])
            if isinstance(part, (str, dict))
        )

    if "terminé" in text.lower():
        return END
    return "continue"
|
|
|
|
# Prebuilt LangGraph ToolNode instance, built from the tools returned by getTools().
tool_node = ToolNode(tools=getTools())
|
|
|
|
|
|
class BasicToolNode:  # From my previous project, https://github.com/LJ5O/Assistant/blob/main/modules/Brain/src/LLM/graph/nodes/BasicToolNode.py
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        """Index the given tools by name.

        Args:
            tools: Tool objects; each must expose a ``name`` attribute and an
                ``invoke`` method.
        """
        # name -> tool lookup used to dispatch every requested tool call
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, inputs: dict):
        """Run every tool call requested by the last message.

        Args:
            inputs: Mapping holding the conversation under ``"messages"``.

        Returns:
            A dict whose ``"messages"`` key holds one ``ToolMessage`` per tool
            call, in request order (an empty list when nothing was requested).

        Raises:
            ValueError: If ``inputs`` holds no message.
            KeyError: If a tool call names a tool that was not registered.
        """
        messages = inputs.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")

        tool_calls = messages[-1].tool_calls
        if not tool_calls:
            # Nothing requested (empty list or None): no tool output to report.
            return {"messages": []}

        # The module's import block does not bring these names into scope, so
        # they are imported here, on the only path that needs them.
        import json

        from langchain_core.messages import ToolMessage

        outputs = []
        for tool_call in tool_calls:
            tool = self.tools_by_name[tool_call["name"]]
            tool_result = tool.invoke(tool_call["args"])
            outputs.append(
                ToolMessage(
                    content=json.dumps(tool_result),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}