Cross-Encoder Model for Reranking
This example shows how to use a cross-encoder model to rerank a list of passages by their relevance to a query. This is useful for hybrid search, where results from multiple retrievers need to be merged into a single ranking.
Server
python examples/rerank/server.py
# Copyright 2024 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from os import environ
from typing import List
from msgspec import Struct
from sentence_transformers import CrossEncoder # type: ignore
from mosec import Server, Worker
from mosec.mixin import TypedMsgPackMixin
# Model identifier handed to CrossEncoder when MODEL_NAME is not set.
DEFAULT_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# How many Encoder workers to start; override with the WORKER_NUM env var.
WORKER_NUM = int(environ.get("WORKER_NUM", "1"))
class Request(Struct, kw_only=True):
    """Rerank request, decoded from the MessagePack request body.

    All fields are keyword-only.
    """

    # Query text that every document is scored against.
    query: str
    # Candidate passages; the response scores follow this order.
    docs: List[str]
class Response(Struct, kw_only=True):
    """Rerank result, encoded back to the client as MessagePack."""

    # One score per entry of the request's ``docs``, in the same order.
    scores: List[float]
class Encoder(TypedMsgPackMixin, Worker):
    """Score (query, document) pairs with a sentence-transformers CrossEncoder.

    The model is chosen by the ``MODEL_NAME`` environment variable and falls
    back to ``DEFAULT_MODEL``. Requests and responses are the typed msgspec
    structs ``Request`` and ``Response``, (de)serialized by TypedMsgPackMixin.
    """

    def __init__(self):
        # Let mosec's Worker (and the mixin) run their own initialisation
        # before this subclass adds its state.
        super().__init__()
        self.model_name = environ.get("MODEL_NAME", DEFAULT_MODEL)
        self.model = CrossEncoder(self.model_name)

    def forward(self, data: Request) -> Response:
        """Return one relevance score per document in ``data.docs``.

        Args:
            data: the decoded request holding the query and candidate docs.

        Returns:
            A ``Response`` whose ``scores`` list is aligned with ``data.docs``;
            empty when no documents were sent.
        """
        if not data.docs:
            # NOTE(review): some sentence-transformers releases index
            # sentences[0] inside CrossEncoder.predict, so an empty list would
            # raise instead of yielding no scores. Answer it directly.
            return Response(scores=[])
        pairs = [[data.query, doc] for doc in data.docs]
        scores = self.model.predict(pairs)
        # predict returns a NumPy array by default; msgspec needs plain floats.
        return Response(scores=scores.tolist())
if __name__ == "__main__":
    # A single-stage pipeline: each request is handled by one Encoder worker.
    rerank_server = Server()
    rerank_server.append_worker(Encoder, num=WORKER_NUM)
    rerank_server.run()
Client
python examples/rerank/client.py
# Copyright 2024 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
import httpx
import msgspec
# Sample rerank request: one query and the passages to score against it.
req = {
    "query": "talk is cheap, show me the code",
    "docs": [
        "what a nice day",
        "life is short, use python",
        "early bird catches the worm",
    ],
}


def main(url: str = "http://127.0.0.1:8000/inference") -> None:
    """Send ``req`` to the rerank server as MessagePack and print the reply.

    Args:
        url: inference endpoint of the running server; defaults to the local
            address used by this example.
    """
    resp = httpx.post(url, content=msgspec.msgpack.encode(req))
    if resp.status_code == HTTPStatus.OK:
        print(f"OK: {msgspec.msgpack.decode(resp.content)}")
    else:
        print(f"err[{resp.status_code}] {resp.text}")


# Only talk to the server when run as a script, not on import.
if __name__ == "__main__":
    main()