Stable Diffusion
This example provides a demo service for Stable Diffusion. You can develop it in a container environment with envd: `envd up -p examples/stable_diffusion`.
You can try this demo from the `mosec/examples/stable_diffusion/` directory.
Server
envd build -t sd:serving
docker run --rm --gpus all -p 8000:8000 sd:serving
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import BytesIO
from typing import List
import torch # type: ignore
from diffusers import StableDiffusionPipeline # type: ignore
from mosec import Server, Worker, get_logger
from mosec.mixin import MsgpackMixin
logger = get_logger()


class StableDiffusion(MsgpackMixin, Worker):
    """Mosec worker that turns a batch of text prompts into JPEG images.

    Request and response (de)serialization is provided by ``MsgpackMixin``.
    """

    def __init__(self):
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "sd-legacy/stable-diffusion-v1-5",
            torch_dtype=torch.float16,
        )
        # Per the diffusers docs, this keeps sub-models on the CPU and moves
        # each to the GPU only while it runs, lowering peak GPU memory.
        self.pipe.enable_model_cpu_offload()
        self.example = ["useless example prompt"] * 4  # warmup (bs=4)

    def forward(self, data: List[str]) -> List[memoryview]:
        """Generate one image per prompt and return each encoded as JPEG.

        Args:
            data: a batch of text prompts.

        Returns:
            One JPEG buffer per prompt, in the same order as ``data``.
        """
        logger.debug("generate images for %s", data)
        res = self.pipe(data)  # type: ignore
        # Read output fields by name rather than position: the pipeline
        # output omits ``None`` fields from its tuple view, so ``res[1]``
        # raises IndexError when the safety checker is disabled.
        logger.debug("NSFW: %s", res.nsfw_content_detected)  # type: ignore
        return [self._encode_jpeg(img) for img in res.images]  # type: ignore

    @staticmethod
    def _encode_jpeg(img) -> memoryview:
        """Serialize one image (assumed to be a PIL image) into an in-memory JPEG."""
        buffer = BytesIO()
        img.save(buffer, format="JPEG")  # type: ignore
        return buffer.getbuffer()
if __name__ == "__main__":
    # One worker process. Prompts are grouped into batches of up to 4
    # (matching the bs=4 warmup example). A partial batch runs after waiting
    # at most max_wait_time (milliseconds, per mosec's API docs).
    app = Server()
    app.append_worker(
        StableDiffusion,
        num=1,
        max_batch_size=4,
        max_wait_time=10,
    )
    app.run()
python server.py --timeout 30000
Client
python client.py --prompt "a cute cat sitting on the basketball"
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from http import HTTPStatus
import httpx
import msgpack # type: ignore
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the stable diffusion demo client."""
    parser = argparse.ArgumentParser(
        prog="stable diffusion client demo",
    )
    parser.add_argument(
        "-p", "--prompt", default="a photo of an astronaut riding a horse on mars"
    )
    parser.add_argument(
        "-o", "--output", default="stable_diffusion_result.jpg", help="output filename"
    )
    parser.add_argument(
        "--port",
        default=8000,
        type=int,
        help="service port",
    )
    # Default matches the `--timeout 30000` (ms) the demo server is started
    # with, so the client does not give up before the server would.
    parser.add_argument(
        "--timeout",
        default=30.0,
        type=float,
        help="request timeout in seconds",
    )
    return parser


def main() -> None:
    """Send one prompt to the local service and write the returned JPEG.

    Raises:
        SystemExit: with a non-zero status if the request fails or the
            service does not answer 200 OK.
    """
    args = build_parser().parse_args()
    try:
        resp = httpx.post(
            f"http://127.0.0.1:{args.port}/inference",
            content=msgpack.packb(args.prompt),
            timeout=httpx.Timeout(args.timeout),
        )
    except httpx.HTTPError as err:
        raise SystemExit(f"ERROR: request failed: {err}") from err
    if resp.status_code != HTTPStatus.OK:
        # SystemExit with a string prints it to stderr and exits with status 1.
        raise SystemExit(f"ERROR: <{resp.status_code}> {resp.text}")
    data = msgpack.unpackb(resp.content)
    with open(args.output, "wb") as f:
        f.write(data)


if __name__ == "__main__":
    main()