"""HTTP routes exposing Anthropic-style ``/v1/messages`` endpoints."""

from __future__ import annotations

from typing import Annotated

from fastapi import APIRouter, Depends, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import ValidationError

from nexus_claude_api.errors import NexusClaudeError, anthropic_error_response
from nexus_claude_api.models import (
    AnthropicMessagesRequest,
    CountTokensRequest,
    CountTokensResponse,
)
from nexus_claude_api.nexus_client import NexusClient
from nexus_claude_api.tokens import estimate_input_tokens
from nexus_claude_api.translators.anthropic_to_bedrock import anthropic_to_bedrock_request
from nexus_claude_api.translators.bedrock_to_anthropic import bedrock_to_anthropic_response
from nexus_claude_api.translators.stream import (
    bedrock_stream_to_anthropic_events,
    sse_frame,
)

router = APIRouter()


def get_nexus_client(request: Request) -> NexusClient:
    """Return the client stored on ``request.app.state.nexus_client``."""
    return request.app.state.nexus_client


def _invalid_json_response(exc: ValueError) -> JSONResponse:
    """Build the 400 error response used when the request body is not JSON."""
    return anthropic_error_response(
        f"Request body is not valid JSON: {exc}",
        status_code=400,
    )


@router.post("/v1/messages", response_model=None)
async def create_message(
    request: Request,
    client: Annotated[NexusClient, Depends(get_nexus_client)],
) -> Response:
    """Handle ``POST /v1/messages``.

    The JSON body is validated as an ``AnthropicMessagesRequest``, converted
    with ``anthropic_to_bedrock_request`` and sent through ``client``. When
    ``payload.stream`` is truthy the result is returned as a
    ``text/event-stream`` response; otherwise as a single JSON document.

    Returns:
        A ``StreamingResponse`` or ``JSONResponse`` on success. A 400 error
        response for a malformed JSON body or a ``ValidationError``. For a
        ``NexusClaudeError``, an error response carrying that exception's
        ``message``, ``status_code`` and ``error_type``.
    """
    # ``Request.json()`` raises ``json.JSONDecodeError`` (a ``ValueError``) on
    # a malformed body. Handle it separately so the caller gets a 400 rather
    # than an unhandled-exception 500.
    try:
        raw = await request.json()
    except ValueError as exc:
        return _invalid_json_response(exc)

    try:
        payload = AnthropicMessagesRequest.model_validate(raw)
        bedrock_request = anthropic_to_bedrock_request(payload)

        if payload.stream:
            # The client methods are invoked without ``await`` (i.e. they are
            # synchronous), so run them in the threadpool to keep the event
            # loop free while the upstream call is in flight.
            stream = await run_in_threadpool(client.converse_stream, bedrock_request)
            events = bedrock_stream_to_anthropic_events(stream, model=payload.model)
            # NOTE(review): errors raised while the generator is consumed
            # happen after this function has returned, so the ``except``
            # clauses below cannot translate them.
            return StreamingResponse(
                (sse_frame(event) for event in events),
                media_type="text/event-stream",
            )

        response = await run_in_threadpool(client.converse, bedrock_request)
        anthropic_response = bedrock_to_anthropic_response(
            response,
            model=payload.model,
        )
        return JSONResponse(content=anthropic_response.model_dump(exclude_none=True))
    except ValidationError as exc:
        return anthropic_error_response(str(exc), status_code=400)
    except NexusClaudeError as exc:
        return anthropic_error_response(
            exc.message,
            status_code=exc.status_code,
            error_type=exc.error_type,
        )


@router.post("/v1/messages/count_tokens")
async def count_tokens(request: Request) -> JSONResponse:
    """Handle ``POST /v1/messages/count_tokens``.

    The JSON body is validated as a ``CountTokensRequest`` and the value of
    ``estimate_input_tokens`` for it is returned as a ``CountTokensResponse``.

    Returns:
        A JSON response with the serialized ``CountTokensResponse``, or a 400
        error response for a malformed JSON body or a ``ValidationError``.
    """
    # A malformed body raises a ``ValueError`` subclass from ``Request.json()``;
    # report it as a client error instead of letting it surface as a 500.
    try:
        raw = await request.json()
    except ValueError as exc:
        return _invalid_json_response(exc)

    try:
        payload = CountTokensRequest.model_validate(raw)
        response = CountTokensResponse(input_tokens=estimate_input_tokens(payload))
        return JSONResponse(content=response.model_dump())
    except ValidationError as exc:
        return anthropic_error_response(str(exc), status_code=400)