封装一个文本片段召回的函数
函数中的 retriever
是 langchain 提供的用于文本召回的类。在RAG简单实现的演示中已经做过介绍,此处不再赘述。
def Retriever(query):
return retriever.get_relevant_documents(query)
实现问题拆解的提示词
为了让大模型执行分解问题的任务,设计以下提示词模板
1from langchain_core.prompts import ChatPromptTemplate
2from langchain_openai import ChatOpenAI
3
4system = """
5 You are an expert at converting user questions into database queries. \
6 You have access to a database of documents about financial reports. \
7
8 Perform query decomposition. Given a user question, break it down into distinct sub questions that \
9 you need to answer in order to answer the original question. Focus on creating retrievable queries without adding any processing steps like calculations. \
10
11 If there are acronyms or words you are not familiar with, do not try to rephrase them.
12
13 Ensure that your responses strictly adhere to the format of the following example.
14
15 example:
16 question: "What's the difference between LangChain agents and LangGraph?"
17 output format:
18 [Retriever("What's the difference between LangChain agents and LangGraph?"),
19 Retriever("What are LangChain agents"),
20 Retriever("What is LangGraph")]
21"""
- 在提示词中,我们加入了一个案例,以及输出的格式要求。
- 案例中的输出格式,是方便我们使用 eval 函数进行解析,利用我们上述定义的 Retriever 函数进行文档召回。
问题分解测试
prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "{question}"),
]
)
query_analyzer = prompt | model
querys = query_analyzer.invoke({"question": "腾讯2023年收入比2022年高多少?"}).content
模型返回的结果如下,成功分解为三个子问题,分别询问2023年和2022年的腾讯收入,以及两者的差值。
'[Retriever("What was Tencent\'s income in 2023?"),\nRetriever("What was Tencent\'s income in 2022?"),\nRetriever("What is the difference between Tencent\'s income in 2023 and 2022?")]'
子问题召回
由于我们要求模型输出的格式,是方便我们使用 eval 函数进行解析执行的,以下命令即可完成文档召回:
由于是对三个子问题进行文档召回,因此我们获得的是一个双层列表,总共召回了 9 个文本片段。
len(docs): 3
len(docs[0]): 3
答案生成
最后将召回的文本通过格式化后提供给大模型进行答案生成,先定义一个格式化函数
def format_docs(docs):
if type(docs[0]) == list:
flattened_docs = [doc for sublist in docs for doc in sublist]
else:
flattened_docs = docs
doc_list = [doc.page_content for doc in flattened_docs]
seen = set()
unique_list = [x for x in doc_list if x not in seen and (seen.add(x) or True)]
return "\n\n".join(doc for doc in unique_list)
- 函数中 if 判断的作用是当多个子问题召回时,结果是一个双层列表,我们需要将其展平。
- 由于多个子问题的召回结果可能会有重复的文本片段,我们需要对其进行去重。
context = format_docs(docs)
rag_chain = (
{"context": RunnablePassthrough(), "question": RunnablePassthrough()}
| custom_rag_prompt
| model
| StrOutputParser()
)
rag_chain.invoke({'context': context, 'question': "腾讯2023年收入比2022年高多少"})
最终得到的答案如下,可以看到模型成功回答了我们的问题。
AIMessage(content='腾讯2023年的收入比2022年高10,242万元人民币。', response_metadata={'finish_reason': 'stop', 'logprobs': None})