2026-06-26 17:02:21 +08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2026-06-26 22:36:09 +08:00
|
|
|
import logging
|
|
|
|
|
import uuid
|
2026-06-27 14:32:52 +08:00
|
|
|
from collections.abc import Iterable, Iterator
|
|
|
|
|
from typing import Annotated, Any
|
2026-06-26 17:02:21 +08:00
|
|
|
|
|
|
|
|
from fastapi import APIRouter, Depends, Request
|
|
|
|
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
|
|
|
|
from pydantic import ValidationError
|
|
|
|
|
|
2026-06-26 22:36:09 +08:00
|
|
|
from nexus_claude_api.diagnostics import summarize_messages_request
|
2026-06-26 17:02:21 +08:00
|
|
|
from nexus_claude_api.errors import NexusClaudeError, anthropic_error_response
|
|
|
|
|
from nexus_claude_api.models import (
|
|
|
|
|
AnthropicMessagesRequest,
|
|
|
|
|
CountTokensRequest,
|
|
|
|
|
CountTokensResponse,
|
|
|
|
|
)
|
|
|
|
|
from nexus_claude_api.nexus_client import NexusClient
|
|
|
|
|
from nexus_claude_api.tokens import estimate_input_tokens
|
|
|
|
|
from nexus_claude_api.translators.anthropic_to_bedrock import anthropic_to_bedrock_request
|
|
|
|
|
from nexus_claude_api.translators.bedrock_to_anthropic import bedrock_to_anthropic_response
|
|
|
|
|
from nexus_claude_api.translators.stream import (
|
|
|
|
|
bedrock_stream_to_anthropic_events,
|
|
|
|
|
sse_frame,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Router that the endpoint handlers defined in this module are registered on.
router = APIRouter()
|
2026-06-26 22:36:09 +08:00
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
2026-06-26 17:02:21 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_nexus_client(request: Request) -> NexusClient:
    """Return the Nexus client stored on the application state.

    Used as a FastAPI dependency: the client is read from
    ``request.app.state.nexus_client`` on every call.
    """
    app_state = request.app.state
    return app_state.nexus_client
|
|
|
|
|
|
|
|
|
|
|
2026-06-26 22:36:09 +08:00
|
|
|
def normalize_legacy_system_messages(raw: object) -> object:
    """Lift ``role == "system"`` entries out of ``messages`` into ``system``.

    The input is returned unchanged (same object) when it is not a dict, when
    its ``"messages"`` value is not a list, or when no message dict has the
    role ``"system"``.

    Otherwise a shallow copy of ``raw`` is returned in which:

    * ``"messages"`` holds the non-system entries in their original order;
    * ``"system"`` is set only if the key was not already present: to the
      single extracted content when exactly one system message was found, or
      to the list of extracted contents when there were several. A system
      message without a ``"content"`` key contributes ``""``.

    The input mapping and its message list are never mutated.
    """
    if not isinstance(raw, dict):
        return raw

    entries = raw.get("messages")
    if not isinstance(entries, list):
        return raw

    def _is_system(entry: object) -> bool:
        # Non-dict entries are passed through untouched as ordinary messages.
        return isinstance(entry, dict) and entry.get("role") == "system"

    extracted = [entry.get("content", "") for entry in entries if _is_system(entry)]
    if not extracted:
        # Nothing to lift: hand back the caller's object as-is.
        return raw

    remaining = [entry for entry in entries if not _is_system(entry)]
    result = {**raw, "messages": remaining}

    # An explicit top-level "system" wins; extracted contents are then dropped.
    if "system" not in result:
        result["system"] = extracted[0] if len(extracted) == 1 else extracted
    return result
|
|
|
|
|
|
|
|
|
|
|
2026-06-27 14:32:52 +08:00
|
|
|
def anthropic_sse_stream(
    stream: Iterable[dict[str, Any]],
    *,
    model: str,
    correlation_id: str,
) -> Iterator[str]:
    """Yield SSE frames for a Bedrock stream, ending with an error frame on failure.

    Each event produced by ``bedrock_stream_to_anthropic_events(stream,
    model=model)`` is passed through ``sse_frame`` and yielded. If translation
    or iteration raises, one final ``"error"`` frame is yielded instead of
    propagating the exception:

    * ``NexusClaudeError`` -> the exception's ``error_type`` and ``message``;
    * any other ``Exception`` -> logged with its traceback, then reported as
      a generic ``"api_error"``.

    Both error frames carry ``correlation_id``, in the message text and as a
    separate field.
    """

    def _error_frame(error_type: str, message: str) -> str:
        # Both failure paths share one envelope; only type and message differ.
        return sse_frame(
            {
                "type": "error",
                "error": {
                    "type": error_type,
                    "message": message,
                    "correlation_id": correlation_id,
                },
            }
        )

    try:
        translated = bedrock_stream_to_anthropic_events(stream, model=model)
        for anthropic_event in translated:
            yield sse_frame(anthropic_event)
    except NexusClaudeError as exc:
        yield _error_frame(
            exc.error_type,
            f"{exc.message} [correlation_id={correlation_id}]",
        )
    except Exception:
        logger.exception(
            "anthropic_messages_stream_error correlation_id=%s",
            correlation_id,
        )
        yield _error_frame(
            "api_error",
            "Unexpected error while streaming response "
            f"[correlation_id={correlation_id}]",
        )
|
|
|
|
|
|
|
|
|
|
|
2026-06-26 17:02:21 +08:00
|
|
|
@router.post("/v1/messages", response_model=None)
async def create_message(
    request: Request,
    client: Annotated[NexusClient, Depends(get_nexus_client)],
) -> Response:
    """Handle ``POST /v1/messages``: translate the request and call Nexus.

    A fresh correlation id is generated per request and passed to the Nexus
    client calls and to every error response built here.

    Flow:
      1. Parse the JSON body; a malformed body yields a 400 error response.
      2. Normalize legacy system messages and validate the result as an
         ``AnthropicMessagesRequest``.
      3. Log a summary of the request and translate it to a Bedrock request.
      4. If ``payload.stream`` is truthy, return a ``text/event-stream``
         ``StreamingResponse`` fed by ``anthropic_sse_stream``; otherwise call
         ``client.converse`` and return the translated response as JSON, with
         ``None`` fields omitted.

    Returns:
        The streaming or JSON response, or an error response when
        - the body is not valid JSON (status 400),
        - a pydantic ``ValidationError`` is raised (status 400),
        - a ``NexusClaudeError`` is raised (the exception's own status code
          and error type).
    """
    correlation_id = uuid.uuid4().hex

    # Parse the body in its own narrow try: json.JSONDecodeError and
    # UnicodeDecodeError are both ValueError subclasses and would otherwise
    # escape the handlers below as an unhandled server error.
    try:
        body = await request.json()
    except ValueError as exc:
        return anthropic_error_response(
            f"Request body is not valid JSON: {exc}",
            status_code=400,
            correlation_id=correlation_id,
        )

    try:
        raw = normalize_legacy_system_messages(body)
        payload = AnthropicMessagesRequest.model_validate(raw)
        summary = summarize_messages_request(
            payload,
            correlation_id=correlation_id,
        )
        logger.info(
            "anthropic_messages_request %s",
            summary,
            extra=summary,
        )
        bedrock_request = anthropic_to_bedrock_request(payload)
        if payload.stream:
            stream = client.converse_stream(
                bedrock_request,
                correlation_id=correlation_id,
            )
            # Errors raised while the stream is consumed are reported in-band
            # by anthropic_sse_stream, not by the except clauses below.
            return StreamingResponse(
                anthropic_sse_stream(
                    stream,
                    model=payload.model,
                    correlation_id=correlation_id,
                ),
                media_type="text/event-stream",
            )

        # NOTE(review): converse() is called without await inside an async
        # handler; if it performs blocking I/O it will stall the event loop
        # for its duration — confirm against NexusClient.
        response = client.converse(
            bedrock_request,
            correlation_id=correlation_id,
        )
        anthropic_response = bedrock_to_anthropic_response(
            response,
            model=payload.model,
        )
        return JSONResponse(content=anthropic_response.model_dump(exclude_none=True))
    except ValidationError as exc:
        return anthropic_error_response(
            str(exc),
            status_code=400,
            correlation_id=correlation_id,
        )
    except NexusClaudeError as exc:
        return anthropic_error_response(
            exc.message,
            status_code=exc.status_code,
            error_type=exc.error_type,
            correlation_id=correlation_id,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/v1/messages/count_tokens")
async def count_tokens(request: Request) -> JSONResponse:
    """Handle ``POST /v1/messages/count_tokens``: estimate input tokens.

    The JSON body is parsed, legacy system messages are normalized, and the
    result is validated as a ``CountTokensRequest``. The estimate from
    ``estimate_input_tokens`` is returned as a ``CountTokensResponse`` dumped
    to JSON.

    Returns:
        The token-count JSON response, or a 400 error response when the body
        is not valid JSON or a pydantic ``ValidationError`` is raised.
    """
    # Parse the body in its own narrow try: json.JSONDecodeError and
    # UnicodeDecodeError are both ValueError subclasses and would otherwise
    # surface as an unhandled server error instead of a client error.
    try:
        body = await request.json()
    except ValueError as exc:
        return anthropic_error_response(
            f"Request body is not valid JSON: {exc}",
            status_code=400,
        )

    try:
        raw = normalize_legacy_system_messages(body)
        payload = CountTokensRequest.model_validate(raw)
        response = CountTokensResponse(input_tokens=estimate_input_tokens(payload))
        return JSONResponse(content=response.model_dump())
    except ValidationError as exc:
        return anthropic_error_response(str(exc), status_code=400)
|