Source code for mosec.mixin.typed_worker

# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MOSEC type validation mixin."""

# pylint: disable=import-outside-toplevel

from typing import Any, Dict, Optional, Tuple

from mosec import get_logger
from mosec.errors import ValidationError
from mosec.utils import ParseTarget, parse_func_type
from mosec.worker import Worker

logger = get_logger()


class TypedMsgPackMixin(Worker):
    """Enable request type validation with `msgspec` and serde with `msgpack`.

    Requests are decoded (and validated) with `msgspec.msgpack` when the input
    annotation of `forward` is a `msgspec.Struct` subclass; any other
    annotation is handed to the parent class' `deserialize`. Responses are
    always encoded with `msgspec.msgpack`.
    """

    # pylint: disable=no-self-use

    # MIME type advertised for responses produced by `serialize`.
    resp_mime_type = "application/msgpack"
    # Cache for the input type obtained from `parse_func_type`; filled in on
    # the first call to `deserialize`.
    _input_typ: Optional[type] = None

    def deserialize(self, data: Any) -> Any:
        """Deserialize and validate request with msgspec.

        Args:
            data: the raw request payload.

        Returns:
            The payload decoded into the input type of `forward` when that
            type is a `msgspec.Struct` subclass, otherwise whatever the parent
            class' `deserialize` returns.

        Raises:
            ValidationError: if `msgspec` rejects the payload for the
                annotated input type.
        """
        import msgspec

        if self._input_typ is None:
            self._input_typ = parse_func_type(self.forward, ParseTarget.INPUT)
        input_typ = self._input_typ

        # `issubclass` raises `TypeError` when its first argument is not a
        # class (e.g. `typing.Any` on older Python versions), so make sure we
        # really hold a class before asking whether it is a Struct.
        is_struct = isinstance(input_typ, type) and issubclass(
            input_typ, msgspec.Struct
        )
        if not is_struct:
            # skip other annotation type
            return super().deserialize(data)

        try:
            return msgspec.msgpack.decode(data, type=input_typ)
        except msgspec.ValidationError as err:
            raise ValidationError(err) from err

    def serialize(self, data: Any) -> bytes:
        """Serialize with `msgpack`.

        Args:
            data: the object to encode.

        Returns:
            The bytes produced by `msgspec.msgpack.encode`.
        """
        import msgspec

        return msgspec.msgpack.encode(data)

    @classmethod
    def get_forward_json_schema(
        cls, target: ParseTarget, ref_template: str
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Get the JSON schema of the forward function.

        Args:
            target: which side of `forward` to inspect, passed through to
                `parse_func_type`.
            ref_template: reference template forwarded to
                `msgspec.json.schema_components`.

        Returns:
            A `(schema, components)` pair. Both are empty dictionaries when
            `msgspec` raises `TypeError` while generating the schema; in that
            case a warning is logged instead of propagating the error.
        """
        import msgspec

        schema: Dict[str, Any] = {}
        comp_schema: Dict[str, Any] = {}
        typ = parse_func_type(cls.forward, target)
        try:
            # A single type is passed in, so exactly one schema comes back.
            (schema,), comp_schema = msgspec.json.schema_components(
                [typ], ref_template=ref_template
            )
        except TypeError as err:
            logger.warning(
                "Failed to generate JSON schema for %s: %s", cls.__name__, err
            )
        return schema, comp_schema