"""Embeddings Components Derived from NVEModel/Embeddings"""
import os
import warnings
from typing import Any, Dict, List, Literal, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
PrivateAttr,
root_validator,
validator,
)
from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model
from langchain_nvidia_ai_endpoints.callbacks import usage_callback_var
# [docs]  (documentation-link marker left over from the HTML source view; not code)
class NVIDIAEmbeddings(BaseModel, Embeddings):
    """
    Client to NVIDIA embeddings models.

    Fields:
    - model: str, the name of the model to use
    - truncate: "NONE", "START", "END", truncate input text if it exceeds the model's
        maximum token length. Default is "NONE", which raises an error if an input is
        too long.
    - base_url: str, base URL used for model listing and invocation
    - max_batch_size: int, maximum number of texts sent per embedding request
    - model_type: (DEPRECATED) "passage" or "query"; when set, it overrides the
        input type that `embed_query` / `embed_documents` would otherwise use
    """

    class Config:
        # Validate field values on attribute assignment as well as at construction
        # (`__init__` below re-assigns `model` and possibly `truncate`).
        validate_assignment = True

    _client: _NVIDIAClient = PrivateAttr(_NVIDIAClient)
    _default_model_name: str = "NV-Embed-QA"
    _default_max_batch_size: int = 50
    _default_base_url: str = "https://integrate.api.nvidia.com/v1"

    base_url: str = Field(
        description="Base url for model listing an invocation",
    )
    model: Optional[str] = Field(description="Name of the model to invoke")
    truncate: Literal["NONE", "START", "END"] = Field(
        default="NONE",
        description=(
            "Truncate input text if it exceeds the model's maximum token length. "
            "Default is 'NONE', which raises an error if an input is too long."
        ),
    )
    max_batch_size: int = Field(default=_default_max_batch_size)
    model_type: Optional[Literal["passage", "query"]] = Field(
        None, description="(DEPRECATED) The type of text to be embedded."
    )

    # Name of the environment variable consulted for the base URL; its
    # lower-cased form is also accepted as a constructor keyword.
    _base_url_var = "NVIDIA_BASE_URL"

    @root_validator(pre=True)
    def _validate_base_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Resolve `base_url` before field validation runs.

        The first truthy value wins, in this order: the `nvidia_base_url`
        keyword, the `base_url` keyword, the `NVIDIA_BASE_URL` environment
        variable, and finally the class default URL.
        """
        values["base_url"] = (
            values.get(cls._base_url_var.lower())
            or values.get("base_url")
            or os.getenv(cls._base_url_var)
            or cls._default_base_url
        )
        return values

    def __init__(self, **kwargs: Any):
        """
        Create a new NVIDIAEmbeddings embedder.

        This class provides access to a NVIDIA NIM for embedding. By default, it
        connects to a hosted NIM, but can be configured to connect to a local NIM
        using the `base_url` parameter. An API key is required to connect to the
        hosted NIM.

        Args:
            model (str): The model to use for embedding.
            nvidia_api_key (str): The API key to use for connecting to the hosted NIM.
            api_key (str): Alternative to nvidia_api_key.
            base_url (str): The base URL of the NIM to connect to.
                            Format for base URL is http://host:port
            truncate (str): "NONE", "START", "END", truncate input text if it exceeds
                            the model's context length. Default is "NONE", which raises
                            an error if an input is too long.

        API Key:
        - The recommended way to provide the API key is through the `NVIDIA_API_KEY`
            environment variable.
        """
        super().__init__(**kwargs)
        self._client = _NVIDIAClient(
            base_url=self.base_url,
            model_name=self.model,
            default_hosted_model_name=self._default_model_name,
            # `nvidia_api_key` takes precedence; `api_key` is only consulted when
            # the `nvidia_api_key` keyword is absent.
            api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
            infer_path="{base_url}/embeddings",
            cls=self.__class__.__name__,
        )
        # todo: only store the model in one place
        # the model may be updated to a newer name during initialization
        self.model = self._client.model_name

        # todo: remove when nvolveqa_40k is removed from MODEL_TABLE
        # The check is made against the name the caller passed, not the
        # (possibly updated) name stored in `self.model` above.
        if "model" in kwargs and kwargs["model"] in [
            "playground_nvolveqa_40k",
            "nvolveqa_40k",
        ]:
            warnings.warn(
                'Setting truncate="END" for nvolveqa_40k backward compatibility'
            )
            self.truncate = "END"

    @validator("model_type")
    def _validate_model_type(
        cls, v: Optional[Literal["passage", "query"]]
    ) -> Optional[Literal["passage", "query"]]:
        """Warn that `model_type` is deprecated whenever a truthy value is set."""
        if v:
            warnings.warn(
                "Warning: `model_type` is deprecated and will be removed "
                "in a future release. Please use `embed_query` or "
                "`embed_documents` appropriately."
            )
        return v

    @property
    def available_models(self) -> List[Model]:
        """
        Get a list of available models that work with NVIDIAEmbeddings.
        """
        return self._client.get_available_models(self.__class__.__name__)

    @classmethod
    def get_available_models(
        cls,
        **kwargs: Any,
    ) -> List[Model]:
        """
        Get a list of available models that work with NVIDIAEmbeddings.

        A temporary instance is constructed from `kwargs` (so the same keyword
        arguments as the constructor apply) and its `available_models` is returned.
        """
        return cls(**kwargs).available_models

    def _embed(
        self, texts: List[str], model_type: Literal["passage", "query"]
    ) -> List[List[float]]:
        """
        Embed a batch of texts as either passage or query type.

        Args:
            texts: The texts to embed in a single request.
            model_type: Sent to the service as `input_type`.

        Returns:
            One embedding per returned item, ordered by the item's `index`.

        Raises:
            ValueError: If the response payload is not a list of embeddings.
        """
        # API Catalog API -
        #  input: str | list[str]              -- char limit depends on model
        #  model: str                          -- model name, e.g. NV-Embed-QA
        #  encoding_format: "float" | "base64"
        #  input_type: "query" | "passage"
        #  user: str                           -- ignored
        #  truncate: "NONE" | "START" | "END"  -- default "NONE", error raised if
        #                                         an input is too long
        payload = {
            "input": texts,
            "model": self.model,
            "encoding_format": "float",
            "input_type": model_type,
        }
        if self.truncate:
            payload["truncate"] = self.truncate

        response = self._client.get_req(
            payload=payload,
        )
        response.raise_for_status()
        result = response.json()
        # Use the "data" entry when present, otherwise treat the whole body as
        # the data.
        data = result.get("data", result)
        if not isinstance(data, list):
            raise ValueError(f"Expected data with a list of embeddings. Got: {data}")
        embedding_list = [(res["embedding"], res["index"]) for res in data]
        self._invoke_callback_vars(result)
        # Re-order by the index reported for each item rather than relying on
        # the order of the response list.
        return [x[0] for x in sorted(embedding_list, key=lambda x: x[1])]

    def embed_query(self, text: str) -> List[float]:
        """Input pathway for query embeddings."""
        return self._embed([text], model_type=self.model_type or "query")[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Input pathway for document embeddings.

        The texts are sent in consecutive batches of at most `max_batch_size`
        items and the per-batch results are concatenated in order.

        Raises:
            ValueError: If `texts` is not a list of strings, or if
                `max_batch_size` is not a positive integer.
        """
        if not isinstance(texts, list) or not all(
            isinstance(text, str) for text in texts
        ):
            raise ValueError(f"`texts` must be a list of strings, given: {repr(texts)}")

        # A zero step makes `range` raise and a negative step yields no batches
        # at all, which would silently drop every input; fail loudly instead.
        if self.max_batch_size <= 0:
            raise ValueError(
                "`max_batch_size` must be a positive integer, "
                f"given: {self.max_batch_size!r}"
            )

        all_embeddings = []
        for i in range(0, len(texts), self.max_batch_size):
            batch = texts[i : i + self.max_batch_size]
            all_embeddings.extend(
                self._embed(batch, model_type=self.model_type or "passage")
            )
        return all_embeddings

    def _invoke_callback_vars(self, response: dict) -> None:
        """Invoke the callback context variables if there are any."""
        callback_vars = [
            usage_callback_var.get(),
        ]
        # The raw response plus the model name is handed to the callbacks as
        # `llm_output`; there are no generations for an embedding call.
        llm_output = {**response, "model_name": self.model}
        result = LLMResult(generations=[[]], llm_output=llm_output)
        for cb_var in callback_vars:
            if cb_var:
                cb_var.on_llm_end(result)