Files
nexus-claude-api/src/nexus_claude_api/routes/messages.py

170 lines
5.4 KiB
Python

from __future__ import annotations
import logging
import uuid
from collections.abc import Iterable, Iterator
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import ValidationError
from nexus_claude_api.diagnostics import summarize_messages_request
from nexus_claude_api.errors import NexusClaudeError, anthropic_error_response
from nexus_claude_api.models import (
AnthropicMessagesRequest,
CountTokensRequest,
CountTokensResponse,
)
from nexus_claude_api.nexus_client import NexusClient
from nexus_claude_api.tokens import estimate_input_tokens
from nexus_claude_api.translators.anthropic_to_bedrock import anthropic_to_bedrock_request
from nexus_claude_api.translators.bedrock_to_anthropic import bedrock_to_anthropic_response
from nexus_claude_api.translators.stream import (
bedrock_stream_to_anthropic_events,
sse_frame,
)
# Router that this module's endpoints are registered on via @router.post(...).
router = APIRouter()
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def get_nexus_client(request: Request) -> NexusClient:
    """Return the client stored at ``request.app.state.nexus_client``.

    Used as a FastAPI dependency (``Depends(get_nexus_client)``) so route
    handlers receive the application's shared client instance.
    """
    app_state = request.app.state
    return app_state.nexus_client
def _merge_system_parts(parts: list[object]) -> list[object]:
    """Flatten several legacy system contents into one flat list of content blocks."""
    blocks: list[object] = []
    for part in parts:
        # A list content is treated as a sequence of blocks; anything else is one item.
        items = part if isinstance(part, list) else [part]
        for item in items:
            if isinstance(item, str):
                # Wrap plain strings as text blocks; empty strings carry no instruction.
                if item:
                    blocks.append({"type": "text", "text": item})
            else:
                # Dicts (and anything unexpected) pass through untouched so that
                # schema validation downstream can accept or reject them.
                blocks.append(item)
    return blocks


def normalize_legacy_system_messages(raw: object) -> object:
    """Move ``role == "system"`` entries out of ``messages`` into a top-level ``system``.

    Args:
        raw: The decoded JSON request body. Anything that is not a dict, or whose
            ``"messages"`` value is not a list, is returned unchanged.

    Returns:
        ``raw`` itself when no system-role message is present; otherwise a shallow
        copy in which ``"messages"`` no longer contains the system-role entries.
        The input dict is never mutated.

        If the body has no ``"system"`` key, one is derived from the removed
        messages: a single message contributes its ``content`` as-is, while
        several messages are merged into one flat list of content blocks (string
        contents become ``{"type": "text", "text": ...}``, list contents are
        spliced in, empty strings are skipped). If merging yields no blocks,
        ``"system"`` is left unset.

        If the body already has a ``"system"`` key it takes precedence: the
        system-role messages are still removed, but their content is discarded.
    """
    if not isinstance(raw, dict):
        return raw
    messages = raw.get("messages")
    if not isinstance(messages, list):
        return raw

    system_parts: list[object] = []
    remaining: list[object] = []
    for message in messages:
        if isinstance(message, dict) and message.get("role") == "system":
            system_parts.append(message.get("content", ""))
        else:
            remaining.append(message)

    if not system_parts:
        return raw

    normalized = dict(raw)
    normalized["messages"] = remaining
    if "system" not in normalized:
        if len(system_parts) == 1:
            normalized["system"] = system_parts[0]
        else:
            # A bare list of the raw contents (strings and/or nested lists) is
            # not a list of content blocks, so flatten it into one.
            merged = _merge_system_parts(system_parts)
            if merged:
                normalized["system"] = merged
    return normalized
def anthropic_sse_stream(
    stream: Iterable[dict[str, Any]],
    *,
    model: str,
    correlation_id: str,
) -> Iterator[str]:
    """Yield SSE frames for the events translated from ``stream``.

    Each event produced by ``bedrock_stream_to_anthropic_events(stream, model=model)``
    is passed through ``sse_frame`` and yielded. Failures are not re-raised;
    they are reported in-band as a final ``"error"`` frame that carries
    ``correlation_id``:

    * ``NexusClaudeError``: the frame uses the exception's ``error_type`` and
      ``message``.
    * any other ``Exception``: the traceback is logged and a generic
      ``"api_error"`` frame is sent.
    """

    def error_frame(error_type: str, message: str) -> str:
        # Both failure paths share this envelope; only type and message differ.
        envelope = {
            "type": "error",
            "error": {
                "type": error_type,
                "message": message,
                "correlation_id": correlation_id,
            },
        }
        return sse_frame(envelope)

    try:
        events = bedrock_stream_to_anthropic_events(stream, model=model)
        for event in events:
            yield sse_frame(event)
    except NexusClaudeError as exc:
        yield error_frame(
            exc.error_type,
            f"{exc.message} [correlation_id={correlation_id}]",
        )
    except Exception:
        logger.exception(
            "anthropic_messages_stream_error correlation_id=%s",
            correlation_id,
        )
        yield error_frame(
            "api_error",
            "Unexpected error while streaming response "
            f"[correlation_id={correlation_id}]",
        )
@router.post("/v1/messages", response_model=None)
async def create_message(
    request: Request,
    client: Annotated[NexusClient, Depends(get_nexus_client)],
) -> Response:
    """Handle ``POST /v1/messages``: validate, translate, and forward the request.

    A fresh correlation id is generated per call and passed to the client calls
    and to every error response built here.

    Returns:
        * a ``StreamingResponse`` (``text/event-stream``) fed by
          ``anthropic_sse_stream`` when ``payload.stream`` is truthy;
        * otherwise a ``JSONResponse`` with the translated response dumped via
          ``model_dump(exclude_none=True)``;
        * a 400 error response when the body is not valid JSON or fails model
          validation;
        * an error response with the exception's ``status_code`` and
          ``error_type`` when a ``NexusClaudeError`` is raised.
    """
    correlation_id = uuid.uuid4().hex

    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses. Without this handler a malformed body escapes as an
        # unhandled exception instead of a client error.
        return anthropic_error_response(
            f"Request body is not valid JSON: {exc}",
            status_code=400,
            correlation_id=correlation_id,
        )

    try:
        raw = normalize_legacy_system_messages(body)
        payload = AnthropicMessagesRequest.model_validate(raw)
        summary = summarize_messages_request(
            payload,
            correlation_id=correlation_id,
        )
        logger.info(
            "anthropic_messages_request %s",
            summary,
            extra=summary,
        )
        bedrock_request = anthropic_to_bedrock_request(payload)

        # NOTE(review): both client calls below are invoked without ``await``,
        # so they run synchronously inside this coroutine. If they perform
        # blocking network I/O they will stall the event loop; confirm, and
        # consider running them in a threadpool.
        if payload.stream:
            stream = client.converse_stream(
                bedrock_request,
                correlation_id=correlation_id,
            )
            return StreamingResponse(
                anthropic_sse_stream(
                    stream,
                    model=payload.model,
                    correlation_id=correlation_id,
                ),
                media_type="text/event-stream",
            )

        response = client.converse(
            bedrock_request,
            correlation_id=correlation_id,
        )
        anthropic_response = bedrock_to_anthropic_response(
            response,
            model=payload.model,
        )
        return JSONResponse(content=anthropic_response.model_dump(exclude_none=True))
    except ValidationError as exc:
        return anthropic_error_response(
            str(exc),
            status_code=400,
            correlation_id=correlation_id,
        )
    except NexusClaudeError as exc:
        return anthropic_error_response(
            exc.message,
            status_code=exc.status_code,
            error_type=exc.error_type,
            correlation_id=correlation_id,
        )
@router.post("/v1/messages/count_tokens")
async def count_tokens(request: Request) -> JSONResponse:
    """Handle ``POST /v1/messages/count_tokens``: estimate input tokens for a request.

    The body is normalized with ``normalize_legacy_system_messages``, validated
    as a ``CountTokensRequest``, and the result of ``estimate_input_tokens`` is
    returned as a ``CountTokensResponse`` JSON document. A body that is not
    valid JSON, or that fails validation, yields a 400 error response.
    """
    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses. Without this handler a malformed body escapes as an
        # unhandled exception instead of a client error.
        return anthropic_error_response(
            f"Request body is not valid JSON: {exc}",
            status_code=400,
        )

    try:
        raw = normalize_legacy_system_messages(body)
        payload = CountTokensRequest.model_validate(raw)
        response = CountTokensResponse(input_tokens=estimate_input_tokens(payload))
        return JSONResponse(content=response.model_dump())
    except ValidationError as exc:
        return anthropic_error_response(str(exc), status_code=400)