Compression

This example demonstrates how to use the --compression feature for segmentation tasks. It is based on the image predictor example from Segment Anything Model 2. The request contains an image and its low-resolution mask; the response is the final mask. Since the mask contains many duplicate values, it can be compressed effectively with gzip or zstd.

Server

python examples/segment/server.py --compression
server.py
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# refer to https://github.com/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb

import numbin
import torch  # type: ignore
from sam2.sam2_image_predictor import SAM2ImagePredictor  # type: ignore

from mosec import Server, Worker, get_logger
from mosec.mixin import MsgpackMixin

# Lowest CUDA device major version for which the worker turns on TF32.
MIN_TF32_MAJOR = 8
# Logger used by the worker defined in this module.
logger = get_logger()


class SegmentAnything(MsgpackMixin, Worker):
    """Worker that runs the SAM 2 image predictor on each request."""

    def __init__(self):
        """Load the pretrained predictor on the best available device."""
        # prefer CUDA, then MPS, and fall back to the CPU
        if torch.cuda.is_available():
            backend = "cuda"
        elif torch.backends.mps.is_available():
            backend = "mps"
        else:
            backend = "cpu"
        device = torch.device(backend)
        logger.info("using device: %s", device)

        self.predictor = SAM2ImagePredictor.from_pretrained(
            "facebook/sam2-hiera-large", device=device
        )

        # the remaining tuning only applies to CUDA devices
        if device.type != "cuda":
            return

        # use bfloat16; the autocast context is entered once and left active
        torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
        # turn on tf32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        capability_major = torch.cuda.get_device_properties(0).major
        if capability_major >= MIN_TF32_MAJOR:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    def forward(self, data: dict) -> bytes:
        """Predict a single mask for the image described by ``data``.

        Args:
            data: request payload with the keys ``"image"`` and ``"mask"``
                (numbin-serialized arrays) plus ``"point_coords"`` and
                ``"labels"``, which are handed to the predictor unchanged.

        Returns:
            The first predicted mask, serialized with numbin.
        """
        with torch.inference_mode():
            image = numbin.loads(data["image"])
            self.predictor.set_image(image)
            prompt = {
                "point_coords": data["point_coords"],
                "point_labels": data["labels"],
                # add a leading axis to the 2-D mask before passing it on
                "mask_input": numbin.loads(data["mask"])[None, :, :],
                "multimask_output": False,
            }
            masks, _, _ = self.predictor.predict(**prompt)
        return numbin.dumps(masks[0])


if __name__ == "__main__":
    # register the SegmentAnything worker (num=1, max_batch_size=1) and
    # start serving
    app = Server()
    app.append_worker(SegmentAnything, max_batch_size=1, num=1)
    app.run()

Client

python examples/segment/client.py
client.py
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gzip
from http import HTTPStatus
from io import BytesIO

import httpx
import msgpack  # type: ignore
import numbin
import numpy as np
from PIL import Image  # type: ignore

# Download the sample image from the SAM 2 repository.
image_resp = httpx.get(
    "https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/truck.jpg"
)
# Fail here with the HTTP error instead of letting PIL raise an opaque
# decoding error on an error page when the download did not succeed.
image_resp.raise_for_status()
truck_image = Image.open(BytesIO(image_resp.content))
array = np.array(truck_image.convert("RGB"))
# assume we have obtained the low resolution mask from the previous step
mask = np.zeros((256, 256))

# The request body is a msgpack map; both arrays are serialized with numbin.
payload = msgpack.packb(  # type: ignore
    {
        "image": numbin.dumps(array),
        "mask": numbin.dumps(mask),
        "labels": [1, 1],
        "point_coords": [[500, 375], [1125, 625]],
    }
)

resp = httpx.post(
    "http://127.0.0.1:8000/inference",
    # the body is gzip-compressed, as declared by the Content-Encoding header;
    # Accept-Encoding asks the server for a gzip-compressed response
    content=gzip.compress(payload),
    headers={"Accept-Encoding": "gzip", "Content-Encoding": "gzip"},
)
assert resp.status_code == HTTPStatus.OK, resp.status_code
# httpx decodes the gzip-encoded response body, so resp.content is the
# msgpack payload wrapping the numbin-serialized mask
res = numbin.loads(msgpack.loads(resp.content))
# the returned mask must have the same height and width as the input image
assert res.shape == array.shape[:2], f"expect {array.shape[:2]}, got {res.shape}"