79 lines
2.7 KiB
Python
79 lines
2.7 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import Annotated
|
||
|
|
|
||
|
|
from fastapi import APIRouter, Depends, Request
|
||
|
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||
|
|
from pydantic import ValidationError
|
||
|
|
|
||
|
|
from nexus_claude_api.errors import NexusClaudeError, anthropic_error_response
|
||
|
|
from nexus_claude_api.models import (
|
||
|
|
AnthropicMessagesRequest,
|
||
|
|
CountTokensRequest,
|
||
|
|
CountTokensResponse,
|
||
|
|
)
|
||
|
|
from nexus_claude_api.nexus_client import NexusClient
|
||
|
|
from nexus_claude_api.tokens import estimate_input_tokens
|
||
|
|
from nexus_claude_api.translators.anthropic_to_bedrock import anthropic_to_bedrock_request
|
||
|
|
from nexus_claude_api.translators.bedrock_to_anthropic import bedrock_to_anthropic_response
|
||
|
|
from nexus_claude_api.translators.stream import (
|
||
|
|
bedrock_stream_to_anthropic_events,
|
||
|
|
sse_frame,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# Router that the Anthropic-compatible endpoints in this module register on
# via ``@router.post``; presumably included by the application factory — verify.
router = APIRouter()
|
||
|
|
|
||
|
|
|
||
|
|
def get_nexus_client(request: Request) -> NexusClient:
    """Return the shared ``NexusClient`` stored on the application state.

    Used as a FastAPI dependency. Nothing is constructed here: the client is
    read from ``request.app.state.nexus_client``, so it must already have been
    placed there (presumably during application startup — verify). If it is
    missing, attribute access raises ``AttributeError``.
    """
    app_state = request.app.state
    return app_state.nexus_client
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/v1/messages", response_model=None)
async def create_message(
    request: Request,
    client: Annotated[NexusClient, Depends(get_nexus_client)],
) -> Response:
    """Serve an Anthropic Messages API request by proxying it through Nexus.

    The JSON body is validated as an ``AnthropicMessagesRequest`` and
    translated into a Bedrock request, which is sent with ``client``.

    - When ``payload.stream`` is set, the Bedrock stream is converted to
      Anthropic events and returned as server-sent events.
    - Otherwise the Bedrock response is translated back into an Anthropic
      message and returned as JSON.

    Args:
        request: The incoming HTTP request; its body must be a JSON document.
        client: The Nexus client injected by ``get_nexus_client``.

    Returns:
        One of the following:

        - a ``StreamingResponse`` with ``media_type="text/event-stream"`` for
          streaming requests;
        - a ``JSONResponse`` with the translated message (``None`` fields
          omitted) for non-streaming requests;
        - an Anthropic-style error response from ``anthropic_error_response``:
          400 when the body is not valid JSON or fails validation, and the
          error's own status code and type for a ``NexusClaudeError``.
    """
    # Parse the body separately so that malformed JSON becomes a 400 rather
    # than an unhandled exception (HTTP 500). Starlette's ``Request.json``
    # raises ``json.JSONDecodeError`` (or ``UnicodeDecodeError`` for
    # undecodable bytes); both are subclasses of ``ValueError``.
    try:
        raw = await request.json()
    except ValueError as exc:
        return anthropic_error_response(
            f"Request body is not valid JSON: {exc}",
            status_code=400,
        )

    try:
        payload = AnthropicMessagesRequest.model_validate(raw)
        bedrock_request = anthropic_to_bedrock_request(payload)

        if payload.stream:
            stream = client.converse_stream(bedrock_request)
            events = bedrock_stream_to_anthropic_events(stream, model=payload.model)
            # NOTE(review): StreamingResponse consumes this generator after the
            # handler has returned. Exceptions raised mid-stream are therefore
            # NOT converted by the except clauses below. Confirm whether the
            # stream translator emits its own error events.
            return StreamingResponse(
                (sse_frame(event) for event in events),
                media_type="text/event-stream",
            )

        # NOTE(review): ``client.converse`` is called without ``await`` inside
        # an async handler. If it performs blocking I/O it stalls the event
        # loop; consider running it in a threadpool — verify its implementation.
        response = client.converse(bedrock_request)
        anthropic_response = bedrock_to_anthropic_response(
            response,
            model=payload.model,
        )
        return JSONResponse(content=anthropic_response.model_dump(exclude_none=True))
    except ValidationError as exc:
        return anthropic_error_response(str(exc), status_code=400)
    except NexusClaudeError as exc:
        return anthropic_error_response(
            exc.message,
            status_code=exc.status_code,
            error_type=exc.error_type,
        )
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/v1/messages/count_tokens")
async def count_tokens(request: Request) -> JSONResponse:
    """Return an estimated input-token count for a Messages-style request.

    The JSON body is validated as a ``CountTokensRequest``. The count comes
    from ``estimate_input_tokens`` and is returned as a serialized
    ``CountTokensResponse``.

    Args:
        request: The incoming HTTP request; its body must be a JSON document.

    Returns:
        A ``JSONResponse`` with the ``CountTokensResponse`` fields on success.
        On failure, a 400 Anthropic-style error from
        ``anthropic_error_response`` when the body is not valid JSON or fails
        validation.
    """
    # Parse the body separately so that malformed JSON becomes a 400 rather
    # than an unhandled exception (HTTP 500). ``Request.json`` raises
    # ``json.JSONDecodeError``/``UnicodeDecodeError``, both ``ValueError``s.
    try:
        raw = await request.json()
    except ValueError as exc:
        return anthropic_error_response(
            f"Request body is not valid JSON: {exc}",
            status_code=400,
        )

    try:
        payload = CountTokensRequest.model_validate(raw)
        response = CountTokensResponse(input_tokens=estimate_input_tokens(payload))
        return JSONResponse(content=response.model_dump())
    except ValidationError as exc:
        return anthropic_error_response(str(exc), status_code=400)
|