Files
nexus-claude-api/src/nexus_claude_api/routes/messages.py

170 lines
5.4 KiB
Python
Raw Normal View History

2026-06-26 17:02:21 +08:00
from __future__ import annotations
import logging
import uuid
from collections.abc import Iterable, Iterator
from typing import Annotated, Any
2026-06-26 17:02:21 +08:00
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import ValidationError
from nexus_claude_api.diagnostics import summarize_messages_request
2026-06-26 17:02:21 +08:00
from nexus_claude_api.errors import NexusClaudeError, anthropic_error_response
from nexus_claude_api.models import (
AnthropicMessagesRequest,
CountTokensRequest,
CountTokensResponse,
)
from nexus_claude_api.nexus_client import NexusClient
from nexus_claude_api.tokens import estimate_input_tokens
from nexus_claude_api.translators.anthropic_to_bedrock import anthropic_to_bedrock_request
from nexus_claude_api.translators.bedrock_to_anthropic import bedrock_to_anthropic_response
from nexus_claude_api.translators.stream import (
bedrock_stream_to_anthropic_events,
sse_frame,
)
# Module-level logger; every handler and stream helper in this module logs through it.
logger = logging.getLogger(__name__)
# Router on which the ``@router.post`` handlers below register the
# Anthropic-compatible ``/v1/messages`` endpoints.
router = APIRouter()
def get_nexus_client(request: Request) -> NexusClient:
    """FastAPI dependency that returns the ``NexusClient`` kept on application state.

    The client is read from ``request.app.state.nexus_client``. Presumably it is
    created once at application startup and shared by all requests (TODO: confirm
    where it is assigned).
    """
    state = request.app.state
    return state.nexus_client
def normalize_legacy_system_messages(raw: object) -> object:
    """Move ``role="system"`` chat messages into the top-level ``system`` field.

    The Anthropic Messages API takes the system prompt as a top-level ``system``
    field and does not accept ``"system"`` as a message role. Some clients still
    send it as a message, so this rewrites such payloads before validation.

    Rules:
    * Non-dict payloads, and payloads whose ``messages`` is not a list, are
      returned unchanged (validation reports them later).
    * If no message has ``role == "system"``, the original object is returned
      as-is (same identity).
    * Otherwise a shallow copy is returned. Its ``messages`` omits the system
      messages. Its ``system`` combines any pre-existing ``system`` value
      (first) with the contents of the removed messages, in order. A single
      source is used unchanged; several are merged by ``_merge_system_parts``.

    The input object is never mutated.
    """
    if not isinstance(raw, dict):
        return raw
    messages = raw.get("messages")
    if not isinstance(messages, list):
        return raw

    system_parts: list[object] = []
    kept_messages: list[object] = []
    for message in messages:
        if isinstance(message, dict) and message.get("role") == "system":
            system_parts.append(message.get("content", ""))
        else:
            kept_messages.append(message)
    if not system_parts:
        return raw

    normalized = dict(raw)
    normalized["messages"] = kept_messages
    existing = normalized.get("system")
    if existing is not None:
        # Keep an explicit top-level system prompt instead of silently discarding
        # the legacy system messages (the previous behaviour dropped them).
        system_parts.insert(0, existing)
    normalized["system"] = _merge_system_parts(system_parts)
    return normalized


def _merge_system_parts(parts: list[object]) -> object:
    """Combine one or more system prompts into a single ``system`` value.

    A single part is returned unchanged. If every non-``None`` part is a string,
    the non-empty strings are joined with a blank line so ``system`` stays a
    plain string. Otherwise every part is flattened into one list of content
    blocks, with bare strings wrapped as ``{"type": "text", "text": ...}``.
    A list of bare strings is not a valid Anthropic ``system`` value.
    """
    if len(parts) == 1:
        return parts[0]
    present = [part for part in parts if part is not None]
    if all(isinstance(part, str) for part in present):
        return "\n\n".join(part for part in present if part)
    blocks: list[object] = []
    for part in present:
        if isinstance(part, str):
            if part:
                blocks.append({"type": "text", "text": part})
        elif isinstance(part, list):
            blocks.extend(
                {"type": "text", "text": item} if isinstance(item, str) else item
                for item in part
            )
        else:
            # Unknown shape (e.g. a single block dict): pass it through and let
            # request-model validation decide whether it is acceptable.
            blocks.append(part)
    return blocks
def anthropic_sse_stream(
    stream: Iterable[dict[str, Any]],
    *,
    model: str,
    correlation_id: str,
) -> Iterator[str]:
    """Relay a Bedrock converse stream to the client as Anthropic SSE frames.

    Each event produced by ``bedrock_stream_to_anthropic_events`` is serialized
    with ``sse_frame`` and yielded.

    When this generator is used as a streaming response body (as the messages
    endpoint does), the response has already started by the time a failure
    occurs. So errors are reported in-band as one final ``{"type": "error"}``
    frame instead of being raised:
    * ``NexusClaudeError`` keeps its ``error_type`` and message and is logged
      at WARNING.
    * Any other exception is logged with its traceback and reported as a generic
      ``api_error``.
    Both frames carry ``correlation_id`` so client reports can be matched to
    server logs.
    """
    try:
        for event in bedrock_stream_to_anthropic_events(stream, model=model):
            yield sse_frame(event)
    except NexusClaudeError as exc:
        # Previously this branch was not logged at all, leaving mid-stream
        # upstream failures visible only to the client.
        logger.warning(
            "anthropic_messages_stream_error correlation_id=%s error_type=%s message=%s",
            correlation_id,
            exc.error_type,
            exc.message,
        )
        yield _stream_error_frame(exc.error_type, exc.message, correlation_id)
    except Exception:
        logger.exception(
            "anthropic_messages_stream_error correlation_id=%s",
            correlation_id,
        )
        yield _stream_error_frame(
            "api_error",
            "Unexpected error while streaming response",
            correlation_id,
        )


def _stream_error_frame(error_type: str, message: str, correlation_id: str) -> str:
    """Build the Anthropic-style SSE ``error`` frame used by ``anthropic_sse_stream``."""
    return sse_frame(
        {
            "type": "error",
            "error": {
                "type": error_type,
                "message": f"{message} [correlation_id={correlation_id}]",
                "correlation_id": correlation_id,
            },
        }
    )
@router.post("/v1/messages", response_model=None)
async def create_message(
    request: Request,
    client: Annotated[NexusClient, Depends(get_nexus_client)],
) -> Response:
    """Serve an Anthropic Messages API request by proxying it through Nexus.

    Steps:
    1. The JSON body is parsed and passed through
       ``normalize_legacy_system_messages``.
    2. It is validated as an ``AnthropicMessagesRequest``, summarized, and logged.
    3. It is translated to a Bedrock request with ``anthropic_to_bedrock_request``.
    4. When ``payload.stream`` is true, ``client.converse_stream`` is called and
       its events are relayed as SSE via ``anthropic_sse_stream``. Otherwise
       ``client.converse`` is called and its reply is converted with
       ``bedrock_to_anthropic_response`` and returned as JSON.

    A fresh ``correlation_id`` is generated per request. It is logged and
    attached to every error response.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) for streaming requests,
        a ``JSONResponse`` otherwise, or an Anthropic-style error response:
        * 400 when the body is not valid JSON or fails model validation;
        * ``exc.status_code`` / ``exc.error_type`` when a ``NexusClaudeError``
          is raised before the response starts.
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    correlation_id = uuid.uuid4().hex

    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses. Without this, a malformed body surfaced as a bare 500.
        return anthropic_error_response(
            f"Request body is not valid JSON: {exc}",
            status_code=400,
            correlation_id=correlation_id,
        )

    try:
        payload = AnthropicMessagesRequest.model_validate(
            normalize_legacy_system_messages(body)
        )
        summary = summarize_messages_request(
            payload,
            correlation_id=correlation_id,
        )
        logger.info(
            "anthropic_messages_request %s",
            summary,
            extra=summary,
        )

        bedrock_request = anthropic_to_bedrock_request(payload)
        # The Nexus client methods are synchronous (their results are used
        # directly, not awaited). Calling them inline would block the event
        # loop for the whole upstream call, so run them in the threadpool.
        # NOTE(review): assumes NexusClient is safe to call from worker
        # threads (as boto3-style clients are). Confirm.
        if payload.stream:
            stream = await run_in_threadpool(
                client.converse_stream,
                bedrock_request,
                correlation_id=correlation_id,
            )
            return StreamingResponse(
                anthropic_sse_stream(
                    stream,
                    model=payload.model,
                    correlation_id=correlation_id,
                ),
                media_type="text/event-stream",
            )

        response = await run_in_threadpool(
            client.converse,
            bedrock_request,
            correlation_id=correlation_id,
        )
        anthropic_response = bedrock_to_anthropic_response(
            response,
            model=payload.model,
        )
        return JSONResponse(content=anthropic_response.model_dump(exclude_none=True))
    except ValidationError as exc:
        return anthropic_error_response(
            str(exc),
            status_code=400,
            correlation_id=correlation_id,
        )
    except NexusClaudeError as exc:
        return anthropic_error_response(
            exc.message,
            status_code=exc.status_code,
            error_type=exc.error_type,
            correlation_id=correlation_id,
        )
@router.post("/v1/messages/count_tokens")
async def count_tokens(request: Request) -> JSONResponse:
    """Serve an Anthropic ``count_tokens`` request with a local token estimate.

    The JSON body is passed through ``normalize_legacy_system_messages`` and
    validated as a ``CountTokensRequest``. The count comes from
    ``estimate_input_tokens``; this handler does not call the Nexus client.

    Returns:
        ``{"input_tokens": N}`` as JSON, or an Anthropic-style 400 error (tagged
        with a fresh ``correlation_id``, like ``/v1/messages`` errors) when the
        body is not valid JSON or fails validation.
    """
    correlation_id = uuid.uuid4().hex

    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError / UnicodeDecodeError. Previously this escaped as a 500.
        return anthropic_error_response(
            f"Request body is not valid JSON: {exc}",
            status_code=400,
            correlation_id=correlation_id,
        )

    try:
        payload = CountTokensRequest.model_validate(normalize_legacy_system_messages(body))
        response = CountTokensResponse(input_tokens=estimate_input_tokens(payload))
        return JSONResponse(content=response.model_dump())
    except ValidationError as exc:
        return anthropic_error_response(
            str(exc),
            status_code=400,
            correlation_id=correlation_id,
        )