# Stable Diffusion

This example provides a demo service for Stable Diffusion. You can develop it in a container environment using envd: `envd up -p examples/stable_diffusion`.

You can try this demo from the `mosec/examples/stable_diffusion/` directory.

## Server

envd build -t sd:serving
docker run --rm --gpus all -p 8000:8000 sd:serving
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from io import BytesIO
from typing import List

import torch  # type: ignore
from diffusers import StableDiffusionPipeline  # type: ignore

from mosec import Server, Worker, get_logger
from mosec.mixin import MsgpackMixin

# Module-level logger obtained from mosec; the worker below uses it for debug output.
logger = get_logger()


class StableDiffusion(MsgpackMixin, Worker):
    """Mosec worker that turns a batch of text prompts into JPEG images.

    Request and response payloads are (de)serialized with msgpack through
    ``MsgpackMixin``.
    """

    def __init__(self):
        super().__init__()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # float16 halves GPU memory, but many CPU kernels have no float16
        # implementation, so fall back to float32 when running on CPU.
        dtype = torch.float16 if device == "cuda" else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=dtype
        )
        self.pipe = self.pipe.to(device)  # type: ignore
        # Warmup batch; its size (4) matches the max_batch_size=4 used when
        # this worker is registered with the server.
        self.example = ["useless example prompt"] * 4  # warmup (bs=4)

    def forward(self, data: List[str]) -> List[memoryview]:
        """Generate one JPEG image per prompt.

        Args:
            data: batch of text prompts.

        Returns:
            One JPEG-encoded image buffer per prompt, in input order.
        """
        logger.debug("generate images for %s", data)
        res = self.pipe(data)  # type: ignore
        # Access output fields by name: positional indexing on the pipeline
        # output skips fields that are None, so ``res[1]`` would raise when
        # the safety checker is disabled and ``nsfw_content_detected`` is None.
        logger.debug("NSFW: %s", res.nsfw_content_detected)  # type: ignore
        return [self._encode_jpeg(img) for img in res.images]  # type: ignore

    @staticmethod
    def _encode_jpeg(img) -> memoryview:
        """Serialize a PIL image to JPEG and return a view over the bytes."""
        buf = BytesIO()
        img.save(buf, format="JPEG")  # type: ignore
        return buf.getbuffer()


if __name__ == "__main__":
    # Serve a single Stable Diffusion worker. Incoming prompts are grouped
    # into batches of up to 4; max_wait_time bounds how long mosec waits for
    # a batch to fill (unit defined by mosec — milliseconds per its docs).
    sd_server = Server()
    sd_server.append_worker(
        StableDiffusion,
        num=1,
        max_batch_size=4,
        max_wait_time=10,
    )
    sd_server.run()
python server.py --timeout 30000

## Client

python client.py --prompt "a cute cat sitting on the basketball"
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from http import HTTPStatus

import httpx
import msgpack  # type: ignore

# Command-line client: send one prompt to the mosec service and save the
# returned JPEG image to disk.
parser = argparse.ArgumentParser(
    prog="stable diffusion client demo",
)
parser.add_argument(
    "-p",
    "--prompt",
    default="a photo of an astronaut riding a horse on mars",
    help="text prompt to generate an image for",
)
parser.add_argument(
    "-o", "--output", default="stable_diffusion_result.jpg", help="output filename"
)
parser.add_argument(
    "--port",
    default=8000,
    type=int,
    help="service port",
)
# Image generation is slow; the default matches the 30000 ms server timeout
# used in the server instructions so the client does not give up on requests
# the server would still complete.
parser.add_argument(
    "--timeout",
    default=30.0,
    type=float,
    help="request timeout in seconds",
)


args = parser.parse_args()
resp = httpx.post(
    f"http://127.0.0.1:{args.port}/inference",
    content=msgpack.packb(args.prompt),
    timeout=httpx.Timeout(args.timeout),
)
if resp.status_code != HTTPStatus.OK:
    # SystemExit with a message prints it to stderr and exits with status 1,
    # so scripts calling this client can detect the failure.
    raise SystemExit(f"ERROR: <{resp.status_code}> {resp.text}")

# The server responds with the msgpack-encoded JPEG bytes.
image = msgpack.unpackb(resp.content)
with open(args.output, "wb") as f:
    f.write(image)