FastAPI 异步架构与依赖注入体系

FastAPI 是现代 Python 微服务开发的首选框架。其架构设计围绕三个核心支柱:基于 Starlette 的高性能异步 HTTP 引擎、基于 Pydantic 的自动请求验证和序列化、以及类型提示驱动的依赖注入系统。

from fastapi import FastAPI, Depends, HTTPException, Request
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Annotated
import asyncio

class Database:
    """Toy async database handle used by the example service."""

    def __init__(self, dsn: str):
        self.dsn = dsn
        self.pool = None  # populated by connect()

    async def connect(self):
        # Stand-in for real pool creation.
        self.pool = {"connections": 10}
        print(f"Connected to {self.dsn}")

    async def disconnect(self):
        self.pool = None

    async def fetch(self, query: str):
        # Emulate a network round-trip, then return a canned row set.
        await asyncio.sleep(0.001)
        return [{"id": 1, "name": "test"}]

class Cache:
    """Toy cache client; connect/disconnect are simulated with prints."""

    def __init__(self, redis_url: str):
        self.redis_url = redis_url

    async def connect(self):
        print(f"Cache connected to {self.redis_url}")

    async def disconnect(self):
        print("Cache disconnected")


class AppState:
    """Container for the process-wide resources shared across requests."""

    def __init__(self):
        self.db = Database("postgresql://localhost/myapp")
        self.cache = Cache("redis://localhost")
        # Simple in-process counters, surfaced by the /health endpoint.
        self.metrics = {"requests": 0, "errors": 0}

# Process-wide application state; set by the lifespan handler on startup
# and cleared again on shutdown.
_app_state: AppState | None = None


async def get_app_state() -> AppState:
    """Dependency root: return the live AppState, failing fast before startup."""
    global _app_state
    if _app_state is not None:
        return _app_state
    raise RuntimeError("App not initialized")


async def get_db(state: Annotated[AppState, Depends(get_app_state)]) -> Database:
    """FastAPI dependency: extract the shared Database from app state."""
    return state.db


async def get_cache(state: Annotated[AppState, Depends(get_app_state)]) -> Cache:
    """FastAPI dependency: extract the shared Cache from app state."""
    return state.cache


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[AppState, None]:
    """Application lifespan: build global state on startup, tear it down on shutdown.

    Replaces the legacy startup/shutdown events. The teardown now runs in a
    ``finally`` block so connections are released even when the server loop
    exits with an error propagated through the ``yield``.
    """
    global _app_state
    _app_state = AppState()
    await _app_state.db.connect()
    await _app_state.cache.connect()
    print("Application startup complete")

    try:
        yield _app_state
    finally:
        # Shut down in reverse order of startup.
        await _app_state.cache.disconnect()
        await _app_state.db.disconnect()
        _app_state = None
        print("Application shutdown complete")


app = FastAPI(title="Microservice API", lifespan=lifespan)


@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
    """Count every request, and every request that escapes with an exception."""
    state = await get_app_state()
    state.metrics["requests"] += 1
    try:
        return await call_next(request)
    except Exception:
        # Record the failure, then let FastAPI's error handling see it.
        # (Previously bound the exception to an unused name `e`.)
        state.metrics["errors"] += 1
        raise


@app.get("/health")
async def health_check(state: Annotated[AppState, Depends(get_app_state)]):
    """Liveness probe: reports the in-process request/error counters."""
    return {"status": "healthy", "metrics": state.metrics}


@app.get("/users/{user_id}")
async def get_user(
    user_id: int,
    db: Annotated[Database, Depends(get_db)],
    cache: Annotated[Cache, Depends(get_cache)]
):
    """Fetch a single user by id; 404 when no row matches.

    NOTE(review): the cache dependency is injected but Cache exposes no
    get/set API yet, so no read-through caching happens here. The previous
    `cache_key` local was computed and never used; removed.
    """
    # user_id is validated as int by FastAPI, but prefer the driver's
    # parameterized queries over f-string interpolation once available.
    result = await db.fetch(f"SELECT * FROM users WHERE id = {user_id}")
    if not result:
        raise HTTPException(status_code=404, detail="User not found")
    return {"data": result[0]}
重点提示:
  • Annotated + Depends 是 FastAPI 推荐的现代依赖声明方式
  • 生命周期钩子(lifespan)替代了旧的 startup/shutdown 事件
  • 依赖可以嵌套形成依赖树,FastAPI 自动处理缓存

gRPC 流式通信与双向 RPC 模式

gRPC 基于 HTTP/2 和 Protocol Buffers,提供高效的二进制序列化、流式通信和强类型契约。支持四种通信模式:简单 RPC、服务端流式、客户端流式、双向流式。

# gRPC 服务端实现
import asyncio
import time
from concurrent import futures
import grpc
from generated import service_pb2, service_pb2_grpc


class OrderServiceServicer(service_pb2_grpc.OrderServiceServicer):
    """Order service covering unary, server-streaming and bidirectional RPCs."""

    def CreateOrder(self, request, context):
        """Unary RPC: price the requested items and return a PENDING order."""
        generated_id = f"ORD-{int(time.time() * 1000)}"
        amount = sum(entry.price * entry.quantity for entry in request.items)
        return service_pb2.Order(
            order_id=generated_id,
            user_id=request.user_id,
            total_amount=amount,
            status="PENDING"
        )

    def StreamOrderUpdates(self, request, context):
        """Server-streaming RPC: emit the order lifecycle, one step per second."""
        lifecycle = (
            ("PENDING", "Order received"),
            ("PROCESSING", "Payment processing"),
            ("CONFIRMED", "Payment confirmed"),
            ("SHIPPING", "Preparing shipment"),
            ("DELIVERED", "Order delivered"),
        )
        for state, note in lifecycle:
            yield service_pb2.OrderUpdate(
                order_id=request.order_id,
                status=state,
                message=note
            )
            time.sleep(1)

    def Chat(self, request_iterator, context):
        """Bidirectional streaming RPC: echo every inbound message."""
        for inbound in request_iterator:
            yield service_pb2.ChatMessage(
                sender_id="bot",
                message=f"Echo: {inbound.message}"
            )


def serve():
    """Build, bind and run the blocking gRPC server on port 50051."""
    grpc_server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[
            # Allow responses up to 50 MiB.
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
        ]
    )
    service_pb2_grpc.add_OrderServiceServicer_to_server(
        OrderServiceServicer(), grpc_server
    )
    grpc_server.add_insecure_port('[::]:50051')
    grpc_server.start()
    grpc_server.wait_for_termination()


# 异步 gRPC 客户端
class OrderServiceClient:
    """Thin async wrapper around the OrderService gRPC stub."""

    def __init__(self, target: str = "localhost:50051"):
        self.channel = grpc.aio.insecure_channel(target)
        self.stub = service_pb2_grpc.OrderServiceStub(self.channel)

    async def create_order(self, user_id: str, items: list):
        """Unary call; `items` is a list of dicts matching OrderItem fields."""
        order_items = [service_pb2.OrderItem(**spec) for spec in items]
        request = service_pb2.CreateOrderRequest(
            user_id=user_id,
            items=order_items
        )
        return await self.stub.CreateOrder(request)

    async def stream_order_updates(self, order_id: str):
        """Async generator over the server-streaming update RPC."""
        async for update in self.stub.StreamOrderUpdates(
            service_pb2.OrderIdRequest(order_id=order_id)
        ):
            yield update

    async def close(self):
        """Release the underlying HTTP/2 channel."""
        await self.channel.close()
重点提示:
  • gRPC 使用 HTTP/2 多路复用,单个连接可处理多个 RPC 调用
  • 流式 RPC 支持背压控制,可根据处理能力调整发送速率
  • 生产环境应使用 TLS 加密和客户端证书认证

服务发现、负载均衡与熔断降级

服务注册中心解决服务位置动态变化问题;负载均衡算法决定流量分发;熔断器防止故障级联扩散。

import asyncio
import random
import time
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Dict, List, Optional

import aiohttp


@dataclass
class ServiceInstance:
    """One registered instance of a service, as reported by the registry."""
    service_name: str
    instance_id: str
    host: str
    port: int
    weight: int = 1  # relative traffic share for weighted balancing

    @property
    def address(self) -> str:
        """host:port pair used to build request URLs."""
        return f"{self.host}:{self.port}"


class ServiceRegistry:
    """Client for a remote service registry, with a naive local cache.

    NOTE(review): cached instance lists never expire; add a TTL/refresh
    before production use.
    """

    def __init__(self, registry_url: str):
        self.registry_url = registry_url
        # service name -> last instance list fetched from the registry
        self._local_cache: Dict[str, List["ServiceInstance"]] = {}

    async def discover(self, service_name: str) -> List["ServiceInstance"]:
        """Return the instances of `service_name`, hitting the local cache first."""
        if service_name in self._local_cache:
            return self._local_cache[service_name]

        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.registry_url}/services/{service_name}"
            ) as resp:
                data = await resp.json()
                instances = [
                    ServiceInstance(
                        service_name=service_name,
                        instance_id=i["id"],
                        host=i["host"],
                        port=i["port"],
                        weight=i.get("weight", 1)
                    )
                    for i in data.get("instances", [])
                ]
                self._local_cache[service_name] = instances
                return instances

    async def register(self, instance: "ServiceInstance") -> None:
        """Announce `instance` to the remote registry.

        BUG FIX: MicroserviceFramework.start() awaits registry.register(),
        but this method did not exist and the call raised AttributeError.
        """
        payload = {
            "id": instance.instance_id,
            "host": instance.host,
            "port": instance.port,
            "weight": instance.weight,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.registry_url}/services/{instance.service_name}",
                json=payload,
            ) as resp:
                resp.raise_for_status()


class LoadBalancer:
    """Selects an instance via round-robin, uniform random or weighted random."""

    def __init__(self, strategy: str = "round_robin"):
        self.strategy = strategy
        self._index = 0  # round-robin cursor; grows unbounded (harmless in Python)

    def select(self, instances: List["ServiceInstance"]) -> Optional["ServiceInstance"]:
        """Pick one instance, or None if the pool is empty.

        Unknown strategies (and degenerate weighted pools) fall back to the
        first instance rather than failing the call.
        """
        if not instances:
            return None
        if self.strategy == "round_robin":
            chosen = instances[self._index % len(instances)]
            self._index += 1
            return chosen
        if self.strategy == "random":
            return random.choice(instances)
        if self.strategy == "weighted":
            total = sum(i.weight for i in instances)
            # BUG FIX: randint(1, 0) raised ValueError when every weight is 0
            # (registry data may supply zero weights); degrade to the default
            # below instead of crashing.
            if total > 0:
                ticket = random.randint(1, total)
                for inst in instances:
                    ticket -= inst.weight
                    if ticket <= 0:
                        return inst
        return instances[0]


class CircuitState(Enum):
    CLOSED = auto()
    OPEN = auto()
    HALF_OPEN = auto()


class CircuitBreaker:
    """熔断器模式实现"""

    def __init__(self, name: str, failure_threshold: int = 5,
                 recovery_timeout: float = 30.0):
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time: Optional[float] = None
        self._lock = asyncio.Lock()

    async def call(self, func, *args, **kwargs):
        async with self._lock:
            if self._state == CircuitState.OPEN:
                if time.time() - self._last_failure_time >= self.recovery_timeout:
                    self._state = CircuitState.HALF_OPEN
                else:
                    raise CircuitBreakerOpen(f"Circuit {self.name} is OPEN")

        try:
            result = await func(*args, **kwargs)
            await self._on_success()
            return result
        except Exception as e:
            await self._on_failure()
            raise

    async def _on_success(self):
        async with self._lock:
            if self._state == CircuitState.HALF_OPEN:
                self._state = CircuitState.CLOSED
                self._failure_count = 0
            else:
                self._failure_count = max(0, self._failure_count - 1)

    async def _on_failure(self):
        async with self._lock:
            self._failure_count += 1
            self._last_failure_time = time.time()
            if self._failure_count >= self.failure_threshold:
                self._state = CircuitState.OPEN
                print(f"Circuit {self.name} OPEN")


class CircuitBreakerOpen(Exception):
    pass
重点提示:
  • 服务注册中心应支持 TTL 健康检查,自动剔除无心跳实例
  • 一致性哈希适用于有状态服务,避免请求重新分配导致缓存失效
  • 熔断器应支持半开状态,自动检测下游服务恢复情况

分布式链路追踪与可观测性

链路追踪通过为每个请求分配 Trace ID,在服务间传递 Span 上下文,使得跨服务调用链可视化。OpenTelemetry 是当前标准实现。

import time
import uuid
from contextvars import ContextVar
from typing import Optional, Dict, Any
from dataclasses import dataclass, field


# Per-task current span: lets nested spans find their parent without
# threading it through every call (async-safe, unlike a plain global).
trace_context: ContextVar[Optional["Span"]] = ContextVar(
    "trace_context", default=None
)


@dataclass
class Span:
    """A single timed unit of work within a trace."""
    trace_id: str  # shared by every span in one request chain
    span_id: str
    parent_id: Optional[str] = None  # None for the root span
    name: str = ""
    start_time: float = field(default_factory=time.time)
    end_time: Optional[float] = None  # set by finish()
    tags: Dict[str, Any] = field(default_factory=dict)
    status: str = "OK"

    def finish(self, status: str = "OK"):
        """Stamp the end time and final status."""
        self.end_time = time.time()
        self.status = status

    def set_tag(self, key: str, value: Any):
        """Attach an arbitrary key/value annotation to this span."""
        self.tags[key] = value


class Tracer:
    """Creates and reports spans tagged with this service's name."""

    def __init__(self, service_name: str):
        self.service_name = service_name

    def start_span(self, name: str, parent: "Optional[Span]" = None) -> "Span":
        """Open a span, inheriting the parent's trace id or minting a new one."""
        trace_id = parent.trace_id if parent else uuid.uuid4().hex[:16]
        span = Span(
            trace_id=trace_id,
            span_id=uuid.uuid4().hex[:8],
            parent_id=parent.span_id if parent else None,
            name=name,
        )
        span.set_tag("service", self.service_name)
        return span

    def finish_span(self, span: "Span"):
        """Report the span, closing it first only if not already finished.

        BUG FIX: unconditionally calling span.finish() here reset the status
        of spans already finished with "ERROR" back to "OK" and moved their
        end time (ObservabilityMiddleware finishes error spans before
        calling finish_span).
        """
        if span.end_time is None:
            span.finish()
        duration = (span.end_time - span.start_time) * 1000
        print(f"[TRACE] {span.name}: {duration:.2f}ms")


class MetricsCollector:
    """In-process counters and histograms, all named `<prefix>_<metric>`."""

    def __init__(self, prefix: str = "app"):
        self.prefix = prefix
        self._counters: dict[str, int] = {}
        # BUG FIX: the annotation here previously used typing.List, which is
        # not imported in this module and raised NameError on construction
        # (attribute annotations are evaluated at runtime); builtin generics
        # need no import.
        self._histograms: dict[str, list[float]] = {}

    def increment(self, name: str, value: int = 1):
        """Add `value` to the named counter, creating it at 0 if absent."""
        full_name = f"{self.prefix}_{name}"
        self._counters[full_name] = self._counters.get(full_name, 0) + value

    def histogram(self, name: str, value: float):
        """Record one observation for the named histogram."""
        full_name = f"{self.prefix}_{name}"
        self._histograms.setdefault(full_name, []).append(value)

    def get_metrics(self) -> Dict:
        """Snapshot all metrics; histograms are summarized as count + mean."""
        return {
            "counters": self._counters,
            "histograms": {k: {"count": len(v), "avg": sum(v) / len(v)}
                          for k, v in self._histograms.items()},
        }


# 可观测性中间件
class ObservabilityMiddleware:
    """Middleware wiring tracing and metrics around each HTTP request."""

    def __init__(self, tracer: "Tracer", metrics: "MetricsCollector"):
        self.tracer = tracer
        self.metrics = metrics

    async def __call__(self, request, call_next):
        """Trace the request, record duration/errors, propagate X-Trace-ID."""
        # Reuse the caller's trace id when provided so cross-service spans join up.
        trace_id = request.headers.get("X-Trace-ID") or uuid.uuid4().hex[:16]
        span = self.tracer.start_span(f"{request.method} {request.url.path}")
        span.set_tag("trace_id", trace_id)

        start_time = time.time()
        try:
            response = await call_next(request)
            span.set_tag("status_code", response.status_code)
            span.finish("OK")
            # Echo the trace id so clients and downstream logs can correlate.
            response.headers["X-Trace-ID"] = trace_id
            return response
        except Exception as e:
            span.set_tag("error", str(e))
            span.finish("ERROR")
            self.metrics.increment("errors_total")
            raise
        finally:
            # BUG FIX: this finally block previously set response.headers,
            # but `response` is unbound when call_next raised, so the real
            # exception was masked by UnboundLocalError. The header is now
            # set on the success path above.
            duration = time.time() - start_time
            self.metrics.histogram("request_duration_seconds", duration)
            self.tracer.finish_span(span)
重点提示:
  • OpenTelemetry SDK 可统一收集 traces、metrics 和 logs
  • Trace ID 应在所有服务间传递,通过 HTTP 头或 gRPC 元数据
  • 高流量场景应使用概率采样(如 1%)避免数据量过大

实战:完整的微服务通信框架

import asyncio
import aiohttp
from typing import Dict
from dataclasses import dataclass
import uuid


@dataclass
class ServiceConfig:
    """Static configuration for one microservice process."""
    name: str
    host: str = "0.0.0.0"  # bind address; also reported to the registry
    http_port: int = 8000
    grpc_port: int = 50051
    registry_url: str = "http://localhost:8500"  # e.g. a Consul agent


class MicroserviceFramework:
    """Microservice core: wires discovery, load balancing, circuit breaking,
    tracing and metrics behind one `call_service` entry point."""

    def __init__(self, config: ServiceConfig):
        self.config = config
        self.registry = ServiceRegistry(config.registry_url)
        self.load_balancer = LoadBalancer(strategy="round_robin")
        # One breaker per downstream service, created lazily on first call.
        self.circuit_breakers: Dict[str, CircuitBreaker] = {}
        self.tracer = Tracer(config.name)
        self.metrics = MetricsCollector(config.name)

    async def call_service(self, service_name: str, endpoint: str,
                          method: str = "GET", data: dict | None = None) -> dict:
        """Call another service with circuit-breaker protection.

        Raises CircuitBreakerOpen while the target's breaker is open, and
        re-raises any transport/JSON error after recording metrics.
        (The `data` annotation previously claimed `dict` while defaulting
        to None; it is now `dict | None`.)
        """
        if service_name not in self.circuit_breakers:
            self.circuit_breakers[service_name] = CircuitBreaker(service_name)

        cb = self.circuit_breakers[service_name]
        span = self.tracer.start_span(f"call.{service_name}.{endpoint}")

        async def _do_request():
            # Resolve a live instance on each attempt so retries can land
            # on a different replica.
            instances = await self.registry.discover(service_name)
            if not instances:
                raise Exception(f"Service {service_name} not found")
            instance = self.load_balancer.select(instances)
            url = f"http://{instance.address}{endpoint}"

            async with aiohttp.ClientSession() as session:
                async with session.request(method, url, json=data) as resp:
                    return await resp.json()

        try:
            result = await cb.call(_do_request)
            span.finish("OK")
            self.metrics.increment("calls_success")
            return result
        except CircuitBreakerOpen:
            span.finish("CIRCUIT_OPEN")
            self.metrics.increment("calls_circuit_open")
            raise
        except Exception:
            span.finish("ERROR")
            self.metrics.increment("calls_error")
            raise
        finally:
            self.tracer.finish_span(span)

    async def start(self):
        """Register this instance with the service registry."""
        instance = ServiceInstance(
            service_name=self.config.name,
            instance_id=f"{self.config.name}-{uuid.uuid4().hex[:8]}",
            host=self.config.host,
            port=self.config.http_port
        )
        await self.registry.register(instance)
        print(f"Service {self.config.name} started")


# Usage example
async def main():
    """Demo: build the framework, call a peer service, then register ourselves."""
    cfg = ServiceConfig(name="order-service", http_port=8001)
    framework = MicroserviceFramework(cfg)

    # Call the user service through discovery + circuit breaking.
    user = await framework.call_service(
        "user-service", "/users/123"
    )
    print(f"User data: {user}")

    await framework.start()


if __name__ == "__main__":
    asyncio.run(main())
重点提示:
  • 生产环境应使用成熟框架如 FastAPI + Nameko 或 Dapr sidecar 模式
  • Service Mesh(Istio、Linkerd)可在基础设施层实现上述能力
  • 服务配置应支持热更新,避免重启才能修改超时等参数

架构决策总结

Python 微服务架构设计需在开发效率和运维复杂度间取得平衡。关键决策点:

  • 通信协议:内部服务优先 gRPC,外部网关使用 REST + GraphQL
  • 服务发现:K8s 环境使用 DNS + Headless Service,非 K8s 使用 Consul/Nacos
  • 可观测性:OpenTelemetry 统一采集,Prometheus + Grafana 监控,Jaeger 链路追踪
  • 部署模式:容器化部署,Kubernetes 编排,GitOps 工作流

Python 在微服务生态中的优势在于开发速度快、生态丰富。通过合理的架构设计和对异步编程模型的深入理解,Python 完全可以支撑高并发、高可用的生产级微服务体系。