fastapi自定义类中间件

81次阅读
没有评论

因为最开始python我用的是django,特别喜欢他的中间件的写法。

就是请求和响应分别进行处理。

最近用fastapi开发了一个小应用,想要对前端的一些小加密进行校验,直接使用@app.middleware("http") 写起来感觉太占地方,也不方便整理。研究了一下写了一个类示例。

使用的是app.add_middleware(XXXMiddleware)的方式,记录一下,防止下次再忘了。

直接上代码:

下面是中间件的实际代码

from typing import Optional
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from .base import BaseHttpMiddleware
import json
import time

from loguru import logger


class CryptoMiddleware(BaseHttpMiddleware):
    """
    加密中间件
    """

    async def process_request(self, request: Request) -> Optional[Response]:
        """
        处理请求
        """
        request.state.start_time = round(time.time() * 1000)

    async def process_response(
        self, request: Request, response: Response
    ) -> Optional[Response]:
        """
        处理响应
        """
        end_time = round(time.time() * 1000)
        logger.info(f"请求处理时间: {end_time - request.state.start_time} 毫秒")
        response.headers["X-process-time"] = f"{end_time - request.state.start_time}ms"

        # 非json响应直接放给下面的中间件处理
        if "application/json" not in response.headers["content-type"]:
            return None

        # 将字节数据转为dict格式
        content = await self.parse_stream_response(response)

        # 解析JSON
        try:
            json_content = json.loads(content)
            logger.info(f"JSON响应内容: {json_content}")

            # 创建新的响应对象,确保Content-Length正确
            return JSONResponse(
                content=json_content,
                status_code=response.status_code,
                headers=dict(response.headers),
            )

        except json.JSONDecodeError:
            logger.warning("响应不是有效的JSON")
            return Response(
                content=content,
                status_code=response.status_code,
                headers=dict(response.headers),
            )

这里导入了一个我写的类,方便后续的使用,基类如下:

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware as FastAppBaseHTTPMiddleware
from typing import Optional


class BaseHttpMiddleware(FastAppBaseHTTPMiddleware):
    """
    基础中间件
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        # 处理请求
        request_response = await self.process_request(request)

        if request_response:
            return request_response

        # 处理响应
        response = await call_next(request)

        # 处理响应
        response_response = await self.process_response(request, response)
        if response_response:
            return response_response

        return response

    async def process_request(self, request: Request) -> Optional[Response]:
        """
        处理请求
        """
        pass

    async def process_response(
        self, request: Request, response: Response
    ) -> Optional[Response]:
        """
        处理响应
        """
        pass

    async def parse_stream_response(self, response: Response):
        """
        解析流式响应
        """
        # 读取响应内容
        body = []
        async for chunk in response.body_iterator:
            body.append(chunk)
        content = b''.join(body)
        return content.decode('utf-8')

如代码所示,我先继承了starletteBaseHTTPMiddleware类,主要是这个我感觉写起来也方便。

当然fastapi支持多种格式的中间件写法。这是fastapi文档的说明:https://fastapi.tiangolo.com/zh/tutorial/middleware/

fastapi的BaseHTTPMiddleware会调用类dispatch方法,这个方法基类没有定义,需要我们自己实现。

所以我写了一个dispatch函数,分别调用process_request方法和 process_response 方法,传入了,重点在于这两个方法的返回值,他们的返回值需要是一个Response对象,如果返回了Response对象,则会停止后续的操作,直接返回结果到前端。不返回的话则会继续向下走。

在这两个方法里面直接修改request和response对象可以直接影响到后续的结果,所以不需要返回结果。我的例子是添加了一个计算程序允许时间的计算。在请求进来的时候记录一下时间,响应的时候记录一下时间并且加入到headers里面。

需要注意的是如果需要设计修改响应体的内容,需要额外的处理一下,因为我们收到的response是一个StreamReponse对象,无法直接获取结果,需要遍历body_iterator来获取结果拼接在一起。所以我封装了一个函数parse_stream_response用来解析,方便后续使用。注意解析之前需要判断好是否有解析的需求,按需调用。

因为body_iterator遍历之后就没有值了,所以直接返回response值是空的,无法正常返回,所以当我们读取了数据之后需要重新给response对象赋值,可以直接构建response返回,也可以直接给response的body对象赋值。这两个的区别就是直接返回则不会后续的操作,重新赋值还能继续往下走。body接收的参数是一个byte对象,可以直接赋值,也可以调用函数的render函数传入一个字符串赋值。

下面是fastapi 的response对象的构建方法,可以看到body属性的值的由来

class Response:
    media_type = None
    charset = "utf-8"

    def __init__(
        self,
        content: Any = None,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        self.status_code = status_code
        if media_type is not None:
            self.media_type = media_type
        self.background = background
        self.body = self.render(content)
        self.init_headers(headers)

    def render(self, content: Any) -> bytes | memoryview:
        if content is None:
            return b""
        if isinstance(content, (bytes, memoryview)):
            return content
        return content.encode(self.charset)  # type: ignore

此时可能会出现headers['content-length']报错说个内容的长度不一致,需要重新计算body的长度给content-length赋值,也可以重新调用一下init_headers方法,详情自己读源码问ai。

ok记录结束!!!

正文完
 0
评论(没有评论)