Cross-Encoder model for reranking

This example shows how to use a cross-encoder model to rerank a list of passages by their relevance to a query. This is useful for hybrid search, where results from multiple retrievers need to be merged into a single ranking.

Server

python examples/rerank/server.py
# Copyright 2024 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from os import environ
from typing import List

from msgspec import Struct
from sentence_transformers import CrossEncoder

from mosec import Server, Worker
from mosec.mixin import TypedMsgPackMixin

# Model identifier handed to CrossEncoder when MODEL_NAME is not set.
DEFAULT_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# How many Encoder workers to start; override via the WORKER_NUM env var.
WORKER_NUM = int(environ.get("WORKER_NUM", "1"))


class Request(Struct, kw_only=True):
    """Rerank request decoded from the msgpack body of an inference call.

    ``kw_only=True`` makes the generated constructor keyword-only.
    """

    # Text that every document is scored against.
    query: str
    # Candidate passages to rerank; each one is paired with ``query``.
    docs: List[str]


class Response(Struct, kw_only=True):
    """Rerank result encoded as the msgpack body of the reply."""

    # One score per request document, in the same order as ``Request.docs``.
    scores: List[float]


class Encoder(TypedMsgPackMixin, Worker):
    """Score (query, document) pairs with a sentence-transformers CrossEncoder.

    The model id comes from the ``MODEL_NAME`` environment variable and falls
    back to ``DEFAULT_MODEL``. ``TypedMsgPackMixin`` handles msgpack
    (de)serialization based on the ``Request``/``Response`` annotations of
    :meth:`forward`.
    """

    def __init__(self):
        # Run the Worker/mixin initializer so any state it sets up exists
        # before the (potentially slow) model load below.
        super().__init__()
        self.model_name = environ.get("MODEL_NAME", DEFAULT_MODEL)
        self.model = CrossEncoder(self.model_name)

    def forward(self, data: Request) -> Response:
        """Return one relevance score per document, in request order.

        Args:
            data: the query and the candidate documents to rerank.

        Returns:
            A ``Response`` whose ``scores[i]`` belongs to ``data.docs[i]``;
            an empty ``docs`` list yields an empty ``scores`` list.
        """
        if not data.docs:
            # Older sentence-transformers releases index the first input pair
            # inside CrossEncoder.predict and raise IndexError on an empty
            # list; answer directly instead of calling the model at all.
            return Response(scores=[])
        pairs = [[data.query, doc] for doc in data.docs]
        scores = self.model.predict(pairs)
        return Response(scores=scores.tolist())


if __name__ == "__main__":
    # A single stage: WORKER_NUM copies of the Encoder worker behind one
    # mosec server.
    rerank_server = Server()
    rerank_server.append_worker(Encoder, num=WORKER_NUM)
    rerank_server.run()

Client

python examples/rerank/client.py
# Copyright 2024 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from http import HTTPStatus

import httpx
import msgspec

# Sample rerank request: one query and the passages to score against it.
req = dict(
    query="talk is cheap, show me the code",
    docs=[
        "what a nice day",
        "life is short, use python",
        "early bird catches the worm",
    ],
)

# The server worker speaks msgpack, so encode the body before posting it.
body = msgspec.msgpack.encode(req)
resp = httpx.post("http://127.0.0.1:8000/inference", content=body)

succeeded = resp.status_code == HTTPStatus.OK
print(
    f"OK: {msgspec.msgpack.decode(resp.content)}"
    if succeeded
    else f"err[{resp.status_code}] {resp.text}"
)