42 lines
1.4 KiB
Python
42 lines
1.4 KiB
Python
from langchain_mistralai import ChatMistralAI
|
|
from langgraph.graph import MessagesState
|
|
from langgraph.prebuilt import ToolNode
|
|
from langchain.chat_models import init_chat_model
|
|
from langgraph.graph import START, END
|
|
|
|
from .tools import getTools
|
|
|
|
# Main LLM: the plain Mistral chat model, with no tools bound to this instance.
llm = ChatMistralAI(
    temperature=0,  # least random sampling, for reproducible answers
    max_retries=2,  # retry failed API calls up to twice before giving up
    model="mistral-large-latest",  # Mistral "large" model name as requested from the API
)
|
|
|
|
# NODES
def reponse_question(state: MessagesState):
    """Answer the question in the conversation, using the available tools if needed.

    The tools returned by ``getTools()`` are bound to the module-level ``llm``
    on every call, and the resulting model is invoked on the message history.

    Args:
        state: Graph state; its ``"messages"`` entry is the conversation so far.

    Returns:
        A partial state update ``{"messages": [reply]}`` containing the model's reply.
    """
    # Bind the tools first (same order as before: tools are fetched before the
    # state is read).
    tool_aware_llm = llm.bind_tools(getTools())

    history = state["messages"]
    reply = tool_aware_llm.invoke(history)
    return {"messages": [reply]}
|
|
|
|
# Prebuilt LangGraph node that executes tool calls using the tools from getTools().
# NOTE(review): getTools() is called here once and again on every reponse_question
# call; presumably it returns the same tool set each time — confirm.
tool_node = ToolNode(
    tools=getTools(),
)
|
|
|
|
# Routing function: after reponse_question, go to the tool node if the LLM
# requested a tool call; otherwise end the graph.
def should_continue(state: MessagesState):
    """Route to the ToolNode if the last message has tool calls, otherwise to END.

    Intended to be used as the path function of a conditional edge.

    Args:
        state: Either a list of messages, or a mapping whose ``"messages"``
            entry is the list of messages.

    Returns:
        ``"tools"`` when the last message carries at least one tool call
        (assumes the ToolNode is registered in the graph under the name
        ``"tools"`` — confirm against the graph builder), otherwise ``END``.

    Raises:
        ValueError: If ``state`` contains no messages (empty list, missing or
            empty ``"messages"`` entry).
    """
    if isinstance(state, list):
        messages = state
    else:
        messages = state.get("messages", [])

    # Check emptiness for both shapes, so an empty list raises the same
    # ValueError as an empty mapping instead of an IndexError.
    if not messages:
        raise ValueError(f"No messages found in input state to should_continue: {state}")

    ai_message = messages[-1]
    # tool_calls may be absent (e.g. non-AI messages) or None; treat both as
    # "no tool call" rather than crashing on len(None).
    if getattr(ai_message, "tool_calls", None):
        return "tools"
    return END