Source code for langchain.chains.openai_functions.qa_with_structure

from typing import Any, List, Optional, Type, Union, cast

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import BaseLLMOutputParser
from langchain_core.output_parsers.openai_functions import (
    OutputFunctionsParser,
    PydanticOutputFunctionsParser,
)
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import is_basemodel_subclass

from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import get_llm_kwargs


[docs] class AnswerWithSources(BaseModel): """An answer to the question, with sources.""" answer: str = Field(..., description="Answer to the question that was asked") sources: List[str] = Field( ..., description="List of sources used to answer the question" )
[docs] def create_qa_with_structure_chain( llm: BaseLanguageModel, schema: Union[dict, Type[BaseModel]], output_parser: str = "base", prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None, verbose: bool = False, ) -> LLMChain: """Create a question answering chain that returns an answer with sources based on schema. Args: llm: Language model to use for the chain. schema: Pydantic schema to use for the output. output_parser: Output parser to use. Should be one of `pydantic` or `base`. Default to `base`. prompt: Optional prompt to use for the chain. Returns: """ if output_parser == "pydantic": if not (isinstance(schema, type) and is_basemodel_subclass(schema)): raise ValueError( "Must provide a pydantic class for schema when output_parser is " "'pydantic'." ) _output_parser: BaseLLMOutputParser = PydanticOutputFunctionsParser( pydantic_schema=schema ) elif output_parser == "base": _output_parser = OutputFunctionsParser() else: raise ValueError( f"Got unexpected output_parser: {output_parser}. " f"Should be one of `pydantic` or `base`." ) if isinstance(schema, type) and is_basemodel_subclass(schema): schema_dict = cast(dict, schema.schema()) else: schema_dict = cast(dict, schema) function = { "name": schema_dict["title"], "description": schema_dict["description"], "parameters": schema_dict, } llm_kwargs = get_llm_kwargs(function) messages = [ SystemMessage( content=( "You are a world class algorithm to answer " "questions in a specific format." ) ), HumanMessage(content="Answer question using the following context"), HumanMessagePromptTemplate.from_template("{context}"), HumanMessagePromptTemplate.from_template("Question: {question}"), HumanMessage(content="Tips: Make sure to answer in the correct format"), ] prompt = prompt or ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg] chain = LLMChain( llm=llm, prompt=prompt, llm_kwargs=llm_kwargs, output_parser=_output_parser, verbose=verbose, ) return chain
[docs] def create_qa_with_sources_chain( llm: BaseLanguageModel, verbose: bool = False, **kwargs: Any ) -> LLMChain: """Create a question answering chain that returns an answer with sources. Args: llm: Language model to use for the chain. verbose: Whether to print the details of the chain **kwargs: Keyword arguments to pass to `create_qa_with_structure_chain`. Returns: Chain (LLMChain) that can be used to answer questions with citations. """ return create_qa_with_structure_chain( llm, AnswerWithSources, verbose=verbose, **kwargs )