@@ -288,299 +288,27 @@ def _prepare_response_content(
288288 """
289289 Prepares the response content for serialization.
290290 """
291- if isinstance (res , BaseModel ):
292- return _model_dump (
291+ if isinstance (res , BaseModel ): # pragma: no cover
292+ return _model_dump ( # pragma: no cover
293293 res ,
294294 by_alias = True ,
295295 exclude_unset = exclude_unset ,
296296 exclude_defaults = exclude_defaults ,
297297 exclude_none = exclude_none ,
298298 )
299- elif isinstance (res , list ):
300- return [
299+ elif isinstance (res , list ): # pragma: no cover
300+ return [ # pragma: no cover
301301 self ._prepare_response_content (item , exclude_unset = exclude_unset , exclude_defaults = exclude_defaults )
302302 for item in res
303303 ]
304- elif isinstance (res , dict ):
305- return {
304+ elif isinstance (res , dict ): # pragma: no cover
305+ return { # pragma: no cover
306306 k : self ._prepare_response_content (v , exclude_unset = exclude_unset , exclude_defaults = exclude_defaults )
307307 for k , v in res .items ()
308308 }
309- elif dataclasses .is_dataclass (res ):
310- return dataclasses .asdict (res ) # type: ignore[arg-type]
311- return res
312-
313-
314- class OpenAPIValidationMiddleware (BaseMiddlewareHandler ):
315- """
316- OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the
317- Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It
318- should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`.
319-
320- Example
321- --------
322-
323- ```python
324- from pydantic import BaseModel
325-
326- from aws_lambda_powertools.event_handler.api_gateway import (
327- APIGatewayRestResolver,
328- )
329-
330- class Todo(BaseModel):
331- name: str
332-
333- app = APIGatewayRestResolver(enable_validation=True)
334-
335- @app.get("/todos")
336- def get_todos(): list[Todo]:
337- return [Todo(name="hello world")]
338- ```
339- """
340-
341- def __init__ (
342- self ,
343- validation_serializer : Callable [[Any ], str ] | None = None ,
344- has_response_validation_error : bool = False ,
345- ):
346- """
347- Initialize the OpenAPIValidationMiddleware.
348-
349- Parameters
350- ----------
351- validation_serializer : Callable, optional
352- Optional serializer to use when serializing the response for validation.
353- Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
354-
355- has_response_validation_error: bool, optional
356- Optional flag used to distinguish between payload and validation errors.
357- By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
358- """
359- self ._validation_serializer = validation_serializer
360- self ._has_response_validation_error = has_response_validation_error
361-
362- def handler (self , app : EventHandlerInstance , next_middleware : NextMiddleware ) -> Response :
363- logger .debug ("OpenAPIValidationMiddleware handler" )
364-
365- route : Route = app .context ["_route" ]
366-
367- values : dict [str , Any ] = {}
368- errors : list [Any ] = []
369-
370- # Process path values, which can be found on the route_args
371- path_values , path_errors = _request_params_to_args (
372- route .dependant .path_params ,
373- app .context ["_route_args" ],
374- )
375-
376- # Normalize query values before validate this
377- query_string = _normalize_multi_query_string_with_param (
378- app .current_event .resolved_query_string_parameters ,
379- route .dependant .query_params ,
380- )
381-
382- # Process query values
383- query_values , query_errors = _request_params_to_args (
384- route .dependant .query_params ,
385- query_string ,
386- )
387-
388- # Normalize header values before validate this
389- headers = _normalize_multi_header_values_with_param (
390- app .current_event .resolved_headers_field ,
391- route .dependant .header_params ,
392- )
393-
394- # Process header values
395- header_values , header_errors = _request_params_to_args (
396- route .dependant .header_params ,
397- headers ,
398- )
399-
400- values .update (path_values )
401- values .update (query_values )
402- values .update (header_values )
403- errors += path_errors + query_errors + header_errors
404-
405- # Process the request body, if it exists
406- if route .dependant .body_params :
407- (body_values , body_errors ) = _request_body_to_args (
408- required_params = route .dependant .body_params ,
409- received_body = self ._get_body (app ),
410- )
411- values .update (body_values )
412- errors .extend (body_errors )
413-
414- if errors :
415- # Raise the validation errors
416- raise RequestValidationError (_normalize_errors (errors ))
417- else :
418- # Re-write the route_args with the validated values, and call the next middleware
419- app .context ["_route_args" ] = values
420-
421- # Call the handler by calling the next middleware
422- response = next_middleware (app )
423-
424- # Process the response
425- return self ._handle_response (route = route , response = response )
426-
427- def _handle_response (self , * , route : Route , response : Response ):
428- # Process the response body if it exists
429- if response .body and response .is_json ():
430- response .body = self ._serialize_response (
431- field = route .dependant .return_param ,
432- response_content = response .body ,
433- has_route_custom_response_validation = route .custom_response_validation_http_code is not None ,
434- )
435-
436- return response
437-
438- def _serialize_response (
439- self ,
440- * ,
441- field : ModelField | None = None ,
442- response_content : Any ,
443- include : IncEx | None = None ,
444- exclude : IncEx | None = None ,
445- by_alias : bool = True ,
446- exclude_unset : bool = False ,
447- exclude_defaults : bool = False ,
448- exclude_none : bool = False ,
449- has_route_custom_response_validation : bool = False ,
450- ) -> Any :
451- """
452- Serialize the response content according to the field type.
453- """
454- if field :
455- errors : list [dict [str , Any ]] = []
456- value = _validate_field (field = field , value = response_content , loc = ("response" ,), existing_errors = errors )
457- if errors :
458- # route-level validation must take precedence over app-level
459- if has_route_custom_response_validation :
460- raise ResponseValidationError (
461- errors = _normalize_errors (errors ),
462- body = response_content ,
463- source = "route" ,
464- )
465- if self ._has_response_validation_error :
466- raise ResponseValidationError (errors = _normalize_errors (errors ), body = response_content , source = "app" )
467-
468- raise RequestValidationError (errors = _normalize_errors (errors ), body = response_content )
469-
470- if hasattr (field , "serialize" ):
471- return field .serialize (
472- value ,
473- include = include ,
474- exclude = exclude ,
475- by_alias = by_alias ,
476- exclude_unset = exclude_unset ,
477- exclude_defaults = exclude_defaults ,
478- exclude_none = exclude_none ,
479- )
480- return jsonable_encoder (
481- value ,
482- include = include ,
483- exclude = exclude ,
484- by_alias = by_alias ,
485- exclude_unset = exclude_unset ,
486- exclude_defaults = exclude_defaults ,
487- exclude_none = exclude_none ,
488- custom_serializer = self ._validation_serializer ,
489- )
490- else :
491- # Just serialize the response content returned from the handler.
492- return jsonable_encoder (response_content , custom_serializer = self ._validation_serializer )
493-
494- def _prepare_response_content (
495- self ,
496- res : Any ,
497- * ,
498- exclude_unset : bool ,
499- exclude_defaults : bool = False ,
500- exclude_none : bool = False ,
501- ) -> Any :
502- """
503- Prepares the response content for serialization.
504- """
505- if isinstance (res , BaseModel ):
506- return _model_dump (
507- res ,
508- by_alias = True ,
509- exclude_unset = exclude_unset ,
510- exclude_defaults = exclude_defaults ,
511- exclude_none = exclude_none ,
512- )
513- elif isinstance (res , list ):
514- return [
515- self ._prepare_response_content (item , exclude_unset = exclude_unset , exclude_defaults = exclude_defaults )
516- for item in res
517- ]
518- elif isinstance (res , dict ):
519- return {
520- k : self ._prepare_response_content (v , exclude_unset = exclude_unset , exclude_defaults = exclude_defaults )
521- for k , v in res .items ()
522- }
523- elif dataclasses .is_dataclass (res ):
524- return dataclasses .asdict (res ) # type: ignore[arg-type]
525- return res
526-
527- def _get_body (self , app : EventHandlerInstance ) -> dict [str , Any ]:
528- """
529- Get the request body from the event, and parse it according to content type.
530- """
531- content_type = app .current_event .headers .get ("content-type" , "" ).strip ()
532-
533- # Handle JSON content
534- if not content_type or content_type .startswith (APPLICATION_JSON_CONTENT_TYPE ):
535- return self ._parse_json_data (app )
536-
537- # Handle URL-encoded form data
538- elif content_type .startswith (APPLICATION_FORM_CONTENT_TYPE ):
539- return self ._parse_form_data (app )
540-
541- else :
542- raise NotImplementedError ("Only JSON body or Form() are supported" )
543-
544- def _parse_json_data (self , app : EventHandlerInstance ) -> dict [str , Any ]:
545- """Parse JSON data from the request body."""
546- try :
547- return app .current_event .json_body
548- except json .JSONDecodeError as e :
549- raise RequestValidationError (
550- [
551- {
552- "type" : "json_invalid" ,
553- "loc" : ("body" , e .pos ),
554- "msg" : "JSON decode error" ,
555- "input" : {},
556- "ctx" : {"error" : e .msg },
557- },
558- ],
559- body = e .doc ,
560- ) from e
561-
562- def _parse_form_data (self , app : EventHandlerInstance ) -> dict [str , Any ]:
563- """Parse URL-encoded form data from the request body."""
564- try :
565- body = app .current_event .decoded_body or ""
566- # parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values
567- parsed = parse_qs (body , keep_blank_values = True )
568-
569- result : dict [str , Any ] = {key : values [0 ] if len (values ) == 1 else values for key , values in parsed .items ()}
570- return result
571-
572- except Exception as e : # pragma: no cover
573- raise RequestValidationError ( # pragma: no cover
574- [
575- {
576- "type" : "form_invalid" ,
577- "loc" : ("body" ,),
578- "msg" : "Form data parsing error" ,
579- "input" : {},
580- "ctx" : {"error" : str (e )},
581- },
582- ],
583- ) from e
309+ elif dataclasses .is_dataclass (res ): # pragma: no cover
310+ return dataclasses .asdict (res ) # type: ignore[arg-type] # pragma: no cover
311+ return res # pragma: no cover
584312
585313
586314def _request_params_to_args (
0 commit comments