from __future__ import annotations import logging import uuid from typing import Annotated from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse, Response, StreamingResponse from pydantic import ValidationError from nexus_claude_api.diagnostics import summarize_messages_request from nexus_claude_api.errors import NexusClaudeError, anthropic_error_response from nexus_claude_api.models import ( AnthropicMessagesRequest, CountTokensRequest, CountTokensResponse, ) from nexus_claude_api.nexus_client import NexusClient from nexus_claude_api.tokens import estimate_input_tokens from nexus_claude_api.translators.anthropic_to_bedrock import anthropic_to_bedrock_request from nexus_claude_api.translators.bedrock_to_anthropic import bedrock_to_anthropic_response from nexus_claude_api.translators.stream import ( bedrock_stream_to_anthropic_events, sse_frame, ) router = APIRouter() logger = logging.getLogger(__name__) def get_nexus_client(request: Request) -> NexusClient: return request.app.state.nexus_client def normalize_legacy_system_messages(raw: object) -> object: if not isinstance(raw, dict): return raw messages = raw.get("messages") if not isinstance(messages, list): return raw system_parts: list[object] = [] normalized_messages: list[object] = [] changed = False for message in messages: if isinstance(message, dict) and message.get("role") == "system": system_parts.append(message.get("content", "")) changed = True else: normalized_messages.append(message) if not changed: return raw normalized = dict(raw) normalized["messages"] = normalized_messages if "system" not in normalized and system_parts: normalized["system"] = system_parts[0] if len(system_parts) == 1 else system_parts return normalized @router.post("/v1/messages", response_model=None) async def create_message( request: Request, client: Annotated[NexusClient, Depends(get_nexus_client)], ) -> Response: correlation_id = uuid.uuid4().hex try: raw = normalize_legacy_system_messages(await request.json()) payload = AnthropicMessagesRequest.model_validate(raw) summary = summarize_messages_request( payload, correlation_id=correlation_id, ) logger.info( "anthropic_messages_request %s", summary, extra=summary, ) bedrock_request = anthropic_to_bedrock_request(payload) if payload.stream: stream = client.converse_stream( bedrock_request, correlation_id=correlation_id, ) return StreamingResponse( ( sse_frame(event) for event in bedrock_stream_to_anthropic_events( stream, model=payload.model, ) ), media_type="text/event-stream", ) response = client.converse( bedrock_request, correlation_id=correlation_id, ) anthropic_response = bedrock_to_anthropic_response( response, model=payload.model, ) return JSONResponse(content=anthropic_response.model_dump(exclude_none=True)) except ValidationError as exc: return anthropic_error_response( str(exc), status_code=400, correlation_id=correlation_id, ) except NexusClaudeError as exc: return anthropic_error_response( exc.message, status_code=exc.status_code, error_type=exc.error_type, correlation_id=correlation_id, ) @router.post("/v1/messages/count_tokens") async def count_tokens(request: Request) -> JSONResponse: try: raw = normalize_legacy_system_messages(await request.json()) payload = CountTokensRequest.model_validate(raw) response = CountTokensResponse(input_tokens=estimate_input_tokens(payload)) return JSONResponse(content=response.model_dump()) except ValidationError as exc: return anthropic_error_response(str(exc), status_code=400)