Source code for langchain_nvidia_ai_endpoints.reranking

from __future__ import annotations

import os
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence

from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr, root_validator

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model


class Ranking(BaseModel):
    """One entry of a reranking response, as consumed by ``NVIDIARerank``."""

    # Position of the ranked passage within the list of passages that was sent.
    index: int
    # Score attached to the passage; stored by NVIDIARerank as the document's
    # "relevance_score" metadata.
    logit: float
class NVIDIARerank(BaseDocumentCompressor):
    """
    LangChain Document Compressor that uses the NVIDIA NeMo Retriever Reranking API.
    """

    class Config:
        validate_assignment = True

    _client: _NVIDIAClient = PrivateAttr(_NVIDIAClient)

    _default_batch_size: int = 32
    _default_model_name: str = "nv-rerank-qa-mistral-4b:1"
    _default_base_url: str = "https://integrate.api.nvidia.com/v1"
    base_url: str = Field(
        description="Base url for model listing an invocation",
    )
    top_n: int = Field(5, ge=0, description="The number of documents to return.")
    model: Optional[str] = Field(description="The model to use for reranking.")
    truncate: Optional[Literal["NONE", "END"]] = Field(
        description=(
            "Truncate input text if it exceeds the model's maximum token length. "
            "Default is model dependent and is likely to raise error if an "
            "input is too long."
        ),
    )
    max_batch_size: int = Field(
        _default_batch_size, ge=1, description="The maximum batch size."
    )

    # Name of the environment variable consulted for the base URL; its
    # lower-cased form is also accepted as a constructor keyword.
    _base_url_var = "NVIDIA_BASE_URL"

    @root_validator(pre=True)
    def _validate_base_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Resolve ``base_url`` before field validation.

        The first truthy value wins, in this order: the ``nvidia_base_url``
        keyword, the ``base_url`` keyword, the ``NVIDIA_BASE_URL`` environment
        variable, and finally the class default URL.
        """
        values["base_url"] = (
            values.get(cls._base_url_var.lower())
            or values.get("base_url")
            or os.getenv(cls._base_url_var)
            or cls._default_base_url
        )
        return values

    def __init__(self, **kwargs: Any):
        """
        Create a new NVIDIARerank document compressor.

        This class provides access to a NVIDIA NIM for reranking. By default, it
        connects to a hosted NIM, but can be configured to connect to a local NIM
        using the `base_url` parameter. An API key is required to connect to the
        hosted NIM.

        Args:
            model (str): The model to use for reranking.
            nvidia_api_key (str): The API key to use for connecting to the hosted NIM.
            api_key (str): Alternative to nvidia_api_key.
            base_url (str): The base URL of the NIM to connect to.
            truncate (str): "NONE", "END", truncate input text if it exceeds
                            the model's context length. Default is model dependent and
                            is likely to raise an error if an input is too long.

        API Key:
        - The recommended way to provide the API key is through the `NVIDIA_API_KEY`
            environment variable.
        """
        super().__init__(**kwargs)
        self._client = _NVIDIAClient(
            base_url=self.base_url,
            model_name=self.model,
            default_hosted_model_name=self._default_model_name,
            # `nvidia_api_key` takes precedence over `api_key` when both are given.
            api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
            infer_path="{base_url}/ranking",
            cls=self.__class__.__name__,
        )
        # todo: only store the model in one place
        # the model may be updated to a newer name during initialization
        self.model = self._client.model_name

    @property
    def available_models(self) -> List[Model]:
        """
        Get a list of available models that work with NVIDIARerank.
        """
        return self._client.get_available_models(self.__class__.__name__)

    @classmethod
    def get_available_models(
        cls,
        **kwargs: Any,
    ) -> List[Model]:
        """
        Get a list of available models that work with NVIDIARerank.

        A temporary instance is constructed from ``kwargs`` and its
        ``available_models`` property is returned.
        """
        return cls(**kwargs).available_models

    # todo: batching when len(documents) > endpoint's max batch size
    def _rank(self, documents: List[str], query: str) -> List[Ranking]:
        """
        Send one ranking request and parse the response.

        Args:
            documents: Passage texts to rank against the query.
            query: The query text.

        Returns:
            At most ``top_n`` ``Ranking`` entries, taken from the front of the
            "rankings" list in the response.
            NOTE(review): taking the front of the list presumes the service
            returns rankings ordered best-first - confirm against the API.
        """
        payload = {
            "model": self.model,
            "query": {"text": query},
            "passages": [{"text": passage} for passage in documents],
        }
        # Only send "truncate" when it is set, so the service default applies.
        if self.truncate:
            payload["truncate"] = self.truncate
        response = self._client.get_req(payload=payload)
        # No-op for successful responses; raises for HTTP error statuses.
        response.raise_for_status()
        # todo: handle errors
        rankings = response.json()["rankings"]
        # todo: callback support
        return [Ranking(**ranking) for ranking in rankings[: self.top_n]]

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using the NVIDIA NeMo Retriever Reranking microservice API.

        The input documents are modified in place: each returned document has
        its ranking score stored under ``metadata["relevance_score"]``.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
                Currently unused by this implementation.

        Returns:
            A sequence of compressed documents, at most ``top_n`` long.

        Raises:
            AssertionError: If the service returns a ranking whose index does
                not refer to a document in the batch that was sent.
        """
        if len(documents) == 0 or self.top_n < 1:
            return []

        def batch(
            ls: List[Document], size: int
        ) -> Generator[List[Document], None, None]:
            # Yield consecutive slices of at most `size` documents.
            for i in range(0, len(ls), size):
                yield ls[i : i + size]

        doc_list = list(documents)
        results = []
        for doc_batch in batch(doc_list, self.max_batch_size):
            rankings = self._rank(
                query=query, documents=[d.page_content for d in doc_batch]
            )
            for ranking in rankings:
                if not 0 <= ranking.index < len(doc_batch):
                    # Raised explicitly rather than via `assert` so the check
                    # survives `python -O`; without it a negative index would
                    # silently pick a document from the end of the batch.
                    # Exception type and message match the original assertion.
                    raise AssertionError(
                        "invalid response from server: index out of range"
                    )
                doc = doc_batch[ranking.index]
                doc.metadata["relevance_score"] = ranking.logit
                results.append(doc)

        # if we batched, we need to sort the results
        if len(doc_list) > self.max_batch_size:
            results.sort(key=lambda x: x.metadata["relevance_score"], reverse=True)

        return results[: self.top_n]