FastAPI 异步架构与依赖注入体系
FastAPI 是现代 Python 微服务开发的首选框架。其架构设计围绕三个核心支柱:基于 Starlette 的高性能异步 HTTP 引擎、基于 Pydantic 的自动请求验证和序列化、以及类型提示驱动的依赖注入系统。
from fastapi import FastAPI, Depends, HTTPException, Request
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Annotated
import asyncio
class Database:
    """Stand-in async database with a fake connection pool.

    Simulates the connect/disconnect lifecycle and query latency so the
    FastAPI example can run without a real PostgreSQL server.
    """

    def __init__(self, dsn: str):
        self.dsn = dsn
        self.pool = None  # populated by connect()

    async def connect(self):
        # Fake pool; a real implementation would build an asyncpg/SQLAlchemy pool.
        self.pool = {"connections": 10}
        print(f"Connected to {self.dsn}")

    async def disconnect(self):
        self.pool = None

    async def fetch(self, query: str):
        # Simulate a network round-trip, then return a canned row.
        await asyncio.sleep(0.001)
        return [{"id": 1, "name": "test"}]
class Cache:
    """Stand-in async cache client; no real Redis connection is made."""

    def __init__(self, redis_url: str):
        self.redis_url = redis_url

    async def connect(self):
        print(f"Cache connected to {self.redis_url}")

    async def disconnect(self):
        print("Cache disconnected")
class AppState:
    """Container for process-wide singletons shared across requests."""

    def __init__(self):
        # DSNs are hard-coded for the example; real apps read them from config.
        self.db = Database("postgresql://localhost/myapp")
        self.cache = Cache("redis://localhost")
        # Simple in-process counters surfaced by /health and the middleware.
        self.metrics = {"requests": 0, "errors": 0}
# Module-level singleton, created by the lifespan handler.
# (String annotation: not evaluated at import time.)
_app_state: "AppState | None" = None


async def get_app_state() -> "AppState":
    """Dependency root: return the process-wide AppState.

    Raises:
        RuntimeError: if called before the lifespan handler has
            initialized the singleton (e.g. in tests without startup).
    """
    state = _app_state
    if state is None:
        raise RuntimeError("App not initialized")
    return state
async def get_db(state: Annotated[AppState, Depends(get_app_state)]) -> Database:
    # Thin projection dependency: resolves AppState via get_app_state and
    # hands endpoints only the Database handle they need.
    return state.db
async def get_cache(state: Annotated[AppState, Depends(get_app_state)]) -> Cache:
    # Thin projection dependency: same pattern as get_db, for the Cache handle.
    return state.cache
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[AppState, None]:
    """Application lifespan: build the shared state and open connections
    on startup; tear everything down in reverse order on shutdown."""
    global _app_state
    state = AppState()
    _app_state = state
    await state.db.connect()
    await state.cache.connect()
    print("Application startup complete")
    yield state
    # Shutdown path: close in reverse order of startup.
    await state.cache.disconnect()
    await state.db.disconnect()
    _app_state = None
    print("Application shutdown complete")
# Wire the lifespan handler into the app so startup/shutdown hooks run.
app = FastAPI(title="Microservice API", lifespan=lifespan)


@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
    # Count every request; count errors separately, then re-raise so
    # FastAPI's normal exception handling still produces the response.
    state = await get_app_state()
    state.metrics["requests"] += 1
    try:
        response = await call_next(request)
        return response
    except Exception as e:  # NOTE(review): e is unused — kept as-is
        state.metrics["errors"] += 1
        raise
@app.get("/health")
async def health_check(state: Annotated[AppState, Depends(get_app_state)]):
    # Liveness endpoint; also exposes the in-process request/error counters.
    return {"status": "healthy", "metrics": state.metrics}
@app.get("/users/{user_id}")
async def get_user(
    user_id: int,
    db: Annotated[Database, Depends(get_db)],
    cache: Annotated[Cache, Depends(get_cache)]
):
    """Fetch a single user by id, returning 404 when no row matches.

    NOTE(review): the injected `cache` is currently unused — the example
    Cache class exposes no get/set API, so the read-through caching step
    (the original computed an unused `cache_key` local) is left as a TODO.
    """
    # user_id is an int validated by FastAPI, so interpolation cannot inject
    # SQL here — still, prefer bound parameters once the driver supports them.
    result = await db.fetch(f"SELECT * FROM users WHERE id = {user_id}")
    if not result:
        raise HTTPException(status_code=404, detail="User not found")
    return {"data": result[0]}
重点提示:
- Annotated + Depends 是 FastAPI 推荐的现代依赖声明方式
- 生命周期钩子(lifespan)替代了旧的 startup/shutdown 事件
- 依赖可以嵌套形成依赖树;默认情况下,FastAPI 会在同一个请求内自动缓存每个依赖的结果,避免同一依赖被重复执行
gRPC 流式通信与双向 RPC 模式
gRPC 基于 HTTP/2 和 Protocol Buffers,提供高效的二进制序列化、流式通信和强类型契约。支持四种通信模式:简单 RPC、服务端流式、客户端流式、双向流式。
# gRPC 服务端实现
import asyncio
import time
from concurrent import futures
import grpc
from generated import service_pb2, service_pb2_grpc
class OrderServiceServicer(service_pb2_grpc.OrderServiceServicer):
    """Order service implementing unary, server-streaming and bidi RPCs."""

    def CreateOrder(self, request, context):
        """Unary RPC: build a PENDING order from the requested items."""
        oid = f"ORD-{int(time.time() * 1000)}"
        amount = 0
        for item in request.items:
            amount += item.price * item.quantity
        return service_pb2.Order(
            order_id=oid,
            user_id=request.user_id,
            total_amount=amount,
            status="PENDING",
        )

    def StreamOrderUpdates(self, request, context):
        """Server-streaming RPC: emit the order's status transitions."""
        transitions = (
            ("PENDING", "Order received"),
            ("PROCESSING", "Payment processing"),
            ("CONFIRMED", "Payment confirmed"),
            ("SHIPPING", "Preparing shipment"),
            ("DELIVERED", "Order delivered"),
        )
        for status, message in transitions:
            yield service_pb2.OrderUpdate(
                order_id=request.order_id,
                status=status,
                message=message,
            )
            time.sleep(1)  # simulate time passing between transitions

    def Chat(self, request_iterator, context):
        """Bidirectional streaming RPC: echo each incoming message."""
        for incoming in request_iterator:
            yield service_pb2.ChatMessage(
                sender_id="bot",
                message=f"Echo: {incoming.message}",
            )
def serve():
    # Synchronous gRPC server backed by a thread pool: up to 10 concurrent RPCs.
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[
            # Raise the outbound message cap to 50 MiB (gRPC default is 4 MiB).
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
        ]
    )
    service_pb2_grpc.add_OrderServiceServicer_to_server(
        OrderServiceServicer(), server
    )
    # Plaintext listener on all interfaces; production should use TLS credentials.
    server.add_insecure_port('[::]:50051')
    server.start()
    # Block the calling thread until the server terminates.
    server.wait_for_termination()
# 异步 gRPC 客户端
class OrderServiceClient:
    """Async gRPC client for the order service (grpc.aio channel + stub)."""

    def __init__(self, target: str = "localhost:50051"):
        # Plaintext channel; swap for grpc.aio.secure_channel in production.
        self.channel = grpc.aio.insecure_channel(target)
        self.stub = service_pb2_grpc.OrderServiceStub(self.channel)

    async def create_order(self, user_id: str, items: list):
        # items: dicts whose keys are assumed to match OrderItem fields
        # (e.g. price, quantity) — confirm against the .proto definition.
        request = service_pb2.CreateOrderRequest(
            user_id=user_id,
            items=[service_pb2.OrderItem(**item) for item in items]
        )
        return await self.stub.CreateOrder(request)

    async def stream_order_updates(self, order_id: str):
        # Async generator relaying each server-streamed update to the caller.
        request = service_pb2.OrderIdRequest(order_id=order_id)
        async for update in self.stub.StreamOrderUpdates(request):
            yield update

    async def close(self):
        # Close the underlying channel; the client is unusable afterwards.
        await self.channel.close()
重点提示:
- gRPC 使用 HTTP/2 多路复用,单个连接可处理多个 RPC 调用
- 流式 RPC 支持背压控制,可根据处理能力调整发送速率
- 生产环境应使用 TLS 加密和客户端证书认证
服务发现、负载均衡与熔断降级
服务注册中心解决服务位置动态变化问题;负载均衡算法决定流量分发;熔断器防止故障级联扩散。
import asyncio
import random
import time
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Dict, List, Optional

import aiohttp
@dataclass
class ServiceInstance:
    """A single registered instance of a named service."""

    service_name: str
    instance_id: str
    host: str
    port: int
    weight: int = 1  # relative share used by weighted load balancing

    @property
    def address(self) -> str:
        """host:port string used to build request URLs."""
        return "{}:{}".format(self.host, self.port)
class ServiceRegistry:
    """Client for a simple HTTP service registry.

    Discovery results are memoized per service name in ``_local_cache``.
    NOTE(review): cache entries never expire — a production version needs
    a TTL or registry-driven invalidation.

    Fix: adds the ``register()`` coroutine that MicroserviceFramework.start()
    calls; the original class did not define it (AttributeError at startup).
    """

    def __init__(self, registry_url: str):
        self.registry_url = registry_url
        # service name -> cached instance list (string annotation: not
        # evaluated at runtime).
        self._local_cache: "Dict[str, List[ServiceInstance]]" = {}

    async def discover(self, service_name: str) -> "List[ServiceInstance]":
        """Return instances for `service_name`, querying the registry on a
        cache miss and memoizing the parsed result."""
        cached = self._local_cache.get(service_name)
        if cached is not None:
            return cached
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.registry_url}/services/{service_name}"
            ) as resp:
                data = await resp.json()
        instances = [
            ServiceInstance(
                service_name=service_name,
                instance_id=i["id"],
                host=i["host"],
                port=i["port"],
                weight=i.get("weight", 1)
            )
            for i in data.get("instances", [])
        ]
        self._local_cache[service_name] = instances
        return instances

    async def register(self, instance: "ServiceInstance") -> None:
        """Register `instance` with the remote registry (used on startup).

        The payload mirrors the shape ``discover`` parses; a non-2xx
        response raises aiohttp.ClientResponseError.
        """
        payload = {
            "id": instance.instance_id,
            "host": instance.host,
            "port": instance.port,
            "weight": instance.weight,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.registry_url}/services/{instance.service_name}",
                json=payload,
            ) as resp:
                resp.raise_for_status()
class LoadBalancer:
    """Pick one instance per call: round_robin, random, or weighted.

    Fix: an unknown strategy name previously fell off the end and silently
    returned None, which callers cannot distinguish from "no instances";
    it now raises ValueError so typos fail fast.
    """

    def __init__(self, strategy: str = "round_robin"):
        self.strategy = strategy
        self._index = 0  # monotonically increasing round-robin cursor

    def select(self, instances: "List[ServiceInstance]") -> "Optional[ServiceInstance]":
        """Return one instance, or None when the list is empty.

        Raises:
            ValueError: for an unrecognized strategy name.
        """
        if not instances:
            return None
        if self.strategy == "round_robin":
            idx = self._index % len(instances)
            self._index += 1
            return instances[idx]
        if self.strategy == "random":
            return random.choice(instances)
        if self.strategy == "weighted":
            # Sample proportionally to weight: pick a point in [1, total]
            # and walk the list subtracting weights until it is crossed.
            total = sum(i.weight for i in instances)
            r = random.randint(1, total)
            for inst in instances:
                r -= inst.weight
                if r <= 0:
                    return inst
            return instances[0]  # defensive fallback; loop always returns
        raise ValueError(f"Unknown load-balancing strategy: {self.strategy!r}")
class CircuitState(Enum):
    """Breaker states: CLOSED = normal traffic, OPEN = failing fast,
    HALF_OPEN = probing whether the downstream has recovered."""
    CLOSED = auto()
    OPEN = auto()
    HALF_OPEN = auto()


class CircuitBreakerOpen(Exception):
    """Raised when a call is rejected because the breaker is OPEN."""
    pass


class CircuitBreaker:
    """熔断器模式实现 — circuit breaker for async callables.

    After `failure_threshold` failures the breaker opens and rejects
    calls with CircuitBreakerOpen; once `recovery_timeout` seconds have
    elapsed it half-opens and lets a call probe the downstream — success
    closes the circuit, failure reopens it immediately.

    Fix: this section of the module never imported `time`, so the first
    failure raised NameError (added to the section's imports).
    """

    def __init__(self, name: str, failure_threshold: int = 5,
                 recovery_timeout: float = 30.0):
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time: float | None = None
        # Guards all state transitions; deliberately NOT held while the
        # wrapped call runs, so slow calls do not serialize each other.
        self._lock = asyncio.Lock()

    async def call(self, func, *args, **kwargs):
        """Invoke ``await func(*args, **kwargs)`` under breaker protection.

        Raises:
            CircuitBreakerOpen: when the breaker is OPEN and the recovery
                timeout has not yet elapsed.
        """
        async with self._lock:
            if self._state == CircuitState.OPEN:
                # _last_failure_time is always set before entering OPEN.
                elapsed = time.time() - self._last_failure_time
                if elapsed >= self.recovery_timeout:
                    self._state = CircuitState.HALF_OPEN
                else:
                    raise CircuitBreakerOpen(f"Circuit {self.name} is OPEN")
        try:
            result = await func(*args, **kwargs)
        except Exception:
            await self._on_failure()
            raise
        await self._on_success()
        return result

    async def _on_success(self):
        async with self._lock:
            if self._state == CircuitState.HALF_OPEN:
                # Probe succeeded: downstream has recovered.
                self._state = CircuitState.CLOSED
                self._failure_count = 0
            else:
                # Gradually forgive old failures on each success.
                self._failure_count = max(0, self._failure_count - 1)

    async def _on_failure(self):
        async with self._lock:
            self._failure_count += 1
            self._last_failure_time = time.time()
            if (self._state == CircuitState.HALF_OPEN
                    or self._failure_count >= self.failure_threshold):
                # A failed half-open probe reopens immediately.
                self._state = CircuitState.OPEN
                print(f"Circuit {self.name} OPEN")
重点提示:
- 服务注册中心应支持 TTL 健康检查,自动剔除无心跳实例
- 一致性哈希适用于有状态服务,避免请求重新分配导致缓存失效
- 熔断器应支持半开状态,自动检测下游服务恢复情况
分布式链路追踪与可观测性
链路追踪通过为每个请求分配 Trace ID,在服务间传递 Span 上下文,使得跨服务调用链可视化。OpenTelemetry 是当前标准实现。
import time
import uuid
from contextvars import ContextVar
from typing import Optional, Dict, Any
from dataclasses import dataclass, field
# Context-local carrier for the current Span, so nested spans could pick
# up their parent without explicit plumbing. NOTE(review): nothing in the
# visible code reads or sets this — presumably integration lives elsewhere.
trace_context: ContextVar[Optional["Span"]] = ContextVar(
    "trace_context", default=None
)
@dataclass
class Span:
    """One timed unit of work in a distributed trace."""

    trace_id: str                     # shared by every span in one trace
    span_id: str                      # unique id of this span
    parent_id: Optional[str] = None   # None for a root span
    name: str = ""
    start_time: float = field(default_factory=time.time)
    end_time: Optional[float] = None  # set by finish()
    tags: Dict[str, Any] = field(default_factory=dict)
    status: str = "OK"

    def finish(self, status: str = "OK"):
        """Stamp the end time and record the final status."""
        self.end_time, self.status = time.time(), status

    def set_tag(self, key: str, value: Any):
        """Attach a key/value annotation to this span."""
        self.tags[key] = value
class Tracer:
    """Creates and finishes spans for one service; reports to stdout."""

    def __init__(self, service_name: str):
        self.service_name = service_name

    def start_span(self, name: str, parent: Optional[Span] = None) -> Span:
        """Open a span; a child inherits its parent's trace_id."""
        if parent is None:
            trace_id, parent_id = uuid.uuid4().hex[:16], None
        else:
            trace_id, parent_id = parent.trace_id, parent.span_id
        span = Span(
            trace_id=trace_id,
            span_id=uuid.uuid4().hex[:8],
            parent_id=parent_id,
            name=name,
        )
        span.set_tag("service", self.service_name)
        return span

    def finish_span(self, span: Span):
        """Close the span and log its duration in milliseconds."""
        span.finish()
        duration = (span.end_time - span.start_time) * 1000
        print(f"[TRACE] {span.name}: {duration:.2f}ms")
class MetricsCollector:
    """In-memory counters and histograms with a common name prefix.

    Fix: the original annotated ``_histograms`` with ``List``, which this
    section never imports — attribute annotations are evaluated at runtime
    (PEP 526), so the first instantiation raised NameError. Builtin
    generics (PEP 585, Python 3.9+) remove the typing dependency entirely.
    """

    def __init__(self, prefix: str = "app"):
        self.prefix = prefix
        self._counters: dict[str, int] = {}
        self._histograms: dict[str, list[float]] = {}

    def increment(self, name: str, value: int = 1):
        """Add `value` to counter ``<prefix>_<name>`` (created at 0)."""
        full_name = f"{self.prefix}_{name}"
        self._counters[full_name] = self._counters.get(full_name, 0) + value

    def histogram(self, name: str, value: float):
        """Record one observation for histogram ``<prefix>_<name>``."""
        full_name = f"{self.prefix}_{name}"
        self._histograms.setdefault(full_name, []).append(value)

    def get_metrics(self) -> dict:
        """Snapshot: raw counters plus count/avg per histogram.

        Division is safe: a histogram key exists only after at least one
        observation has been appended.
        """
        return {
            "counters": self._counters,
            "histograms": {k: {"count": len(v), "avg": sum(v) / len(v)}
                           for k, v in self._histograms.items()},
        }
# Observability middleware (tracing + metrics)
class ObservabilityMiddleware:
    """Middleware that wraps each request in a span, records duration and
    error counts, and propagates the X-Trace-ID header.

    Fix: the original set ``response.headers[...]`` inside ``finally``,
    where ``response`` is unbound whenever call_next raised — the
    resulting NameError masked the real exception. The header is now set
    on the success path, before returning the response.
    """

    def __init__(self, tracer: Tracer, metrics: MetricsCollector):
        self.tracer = tracer
        self.metrics = metrics

    async def __call__(self, request, call_next):
        # Reuse a caller-supplied trace id so cross-service hops share one
        # trace; otherwise mint a fresh one.
        trace_id = request.headers.get("X-Trace-ID") or uuid.uuid4().hex[:16]
        span = self.tracer.start_span(f"{request.method} {request.url.path}")
        span.set_tag("trace_id", trace_id)
        start_time = time.time()
        try:
            response = await call_next(request)
            span.set_tag("status_code", response.status_code)
            span.finish("OK")
            response.headers["X-Trace-ID"] = trace_id
            return response
        except Exception as e:
            span.set_tag("error", str(e))
            span.finish("ERROR")
            self.metrics.increment("errors_total")
            raise
        finally:
            # Runs on both paths: record latency and close out the span.
            duration = time.time() - start_time
            self.metrics.histogram("request_duration_seconds", duration)
            self.tracer.finish_span(span)
重点提示:
- OpenTelemetry SDK 可统一收集 traces、metrics 和 logs
- Trace ID 应在所有服务间传递,通过 HTTP 头或 gRPC 元数据
- 高流量场景应使用概率采样(如 1%)避免数据量过大
实战:完整的微服务通信框架
import asyncio
import aiohttp
from typing import Dict
from dataclasses import dataclass
import uuid
@dataclass
class ServiceConfig:
    """Static configuration for one microservice process."""

    name: str
    host: str = "0.0.0.0"           # bind address
    http_port: int = 8000           # REST listener
    grpc_port: int = 50051          # gRPC listener
    registry_url: str = "http://localhost:8500"  # Consul-style registry
class MicroserviceFramework:
    """Microservice framework core: outbound service-to-service calls with
    discovery, load balancing, per-service circuit breaking, tracing and
    metrics."""

    def __init__(self, config: ServiceConfig):
        self.config = config
        self.registry = ServiceRegistry(config.registry_url)
        self.load_balancer = LoadBalancer(strategy="round_robin")
        # One breaker per downstream service, created lazily in call_service.
        self.circuit_breakers: Dict[str, CircuitBreaker] = {}
        self.tracer = Tracer(config.name)
        self.metrics = MetricsCollector(config.name)

    async def call_service(self, service_name: str, endpoint: str,
                           method: str = "GET", data: dict = None) -> dict:
        """Call another service with circuit-breaker protection.

        Resolves instances via the registry, picks one with the load
        balancer, issues the HTTP request, and records span + metric
        outcomes. Raises CircuitBreakerOpen when the target's breaker is
        open; any other failure is counted and re-raised.
        """
        if service_name not in self.circuit_breakers:
            self.circuit_breakers[service_name] = CircuitBreaker(service_name)
        cb = self.circuit_breakers[service_name]
        span = self.tracer.start_span(f"call.{service_name}.{endpoint}")
        async def _do_request():
            # One-shot session per request; fine for an example, but a
            # shared session would avoid reconnect overhead in production.
            instances = await self.registry.discover(service_name)
            if not instances:
                raise Exception(f"Service {service_name} not found")
            instance = self.load_balancer.select(instances)
            url = f"http://{instance.address}{endpoint}"
            async with aiohttp.ClientSession() as session:
                async with session.request(method, url, json=data) as resp:
                    return await resp.json()
        try:
            result = await cb.call(_do_request)
            span.finish("OK")
            self.metrics.increment("calls_success")
            return result
        except CircuitBreakerOpen:
            span.finish("CIRCUIT_OPEN")
            self.metrics.increment("calls_circuit_open")
            raise
        except Exception as e:  # NOTE(review): e is unused — kept as-is
            span.finish("ERROR")
            self.metrics.increment("calls_error")
            raise
        finally:
            # Always emit the span, whatever the outcome.
            self.tracer.finish_span(span)

    async def start(self):
        """Register this instance with the registry and announce startup.

        NOTE(review): relies on ServiceRegistry exposing a register()
        coroutine — confirm the registry implementation provides it.
        """
        instance = ServiceInstance(
            service_name=self.config.name,
            instance_id=f"{self.config.name}-{uuid.uuid4().hex[:8]}",
            host=self.config.host,
            port=self.config.http_port
        )
        await self.registry.register(instance)
        print(f"Service {self.config.name} started")
# Usage example
async def main():
    """Demo: build the framework, call a downstream service, then start."""
    config = ServiceConfig(name="order-service", http_port=8001)
    framework = MicroserviceFramework(config)
    # Call the user service
    user = await framework.call_service(
        "user-service", "/users/123"
    )
    print(f"User data: {user}")
    # NOTE(review): start() (self-registration) runs only after the demo
    # call succeeds — confirm whether registration should happen first.
    await framework.start()
if __name__ == "__main__":
    asyncio.run(main())
重点提示:
- 生产环境应使用成熟框架如 FastAPI + Nameko 或 Dapr sidecar 模式
- Service Mesh(Istio、Linkerd)可在基础设施层实现上述能力
- 服务配置应支持热更新,避免重启才能修改超时等参数
架构决策总结
Python 微服务架构设计需在开发效率和运维复杂度间取得平衡。关键决策点:
- 通信协议:内部服务优先 gRPC,外部网关使用 REST + GraphQL
- 服务发现:K8s 环境使用 DNS + Headless Service,非 K8s 使用 Consul/Nacos
- 可观测性:OpenTelemetry 统一采集,Prometheus + Grafana 监控,Jaeger 链路追踪
- 部署模式:容器化部署,Kubernetes 编排,GitOps 工作流
Python 在微服务生态中的优势在于开发速度快、生态丰富。通过合理的架构设计和对异步编程模型的深入理解,Python 完全可以支撑高并发、高可用的生产级微服务体系。