pydantic_ai.embeddings

EmbeddingModel

Bases: ABC

Abstract base class for an embedding model.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/base.py
class EmbeddingModel(ABC):
    """Abstract class for a model."""

    _settings: EmbeddingSettings | None = None

    def __init__(
        self,
        *,
        settings: EmbeddingSettings | None = None,
    ) -> None:
        """Initialize the model with optional settings and profile.

        Args:
            settings: Model-specific settings that will be used as defaults for this model.
        """
        self._settings = settings

    @property
    def settings(self) -> EmbeddingSettings | None:
        """Get the model settings."""
        return self._settings

    @property
    def base_url(self) -> str | None:
        """The base URL for the provider API, if available."""
        return None

    @property
    @abstractmethod
    def model_name(self) -> str:
        """The model name."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def system(self) -> str:
        """The embedding model provider."""
        raise NotImplementedError()

    @abstractmethod
    async def embed(
        self, inputs: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        raise NotImplementedError

    def prepare_embed(
        self, inputs: str | Sequence[str], settings: EmbeddingSettings | None = None
    ) -> tuple[list[str], EmbeddingSettings]:
        """Prepare the inputs and settings for the embedding."""
        inputs = [inputs] if isinstance(inputs, str) else list(inputs)

        settings = merge_embedding_settings(self._settings, settings) or {}

        return inputs, settings

    async def max_input_tokens(self) -> int | None:
        """Get the maximum number of tokens that can be input to the model.

        `None` means unknown.
        """
        return None

    async def count_tokens(self, text: str) -> int:
        """Count the number of tokens in the text."""
        raise NotImplementedError
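
For illustration, a minimal sketch of a concrete subclass, assuming the import paths match the source files listed on this page; the character-code "embedding" stands in for a real provider call.

from collections.abc import Sequence

from pydantic_ai.embeddings.base import EmbeddingModel
from pydantic_ai.embeddings.result import EmbeddingResult
from pydantic_ai.embeddings.settings import EmbeddingSettings


class CharCodeEmbeddingModel(EmbeddingModel):
    """Toy subclass that 'embeds' text as character-code vectors (illustration only)."""

    @property
    def model_name(self) -> str:
        return 'char-code-embedding'  # hypothetical model name

    @property
    def system(self) -> str:
        return 'char-code'  # hypothetical provider identifier

    async def embed(
        self, inputs: str | Sequence[str], *, input_type, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        # prepare_embed normalizes a single string to a list and merges the default settings.
        inputs, settings = self.prepare_embed(inputs, settings)
        embeddings = [[float(ord(c)) for c in text[:8]] for text in inputs]
        return EmbeddingResult(
            embeddings=embeddings,
            inputs=inputs,
            input_type=input_type,
            model_name=self.model_name,
            provider_name=self.system,
        )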

__init__

__init__(
    *, settings: EmbeddingSettings | None = None
) -> None

Initialize the model with optional settings.

Parameters:

settings (EmbeddingSettings | None, default: None)
    Model-specific settings that will be used as defaults for this model.
Source code in pydantic_ai_slim/pydantic_ai/embeddings/base.py
def __init__(
    self,
    *,
    settings: EmbeddingSettings | None = None,
) -> None:
    """Initialize the model with optional settings and profile.

    Args:
        settings: Model-specific settings that will be used as defaults for this model.
    """
    self._settings = settings

settings property

settings: EmbeddingSettings | None

Get the model settings.

base_url property

base_url: str | None

The base URL for the provider API, if available.

model_name abstractmethod property

model_name: str

The model name.

system abstractmethod property

system: str

The embedding model provider.

prepare_embed

prepare_embed(
    inputs: str | Sequence[str],
    settings: EmbeddingSettings | None = None,
) -> tuple[list[str], EmbeddingSettings]

Prepare the inputs and settings for the embedding.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/base.py
def prepare_embed(
    self, inputs: str | Sequence[str], settings: EmbeddingSettings | None = None
) -> tuple[list[str], EmbeddingSettings]:
    """Prepare the inputs and settings for the embedding."""
    inputs = [inputs] if isinstance(inputs, str) else list(inputs)

    settings = merge_embedding_settings(self._settings, settings) or {}

    return inputs, settings

max_input_tokens async

max_input_tokens() -> int | None

Get the maximum number of tokens that can be input to the model.

None means unknown.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/base.py
async def max_input_tokens(self) -> int | None:
    """Get the maximum number of tokens that can be input to the model.

    `None` means unknown.
    """
    return None

count_tokens async

count_tokens(text: str) -> int

Count the number of tokens in the text.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/base.py
async def count_tokens(self, text: str) -> int:
    """Count the number of tokens in the text."""
    raise NotImplementedError

InstrumentedEmbeddingModel dataclass

Bases: WrapperEmbeddingModel

Embedding model which wraps another model so that requests are instrumented with OpenTelemetry.

See the Debugging and Monitoring guide for more info.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/instrumented.py
@dataclass(init=False)
class InstrumentedEmbeddingModel(WrapperEmbeddingModel):
    """Embedding model which wraps another model so that requests are instrumented with OpenTelemetry.

    See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
    """

    instrumentation_settings: InstrumentationSettings
    """Instrumentation settings for this model."""

    def __init__(
        self,
        wrapped: EmbeddingModel | str,
        options: InstrumentationSettings | None = None,
    ) -> None:
        super().__init__(wrapped)
        self.instrumentation_settings = options or InstrumentationSettings()

    async def embed(
        self, inputs: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        inputs, settings = self.prepare_embed(inputs, settings)
        with self._instrument(inputs, input_type, settings) as finish:
            result = await self.wrapped.embed(inputs, input_type=input_type, settings=settings)
            finish(result)
            return result

    @contextmanager
    def _instrument(
        self,
        inputs: list[str],
        input_type: EmbedInputType,
        settings: EmbeddingSettings | None,
    ) -> Iterator[Callable[[EmbeddingResult], None]]:
        operation = 'embeddings'
        span_name = f'{operation} {self.model_name}'

        inputs_count = len(inputs)

        attributes: dict[str, AttributeValue] = {
            'gen_ai.operation.name': operation,
            **self.model_attributes(self.wrapped),
            'input_type': input_type,
            'inputs_count': inputs_count,
        }

        if settings:
            attributes['embedding_settings'] = json.dumps(self.serialize_any(settings))

        if self.instrumentation_settings.include_content:
            attributes['inputs'] = json.dumps(inputs)

        attributes['logfire.json_schema'] = json.dumps(
            {
                'type': 'object',
                'properties': {
                    'input_type': {'type': 'string'},
                    'inputs_count': {'type': 'integer'},
                    'embedding_settings': {'type': 'object'},
                    **(
                        {'inputs': {'type': ['array']}, 'embeddings': {'type': 'array'}}
                        if self.instrumentation_settings.include_content
                        else {}
                    ),
                },
            }
        )

        record_metrics: Callable[[], None] | None = None
        try:
            with self.instrumentation_settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:

                def finish(result: EmbeddingResult):
                    # Prepare metric recording closure first so metrics are recorded
                    # even if the span is not recording.
                    provider_name = attributes[GEN_AI_PROVIDER_NAME_ATTRIBUTE]
                    request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
                    response_model = result.model_name or request_model
                    price_calculation = None

                    def _record_metrics():
                        token_attributes = {
                            GEN_AI_PROVIDER_NAME_ATTRIBUTE: provider_name,
                            'gen_ai.operation.name': operation,
                            GEN_AI_REQUEST_MODEL_ATTRIBUTE: request_model,
                            'gen_ai.response.model': response_model,
                            'gen_ai.token.type': 'input',
                        }
                        tokens = result.usage.input_tokens or 0
                        if tokens:
                            self.instrumentation_settings.tokens_histogram.record(tokens, token_attributes)
                            if price_calculation is not None:
                                self.instrumentation_settings.cost_histogram.record(
                                    float(getattr(price_calculation, 'input_price', 0.0)),
                                    token_attributes,
                                )

                    nonlocal record_metrics
                    record_metrics = _record_metrics

                    if not span.is_recording():
                        return

                    attributes_to_set: dict[str, AttributeValue] = {
                        **result.usage.opentelemetry_attributes(),
                        'gen_ai.response.model': response_model,
                    }

                    try:
                        price_calculation = result.cost()
                    except LookupError:
                        # The cost of this provider/model is unknown, which is common.
                        pass
                    except Exception as e:
                        warnings.warn(
                            f'Failed to get cost from response: {type(e).__name__}: {e}', CostCalculationFailedWarning
                        )
                    else:
                        attributes_to_set['operation.cost'] = float(price_calculation.total_price)

                    embeddings = result.embeddings
                    if embeddings:
                        attributes_to_set['gen_ai.embeddings.dimension.count'] = len(embeddings[0])
                        if self.instrumentation_settings.include_content:
                            attributes['embeddings'] = json.dumps(embeddings)

                    if result.provider_response_id is not None:
                        attributes_to_set['gen_ai.response.id'] = result.provider_response_id

                    span.set_attributes(attributes_to_set)

                yield finish
        finally:
            if record_metrics:
                # Record metrics after the span finishes to avoid duplication.
                record_metrics()

    @staticmethod
    def model_attributes(model: EmbeddingModel) -> dict[str, AttributeValue]:
        attributes: dict[str, AttributeValue] = {
            GEN_AI_PROVIDER_NAME_ATTRIBUTE: model.system,
            GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name,
        }
        if base_url := model.base_url:
            try:
                parsed = urlparse(base_url)
            except Exception:  # pragma: no cover
                pass
            else:
                if parsed.hostname:  # pragma: no branch
                    attributes['server.address'] = parsed.hostname
                if parsed.port:  # pragma: no branch
                    attributes['server.port'] = parsed.port

        return attributes

    @staticmethod
    def serialize_any(value: Any) -> str:
        try:
            return ANY_ADAPTER.dump_python(value, mode='json')
        except Exception:
            try:
                return str(value)
            except Exception as e:
                return f'Unable to serialize: {e}'

instrumentation_settings instance-attribute

instrumentation_settings: InstrumentationSettings = (
    options or InstrumentationSettings()
)

Instrumentation settings for this model.

instrument_embedding_model

instrument_embedding_model(
    model: EmbeddingModel,
    instrument: InstrumentationSettings | bool,
) -> EmbeddingModel

Instrument an embedding model with OpenTelemetry/logfire.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/instrumented.py
def instrument_embedding_model(model: EmbeddingModel, instrument: InstrumentationSettings | bool) -> EmbeddingModel:
    """Instrument an embedding model with OpenTelemetry/logfire."""
    if instrument and not isinstance(model, InstrumentedEmbeddingModel):
        if instrument is True:
            instrument = InstrumentationSettings()

        model = InstrumentedEmbeddingModel(model, instrument)

    return model
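
A usage sketch, assuming an OpenAI API key is configured for the default provider; passing True uses default InstrumentationSettings, or pass a customized instance instead.

import asyncio

from pydantic_ai.embeddings.instrumented import instrument_embedding_model
from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel

model = OpenAIEmbeddingModel('text-embedding-3-small')
instrumented = instrument_embedding_model(model, True)  # True -> default InstrumentationSettings()


async def main():
    # A span named 'embeddings text-embedding-3-small' is emitted around the request.
    result = await instrumented.embed('hello world', input_type='query')
    print(len(result.embeddings[0]))


asyncio.run(main())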

EmbeddingResult dataclass

The result of an embedding operation.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/result.py
@dataclass
class EmbeddingResult:
    """The result of an embedding operation."""

    embeddings: Sequence[Sequence[float]]

    _: KW_ONLY

    inputs: Sequence[str]

    input_type: EmbedInputType

    model_name: str

    provider_name: str

    timestamp: datetime = field(default_factory=_now_utc)

    usage: RequestUsage = field(default_factory=RequestUsage)

    provider_details: dict[str, Any] | None = None

    provider_response_id: str | None = None

    def __getitem__(self, item: int | str) -> Sequence[float]:
        """Get the embedding for an input or input index."""
        if isinstance(item, str):
            item = self.inputs.index(item)

        return self.embeddings[item]

    def cost(self) -> genai_types.PriceCalculation:
        """Calculate the cost of the usage.

        Uses [`genai-prices`](https://github.com/pydantic/genai-prices).
        """
        assert self.model_name, 'Model name is required to calculate price'
        return calc_price(
            self.usage,
            self.model_name,
            provider_id=self.provider_name,
            genai_request_timestamp=self.timestamp,
        )
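
A small sketch of consuming a result, based on the fields, __getitem__, and cost() shown above; the result is assumed to come from some model's embed() call.

from pydantic_ai.embeddings.result import EmbeddingResult


def summarize(result: EmbeddingResult) -> None:
    first_by_index = result[0]                 # embedding for the first input
    first_by_text = result[result.inputs[0]]   # same vector, looked up by the original input string
    assert first_by_index == first_by_text

    print(result.model_name, result.provider_name, result.usage)
    try:
        print(result.cost().total_price)  # uses genai-prices; LookupError if the price is unknown
    except LookupError:
        pass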

__getitem__

__getitem__(item: int | str) -> Sequence[float]

Get the embedding for an input or input index.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/result.py
def __getitem__(self, item: int | str) -> Sequence[float]:
    """Get the embedding for an input or input index."""
    if isinstance(item, str):
        item = self.inputs.index(item)

    return self.embeddings[item]

cost

cost() -> PriceCalculation

Calculate the cost of the usage.

Uses genai-prices.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/result.py
def cost(self) -> genai_types.PriceCalculation:
    """Calculate the cost of the usage.

    Uses [`genai-prices`](https://github.com/pydantic/genai-prices).
    """
    assert self.model_name, 'Model name is required to calculate price'
    return calc_price(
        self.usage,
        self.model_name,
        provider_id=self.provider_name,
        genai_request_timestamp=self.timestamp,
    )

EmbeddingSettings

Bases: TypedDict

Settings to configure an embedding model.

Here we include only settings which apply to multiple models / model providers, though not all of these settings are supported by all models.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/settings.py
class EmbeddingSettings(TypedDict, total=False):
    """Settings to configure an embedding model.

    Here we include only settings which apply to multiple models / model providers,
    though not all of these settings are supported by all models.
    """

    dimensions: int
    """The number of dimensions the resulting output embeddings should have.

    Supported by:

    * OpenAI
    * Cohere
    """

    extra_headers: dict[str, str]
    """Extra headers to send to the model.

    Supported by:

    * OpenAI
    * Cohere
    """

    extra_body: object
    """Extra body to send to the model.

    Supported by:

    * OpenAI
    * Cohere
    """

dimensions instance-attribute

dimensions: int

The number of dimensions the resulting output embeddings should have.

Supported by:

  • OpenAI
  • Cohere

extra_headers instance-attribute

extra_headers: dict[str, str]

Extra headers to send to the model.

Supported by:

  • OpenAI
  • Cohere

extra_body instance-attribute

extra_body: object

Extra body to send to the model.

Supported by:

  • OpenAI
  • Cohere

merge_embedding_settings

merge_embedding_settings(
    base: EmbeddingSettings | None,
    overrides: EmbeddingSettings | None,
) -> EmbeddingSettings | None

Merge two sets of embedding settings, preferring the overrides.

A common use case is: merge_embedding_settings(<embedder settings>, <run settings>)

Source code in pydantic_ai_slim/pydantic_ai/embeddings/settings.py
def merge_embedding_settings(
    base: EmbeddingSettings | None, overrides: EmbeddingSettings | None
) -> EmbeddingSettings | None:
    """Merge two sets of embedding settings, preferring the overrides.

    A common use case is: merge_embedding_settings(<embedder settings>, <run settings>)
    """
    # Note: we may want merge recursively if/when we add non-primitive values
    if base and overrides:
        return base | overrides
    else:
        return base or overrides
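
A sketch of the merge behavior: keys from the overrides win via the dict union, and a None argument falls back to the other side.

from pydantic_ai.embeddings.settings import merge_embedding_settings

base = {'dimensions': 256, 'extra_headers': {'X-A': '1'}}  # hypothetical header
overrides = {'dimensions': 512}

assert merge_embedding_settings(base, overrides) == {'dimensions': 512, 'extra_headers': {'X-A': '1'}}
assert merge_embedding_settings(None, overrides) == overrides
assert merge_embedding_settings(base, None) == base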

WrapperEmbeddingModel dataclass

Bases: EmbeddingModel

Embedding model which wraps another embedding model.

Does nothing on its own, used as a base class.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/wrapper.py
@dataclass(init=False)
class WrapperEmbeddingModel(EmbeddingModel):
    """Embedding model which wraps another embedding model.

    Does nothing on its own, used as a base class.
    """

    wrapped: EmbeddingModel
    """The underlying embedding model being wrapped."""

    def __init__(self, wrapped: EmbeddingModel | str):
        from . import infer_embedding_model

        super().__init__()
        self.wrapped = infer_embedding_model(wrapped) if isinstance(wrapped, str) else wrapped

    async def embed(
        self, inputs: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        return await self.wrapped.embed(inputs, input_type=input_type, settings=settings)

    @property
    def model_name(self) -> str:
        return self.wrapped.model_name

    @property
    def system(self) -> str:
        return self.wrapped.system

    @property
    def settings(self) -> EmbeddingSettings | None:
        """Get the settings from the wrapped embedding model."""
        return self.wrapped.settings

    @property
    def base_url(self) -> str | None:
        return self.wrapped.base_url

    def __getattr__(self, item: str):
        return getattr(self.wrapped, item)
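
A sketch of a custom wrapper built on this base class: everything delegates to the wrapped model except the overridden embed, which here just counts calls (names are illustrative).

from collections.abc import Sequence

from pydantic_ai.embeddings.result import EmbeddingResult
from pydantic_ai.embeddings.settings import EmbeddingSettings
from pydantic_ai.embeddings.wrapper import WrapperEmbeddingModel


class CountingEmbeddingModel(WrapperEmbeddingModel):
    """Wrapper that counts how many embed() calls are forwarded (illustration only)."""

    calls: int = 0

    async def embed(
        self, inputs: str | Sequence[str], *, input_type, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        self.calls += 1
        return await self.wrapped.embed(inputs, input_type=input_type, settings=settings)


# Accepts a model instance or a provider-prefixed name, like WrapperEmbeddingModel itself:
counting = CountingEmbeddingModel('openai:text-embedding-3-small')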

wrapped instance-attribute

wrapped: EmbeddingModel = (
    infer_embedding_model(wrapped)
    if isinstance(wrapped, str)
    else wrapped
)

The underlying embedding model being wrapped.

settings property

settings: EmbeddingSettings | None

Get the settings from the wrapped embedding model.

KnownEmbeddingModelName module-attribute

KnownEmbeddingModelName = TypeAliasType(
    "KnownEmbeddingModelName",
    Literal[
        "openai:text-embedding-ada-002",
        "openai:text-embedding-3-small",
        "openai:text-embedding-3-large",
        "cohere:embed-v4.0",
        "cohere:embed-english-v3.0",
        "cohere:embed-english-light-v3.0",
        "cohere:embed-multilingual-v3.0",
        "cohere:embed-multilingual-light-v3.0",
    ],
)

Known model names that can be used with the model parameter of Embedder.

KnownEmbeddingModelName is provided as a concise way to specify an embedding model.
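
For example, a known name can be passed directly to the Embedder described further down this page (a sketch; the corresponding API key must be configured):

from pydantic_ai.embeddings import Embedder

openai_embedder = Embedder('openai:text-embedding-3-small')
cohere_embedder = Embedder('cohere:embed-v4.0')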

infer_embedding_model

infer_embedding_model(
    model: EmbeddingModel | KnownEmbeddingModelName | str,
    *,
    provider_factory: Callable[
        [str], Provider[Any]
    ] = infer_provider
) -> EmbeddingModel

Infer the model from the name.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/__init__.py
def infer_embedding_model(
    model: EmbeddingModel | KnownEmbeddingModelName | str,
    *,
    provider_factory: Callable[[str], Provider[Any]] = infer_provider,
) -> EmbeddingModel:
    """Infer the model from the name."""
    if isinstance(model, EmbeddingModel):
        return model

    try:
        provider_name, model_name = model.split(':', maxsplit=1)
    except ValueError as e:
        raise ValueError('You must provide a provider prefix when specifying an embedding model name') from e

    provider = provider_factory(provider_name)

    model_kind = provider_name
    if model_kind.startswith('gateway/'):
        from ..providers.gateway import normalize_gateway_provider

        model_kind = normalize_gateway_provider(model_kind)

    if model_kind in (
        'openai',
        # For now, we assume that every chat and completions-compatible provider also
        # supports the embeddings endpoint, as at worst the user would get an `ModelHTTPError`.
        *get_args(OpenAIChatCompatibleProvider.__value__),
        *get_args(OpenAIResponsesCompatibleProvider.__value__),
    ):
        from .openai import OpenAIEmbeddingModel

        return OpenAIEmbeddingModel(model_name, provider=provider)
    elif model_kind == 'cohere':
        from .cohere import CohereEmbeddingModel

        return CohereEmbeddingModel(model_name, provider=provider)
    elif model_kind == 'sentence-transformers':
        from .sentence_transformers import SentenceTransformerEmbeddingModel

        return SentenceTransformerEmbeddingModel(model_name)
    else:
        raise UserError(f'Unknown embeddings model: {model}')  # pragma: no cover
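
A sketch of how provider-prefixed names resolve to concrete model classes; the part before ':' picks the implementation and the rest is the model name.

from pydantic_ai.embeddings import infer_embedding_model

model = infer_embedding_model('openai:text-embedding-3-small')  # -> OpenAIEmbeddingModel
model = infer_embedding_model('cohere:embed-v4.0')  # -> CohereEmbeddingModel
model = infer_embedding_model('sentence-transformers:all-MiniLM-L6-v2')  # -> SentenceTransformerEmbeddingModel

# A name without a provider prefix raises ValueError:
#     infer_embedding_model('text-embedding-3-small')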

Embedder dataclass

High-level interface for generating embeddings with a configured embedding model, supporting default settings, model override, and optional OpenTelemetry instrumentation.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/__init__.py
@dataclass(init=False)
class Embedder:
    """TODO: Docstring."""

    instrument: InstrumentationSettings | bool | None
    """Options to automatically instrument with OpenTelemetry.

    Set to `True` to use default instrumentation settings, which will use Logfire if it's configured.
    Set to an instance of [`InstrumentationSettings`][pydantic_ai.models.instrumented.InstrumentationSettings] to customize.
    If this isn't set, then the last value set by
    [`Embedder.instrument_all()`][pydantic_ai.embeddings.Embedder.instrument_all]
    will be used, which defaults to False.
    See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
    """

    _instrument_default: ClassVar[InstrumentationSettings | bool] = False

    def __init__(
        self,
        model: EmbeddingModel | KnownEmbeddingModelName | str,
        *,
        settings: EmbeddingSettings | None = None,
        defer_model_check: bool = True,
        instrument: InstrumentationSettings | bool | None = None,
    ) -> None:
        """Initialize an Embedder.

        Args:
            model: The embedding model to use - can be a model instance, model name, or string.
            settings: Optional embedding settings to use as defaults.
            defer_model_check: Whether to defer model validation until first use.
            instrument: OpenTelemetry instrumentation settings. Set to `True` to enable with defaults,
                or pass an `InstrumentationSettings` instance to customize. If `None`, uses the value
                from `Embedder.instrument_all()`.
        """
        self._model = model if defer_model_check else infer_embedding_model(model)
        self._settings = settings
        self.instrument = instrument

        self._override_model: ContextVar[EmbeddingModel | None] = ContextVar('_override_model', default=None)

    @staticmethod
    def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
        """Set the instrumentation options for all embedders where `instrument` is not set.

        Args:
            instrument: Instrumentation settings to use as the default. Set to `True` for default settings,
                `False` to disable, or pass an `InstrumentationSettings` instance to customize.
        """
        Embedder._instrument_default = instrument

    @property
    def model(self) -> EmbeddingModel | KnownEmbeddingModelName | str:
        return self._model

    @contextmanager
    def override(
        self,
        *,
        model: EmbeddingModel | KnownEmbeddingModelName | str | _utils.Unset = _utils.UNSET,
    ) -> Iterator[None]:
        if _utils.is_set(model):
            model_token = self._override_model.set(infer_embedding_model(model))
        else:
            model_token = None

        try:
            yield
        finally:
            if model_token is not None:
                self._override_model.reset(model_token)

    async def embed_query(
        self, query: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        return await self.embed(query, input_type='query', settings=settings)

    async def embed_documents(
        self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        return await self.embed(documents, input_type='document', settings=settings)

    async def embed(
        self, inputs: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        model = self._get_model()
        settings = merge_embedding_settings(self._settings, settings)
        return await model.embed(inputs, input_type=input_type, settings=settings)

    def embed_query_sync(
        self, query: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        return _utils.get_event_loop().run_until_complete(self.embed_query(query, settings=settings))

    def embed_documents_sync(
        self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        return _utils.get_event_loop().run_until_complete(self.embed_documents(documents, settings=settings))

    def embed_sync(
        self, inputs: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        return _utils.get_event_loop().run_until_complete(self.embed(inputs, input_type=input_type, settings=settings))

    async def max_input_tokens(self) -> int | None:
        model = self._get_model()
        return await model.max_input_tokens()

    def max_input_tokens_sync(self) -> int | None:
        return _utils.get_event_loop().run_until_complete(self.max_input_tokens())

    async def count_tokens(self, text: str) -> int:
        model = self._get_model()
        return await model.count_tokens(text)

    def count_tokens_sync(self, text: str) -> int:
        return _utils.get_event_loop().run_until_complete(self.count_tokens(text))

    def _get_model(self) -> EmbeddingModel:
        """Create a model configured for this embedder.

        Returns:
            The embedding model to use, with instrumentation applied if configured.
        """
        model_: EmbeddingModel
        if some_model := self._override_model.get():
            model_ = some_model
        else:
            model_ = self._model = infer_embedding_model(self.model)

        instrument = self.instrument
        if instrument is None:
            instrument = self._instrument_default

        return instrument_embedding_model(model_, instrument)
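
A usage sketch assuming an OpenAI API key is configured; the async methods are wrapped in a function since top-level await is not generally available.

import asyncio

from pydantic_ai.embeddings import Embedder

embedder = Embedder('openai:text-embedding-3-small', settings={'dimensions': 256})


async def main():
    query = await embedder.embed_query('what is pydantic?')
    docs = await embedder.embed_documents(
        ['Pydantic is a validation library.', 'Pydantic AI is an agent framework.']
    )
    print(len(query.embeddings[0]), len(docs.embeddings))


asyncio.run(main())

# Synchronous variants exist as well, e.g. embedder.embed_query_sync('what is pydantic?'),
# and `with embedder.override(model=...)` swaps the model within a context (useful in tests).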

__init__

__init__(
    model: EmbeddingModel | KnownEmbeddingModelName | str,
    *,
    settings: EmbeddingSettings | None = None,
    defer_model_check: bool = True,
    instrument: InstrumentationSettings | bool | None = None
) -> None

Initialize an Embedder.

Parameters:

model (EmbeddingModel | KnownEmbeddingModelName | str, required)
    The embedding model to use - can be a model instance, model name, or string.

settings (EmbeddingSettings | None, default: None)
    Optional embedding settings to use as defaults.

defer_model_check (bool, default: True)
    Whether to defer model validation until first use.

instrument (InstrumentationSettings | bool | None, default: None)
    OpenTelemetry instrumentation settings. Set to True to enable with defaults, or pass an InstrumentationSettings instance to customize. If None, uses the value from Embedder.instrument_all().
Source code in pydantic_ai_slim/pydantic_ai/embeddings/__init__.py
def __init__(
    self,
    model: EmbeddingModel | KnownEmbeddingModelName | str,
    *,
    settings: EmbeddingSettings | None = None,
    defer_model_check: bool = True,
    instrument: InstrumentationSettings | bool | None = None,
) -> None:
    """Initialize an Embedder.

    Args:
        model: The embedding model to use - can be a model instance, model name, or string.
        settings: Optional embedding settings to use as defaults.
        defer_model_check: Whether to defer model validation until first use.
        instrument: OpenTelemetry instrumentation settings. Set to `True` to enable with defaults,
            or pass an `InstrumentationSettings` instance to customize. If `None`, uses the value
            from `Embedder.instrument_all()`.
    """
    self._model = model if defer_model_check else infer_embedding_model(model)
    self._settings = settings
    self.instrument = instrument

    self._override_model: ContextVar[EmbeddingModel | None] = ContextVar('_override_model', default=None)

instrument instance-attribute

instrument: InstrumentationSettings | bool | None = (
    instrument
)

Options to automatically instrument with OpenTelemetry.

Set to True to use default instrumentation settings, which will use Logfire if it's configured. Set to an instance of InstrumentationSettings to customize. If this isn't set, then the last value set by Embedder.instrument_all() will be used, which defaults to False. See the Debugging and Monitoring guide for more info.

instrument_all staticmethod

instrument_all(
    instrument: InstrumentationSettings | bool = True,
) -> None

Set the instrumentation options for all embedders where instrument is not set.

Parameters:

instrument (InstrumentationSettings | bool, default: True)
    Instrumentation settings to use as the default. Set to True for default settings, False to disable, or pass an InstrumentationSettings instance to customize.
Source code in pydantic_ai_slim/pydantic_ai/embeddings/__init__.py
@staticmethod
def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
    """Set the instrumentation options for all embedders where `instrument` is not set.

    Args:
        instrument: Instrumentation settings to use as the default. Set to `True` for default settings,
            `False` to disable, or pass an `InstrumentationSettings` instance to customize.
    """
    Embedder._instrument_default = instrument

OpenAIEmbeddingModelName module-attribute

OpenAIEmbeddingModelName = str | EmbeddingModel

Possible OpenAI embeddings model names.

OpenAIEmbeddingSettings

Bases: EmbeddingSettings

Settings used for an OpenAI embedding model request.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/openai.py
class OpenAIEmbeddingSettings(EmbeddingSettings, total=False):
    """Settings used for an OpenAI embedding model request."""

OpenAIEmbeddingModel dataclass

Bases: EmbeddingModel

OpenAI embedding model.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/openai.py
@dataclass(init=False)
class OpenAIEmbeddingModel(EmbeddingModel):
    """OpenAI embedding model."""

    _model_name: OpenAIEmbeddingModelName = field(repr=False)
    _provider: Provider[AsyncOpenAI] = field(repr=False)

    def __init__(
        self,
        model_name: OpenAIEmbeddingModelName,
        *,
        provider: OpenAIEmbeddingsCompatibleProvider | Literal['openai'] | Provider[AsyncOpenAI] = 'openai',
        settings: EmbeddingSettings | None = None,
    ):
        """Initialize an OpenAI embedding model.

        Args:
            model_name: The name of the OpenAI model to use. List of model names
                available [here](https://platform.openai.com/docs/guides/embeddings#embedding-models).
            provider: The provider to use for authentication and API access. Can be either the string
                'openai' or an instance of `Provider[AsyncOpenAI]`. If not provided, a new provider will be
                created using the other parameters.
            settings: Model-specific settings that will be used as defaults for this model.
        """
        self._model_name = model_name

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self._provider = provider
        self._client = provider.client

        super().__init__(settings=settings)

    @property
    def base_url(self) -> str:
        return str(self._client.base_url)

    @property
    def model_name(self) -> OpenAIEmbeddingModelName:
        """The embedding model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The embedding model provider."""
        return self._provider.name

    async def embed(
        self, inputs: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        inputs, settings = self.prepare_embed(inputs, settings)
        settings = cast(OpenAIEmbeddingSettings, settings)

        try:
            response = await self._client.embeddings.create(
                input=inputs,
                model=self.model_name,
                dimensions=settings.get('dimensions') or OMIT,
                extra_headers=settings.get('extra_headers'),
                extra_body=settings.get('extra_body'),
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise  # pragma: lax no cover
        except APIConnectionError as e:
            raise ModelAPIError(model_name=self.model_name, message=e.message) from e

        embeddings = [item.embedding for item in response.data]

        return EmbeddingResult(
            embeddings=embeddings,
            inputs=inputs,
            input_type=input_type,
            usage=_map_usage(response.usage, self.system, self.base_url, response.model),
            model_name=response.model,
            provider_name=self.system,
        )

    async def max_input_tokens(self) -> int | None:
        if self.system != 'openai':
            return None

        # https://platform.openai.com/docs/guides/embeddings#embedding-models
        return 8192

    async def count_tokens(self, text: str) -> int:
        if self.system != 'openai':
            raise UserError(
                'Counting tokens is not supported for non-OpenAI embedding models',
            )
        try:
            encoding = await _utils.run_in_executor(tiktoken.encoding_for_model, self.model_name)
        except KeyError as e:
            raise ValueError(
                f'The embedding model {self.model_name!r} is not supported by tiktoken',
            ) from e
        return len(encoding.encode(text))
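
A construction and usage sketch; assumes the 'openai' provider picks up the API key from the environment.

import asyncio

from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel

model = OpenAIEmbeddingModel('text-embedding-3-large', settings={'dimensions': 1024})


async def main():
    result = await model.embed(['first document', 'second document'], input_type='document')
    print(result.model_name, len(result.embeddings), len(result[0]))
    print(await model.max_input_tokens())  # 8192 for models served by OpenAI itself
    print(await model.count_tokens('first document'))  # counted locally with tiktoken


asyncio.run(main())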

__init__

__init__(
    model_name: OpenAIEmbeddingModelName,
    *,
    provider: (
        OpenAIEmbeddingsCompatibleProvider
        | Literal["openai"]
        | Provider[AsyncOpenAI]
    ) = "openai",
    settings: EmbeddingSettings | None = None
)

Initialize an OpenAI embedding model.

Parameters:

model_name (OpenAIEmbeddingModelName, required)
    The name of the OpenAI model to use. The list of model names is available at https://platform.openai.com/docs/guides/embeddings#embedding-models.

provider (OpenAIEmbeddingsCompatibleProvider | Literal['openai'] | Provider[AsyncOpenAI], default: 'openai')
    The provider to use for authentication and API access. Can be either the string 'openai' or an instance of Provider[AsyncOpenAI]. If not provided, a new provider will be created using the other parameters.

settings (EmbeddingSettings | None, default: None)
    Model-specific settings that will be used as defaults for this model.
Source code in pydantic_ai_slim/pydantic_ai/embeddings/openai.py
def __init__(
    self,
    model_name: OpenAIEmbeddingModelName,
    *,
    provider: OpenAIEmbeddingsCompatibleProvider | Literal['openai'] | Provider[AsyncOpenAI] = 'openai',
    settings: EmbeddingSettings | None = None,
):
    """Initialize an OpenAI embedding model.

    Args:
        model_name: The name of the OpenAI model to use. List of model names
            available [here](https://platform.openai.com/docs/guides/embeddings#embedding-models).
        provider: The provider to use for authentication and API access. Can be either the string
            'openai' or an instance of `Provider[AsyncOpenAI]`. If not provided, a new provider will be
            created using the other parameters.
        settings: Model-specific settings that will be used as defaults for this model.
    """
    self._model_name = model_name

    if isinstance(provider, str):
        provider = infer_provider(provider)
    self._provider = provider
    self._client = provider.client

    super().__init__(settings=settings)

model_name property

model_name: OpenAIEmbeddingModelName

The embedding model name.

system property

system: str

The embedding model provider.

LatestCohereEmbeddingModelNames module-attribute

LatestCohereEmbeddingModelNames = Literal[
    "embed-v4.0",
    "embed-english-v3.0",
    "embed-english-light-v3.0",
    "embed-multilingual-v3.0",
    "embed-multilingual-light-v3.0",
]

Latest Cohere embeddings models.

CohereEmbeddingModelName module-attribute

CohereEmbeddingModelName = (
    str | LatestCohereEmbeddingModelNames
)

Possible Cohere embeddings model names.

CohereEmbeddingSettings

Bases: EmbeddingSettings

Settings used for a Cohere embedding model request.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/cohere.py
class CohereEmbeddingSettings(EmbeddingSettings, total=False):
    """Settings used for a Cohere embedding model request."""

    # ALL FIELDS MUST BE `cohere_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.

    cohere_max_tokens: int
    """The maximum number of tokens to generate before stopping."""

    cohere_input_type: CohereEmbedInputType
    """The input type to use for the embedding model. Overrides the `input_type` argument which only takes `query` and `document`."""

    cohere_truncate: V2EmbedRequestTruncate
    """The truncation strategy to use for the embedding model:

    - `NONE` (default): Do not truncate the input text and raise an error if the input text is too long.
    - `END`: Truncate the input text to the maximum number of tokens.
    - `START`: Truncate the start of the input text.
    """

cohere_max_tokens instance-attribute

cohere_max_tokens: int

The maximum number of tokens to generate before stopping.

cohere_input_type instance-attribute

cohere_input_type: CohereEmbedInputType

The input type to use for the embedding model. Overrides the input_type argument which only takes query and document.

cohere_truncate instance-attribute

cohere_truncate: V2EmbedRequestTruncate

The truncation strategy to use for the embedding model:

  • NONE (default): Do not truncate the input text and raise an error if the input text is too long.
  • END: Truncate the input text to the maximum number of tokens.
  • START: Truncate the start of the input text.

CohereEmbeddingModel dataclass

Bases: EmbeddingModel

Cohere embedding model.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/cohere.py
@dataclass(init=False)
class CohereEmbeddingModel(EmbeddingModel):
    """Cohere embedding model."""

    _model_name: CohereEmbeddingModelName = field(repr=False)
    _provider: Provider[AsyncClientV2] = field(repr=False)

    def __init__(
        self,
        model_name: CohereEmbeddingModelName,
        *,
        provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
        settings: EmbeddingSettings | None = None,
    ):
        """Initialize an Cohere model.

        Args:
            model_name: The name of the Cohere model to use. List of model names
                available [here](https://docs.cohere.com/docs/cohere-embed).
            provider: The provider to use for authentication and API access. Can be either the string
                'cohere' or an instance of `CohereProvider`. If not provided, a new provider will be
                created using the other parameters.
            settings: Model-specific settings that will be used as defaults for this model.
        """
        self._model_name = model_name

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self._provider = provider
        self._client = provider.client
        self._v1_client = provider.v1_client if isinstance(provider, CohereProvider) else None

        super().__init__(settings=settings)

    @property
    def base_url(self) -> str:
        """The base URL for the provider API, if available."""
        return self._provider.base_url

    @property
    def model_name(self) -> CohereEmbeddingModelName:
        """The embedding model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The embedding model provider."""
        return self._provider.name

    async def embed(
        self, inputs: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        inputs, settings = self.prepare_embed(inputs, settings)
        settings = cast(CohereEmbeddingSettings, settings)

        request_options = RequestOptions()
        if extra_headers := settings.get('extra_headers'):
            request_options['additional_headers'] = extra_headers
        if extra_body := settings.get('extra_body'):
            request_options['additional_body_parameters'] = cast(dict[str, Any], extra_body)

        cohere_input_type = settings.get(
            'cohere_input_type', 'search_document' if input_type == 'document' else 'search_query'
        )

        try:
            response = await self._client.embed(
                model=self.model_name,
                texts=inputs,
                output_dimension=settings.get('dimensions'),
                input_type=cohere_input_type,
                max_tokens=settings.get('cohere_max_tokens'),
                truncate=settings.get('cohere_truncate', 'NONE'),
                request_options=request_options,
            )
        except ApiError as e:
            if (status_code := e.status_code) and status_code >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise ModelAPIError(model_name=self.model_name, message=str(e)) from e

        embeddings = response.embeddings.float_
        if embeddings is None:
            raise UnexpectedModelBehavior(
                'The Cohere embeddings response did not have an `embeddings` field holding a list of floats',
                str(response),
            )

        return EmbeddingResult(
            embeddings=embeddings,
            inputs=inputs,
            input_type=input_type,
            usage=_map_usage(response),
            model_name=self.model_name,
            provider_name=self.system,
            provider_response_id=response.id,
        )

    async def max_input_tokens(self) -> int | None:
        return _MAX_INPUT_TOKENS.get(self.model_name)

    async def count_tokens(self, text: str) -> int:
        if self._v1_client is None:
            raise NotImplementedError('Counting tokens requires the Cohere v1 client')
        try:
            result = await self._v1_client.tokenize(
                model=self.model_name,
                text=text,  # Has a max length of 65536 characters
                offline=False,
            )
        except ApiError as e:
            if (status_code := e.status_code) and status_code >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise ModelAPIError(model_name=self.model_name, message=str(e)) from e

        return len(result.tokens)
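
A construction and usage sketch using the Cohere-specific settings above; assumes a Cohere API key is configured for the default provider.

import asyncio

from pydantic_ai.embeddings.cohere import CohereEmbeddingModel

model = CohereEmbeddingModel(
    'embed-v4.0',
    settings={'dimensions': 512, 'cohere_truncate': 'END'},
)


async def main():
    # input_type='query' maps to Cohere's 'search_query' unless cohere_input_type overrides it.
    result = await model.embed('best sourdough recipe', input_type='query')
    print(len(result[0]), result.provider_response_id)


asyncio.run(main())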

__init__

__init__(
    model_name: CohereEmbeddingModelName,
    *,
    provider: (
        Literal["cohere"] | Provider[AsyncClientV2]
    ) = "cohere",
    settings: EmbeddingSettings | None = None
)

Initialize a Cohere model.

Parameters:

model_name (CohereEmbeddingModelName, required)
    The name of the Cohere model to use. The list of model names is available at https://docs.cohere.com/docs/cohere-embed.

provider (Literal['cohere'] | Provider[AsyncClientV2], default: 'cohere')
    The provider to use for authentication and API access. Can be either the string 'cohere' or an instance of CohereProvider. If not provided, a new provider will be created using the other parameters.

settings (EmbeddingSettings | None, default: None)
    Model-specific settings that will be used as defaults for this model.
Source code in pydantic_ai_slim/pydantic_ai/embeddings/cohere.py
def __init__(
    self,
    model_name: CohereEmbeddingModelName,
    *,
    provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
    settings: EmbeddingSettings | None = None,
):
    """Initialize an Cohere model.

    Args:
        model_name: The name of the Cohere model to use. List of model names
            available [here](https://docs.cohere.com/docs/cohere-embed).
        provider: The provider to use for authentication and API access. Can be either the string
            'cohere' or an instance of `CohereProvider`. If not provided, a new provider will be
            created using the other parameters.
        settings: Model-specific settings that will be used as defaults for this model.
    """
    self._model_name = model_name

    if isinstance(provider, str):
        provider = infer_provider(provider)
    self._provider = provider
    self._client = provider.client
    self._v1_client = provider.v1_client if isinstance(provider, CohereProvider) else None

    super().__init__(settings=settings)

base_url property

base_url: str

The base URL for the provider API, if available.

model_name property

model_name: CohereEmbeddingModelName

The embedding model name.

system property

system: str

The embedding model provider.

SentenceTransformersEmbeddingSettings

Bases: EmbeddingSettings

Settings used for a Sentence-Transformers embedding model request.

All fields are sentence_transformers_-prefixed so settings can be merged across providers safely.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/sentence_transformers.py
class SentenceTransformersEmbeddingSettings(EmbeddingSettings, total=False):
    """Settings used for a Sentence-Transformers embedding model request.

    All fields are `sentence_transformers_`-prefixed so settings can be merged across providers safely.
    """

    sentence_transformers_device: str
    """Device to run inference on, e.g. "cpu", "cuda", "cuda:0", "mps"."""

    sentence_transformers_normalize_embeddings: bool
    """Whether to L2-normalize embeddings. Mirrors `normalize_embeddings` in SentenceTransformer.encode."""

    sentence_transformers_batch_size: int
    """Batch size to use during encoding."""

sentence_transformers_device instance-attribute

sentence_transformers_device: str

Device to run inference on, e.g. "cpu", "cuda", "cuda:0", "mps".

sentence_transformers_normalize_embeddings instance-attribute

sentence_transformers_normalize_embeddings: bool

Whether to L2-normalize embeddings. Mirrors normalize_embeddings in SentenceTransformer.encode.

sentence_transformers_batch_size instance-attribute

sentence_transformers_batch_size: int

Batch size to use during encoding.

SentenceTransformerEmbeddingModel dataclass

Bases: EmbeddingModel

Local embeddings using sentence-transformers models.

Example models include "all-MiniLM-L6-v2" and many others hosted on Hugging Face.

Source code in pydantic_ai_slim/pydantic_ai/embeddings/sentence_transformers.py
@dataclass(init=False)
class SentenceTransformerEmbeddingModel(EmbeddingModel):
    """Local embeddings using `sentence-transformers` models.

    Example models include "all-MiniLM-L6-v2" and many others hosted on Hugging Face.
    """

    _model_name: str = field(repr=False)
    _model: SentenceTransformer | None = field(repr=False, default=None)

    def __init__(self, model: SentenceTransformer | str, *, settings: EmbeddingSettings | None = None) -> None:
        """Initialize a Sentence-Transformers embedding model.

        Args:
            model: The model name or local path to load with `SentenceTransformer`, or a `SentenceTransformer` instance.
            settings: Model-specific settings that will be used as defaults for this model.
        """
        if isinstance(model, str):
            self._model_name = model
        else:
            self._model = deepcopy(model)
            self._model_name = model.model_card_data.model_id or model.model_card_data.base_model or 'unknown'

        super().__init__(settings=settings)

    @property
    def base_url(self) -> str | None:
        """No base URL — runs locally."""
        return None

    @property
    def model_name(self) -> str:
        """The embedding model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The embedding model provider/system identifier."""
        return 'sentence-transformers'

    async def embed(
        self, inputs: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        inputs, settings = self.prepare_embed(inputs, settings)
        settings = cast(SentenceTransformersEmbeddingSettings, settings)

        device = settings.get('sentence_transformers_device', None)
        normalize = settings.get('sentence_transformers_normalize_embeddings', False)
        batch_size = settings.get('sentence_transformers_batch_size', None)

        model = await self._get_model()
        encode_func = model.encode_query if input_type == 'query' else model.encode_document  # type: ignore[reportUnknownReturnType]

        np_embeddings: np.ndarray[Any, float] = await _utils.run_in_executor(  # type: ignore[reportAssignmentType]
            encode_func,  # type: ignore[reportArgumentType]
            inputs,
            show_progress_bar=False,
            convert_to_numpy=True,
            convert_to_tensor=False,
            device=device,
            normalize_embeddings=normalize,
            **{'batch_size': batch_size} if batch_size is not None else {},  # type: ignore[reportArgumentType]
        )
        embeddings = np_embeddings.tolist()  # type: ignore[reportAttributeAccessIssue]

        return EmbeddingResult(
            embeddings=embeddings,  # type: ignore[reportUnknownArgumentType]
            inputs=inputs,
            input_type=input_type,
            model_name=self.model_name,
            provider_name=self.system,
        )

    async def max_input_tokens(self) -> int | None:
        model = await self._get_model()
        return model.get_max_seq_length()

    async def count_tokens(self, text: str) -> int:
        model = await self._get_model()
        result: dict[str, torch.Tensor] = await _utils.run_in_executor(
            model.tokenize,  # type: ignore[reportArgumentType]
            [text],
        )
        if 'input_ids' not in result or not isinstance(result['input_ids'], torch.Tensor):
            raise UnexpectedModelBehavior(
                'The SentenceTransformers tokenizer output did not have an `input_ids` field holding a tensor',
                str(result),
            )
        return len(result['input_ids'][0])

    async def _get_model(self) -> SentenceTransformer:
        if self._model is None:
            # This may download the model from Hugging Face, so we do it in a thread
            self._model = await _utils.run_in_executor(SentenceTransformer, self.model_name)
        return self._model
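
A local-inference sketch; the first call may download the model from Hugging Face, and no API key is needed.

import asyncio

from pydantic_ai.embeddings.sentence_transformers import SentenceTransformerEmbeddingModel

model = SentenceTransformerEmbeddingModel(
    'all-MiniLM-L6-v2',
    settings={'sentence_transformers_normalize_embeddings': True},
)


async def main():
    result = await model.embed(['local embeddings run fully offline'], input_type='document')
    print(len(result[0]))  # embedding dimension, 384 for all-MiniLM-L6-v2
    print(await model.count_tokens('local embeddings run fully offline'))


asyncio.run(main())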

__init__

__init__(
    model: SentenceTransformer | str,
    *,
    settings: EmbeddingSettings | None = None
) -> None

Initialize a Sentence-Transformers embedding model.

Parameters:

model (SentenceTransformer | str, required)
    The model name or local path to load with SentenceTransformer, or a SentenceTransformer instance.

settings (EmbeddingSettings | None, default: None)
    Model-specific settings that will be used as defaults for this model.
Source code in pydantic_ai_slim/pydantic_ai/embeddings/sentence_transformers.py
def __init__(self, model: SentenceTransformer | str, *, settings: EmbeddingSettings | None = None) -> None:
    """Initialize a Sentence-Transformers embedding model.

    Args:
        model: The model name or local path to load with `SentenceTransformer`, or a `SentenceTransformer` instance.
        settings: Model-specific settings that will be used as defaults for this model.
    """
    if isinstance(model, str):
        self._model_name = model
    else:
        self._model = deepcopy(model)
        self._model_name = model.model_card_data.model_id or model.model_card_data.base_model or 'unknown'

    super().__init__(settings=settings)

base_url property

base_url: str | None

No base URL — runs locally.

model_name property

model_name: str

The embedding model name.

system property

system: str

The embedding model provider/system identifier.