# Source code for langchain_core.tools.retriever

from __future__ import annotations

from functools import partial
from typing import Optional

from langchain_core.callbacks import Callbacks
from langchain_core.prompts import (
    BasePromptTemplate,
    PromptTemplate,
    aformat_document,
    format_document,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools.simple import Tool


class RetrieverInput(BaseModel):
    """Input to the retriever."""

    # The only argument a retriever tool accepts: the text to search for.
    # NOTE(review): pydantic typically copies the class docstring and this
    # Field description into the generated JSON schema that the tool exposes,
    # so wording changes here may change what the model sees — confirm before
    # editing either string.
    query: str = Field(
        description="query to look up in retriever",
    )
def _get_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    document_prompt: BasePromptTemplate,
    document_separator: str,
    callbacks: Callbacks = None,
) -> str:
    """Run ``retriever`` on ``query`` and render the hits as a single string.

    Args:
        query: Text handed to ``retriever.invoke``.
        retriever: Retriever whose ``invoke`` result is iterated as documents.
        document_prompt: Prompt passed to ``format_document`` for each hit.
        document_separator: String placed between consecutive rendered hits.
        callbacks: Forwarded to the retriever via ``config={"callbacks": ...}``.

    Returns:
        The rendered documents, in retrieval order, joined by
        ``document_separator``.
    """
    run_config = {"callbacks": callbacks}
    hits = retriever.invoke(query, config=run_config)
    rendered = [format_document(hit, document_prompt) for hit in hits]
    return document_separator.join(rendered)


async def _aget_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    document_prompt: BasePromptTemplate,
    document_separator: str,
    callbacks: Callbacks = None,
) -> str:
    """Async counterpart of :func:`_get_relevant_documents`.

    Uses ``retriever.ainvoke`` and ``aformat_document`` instead of their sync
    versions; arguments and return value have the same meaning.
    """
    run_config = {"callbacks": callbacks}
    hits = await retriever.ainvoke(query, config=run_config)
    rendered = []
    # Documents are formatted one after another (not concurrently), which
    # keeps the output in retrieval order.
    for hit in hits:
        rendered.append(await aformat_document(hit, document_prompt))
    return document_separator.join(rendered)
def create_retriever_tool(
    retriever: BaseRetriever,
    name: str,
    description: str,
    *,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = "\n\n",
) -> Tool:
    r"""Create a tool to do retrieval of documents.

    Args:
        retriever: The retriever to use for the retrieval.
        name: The name for the tool. This will be passed to the language model,
            so should be unique and somewhat descriptive.
        description: The description for the tool. This will be passed to the
            language model, so should be descriptive.
        document_prompt: The prompt to use for the document. Defaults to None,
            in which case ``PromptTemplate.from_template("{page_content}")``
            is used.
        document_separator: The separator to use between documents.
            Defaults to "\n\n".

    Returns:
        A ``Tool`` instance to pass to an agent. Its sync ``func`` and async
        ``coroutine`` share the same retriever, prompt and separator, and its
        ``args_schema`` is ``RetrieverInput`` (a single ``query`` string).
    """
    # Explicit None check rather than ``or``: only substitute the default when
    # the caller did not pass a prompt, regardless of the prompt's truthiness.
    if document_prompt is None:
        document_prompt = PromptTemplate.from_template("{page_content}")

    # One set of bound arguments for both call paths so the sync and async
    # variants cannot drift apart; ``query`` and ``callbacks`` stay unbound.
    bound_kwargs = {
        "retriever": retriever,
        "document_prompt": document_prompt,
        "document_separator": document_separator,
    }
    func = partial(_get_relevant_documents, **bound_kwargs)
    afunc = partial(_aget_relevant_documents, **bound_kwargs)

    return Tool(
        name=name,
        description=description,
        func=func,
        coroutine=afunc,
        args_schema=RetrieverInput,
    )