OpenAI-compatible embedding service
This example shows how to create an embedding service that is compatible with the OpenAI API.
It serves an embedding model from the Hugging Face leaderboard. The model is selected with the `EMB_MODEL` environment variable and defaults to `thenlper/gte-base`.
Server
EMB_MODEL=thenlper/gte-base python examples/embedding/server.py
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OpenAI compatible embedding server."""
import base64
import os
from typing import List, Union
import numpy as np
import torch # type: ignore
import torch.nn.functional as F # type: ignore
import transformers # type: ignore
from llmspec import EmbeddingData, EmbeddingRequest, EmbeddingResponse, TokenUsage
from mosec import ClientError, Runtime, Server, Worker
# Hugging Face model id that is loaded when the EMB_MODEL environment variable is unset.
DEFAULT_MODEL = "thenlper/gte-base"
class Embedding(Worker):
    """Mosec worker that serves sentence embeddings in the OpenAI format.

    The served model is read from the ``EMB_MODEL`` environment variable
    (falling back to ``DEFAULT_MODEL``). Each embedding is the mean of the
    non-padding token states, L2-normalized.
    """

    def __init__(self):
        super().__init__()
        self.model_name = os.environ.get("EMB_MODEL", DEFAULT_MODEL)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.model = transformers.AutoModel.from_pretrained(self.model_name)
        self.device = (
            torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
        )
        self.model = self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _mean_pooling(model_output, attention_mask):
        """Average the token embeddings while ignoring padding positions.

        ``model_output[0]`` holds the per-token embeddings; multiplying by the
        expanded attention mask zeroes out padding before summing.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        # The clamp guards against division by zero for an all-padding row.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def get_embedding_with_token_count(
        self, sentences: Union[str, List[Union[str, List[int]]]]
    ):
        """Embed ``sentences`` and count the tokens that were processed.

        Args:
            sentences: a single string or a list of strings.

        Returns:
            ``(token_count, embeddings)``. ``token_count`` is the number of
            non-padding tokens across *all* inputs, including any special
            tokens the tokenizer adds. ``embeddings`` is a tensor with one
            L2-normalized row per input.
        """
        # TODO: support `List[List[int]]` (pre-tokenized) input
        encoded_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        inputs = encoded_input.to(self.device)
        # Sum over the whole batch so usage covers every input, not only the
        # first one.
        token_count = int(inputs["attention_mask"].sum().item())
        # Inference only: skip building the autograd graph to save memory.
        with torch.no_grad():
            model_output = self.model(**inputs)
        sentence_embeddings = self._mean_pooling(
            model_output, inputs["attention_mask"]
        )
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return token_count, sentence_embeddings

    def deserialize(self, data: bytes) -> EmbeddingRequest:
        """Decode the raw request body into an ``EmbeddingRequest``."""
        return EmbeddingRequest.from_bytes(data)

    def serialize(self, data: EmbeddingResponse) -> bytes:
        """Encode the response with ``EmbeddingResponse.to_json``."""
        return data.to_json()

    def forward(self, data: EmbeddingRequest) -> EmbeddingResponse:
        """Embed ``data.input`` and wrap the result in an OpenAI-style response.

        Embeddings are returned as float lists, or as base64-encoded float32
        bytes when ``data.encoding_format == "base64"``.

        Raises:
            ClientError: if the requested model is not the one served by this
                worker, if the input list is empty, or if it contains token
                ids (not supported yet).
        """
        if data.model != self.model_name:
            raise ClientError(
                f"the requested model {data.model} is not supported by "
                f"this worker {self.model_name}"
            )
        texts = data.input
        if not isinstance(texts, str):
            # Reject these here as client errors instead of letting the
            # tokenizer fail with a server-side exception.
            if not texts:
                raise ClientError("the input must not be empty")
            if not all(isinstance(text, str) for text in texts):
                raise ClientError(
                    "token id input is not supported yet, please send text"
                )
        token_count, embeddings = self.get_embedding_with_token_count(texts)
        # `.cpu()` is a no-op for tensors that already live on the CPU.
        vectors = embeddings.detach().cpu().numpy()
        if data.encoding_format == "base64":
            encoded = [
                base64.b64encode(vec.astype(np.float32).tobytes()).decode("utf-8")
                for vec in vectors
            ]
        else:
            encoded = [vec.tolist() for vec in vectors]
        return EmbeddingResponse(
            data=[
                EmbeddingData(embedding=emb, index=i) for i, emb in enumerate(encoded)
            ],
            model=self.model_name,
            usage=TokenUsage(
                prompt_tokens=token_count,
                # No completions performed, only embeddings generated.
                completion_tokens=0,
                total_tokens=token_count,
            ),
        )
if __name__ == "__main__":
    server = Server()
    embedding_runtime = Runtime(Embedding)
    # Expose the same runtime on both the versioned and unversioned paths so
    # OpenAI clients work with either base URL style.
    routes = {
        path: [embedding_runtime] for path in ("/v1/embeddings", "/embeddings")
    }
    server.register_runtime(routes)
    server.run()
Client
EMB_MODEL=thenlper/gte-base python examples/embedding/client.py
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OpenAI embedding client example."""
import os
from openai import Client
DEFAULT_MODEL = "thenlper/gte-base"

# Any API key is accepted by the local example server.
client = Client(api_key="fake", base_url="http://127.0.0.1:8000/")
requested_model = os.environ.get("EMB_MODEL", DEFAULT_MODEL)
emb = client.embeddings.create(model=requested_model, input="Hello world!")
first_vector = emb.data[0].embedding  # type: ignore
print(first_vector)