Jax jitted inference¶
This example shows how to use the JAX framework to build a just-in-time (JIT) compiled inference server. You can install JAX by following its official guide; you also need chex to run this example (`pip install -U chex`).
We use a neural network with a single hidden layer for this minimal example. You can also measure the speedup from JIT by setting the environment variable `USE_JIT=true` and observing the latency difference. Note that in the `__init__` of the worker we set `self.multi_examples` to a list of example inputs used for warm-up, because each different batch size triggers re-jitting the first time it is traced.
Server¶
USE_JIT=true python examples/jax_single_layer/server.py
jax_single_layer.py
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example: Simple jax jitted inference with a single layer classifier."""
import os
import time
from typing import List
import chex # type: ignore
import jax # type: ignore
import jax.numpy as jnp # type: ignore
from mosec import Server, ValidationError, Worker, get_logger
# Logger provided by mosec.
logger = get_logger()

# Layer widths of the example network: INPUT_SIZE -> LATENT_SIZE -> OUTPUT_SIZE.
INPUT_SIZE = 3
LATENT_SIZE = 16
OUTPUT_SIZE = 2

# Largest batch handed to the worker; also the number of warm-up batch sizes.
MAX_BATCH_SIZE = 8

# JIT compilation is enabled only when this is exactly the string "true".
USE_JIT = os.getenv("USE_JIT", "false")
class JittedInference(Worker):
    """Worker that classifies input vectors with a small JAX network.

    The network is ``INPUT_SIZE -> LATENT_SIZE -> OUTPUT_SIZE`` with a ReLU
    after the first layer and a softmax after the second; the predicted
    category is the argmax of the softmax output. The batched forward pass is
    wrapped in ``jax.jit`` when ``USE_JIT`` equals ``"true"``.
    """

    def __init__(self):
        """Create the weights, the warm-up examples and the forward function."""
        super().__init__()
        # Fixed seed, so the weights are identical on every start.
        key = jax.random.PRNGKey(42)
        k_1, k_2 = jax.random.split(key)
        self._layer1_w = jax.random.normal(k_1, (INPUT_SIZE, LATENT_SIZE))
        self._layer1_b = jnp.zeros(LATENT_SIZE)
        self._layer2_w = jax.random.normal(k_2, (LATENT_SIZE, OUTPUT_SIZE))
        self._layer2_b = jnp.zeros(OUTPUT_SIZE)

        # Enumerate all batch sizes for caching: one example batch per size
        # from 1 to MAX_BATCH_SIZE, since a new batch size triggers re-jitting
        # the first time it is traced.
        dummy_array = list(range(INPUT_SIZE))
        self.multi_examples = [
            [{"array": dummy_array}] * batch_size
            for batch_size in range(1, MAX_BATCH_SIZE + 1)
        ]

        if USE_JIT == "true":
            self.batch_forward = jax.jit(self._batch_forward)
        else:
            self.batch_forward = self._batch_forward

    def _forward(self, x_single: jnp.ndarray) -> jnp.ndarray:  # type: ignore
        """Return the predicted category index for one rank-1 input vector."""
        chex.assert_rank([x_single], [1])
        h_1 = jnp.dot(self._layer1_w.T, x_single) + self._layer1_b
        a_1 = jax.nn.relu(h_1)
        h_2 = jnp.dot(self._layer2_w.T, a_1) + self._layer2_b
        o_2 = jax.nn.softmax(h_2)
        return jnp.argmax(o_2, axis=-1)

    def _batch_forward(self, x_batch: jnp.ndarray) -> jnp.ndarray:  # type: ignore
        """Apply ``_forward`` to every row of a rank-2 batch via ``jax.vmap``."""
        chex.assert_rank([x_batch], [2])
        return jax.vmap(self._forward)(x_batch)

    def forward(self, data: List[dict]) -> List[dict]:
        """Classify a batch of requests.

        Args:
            data: One dict per request, each holding an ``"array"`` entry with
                ``INPUT_SIZE`` numbers.

        Returns:
            One dict per request with the predicted ``"category"`` and the
            ``"elapse"`` time in seconds spent on the whole batch.

        Raises:
            ValidationError: If a request is not a dict, lacks the ``"array"``
                key, or its array is not ``INPUT_SIZE`` numbers.
        """
        time_start = time.perf_counter()
        try:
            input_array_raw = [ele["array"] for ele in data]
        except KeyError as err:
            raise ValidationError(f"cannot find key {err}") from err
        except TypeError as err:
            # The request body was not a mapping (e.g. a list or a string).
            raise ValidationError(f"request must be a JSON object: {err}") from err

        # Pin the dtype: the warm-up examples are ints, so without this a float
        # payload would present a new input dtype to the jitted function and
        # trigger re-tracing at request time. The float weights already promote
        # the input to floating point, so the result is unchanged.
        try:
            input_array = jnp.array(input_array_raw, dtype=jnp.float32)
        except (TypeError, ValueError) as err:
            # Ragged or non-numeric arrays cannot form a rectangular batch.
            raise ValidationError(f"invalid input array: {err}") from err
        if input_array.ndim != 2 or input_array.shape[1] != INPUT_SIZE:
            raise ValidationError(
                f"each array must hold exactly {INPUT_SIZE} numbers, "
                f"got batch shape {input_array.shape}"
            )

        output_array = self.batch_forward(input_array)
        output_category = output_array.tolist()
        elapse = time.perf_counter() - time_start
        return [{"category": c, "elapse": elapse} for c in output_category]
if __name__ == "__main__":
    # Register the worker with max_batch_size=MAX_BATCH_SIZE, then start serving.
    inference_server = Server()
    inference_server.append_worker(JittedInference, max_batch_size=MAX_BATCH_SIZE)
    inference_server.run()
Client¶
python examples/jax_single_layer/client.py
jax_single_layer_cli.py
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example: Client of the Jax server."""
import random
from http import HTTPStatus
import httpx
# Three random integers in [-99, 99] form one input vector for the server.
input_data = [random.randint(-99, 99) for _ in range(3)]
print("Client : sending data : ", input_data)

# POST the vector as JSON to the inference endpoint on the local server.
prediction = httpx.post(
    "http://127.0.0.1:8000/inference",
    json={"array": input_data},
)

if prediction.status_code == HTTPStatus.OK:
    print(prediction.json())
else:
    # An error body is not guaranteed to be JSON; decoding it could raise and
    # hide the actual status, so print the raw response text instead.
    print(prediction.status_code, prediction.text)