PyTorch Examples¶
Here are some out-of-the-box model servers powered by mosec for PyTorch users. We use PyTorch 1.9.0 in the following examples.
Natural Language Processing¶
Natural language processing model servers usually receive text data and make predictions ranging from text classification, question answering to translation and text generation.
Sentiment Analysis¶
This server receives a string and predicts how positive its content is. We build the model server based on Transformers version 4.11.0.
We show how to customize the deserialize method of the ingress stage (Preprocess) and the serialize method of the egress stage (Inference). This gives us great flexibility: we read the raw bytes directly from the request body and write the results directly into the response body.
Note that in a stage with batching enabled (e.g. Inference in this example), the worker's forward method deals with a list of data, while its serialize and deserialize methods still handle one datum at a time, as the short sketch below illustrates.
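As a minimal sketch of these semantics, consider the hypothetical Echo worker below (it is not part of the example that follows): with max_batch_size greater than 1, mosec calls forward with a list of deserialized items, while deserialize and serialize still run once per individual request.

from typing import List

from mosec import Server, Worker


class Echo(Worker):
    """Hypothetical batched worker illustrating per-item (de)serialization."""

    def deserialize(self, data: bytes) -> str:
        # Runs once per request, even when batching is enabled.
        return data.decode()

    def forward(self, data: List[str]) -> List[str]:
        # Receives a *list* of items because this stage is batched.
        return [text.upper() for text in data]

    def serialize(self, data: str) -> bytes:
        # Runs once per item of the forward output.
        return data.encode()


if __name__ == "__main__":
    server = Server()
    server.append_worker(Echo, max_batch_size=8)
    server.run()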
Server¶
python distil_bert_server_pytorch.py
distil_bert_server_pytorch.py
# Copyright 2022 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example: Mosec with Pytorch Distil BERT."""
from typing import Any, List

import torch  # type: ignore
from transformers import (  # type: ignore
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

from mosec import Server, Worker, get_logger

logger = get_logger()

# type alias
Returns = Any

INFERENCE_BATCH_SIZE = 32
INFERENCE_WORKER_NUM = 1


class Preprocess(Worker):
    """Preprocess BERT on current setup."""

    def __init__(self):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            "distilbert-base-uncased-finetuned-sst-2-english"
        )

    def deserialize(self, data: bytes) -> str:
        # Override `deserialize` for the *first* stage;
        # `data` is the raw bytes from the request body
        return data.decode()

    def forward(self, data: str) -> Returns:
        tokens = self.tokenizer.encode(data, add_special_tokens=True)
        return tokens


class Inference(Worker):
    """Pytorch Inference class"""

    resp_mime_type = "text/plain"

    def __init__(self):
        super().__init__()
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        logger.info("using computing device: %s", self.device)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "distilbert-base-uncased-finetuned-sst-2-english"
        )
        self.model.eval()
        self.model.to(self.device)

        # Overwrite self.example for warmup
        self.example = [
            [101, 2023, 2003, 1037, 8403, 4937, 999, 102] * 5  # make sentence longer
        ] * INFERENCE_BATCH_SIZE

    def forward(self, data: List[Returns]) -> List[str]:
        tensors = [torch.tensor(token) for token in data]
        with torch.no_grad():
            result = self.model(
                torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True).to(
                    self.device
                )
            )[0]
        scores = result.softmax(dim=1).cpu().tolist()
        return [f"positive={p}" for (_, p) in scores]

    def serialize(self, data: str) -> bytes:
        # Override `serialize` for the *last* stage;
        # `data` is the string from the `forward` output
        return data.encode()


if __name__ == "__main__":
    server = Server()
    server.append_worker(Preprocess, num=2 * INFERENCE_WORKER_NUM)
    server.append_worker(
        Inference, max_batch_size=INFERENCE_BATCH_SIZE, num=INFERENCE_WORKER_NUM
    )
    server.run()
Client¶
echo 'i bought this product for many times, highly recommend' | http POST :8000/inference
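Equivalently, a minimal Python client (assuming the server listens on the default port 8000):

import httpx

resp = httpx.post(
    "http://127.0.0.1:8000/inference",
    content=b"i bought this product for many times, highly recommend",
)
print(resp.status_code, resp.text)

On success this prints a plain-text score such as positive=0.99... (the exact value depends on the model).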
Computer Vision¶
Computer vision model servers usually receive either images or links to images (in the latter case, downloading the image becomes an I/O workload), feed the preprocessed image data into the model, and extract information such as categories, bounding boxes, and pixel labels as results.
Image Recognition¶
This server receives an image and classifies it according to the ImageNet categories. We specifically use ResNet as the image classifier and build a model service around it. Nevertheless, this file can serve as starter code for any kind of image recognition model server.
We enable multiprocessing for the Preprocess stage so that it can produce enough tasks for the Inference stage to do batch inference, which better exploits the GPU computing power. More interestingly, we also start multiple models by setting the number of workers for the Inference stage to 2: a single model rarely saturates the GPU's memory or utilization, so multiple models running in parallel on the same device can further increase the service throughput, as the annotated excerpt below shows.
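These choices correspond to the following lines, excerpted from the server code below:

# 4 CPU-bound preprocessing workers feed 2 GPU-bound inference workers;
# each inference worker batches up to INFERENCE_BATCH_SIZE requests.
server.append_worker(Preprocess, num=4)
server.append_worker(Inference, num=2, max_batch_size=INFERENCE_BATCH_SIZE)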
When instantiating the Server, we enable plasma_shm, which utilizes pyarrow.plasma as a shared-memory data store for IPC. This can speed up data transfer between stages, especially when the data is large (the preprocessed image tensors in this case). Note that you need to run pip install -U pyarrow==11 to install the necessary dependencies.
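To get a feel for what such a shared-memory store does, here is a standalone illustration of pyarrow.plasma itself (this sketches the underlying mechanism, not mosec's internal wiring; the store size is arbitrary):

import numpy as np
import pyarrow.plasma as plasma  # the plasma module was removed in pyarrow >= 12

# Spawn a temporary plasma store backed by ~200 MB of shared memory.
with plasma.start_plasma_store(200 * 1000 * 1000) as (shm_path, _proc):
    client = plasma.connect(shm_path)
    # `put` copies the object into shared memory and returns an ObjectID;
    # any process connected to the same store can `get` it without piping
    # the bytes through a socket.
    object_id = client.put(np.zeros((3, 224, 224), dtype=np.float32))
    print(client.get(object_id).shape)  # (3, 224, 224)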
We also demonstrate how to customize validation of the data content through this example. In the forward method of the Preprocess worker, we first check the key of the input, then try to decode the bytes as an image. If any of these steps fails, we raise a ValidationError, and the status is eventually returned to our clients as HTTP 422.
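For instance, a request whose payload lacks the image key should come back as 422 (a sketch using the same httpx and msgpack combination as the client further below):

import httpx
import msgpack  # type: ignore

resp = httpx.post(
    "http://127.0.0.1:8000/inference",
    content=msgpack.packb({"wrong_key": b"not an image"}),
)
print(resp.status_code)  # expected: 422 (Unprocessable Entity)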
Server¶
python examples/resnet50_msgpack/server.py
resnet50_server_msgpack.py
# Copyright 2022 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example: Sample Resnet server."""
from io import BytesIO
from typing import List
from urllib.request import urlretrieve
import numpy as np # type: ignore
import torch # type: ignore
import torchvision # type: ignore
from PIL import Image # type: ignore
from torchvision import transforms # type: ignore
from mosec import Server, ValidationError, Worker, get_logger
from mosec.mixin import MsgpackMixin
logger = get_logger()
INFERENCE_BATCH_SIZE = 16
class Preprocess(MsgpackMixin, Worker):
    """Sample Preprocess worker"""

    def __init__(self) -> None:
        super().__init__()
        trans = torch.nn.Sequential(
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        )
        self.transform = torch.jit.script(trans)  # type: ignore

    def forward(self, data: dict):
        # Customized validation for input key and field content; raise
        # ValidationError so that the client can get 422 as http status
        try:
            image = Image.open(BytesIO(data["image"]))
        except KeyError as err:
            raise ValidationError(f"cannot find key {err}") from err
        except Exception as err:
            raise ValidationError(f"cannot decode as image data: {err}") from err
        tensor = transforms.ToTensor()(image)
        data = self.transform(tensor)  # type: ignore
        return data


class Inference(Worker):
    """Sample Inference worker"""

    def __init__(self):
        super().__init__()
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        logger.info("using computing device: %s", self.device)
        self.model = torchvision.models.resnet50(pretrained=True)
        self.model.eval()
        self.model.to(self.device)

        # Overwrite self.example for warmup;
        # (3, 224, 224) matches the tensor shape produced by `Preprocess`
        self.example = [
            np.zeros((3, 224, 224), dtype=np.float32)
        ] * INFERENCE_BATCH_SIZE
    def forward(self, data: List[np.ndarray]) -> List[int]:
        logger.info("processing batch with size: %d", len(data))
        with torch.no_grad():
            batch = torch.stack([torch.tensor(arr, device=self.device) for arr in data])
            output = self.model(batch)
            top1 = torch.argmax(output, dim=1)
        return top1.cpu().tolist()


class Postprocess(MsgpackMixin, Worker):
    """Sample Postprocess worker"""

    def __init__(self):
        super().__init__()
        logger.info("loading categories file...")
        local_filename, _ = urlretrieve(
            "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
        )
        with open(local_filename, encoding="utf8") as file:
            self.categories = list(map(lambda x: x.strip(), file.readlines()))

    def forward(self, data: int) -> dict:
        return {"category": self.categories[data]}


if __name__ == "__main__":
    server = Server()
    server.append_worker(Preprocess, num=4)
    server.append_worker(Inference, num=2, max_batch_size=INFERENCE_BATCH_SIZE)
    server.append_worker(Postprocess, num=1)
    server.run()
Client¶
python examples/resnet50_msgpack/client.py
resnet50_client_msgpack.py
# Copyright 2022 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example: Sample Resnet client."""
from http import HTTPStatus

import httpx
import msgpack  # type: ignore

dog_bytes = httpx.get(
    "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg"
).content

prediction = httpx.post(
    "http://127.0.0.1:8000/inference",
    content=msgpack.packb({"image": dog_bytes}),
)
if prediction.status_code == HTTPStatus.OK:
    print(msgpack.unpackb(prediction.content))
else:
    print(prediction.status_code, prediction.content)
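If the server is running, this prints the predicted category as a dict; for the sample dog image it should look like {'category': 'Samoyed'} (the exact label depends on the model's prediction).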