Compare commits
3 Commits
1f96b9a408
...
633726b2a0
| Author | SHA1 | Date | |
|---|---|---|---|
|
633726b2a0
|
|||
|
1c2f0728ea
|
|||
|
14b8664106
|
@@ -1,8 +1,9 @@
|
|||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from utils.nodes import reponse_question, tool_node, should_continue
|
from utils.nodes import call_to_LLM, should_continue, task_ended, BasicToolNode, tool_node
|
||||||
from utils.state import getState
|
from utils.state import getState
|
||||||
|
from utils.tools import getTools
|
||||||
|
|
||||||
def getGraph()->CompiledStateGraph:
|
def getGraph()->CompiledStateGraph:
|
||||||
"""
|
"""
|
||||||
@@ -14,15 +15,15 @@ def getGraph()->CompiledStateGraph:
|
|||||||
workflow = getState() # State prêt à utiliser
|
workflow = getState() # State prêt à utiliser
|
||||||
|
|
||||||
# Définition des sommets du graphe
|
# Définition des sommets du graphe
|
||||||
workflow.add_node(reponse_question)
|
workflow.add_node(call_to_LLM)
|
||||||
workflow.add_node("tool_node", tool_node) # N'est pas une fonction, mais une classe instanciée, je dois précisier le nom du node
|
workflow.add_node("tool_node", tool_node)# BasicToolNode(tools=getTools())) # N'est pas une fonction, mais une classe instanciée, je dois précisier le nom du node
|
||||||
|
|
||||||
# Arrêtes
|
# Arrêtes
|
||||||
workflow.set_entry_point("reponse_question")
|
workflow.set_entry_point("call_to_LLM")
|
||||||
workflow.add_edge("tool_node", "reponse_question")
|
workflow.add_edge("tool_node", "call_to_LLM")
|
||||||
workflow.add_conditional_edges("reponse_question", should_continue, {
|
workflow.add_conditional_edges("call_to_LLM", should_continue, {
|
||||||
"tools":"tool_node",
|
"tools":"tool_node",
|
||||||
END:END
|
"no_tools":END
|
||||||
})
|
})
|
||||||
|
|
||||||
return workflow.compile()
|
return workflow.compile()
|
||||||
|
|||||||
@@ -10,6 +10,6 @@ from agent import getGraph
|
|||||||
mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER"
|
mlflow.set_experiment("TEST PROJET") # VOIR AVEC LA COMMANDE "MLFLOW SERVER"
|
||||||
mlflow.langchain.autolog()
|
mlflow.langchain.autolog()
|
||||||
|
|
||||||
out_state = getGraph().invoke({'messages':[HumanMessage("What's the price for bitcoin ?")]})
|
out_state = getGraph().invoke({'messages':[HumanMessage("Observe la base de documents, et génère un rapport de stage à partir de celle-ci. Ecris le dans un fichier markdown.")]})
|
||||||
for message in out_state['messages']:
|
for message in out_state['messages']:
|
||||||
message.pretty_print()
|
message.pretty_print()
|
||||||
@@ -3,43 +3,21 @@ from langchain_chroma import Chroma # TODO plus tard, ramplacer par PG Vector
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Permet de garder ChromaDB en mémoire.
|
base_dir:str = Path(sys.argv[0]).resolve().parent.as_posix() # Récupérer le chemin vers le point d'entrée du programme
|
||||||
# Cette classe est un Singleton, il n'y en aura qu'une seule et unique instance à tout moment
|
bdd_path:str = base_dir + "/../chroma_db/"
|
||||||
# https://refactoring.guru/design-patterns/singleton
|
|
||||||
class VectorDatabase:
|
|
||||||
instance = None
|
|
||||||
|
|
||||||
def __new__(cls): # Selon https://www.geeksforgeeks.org/python/singleton-pattern-in-python-a-complete-guide/
|
EMBEDDINGS = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large", model_kwargs={"trust_remote_code": True})
|
||||||
if cls.instance is None:
|
CHROMA = Chroma(
|
||||||
cls.instance = super().__new__(cls)
|
|
||||||
# J'initialise les attributs à None ici, permet de tester si la classe a déjà été init une première fois ou non
|
|
||||||
cls.instance.__embeddings = None
|
|
||||||
cls.instance.__chroma = None
|
|
||||||
return cls.instance
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
if self.__embeddings is not None: return
|
|
||||||
|
|
||||||
base_dir:str = Path(sys.argv[0]).resolve().parent.as_posix() # Récupérer le chemin vers le point d'entrée du programme
|
|
||||||
bdd_path:str = base_dir + "/chroma_db/"
|
|
||||||
|
|
||||||
self.__embeddings = HuggingFaceEmbeddings(model_name="jinaai/jina-embeddings-v3", model_kwargs={"trust_remote_code": True})
|
|
||||||
self.__chroma = Chroma(
|
|
||||||
persist_directory=bdd_path,
|
persist_directory=bdd_path,
|
||||||
embedding_function=self.__embeddings
|
embedding_function=EMBEDDINGS
|
||||||
)
|
)
|
||||||
|
|
||||||
def getChroma(self)->Chroma:
|
class VectorDatabase: # Classe pour récupérer la BDD
|
||||||
return self.__chroma
|
|
||||||
|
|
||||||
def getEmbeddings(self)->'Embeddings Hugging Face':
|
@staticmethod
|
||||||
return self.__embeddings
|
def getChroma()->Chroma:
|
||||||
|
return CHROMA
|
||||||
|
|
||||||
if __name__ == "__main__":
|
@staticmethod
|
||||||
|
def getEmbeddings()->'Embeddings Hugging Face':
|
||||||
test1 = VectorDatabase()
|
return EMBEDDINGS
|
||||||
print('TEST 1 INIT')
|
|
||||||
test2 = VectorDatabase()
|
|
||||||
|
|
||||||
print(test1 is test2)
|
|
||||||
assert test1 is test2
|
|
||||||
@@ -14,21 +14,18 @@ llm = ChatMistralAI( # LLM sans outils
|
|||||||
)
|
)
|
||||||
|
|
||||||
# NODES
|
# NODES
|
||||||
def reponse_question(state: MessagesState):
|
def call_to_LLM(state: MessagesState):
|
||||||
"""Noeud qui réponds à la question, en s'aidant si besoin des outils à disposition"""
|
"""Noeud qui s'occupe de gérer les appels au LLM"""
|
||||||
# Initialisation du LLM
|
# Initialisation du LLM
|
||||||
model = llm.bind_tools(getTools())
|
model = llm.bind_tools(getTools())
|
||||||
|
|
||||||
# Appel du LLM
|
# Appel du LLM
|
||||||
return {"messages": [model.invoke(state["messages"])]}
|
return {"messages": [model.invoke(state["messages"])]}
|
||||||
|
|
||||||
tool_node = ToolNode(tools=getTools()) # Node gérant les outils
|
# fonction de routage : Après reponse_question, si le LLM veut appeler un outil, on va au tool_node
|
||||||
|
|
||||||
# fonction de routage : Après reponse_question, si le LLM veut appeler un outil, on va au tool_node, sinon on termine
|
|
||||||
def should_continue(state: MessagesState):
|
def should_continue(state: MessagesState):
|
||||||
"""
|
"""
|
||||||
Use in the conditional_edge to route to the ToolNode if the last message
|
Vérifier s'il y a un appel aux outils dans le dernier message
|
||||||
has tool calls. Otherwise, route to the end.
|
|
||||||
"""
|
"""
|
||||||
if isinstance(state, list):
|
if isinstance(state, list):
|
||||||
ai_message = state[-1]
|
ai_message = state[-1]
|
||||||
@@ -39,4 +36,48 @@ def should_continue(state: MessagesState):
|
|||||||
|
|
||||||
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
|
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
|
||||||
return "tools"
|
return "tools"
|
||||||
return END
|
return "no_tools"
|
||||||
|
|
||||||
|
def task_ended(state: MessagesState):
|
||||||
|
"""
|
||||||
|
Vérifier si l'agent a terminé son cycle, ou s'il faut le relancer
|
||||||
|
"""
|
||||||
|
if isinstance(state, list):
|
||||||
|
ai_message = state[-1]
|
||||||
|
elif messages := state.get("messages", []):
|
||||||
|
ai_message = messages[-1]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"No messages found in input state to tool_edge: {state}")
|
||||||
|
|
||||||
|
if "terminé" in ai_message.content.lower():
|
||||||
|
return END
|
||||||
|
return "continue"
|
||||||
|
|
||||||
|
tool_node = ToolNode(tools=getTools())
|
||||||
|
|
||||||
|
|
||||||
|
class BasicToolNode: # De mon ancien projet, https://github.com/LJ5O/Assistant/blob/main/modules/Brain/src/LLM/graph/nodes/BasicToolNode.py
|
||||||
|
"""A node that runs the tools requested in the last AIMessage."""
|
||||||
|
|
||||||
|
def __init__(self, tools: list) -> None:
|
||||||
|
self.tools_by_name = {tool.name: tool for tool in tools}
|
||||||
|
|
||||||
|
def __call__(self, inputs: dict):
|
||||||
|
if messages := inputs.get("messages", []):
|
||||||
|
message = messages[-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("No message found in input")
|
||||||
|
outputs = []
|
||||||
|
for tool_call in message.tool_calls:
|
||||||
|
#print(tool_call["args"])
|
||||||
|
tool_result = self.tools_by_name[tool_call["name"]].invoke(
|
||||||
|
tool_call["args"]
|
||||||
|
)
|
||||||
|
outputs.append(
|
||||||
|
ToolMessage(
|
||||||
|
content=json.dumps(tool_result),
|
||||||
|
name=tool_call["name"],
|
||||||
|
tool_call_id=tool_call["id"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return {"messages": outputs}
|
||||||
@@ -48,13 +48,13 @@ def write_file(file_path:str, content: str, append:bool=True) -> str:
|
|||||||
return f"Erreur lors de l'écriture: {str(e)}"
|
return f"Erreur lors de l'écriture: {str(e)}"
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def editTodo(index:int, state:int, state: Annotated[dict, InjectedState])->bool: # https://stackoverflow.com/a/79525434
|
def editTodo(index:int, todoState:int, state: Annotated[dict, InjectedState])->bool: # https://stackoverflow.com/a/79525434
|
||||||
"""
|
"""
|
||||||
Modifier l'état d'une tâche (TODO)
|
Modifier l'état d'une tâche (TODO)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
index (int): Index de la tâche à modifier, en commançant à 0 pour la première tâche.
|
index (int): Index de la tâche à modifier, en commançant à 0 pour la première tâche.
|
||||||
state (int): Nouvel état. 0 pour "non commencé, 1 pour "en cours", 2 pour "complété"
|
todoState (int): Nouvel état. 0 pour "non commencé, 1 pour "en cours", 2 pour "complété"
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Réussite de l'opération, ou non.
|
bool: Réussite de l'opération, ou non.
|
||||||
@@ -63,7 +63,7 @@ def editTodo(index:int, state:int, state: Annotated[dict, InjectedState])->bool:
|
|||||||
# Erreur, l'index est trop grand
|
# Erreur, l'index est trop grand
|
||||||
return False
|
return False
|
||||||
|
|
||||||
state["todo"][index].state = state # Modification de l'état de cette tâche
|
state["todo"][index].state = todoState # Modification de l'état de cette tâche
|
||||||
|
|
||||||
# Toutes les tâches complétées ?
|
# Toutes les tâches complétées ?
|
||||||
found = False
|
found = False
|
||||||
@@ -168,9 +168,9 @@ def search_in_files(query:str, state: Annotated[dict, InjectedState])->str:
|
|||||||
Returns:
|
Returns:
|
||||||
str: Échantillons de documents correspondants, concaténés en une seule chaîne de caractères.
|
str: Échantillons de documents correspondants, concaténés en une seule chaîne de caractères.
|
||||||
"""
|
"""
|
||||||
bdd = VectorDatabase() # Récupère l'unique instance de cette BDD, c'est un SIngleton
|
bdd = VectorDatabase.getChroma() # Récupère l'unique instance de cette BDD, c'est un SIngleton
|
||||||
|
|
||||||
retrieved_docs = bdd.getChroma().similarity_search(query, k=5) # 5 documents
|
retrieved_docs = bdd.similarity_search(query, k=5) # 5 documents
|
||||||
|
|
||||||
# Conversion des documents en texte
|
# Conversion des documents en texte
|
||||||
docs_content = "\n".join(
|
docs_content = "\n".join(
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ print("===")
|
|||||||
# Création du modèle d'embeddings
|
# Création du modèle d'embeddings
|
||||||
# https://docs.langchain.com/oss/python/integrations/text_embedding/huggingfacehub
|
# https://docs.langchain.com/oss/python/integrations/text_embedding/huggingfacehub
|
||||||
# https://huggingface.co/jinaai/jina-clip-v2
|
# https://huggingface.co/jinaai/jina-clip-v2
|
||||||
embeddings = HuggingFaceEmbeddings(model_name="jinaai/jina-embeddings-v3", model_kwargs={"trust_remote_code": True})
|
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large", model_kwargs={"trust_remote_code": True})
|
||||||
|
|
||||||
# Stockage des embeddings dans ChromaDB dans un dossier local "chroma_db"
|
# Stockage des embeddings dans ChromaDB dans un dossier local "chroma_db"
|
||||||
vectorstore = Chroma.from_documents(documents=chunks,embedding=embeddings, persist_directory=base_dir.as_posix()+"/chroma_db/",) # https://docs.langchain.com/oss/python/integrations/vectorstores/chroma
|
vectorstore = Chroma.from_documents(documents=chunks,embedding=embeddings, persist_directory=base_dir.as_posix()+"/chroma_db/",) # https://docs.langchain.com/oss/python/integrations/vectorstores/chroma
|
||||||
|
|||||||
@@ -13,8 +13,10 @@
|
|||||||
## Mise en place de l'agent
|
## Mise en place de l'agent
|
||||||
- [X] Préparation du `State`
|
- [X] Préparation du `State`
|
||||||
- [X] Développement des outils de l'agent
|
- [X] Développement des outils de l'agent
|
||||||
- [ ] Préparation des nœuds
|
- [X] Préparation des nœuds
|
||||||
- [ ] Branchement des nœuds entre-eux
|
- [X] Branchement des nœuds entre-eux, **MVP**
|
||||||
|
- [ ] Human in the loop
|
||||||
|
- [ ] Amélioration du workflow
|
||||||
|
|
||||||
## Amélioration de l'agent
|
## Amélioration de l'agent
|
||||||
- [ ] Cross-encoding sur la sortie du **RAG**
|
- [ ] Cross-encoding sur la sortie du **RAG**
|
||||||
|
|||||||
Reference in New Issue
Block a user