【AI Agent系列】【LangGraph】1. 进阶实战:给你的 LangGraph 加入条件分支(Conditional edges)

慈云数据 2024-05-09 技术支持 38 0
  • 大家好,我是同学小张,日常分享AI知识和实战案例
  • 欢迎 点赞 + 关注 👏,持续学习,持续干货输出。
  • +v: jasper_8017 一起交流💬,一起进步💪。
  • 微信公众号也可搜【同学小张】 🙏

    本站文章一览:

    在这里插入图片描述


    书接上文(【AI Agent系列】【LangGraph】0. 快速上手:协同LangChain,LangGraph帮你用图结构轻松构建多智能体应用),前面我们了解了 LangGraph 的概念和基本构造方法,今天我们来看下 LangGraph 构造中的进阶用法:给边加个条件 - 条件分支(Conditional edges)。

    文章目录

    • 1. 完整代码及运行
    • 2. 代码详解
      • 2.1 add_conditional_edges
      • 2.2 条件 router
      • 2.3 各node的定义
      • 2.4 总体流程

        LangGraph 构造的是个图的数据结构,有节点(node) 和边(edge),那它的边也可以是带条件的。如何给边加入条件呢?可以通过 add_conditional_edges 函数添加带条件的边。

        1. 完整代码及运行

        废话不多说,先上完整代码,和运行结果。先跑起来看看效果再说。

        from langchain_openai import ChatOpenAI
        from langchain_core.messages import HumanMessage, BaseMessage
        from langgraph.graph import END, MessageGraph
        import json
        from langchain_core.messages import ToolMessage
        from langchain_core.tools import tool
        from langchain_core.utils.function_calling import convert_to_openai_tool
        from typing import List
        @tool
        def multiply(first_number: int, second_number: int):
            """Multiplies two numbers together."""
            return first_number * second_number
        model = ChatOpenAI(temperature=0)
        model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
        graph = MessageGraph()
        def invoke_model(state: List[BaseMessage]):
            return model_with_tools.invoke(state)
        graph.add_node("oracle", invoke_model)
        def invoke_tool(state: List[BaseMessage]):
            tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
            multiply_call = None
            for tool_call in tool_calls:
                if tool_call.get("function").get("name") == "multiply":
                    multiply_call = tool_call
            if multiply_call is None:
                raise Exception("No adder input found.")
            res = multiply.invoke(
                json.loads(multiply_call.get("function").get("arguments"))
            )
            return ToolMessage(
                tool_call_id=multiply_call.get("id"),
                content=res
            )
        graph.add_node("multiply", invoke_tool)
        graph.add_edge("multiply", END)
        graph.set_entry_point("Oracle")
        def router(state: List[BaseMessage]):
            tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
            if len(tool_calls):
                return "multiply"
            else:
                return "end"
        graph.add_conditional_edges("oracle", router, {
            "multiply": "multiply",
            "end": END,
        })
        runnable = graph.compile()
        response = runnable.invoke(HumanMessage("What is 123 * 456?"))
        print(response)
        

        运行结果如下:

        在这里插入图片描述

        2. 代码详解

        下面对上面的代码进行详细解释。

        2.1 add_conditional_edges

        首先,我们知道了可以通过 add_conditional_edges 来对边进行条件添加。这部分代码如下:

        graph.add_conditional_edges("oracle", router, {
            "multiply": "multiply",
            "end": END,
        })
        

        add_conditional_edges接收三个参数:

        • 第一个为这条边的第一个node的名称
        • 第二个为这条边的条件
        • 第三个为条件返回结果的映射(根据条件结果映射到相应的node)

          如上面的代码,意思就是往 “oracle” node上添加边,这个node有两条边,一条是往“multiply” node上走,一条是往“END”上走。怎么决定往哪个方向去:条件是 router(后面解释),如果 router 返回的是“multiply”,则往“multiply”方向走,如果 router 返回的是 “end”,则走“END”。

          来看下这个函数的源码:

          def add_conditional_edges(
              self,
              start_key: str,
              condition: Callable[..., str],
              conditional_edge_mapping: Optional[Dict[str, str]] = None,
          ) -> None:
              if self.compiled:
                  logger.warning(
                      "Adding an edge to a graph that has already been compiled. This will "
                      "not be reflected in the compiled graph."
                  )
              if start_key not in self.nodes:
                  raise ValueError(f"Need to add_node `{start_key}` first")
              if iscoroutinefunction(condition):
                  raise ValueError("Condition cannot be a coroutine function")
              if conditional_edge_mapping and set(
                  conditional_edge_mapping.values()
              ).difference([END]).difference(self.nodes):
                  raise ValueError(
                      f"Missing nodes which are in conditional edge mapping. Mapping "
                      f"contains possible destinations: "
                      f"{list(conditional_edge_mapping.values())}. Possible nodes are "
                      f"{list(self.nodes.keys())}."
                  )
              self.branches[start_key].append(Branch(condition, conditional_edge_mapping))
          

          重点是这一句:self.branches[start_key].append(Branch(condition, conditional_edge_mapping)),给当前node添加分支Branch。

          2.2 条件 router

          条件代码如下:判断执行结果中是否有 tool_calls 参数,如果有,则返回"multiply",没有,则返回“end”。

          def router(state: List[BaseMessage]):
              tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
              if len(tool_calls):
                  return "multiply"
              else:
                  return "end"
          

          2.3 各node的定义

          (1)起始node:oracle

          @tool
          def multiply(first_number: int, second_number: int):
              """Multiplies two numbers together."""
              return first_number * second_number
          model = ChatOpenAI(temperature=0)
          model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
          graph = MessageGraph()
          def invoke_model(state: List[BaseMessage]):
              return model_with_tools.invoke(state)
          graph.add_node("oracle", invoke_model)
          

          这个node是一个带有Tools 的 ChatOpenAI。在LangChain中使用Tools的详细教程请看这篇文章:【AI大模型应用开发】【LangChain系列】5. 实战LangChain的智能体Agents模块。简单解释就是:这个node的执行结果,将返回是否应该使用绑定的Tools。

          (2)multiply

          def invoke_tool(state: List[BaseMessage]):
              tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
              multiply_call = None
              for tool_call in tool_calls:
                  if tool_call.get("function").get("name") == "multiply":
                      multiply_call = tool_call
              if multiply_call is None:
                  raise Exception("No adder input found.")
              res = multiply.invoke(
                  json.loads(multiply_call.get("function").get("arguments"))
              )
              return ToolMessage(
                  tool_call_id=multiply_call.get("id"),
                  content=res
              )
          graph.add_node("multiply", invoke_tool)
          

          这个node的作用就是执行Tools。

          2.4 总体流程

          在这里插入图片描述

          如果觉得本文对你有帮助,麻烦点个赞和关注呗 ~~~


          • 大家好,我是 同学小张,日常分享AI知识和实战案例
          • 欢迎 点赞 + 关注 👏,持续学习,持续干货输出。
          • +v: jasper_8017 一起交流💬,一起进步💪。
          • 微信公众号也可搜【同学小张】 🙏

            本站文章一览:

            在这里插入图片描述

微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon