diff --git a/AgentReact/agent.py b/AgentReact/agent.py index 1de10f6..1cbef3f 100644 --- a/AgentReact/agent.py +++ b/AgentReact/agent.py @@ -16,6 +16,7 @@ def getGraph()->CompiledStateGraph: workflow = getState() # State prêt à utiliser # Définition des sommets du graphe + workflow.add_node(user_prompt) workflow.add_node(call_to_LLM) workflow.add_node(preparation_docs) workflow.add_node(inject_preparation_prompt) @@ -27,8 +28,9 @@ def getGraph()->CompiledStateGraph: workflow.add_edge("inject_preparation_prompt", "preparation_docs") workflow.add_conditional_edges("preparation_docs", should_continue, { "tools":"weekly_report_tools", - "no_tools":"call_to_LLM" + "no_tools":"user_prompt" }) + workflow.add_edge("user_prompt", "call_to_LLM") #workflow.set_entry_point("call_to_LLM") workflow.add_edge("weekly_report_tools", "preparation_docs") diff --git a/AgentReact/utils/InterruptPayload.py b/AgentReact/utils/InterruptPayload.py index c1a5818..609acbd 100644 --- a/AgentReact/utils/InterruptPayload.py +++ b/AgentReact/utils/InterruptPayload.py @@ -10,9 +10,21 @@ class InterruptPayload(): #EDITED = 2 DENIED = 3 - def __init__(self, fields:Dict, state:int=0): + TOOL_CALL = 999 + USER_PROMPT = 998 + + def __init__(self, fields:Dict, state:int=0, payload_type:int=TOOL_CALL): + """ + Créer unne nouvelle instance de payload pour interrupt() + + Args: + fields (Dict): Un dictionnaire d'arguments pour un call d'outil, ou {'prompt':str} pour une requête de prompt + state (int, optional): État de la requête. Defaults to 0. Définit en variables statiques de l'objet. + payload_type (int, optional): Type d'interuption, appel d'outil ou requête humaine. Defaults to TOOL_CALL. Définit en variables statiques de l'objet. + """ self.__fields = fields self.__state = state + self.__type = payload_type def get(self, key:str)->str: """ @@ -39,7 +51,22 @@ class InterruptPayload(): """ Afficher la requête proprement, permettant à l'utilisateur d'accepter, refuser ou modifier une requête """ + if self.__type == InterruptPayload.USER_PROMPT: # C'est une demande de prompt humain + self.__human_prompt_display() + else: # C'est un appel d'outil + self.__tool_query_display() + def __human_prompt_display(self): + print("=== L'AGENT DEMANDE DES CONSIGNES! ===\n") + + print("Veuillez saisir un prompt pour l'agent...\n") + prompt = input("Prompt...") + + self.__fields = {'prompt': prompt} + print("\nMerci, l'exécution va reprendre.\n") + print("======") + + def __tool_query_display(self): print("=== L'AGENT DEMANDE À UTILISER UN OUTIL RESTREINT! ===\n") keys = list(self.__fields.keys()) @@ -86,7 +113,7 @@ class InterruptPayload(): Returns: str: String sérialisable via la méthode statique InterruptPayload.strImport(string) """ - return '{"state":'+ str(self.__state) +', "fields": ' + json.dumps(self.__fields, ensure_ascii=False, indent=indent) +'}' + return '{"state":'+ str(self.__state) +', "type": '+str(self.__type)+', "fields": ' + json.dumps(self.__fields, ensure_ascii=False, indent=indent) +'}' @staticmethod @@ -104,8 +131,9 @@ class InterruptPayload(): state_ = data.get("state", 0) fields_ = data.get("fields", {}) + type_ = data.get("type", InterruptPayload.TOOL_CALL) - return InterruptPayload(fields=fields_, state=state_) + return InterruptPayload(fields=fields_, state=state_, payload_type=type_) diff --git a/AgentReact/utils/nodes.py b/AgentReact/utils/nodes.py index a2fb619..352beb7 100644 --- a/AgentReact/utils/nodes.py +++ b/AgentReact/utils/nodes.py @@ -33,6 +33,26 @@ def preparation_docs(state: CustomState): return {'messages': model.invoke(state['messages'])} +def user_prompt(state: CustomState): + """ Dans ce nœud, l'utilisateur peut écrire un HumanMessage pour l'IA """ + + messages = [msg for msg in state['messages']] # Je récupère la liste des messages + + sys_message = SystemMessage("Salut") # TODO: Anti-injections + user_message = HumanMessage( + InterruptPayload.fromJSON( + interrupt( + InterruptPayload({'prompt':''}, payload_type=InterruptPayload.USER_PROMPT).toJSON() + ) + ).get("prompt") + ) # Récupérer un prompt + + messages.append(sys_message) # Rajout des nouveaux messages dans le système + messages.append(user_message) + + return {'messages': messages}# Je passe unen liste, devrait écraser tous les messages précédent au lieu d'ajouter à la liste du State + + def call_to_LLM(state: MessagesState): """Noeud qui s'occupe de gérer les appels au LLM""" # Initialisation du LLM diff --git a/imgs/agent.png b/imgs/agent.png index e24f8c4..2dfc650 100644 Binary files a/imgs/agent.png and b/imgs/agent.png differ