OpenAI-compatible embedding service
This example shows how to build an embedding service whose API is compatible with the OpenAI embeddings endpoint.
In this example, we use an embedding model from the Hugging Face leaderboard.
Server
EMB_MODEL=thenlper/gte-base python examples/embedding/server.py
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OpenAI compatible embedding server."""
import base64
import os
from typing import List, Union
import numpy as np
import torch # type: ignore
import torch.nn.functional as F # type: ignore
import transformers # type: ignore
from llmspec import EmbeddingData, EmbeddingRequest, EmbeddingResponse, TokenUsage
from mosec import ClientError, Runtime, Server, Worker
DEFAULT_MODEL = "thenlper/gte-base"
class Embedding(Worker):
    """mosec worker serving OpenAI-compatible text embeddings.

    The Hugging Face model name is read from the ``EMB_MODEL`` environment
    variable (default: ``thenlper/gte-base``) and loaded once per worker.
    """

    def __init__(self):
        self.model_name = os.environ.get("EMB_MODEL", DEFAULT_MODEL)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.model = transformers.AutoModel.from_pretrained(self.model_name)
        # Prefer the current CUDA device when one is available.
        self.device = (
            torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
        )
        self.model = self.model.to(self.device)
        self.model.eval()

    def get_embedding_with_token_count(
        self, sentences: Union[str, List[Union[str, List[int]]]]
    ):
        """Tokenize, embed, mean-pool and L2-normalize ``sentences``.

        Args:
            sentences: a single string or a batch of strings.
                (Pre-tokenized ``List[List[int]]`` input is declared in the
                type but not supported yet — see TODO below.)

        Returns:
            ``(token_count, embeddings)`` where ``token_count`` is the total
            number of non-padding tokens across the whole batch and
            ``embeddings`` is a 2-D tensor of unit-norm sentence embeddings.
        """

        # Mean Pooling - Take attention mask into account for correct averaging
        def mean_pooling(model_output, attention_mask):
            # First element of model_output contains all token embeddings
            token_embeddings = model_output[0]
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            )
            # Average only over real (non-padding) tokens; the clamp guards
            # against division by zero for an all-padding row.
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
                input_mask_expanded.sum(1), min=1e-9
            )

        # Tokenize sentences
        # TODO: support `List[List[int]]` input
        encoded_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        inputs = encoded_input.to(self.device)
        # BUGFIX: sum over the entire batch instead of taking only the first
        # sentence's count, so `usage.prompt_tokens` is correct for list input.
        token_count = int(inputs["attention_mask"].sum().item())
        # Inference only: skip autograd graph construction to save memory/time.
        with torch.no_grad():
            model_output = self.model(**inputs)
        # Perform pooling, then normalize to unit L2 norm.
        sentence_embeddings = mean_pooling(model_output, inputs["attention_mask"])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return token_count, sentence_embeddings

    def deserialize(self, data: bytes) -> EmbeddingRequest:
        """Parse the raw request body into an :class:`EmbeddingRequest`."""
        return EmbeddingRequest.from_bytes(data)

    def serialize(self, data: EmbeddingResponse) -> bytes:
        """Serialize the response to JSON bytes for the HTTP reply."""
        return data.to_json()

    def forward(self, data: EmbeddingRequest) -> EmbeddingResponse:
        """Compute embeddings for one OpenAI-style embedding request.

        Raises:
            ClientError: if the requested model does not match the model this
                worker was started with.
        """
        if data.model != self.model_name:
            raise ClientError(
                f"the requested model {data.model} is not supported by "
                f"this worker {self.model_name}"
            )
        token_count, embeddings = self.get_embedding_with_token_count(data.input)
        # `.cpu()` is a no-op on CPU tensors, so no device check is needed.
        embeddings = embeddings.detach().cpu().numpy()
        if data.encoding_format == "base64":
            # OpenAI's base64 format: little-endian float32 bytes, b64-encoded.
            embeddings = [
                base64.b64encode(emb.astype(np.float32).tobytes()).decode("utf-8")
                for emb in embeddings
            ]
        else:
            embeddings = [emb.tolist() for emb in embeddings]
        resp = EmbeddingResponse(
            data=[
                EmbeddingData(embedding=emb, index=i)
                for i, emb in enumerate(embeddings)
            ],
            model=self.model_name,
            usage=TokenUsage(
                prompt_tokens=token_count,
                # No completions performed, only embeddings generated.
                completion_tokens=0,
                total_tokens=token_count,
            ),
        )
        return resp
if __name__ == "__main__":
    # Share one runtime between the OpenAI-compatible path and a short alias.
    embedding_runtime = Runtime(Embedding)
    routes = {
        "/v1/embeddings": [embedding_runtime],
        "/embeddings": [embedding_runtime],
    }
    server = Server()
    server.register_runtime(routes)
    server.run()
Client
EMB_MODEL=thenlper/gte-base python examples/embedding/client.py
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OpenAI embedding client example."""
import os
from openai import Client
DEFAULT_MODEL = "thenlper/gte-base"

# Point the official OpenAI client at the local mosec server; the API key is
# not checked by the example server, so any placeholder works.
client = Client(api_key="fake", base_url="http://127.0.0.1:8000/")

model_name = os.environ.get("EMB_MODEL", DEFAULT_MODEL)
response = client.embeddings.create(model=model_name, input="Hello world!")
print(response.data[0].embedding)  # type: ignore