# JAX jitted inference

This example shows how to use the [JAX framework](https://github.com/google/jax) to build a just-in-time (JIT) compiled inference server. You can install JAX by following its official guide, and you also need `chex` to run this example (`pip install -U chex`). We use a single-layer neural network to keep the example minimal. You can also experiment with the JIT speedup by setting the environment variable `USE_JIT=true` and observing the latency difference. Note that in the worker's `__init__` we set `self.multi_examples` to a list of example inputs for warmup, because different batch sizes trigger re-jitting when they are traced for the first time.
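To see why the warmup list matters, here is a minimal standalone sketch (separate from the included server) of the re-jitting behavior: `jax.jit` traces and compiles once per input shape, so the first call at every new batch size pays the compilation cost, while later calls with the same shape hit the cache. The `forward` function, weight shapes, and batch sizes below are illustrative only.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def forward(w, b, x):
    # Single dense layer: x @ w + b
    return jnp.dot(x, w) + b


key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (8, 4))
b = jnp.zeros((4,))

for batch_size in (1, 2, 4):  # warm up every batch size the server may see
    x = jnp.ones((batch_size, 8))

    start = time.perf_counter()
    forward(w, b, x).block_until_ready()  # first call: trace + compile
    first = time.perf_counter() - start

    start = time.perf_counter()
    forward(w, b, x).block_until_ready()  # second call: reuses the cached jit
    second = time.perf_counter() - start
    print(f"batch={batch_size} first={first:.4f}s cached={second:.4f}s")
```

Warming up with `multi_examples` front-loads this compilation cost at server startup instead of letting the first request of each batch size absorb it.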
## Server

```shell
USE_JIT=true python examples/jax_single_layer/server.py
```

server.py

```{include} ../../../examples/jax_single_layer/server.py
:code: python
```
## Client

```shell
python examples/jax_single_layer/client.py
```
client.py

```{include} ../../../examples/jax_single_layer/client.py
:code: python
```
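To observe the latency difference mentioned above, one option is to time repeated requests against the running server, once started with `USE_JIT=true` and once without. The sketch below is a rough harness, not part of the example: it assumes the server listens on mosec's default address `http://127.0.0.1:8000` with the `/inference` route, uses `httpx` for convenience, and uses a placeholder payload that you should replace with whatever the included `client.py` actually sends.

```python
import time

import httpx

URL = "http://127.0.0.1:8000/inference"  # assumed default mosec port and route
PAYLOAD = {"array": [1.0, 2.0, 3.0]}  # placeholder; mirror client.py's request body

with httpx.Client() as client:
    latencies = []
    for _ in range(20):
        start = time.perf_counter()
        resp = client.post(URL, json=PAYLOAD)
        resp.raise_for_status()
        latencies.append(time.perf_counter() - start)

    # Drop the first few requests so warmup/compilation does not skew the numbers.
    steady = latencies[3:]
    print(f"mean latency: {sum(steady) / len(steady) * 1000:.2f} ms")
```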