"""FastAPI routes exposing an Anthropic Messages-compatible API over NexusClient.

Incoming Anthropic-style payloads are validated, translated to Bedrock
requests, sent through ``NexusClient``, and the responses (or streamed
events) are translated back to Anthropic format.
"""

from __future__ import annotations

import logging
import uuid
from collections.abc import Iterable, Iterator
from typing import Annotated, Any

from fastapi import APIRouter, Depends, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import ValidationError

from nexus_claude_api.diagnostics import summarize_messages_request
from nexus_claude_api.errors import NexusClaudeError, anthropic_error_response
from nexus_claude_api.models import (
    AnthropicMessagesRequest,
    CountTokensRequest,
    CountTokensResponse,
)
from nexus_claude_api.nexus_client import NexusClient
from nexus_claude_api.tokens import estimate_input_tokens
from nexus_claude_api.translators.anthropic_to_bedrock import anthropic_to_bedrock_request
from nexus_claude_api.translators.bedrock_to_anthropic import bedrock_to_anthropic_response
from nexus_claude_api.translators.stream import (
    bedrock_stream_to_anthropic_events,
    sse_frame,
)

router = APIRouter()
logger = logging.getLogger(__name__)


def get_nexus_client(request: Request) -> NexusClient:
    """FastAPI dependency returning the client stored at ``app.state.nexus_client``."""
    return request.app.state.nexus_client


def normalize_legacy_system_messages(raw: object) -> object:
    """Hoist ``{"role": "system"}`` entries out of ``messages`` into ``system``.

    Args:
        raw: The decoded JSON request body (any JSON value).

    Returns:
        ``raw`` itself when it is not a dict, when its ``messages`` value is
        not a list, or when no message has ``role == "system"``. Otherwise a
        shallow copy of ``raw`` whose ``messages`` omits the system-role
        entries. If the copy has no ``system`` key, ``system`` is set to the
        single system message's ``content`` (default ``""``). When there are
        several system messages, it is set to the list of their contents.

    NOTE(review): when the payload already has a top-level ``system`` key, the
    system-role messages are still removed from ``messages`` and their content
    is discarded — confirm this is intended.
    """
    if not isinstance(raw, dict):
        return raw
    messages = raw.get("messages")
    if not isinstance(messages, list):
        return raw

    system_parts: list[object] = []
    normalized_messages: list[object] = []
    for message in messages:
        if isinstance(message, dict) and message.get("role") == "system":
            system_parts.append(message.get("content", ""))
        else:
            normalized_messages.append(message)

    # Each system-role message contributes exactly one entry, so an empty
    # list means there is nothing to rewrite and the original object is kept.
    if not system_parts:
        return raw

    normalized = dict(raw)
    normalized["messages"] = normalized_messages
    if "system" not in normalized:
        normalized["system"] = system_parts[0] if len(system_parts) == 1 else system_parts
    return normalized


def anthropic_sse_stream(
    stream: Iterable[dict[str, Any]],
    *,
    model: str,
    correlation_id: str,
) -> Iterator[str]:
    """Translate a Bedrock event stream into Anthropic SSE frames.

    Errors raised while iterating are not propagated. Once streaming has
    begun, the HTTP status is typically already sent, so the error is
    reported as a final Anthropic ``error`` SSE event carrying the
    correlation id. ``NexusClaudeError`` keeps its own type and message.
    Any other exception is logged and reported as a generic ``api_error``.
    """
    try:
        for event in bedrock_stream_to_anthropic_events(stream, model=model):
            yield sse_frame(event)
    except NexusClaudeError as exc:
        yield sse_frame(
            {
                "type": "error",
                "error": {
                    "type": exc.error_type,
                    "message": f"{exc.message} [correlation_id={correlation_id}]",
                    "correlation_id": correlation_id,
                },
            }
        )
    except Exception:
        logger.exception(
            "anthropic_messages_stream_error correlation_id=%s",
            correlation_id,
        )
        yield sse_frame(
            {
                "type": "error",
                "error": {
                    "type": "api_error",
                    "message": (
                        "Unexpected error while streaming response "
                        f"[correlation_id={correlation_id}]"
                    ),
                    "correlation_id": correlation_id,
                },
            }
        )


def _invalid_json_response(exc: ValueError, *, correlation_id: str) -> JSONResponse:
    """Build the 400 error response for a request body that is not valid JSON."""
    return anthropic_error_response(
        f"Request body is not valid JSON: {exc}",
        status_code=400,
        correlation_id=correlation_id,
    )


@router.post("/v1/messages", response_model=None)
async def create_message(
    request: Request,
    client: Annotated[NexusClient, Depends(get_nexus_client)],
) -> Response:
    """Handle ``POST /v1/messages`` by proxying the request through Nexus.

    Returns:
        A ``StreamingResponse`` of SSE frames when ``payload.stream`` is set,
        otherwise a ``JSONResponse`` with the translated message. Malformed
        JSON and validation failures become 400 error responses.
        ``NexusClaudeError`` becomes an error response with the exception's
        own status code and type. Every response is tagged with a
        per-request correlation id.
    """
    correlation_id = uuid.uuid4().hex
    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError / UnicodeDecodeError from a malformed body;
        # without this the client receives a bare, non-Anthropic 500.
        return _invalid_json_response(exc, correlation_id=correlation_id)

    try:
        payload = AnthropicMessagesRequest.model_validate(
            normalize_legacy_system_messages(body)
        )
        summary = summarize_messages_request(
            payload,
            correlation_id=correlation_id,
        )
        logger.info(
            "anthropic_messages_request %s",
            summary,
            extra=summary,
        )
        bedrock_request = anthropic_to_bedrock_request(payload)

        if payload.stream:
            # The client API is synchronous; calling it may do blocking work,
            # so keep it off the event loop. Starlette iterates the returned
            # synchronous iterable in a worker thread as well.
            stream = await run_in_threadpool(
                client.converse_stream,
                bedrock_request,
                correlation_id=correlation_id,
            )
            return StreamingResponse(
                anthropic_sse_stream(
                    stream,
                    model=payload.model,
                    correlation_id=correlation_id,
                ),
                media_type="text/event-stream",
            )

        # client.converse is synchronous (its result is used directly, not
        # awaited); running it inline would block every other request served
        # by this event loop for the duration of the upstream call.
        response = await run_in_threadpool(
            client.converse,
            bedrock_request,
            correlation_id=correlation_id,
        )
        anthropic_response = bedrock_to_anthropic_response(
            response,
            model=payload.model,
        )
        return JSONResponse(content=anthropic_response.model_dump(exclude_none=True))
    except ValidationError as exc:
        return anthropic_error_response(
            str(exc),
            status_code=400,
            correlation_id=correlation_id,
        )
    except NexusClaudeError as exc:
        return anthropic_error_response(
            exc.message,
            status_code=exc.status_code,
            error_type=exc.error_type,
            correlation_id=correlation_id,
        )


@router.post("/v1/messages/count_tokens")
async def count_tokens(request: Request) -> JSONResponse:
    """Handle ``POST /v1/messages/count_tokens`` with a local token estimate.

    Returns:
        A ``JSONResponse`` with ``input_tokens`` from ``estimate_input_tokens``.
        Malformed JSON and validation failures become 400 error responses
        tagged with a per-request correlation id.
    """
    correlation_id = uuid.uuid4().hex
    try:
        body = await request.json()
    except ValueError as exc:
        return _invalid_json_response(exc, correlation_id=correlation_id)

    try:
        payload = CountTokensRequest.model_validate(normalize_legacy_system_messages(body))
        response = CountTokensResponse(input_tokens=estimate_input_tokens(payload))
        return JSONResponse(content=response.model_dump())
    except ValidationError as exc:
        return anthropic_error_response(
            str(exc),
            status_code=400,
            correlation_id=correlation_id,
        )