Files
flask_rulebase_serve/app/blueprints/vlm/rotutes.py
2026-04-22 13:35:40 +08:00

886 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime
from functools import lru_cache
import os
import re
import time
import uuid
from flask import Blueprint, current_app, json, request, jsonify
from flask_jwt_extended import get_jwt_identity, jwt_required
import requests
from sqlalchemy import and_, func
from app import db
from app.models import BagFile, BagStatus, Fst, User, VlmFilter
from app import create_app
from app.utils.get_vlm_token import Demo
from app.utils.driving_tree import vlm_data
# Blueprint that every VLM route in this module is registered on.
vlm_bp = Blueprint(name="vlm", import_name=__name__)
def extract_datetime_from_filename(bag_filename):
    """Extract the capture timestamp and the matching video URL from a bag name.

    The name must contain a ``_YYYYMMDD-HHMMSS_`` segment, e.g.
    ``PL162802_event_hmi_console_event_20250827-182132_0.bag``.

    Args:
        bag_filename: Bag file name, normally ending in ``.bag``.

    Returns:
        A 2-tuple ``(datetime_full, video_url)``:
        ``datetime_full`` is a ``"YYYY-MM-DD HH:MM:SS"`` string and
        ``video_url`` is the COS URL of the ``"<stem>- Wide.mp4"`` video stored
        under ``momenta/videos/YYYY/MM/DD/``.

    Raises:
        ValueError: if no ``_YYYYMMDD-HHMMSS_`` segment is present or the
            digits do not form a real calendar date/time.
    """
    match = re.search(r"_(\d{8})-(\d{6})_", bag_filename)
    if not match:
        raise ValueError("文件名格式错误,无法提取时间信息")
    date_part, time_part = match.groups()
    # Reject impossible values (e.g. month 13) instead of storing a bogus
    # capture time; the error type and message match the no-match case.
    try:
        datetime.strptime(date_part + time_part, "%Y%m%d%H%M%S")
    except ValueError:
        raise ValueError("文件名格式错误,无法提取时间信息") from None

    year, month, day = date_part[:4], date_part[4:6], date_part[6:8]
    hour, minute, second = time_part[:2], time_part[2:4], time_part[4:6]
    datetime_full = f"{year}-{month}-{day} {hour}:{minute}:{second}"

    # Strip only the trailing ".bag" extension; replacing every ".bag"
    # occurrence would corrupt names that contain it elsewhere.
    suffix = ".bag"
    stem = bag_filename[: -len(suffix)] if bag_filename.endswith(suffix) else bag_filename
    video_final_path = f"{stem}- Wide.mp4"
    video_cos = f"momenta/videos/{year}/{month}/{day}/{video_final_path}"
    video_url = (
        "https://data-miningc01-1318950322.cos.ap-shanghai-adc.myqcloud.com/"
        + video_cos
    )
    return datetime_full, video_url
# Assumed lifetime of a token from Demo().getAuthToken(), and how many seconds
# before that lifetime ends a fresh token is fetched.
_TOKEN_EXPIRE_SECONDS = 7200
_TOKEN_REFRESH_MARGIN_SECONDS = 60


@lru_cache(maxsize=1)
def _fetch_auth_token():
    """Fetch a new token from the auth service and stamp it with the fetch time."""
    token_info = Demo().getAuthToken()
    token_info["fetch_time"] = time.time()
    return token_info


def get_cached_token():
    """Return the cached auth token, re-fetching it shortly before it expires.

    Caching lives in ``_fetch_auth_token`` so that the age check below runs on
    every call; an ``lru_cache`` on this function itself would only evaluate the
    check on a cache miss and keep serving the first token forever.

    Returns:
        The object returned by ``Demo().getAuthToken()`` with an added
        ``"fetch_time"`` key (epoch seconds when it was fetched).
    """
    token_info = _fetch_auth_token()
    age_seconds = time.time() - token_info["fetch_time"]
    if age_seconds >= _TOKEN_EXPIRE_SECONDS - _TOKEN_REFRESH_MARGIN_SECONDS:
        _fetch_auth_token.cache_clear()
        token_info = _fetch_auth_token()
    return token_info


# Keep the lru_cache helper API available on the public function.
get_cached_token.cache_clear = _fetch_auth_token.cache_clear
get_cached_token.cache_info = _fetch_auth_token.cache_info
@vlm_bp.route("/insert-csv", methods=["POST"])
@jwt_required()
def insert_db_by_csv():
    """Bulk-create BagFile rows from an uploaded CSV plus shared form fields.

    Multipart form fields:
        file: CSV whose header contains a ``bag_name`` column.
        level1Tag: integer level-1 tag id applied to every row (required).
        status: name of a ``BagStatus`` member applied to every row (required).
        radioGroup: ``"annotation"`` (bag_status 0) or ``"qa"`` (bag_status 1).

    Rows with an empty ``bag_name`` or a name without a valid
    ``_YYYYMMDD-HHMMSS_`` timestamp are skipped and listed in ``data.errors``;
    all other rows are inserted in one transaction.

    Responses: 200 with counts, 400 on invalid input, 500 (after rollback) on
    any other error.
    """
    import csv
    from io import StringIO

    try:
        # 1. Validate the uploaded file.
        if "file" not in request.files:
            return jsonify({"code": 400, "message": "未上传CSV文件"}), 400
        file = request.files["file"]
        if not file.filename:
            return jsonify({"code": 400, "message": "未选择文件"}), 400
        # Case-insensitive so "DATA.CSV" is accepted too.
        if not file.filename.lower().endswith(".csv"):
            return jsonify({"code": 400, "message": "仅支持CSV格式文件"}), 400

        # 2. Validate the form fields that apply to every row.
        form_params = {}
        for field in ("level1Tag", "status"):
            value = request.form.get(field)
            if not value:
                return (
                    jsonify({"code": 400, "message": f"缺少必要的表单参数: {field}"}),
                    400,
                )
            form_params[field] = value
        radio_group = request.form.get("radioGroup", "")

        try:
            level1_tag = int(form_params["level1Tag"])
        except ValueError:
            return jsonify({"code": 400, "message": "level1Tag必须是整数"}), 400

        status_str = form_params["status"]
        try:
            status = BagStatus[status_str]
        except KeyError:
            return (
                jsonify({"code": 400, "message": f"无效的status值: {status_str}"}),
                400,
            )

        # radioGroup decides BagFile.bag_status; any other value would leave
        # nothing to store, so reject it instead of failing later.
        bag_status = {"annotation": 0, "qa": 1}.get(radio_group)
        if bag_status is None:
            return (
                jsonify({"code": 400, "message": f"无效的radioGroup值: {radio_group}"}),
                400,
            )

        user_id = get_jwt_identity()
        now = datetime.now()
        # One version tag per upload, e.g. "vlm_<identity>_2025-08-27-18".
        # The f-string also works when the JWT identity is not a str.
        fst_version = f"vlm_{user_id}_{now.strftime('%Y-%m-%d-%H')}"

        # 3. Decode and parse the CSV. "utf-8-sig" drops a leading BOM (as
        # written by Excel) that would otherwise corrupt the first header name.
        try:
            csv_content = file.stream.read().decode("utf-8-sig")
        except UnicodeDecodeError:
            return jsonify({"code": 400, "message": "CSV文件必须为UTF-8编码"}), 400
        csv_reader = csv.DictReader(StringIO(csv_content))
        # fieldnames is None when the file is empty.
        if not csv_reader.fieldnames or "bag_name" not in csv_reader.fieldnames:
            return (
                jsonify({"code": 400, "message": "CSV文件缺少必要的'bag_name'"}),
                400,
            )

        # 4. Build one BagFile per usable row.
        inserted_count = 0
        skipped_count = 0
        error_records = []
        # Line 1 is the header, so data rows are numbered from 2.
        for row_num, row in enumerate(csv_reader, start=2):
            # DictReader fills cells missing from short rows with None.
            bag_name = (row.get("bag_name") or "").strip()
            if not bag_name:
                skipped_count += 1
                error_records.append(
                    {"row": row_num, "reason": "bag_name为空", "content": row}
                )
                continue
            try:
                datetime_full, video_url = extract_datetime_from_filename(bag_name)
            except ValueError as exc:
                # Skip and report a malformed name rather than failing the upload.
                skipped_count += 1
                error_records.append(
                    {"row": row_num, "reason": str(exc), "content": row}
                )
                continue
            db.session.add(
                BagFile(
                    file_name=bag_name,
                    level1_tag_id=level1_tag,
                    status=status,
                    capture_datetime=datetime_full,
                    create_time=now,
                    update_time=now,
                    bag_status=bag_status,
                    sync_status="SYNC_NOT_READY",
                    user_id=user_id,
                    video_url=video_url,
                    fst_version=fst_version,
                )
            )
            inserted_count += 1

        db.session.commit()

        # 5. Report the outcome.
        return jsonify(
            {
                "code": 200,
                "message": f"处理完成,成功插入 {inserted_count} 条记录,跳过 {skipped_count} 条记录",
                "data": {
                    "inserted_count": inserted_count,
                    "skipped_count": skipped_count,
                    "errors": error_records if skipped_count > 0 else None,
                },
            }
        )
    except Exception as e:
        db.session.rollback()
        current_app.logger.exception("【上传CSV接口异常】: %s", e)
        return jsonify({"code": 500, "message": f"处理失败:{str(e)}"}), 500
@vlm_bp.route("/get-models", methods=["GET"])
def get_models():
    """Return the fixed list of selectable VLM models.

    The list is hard-coded. The inner payload has its own ``code``, ``data``,
    ``success`` and ``message`` fields and is itself nested under ``data``:
    ``{"code": 200, "data": {"code": 200, "data": [...], "success": True, ...}}``.

    NOTE(review): the nested shape presumably mimics the upstream endpoint
    ``GET http://10.0.220.110/api/app/model/models-server?modelType=0``
    (called with the cached bearer token), whose JSON would be placed under
    ``data`` unchanged if the live call were used - confirm before switching.
    """
    base_model = {
        "id": 3,
        "name": "图片模型Base",
        "url": "http://10.0.220.110:20080",
        "type": 0,
        "remark": "没有finetune的初始模型\nhttp://10.0.220.110:20080",
        "createUserId": 3,
        "status": 1,
        "createdAt": "2024-11-21T07:36:34.337Z",
        "updatedAt": "2025-04-10T08:00:03.727Z",
    }
    fine_tuned_model = {
        "id": 4,
        "name": "图片模型FT4",
        "url": "http://10.0.220.226:20081",
        "type": 0,
        "remark": "第4轮微调模型\nhttp://10.0.220.226:20081",
        "createUserId": 3,
        "status": 1,
        "createdAt": "2025-04-25T07:22:06.986Z",
        "updatedAt": "2025-05-09T07:00:22.667Z",
    }
    models_payload = dict(
        code=200,
        data=[base_model, fine_tuned_model],
        success=True,
        message="success",
    )
    return jsonify({"code": 200, "data": models_payload})
@vlm_bp.route("/get-datasets", methods=["GET"])
def get_datasets():
    """Proxy the VLM platform's dataset list for model 4.

    Returns ``{"code": 200, "data": <upstream JSON>}``. If fetching the token,
    the HTTP request or JSON decoding fails, responds with HTTP 502 and
    ``{"code": 502, "message": ...}`` instead of an unhandled exception.
    """
    # NOTE(review): unlike the data-writing routes there is no @jwt_required
    # here, so any caller can use the service token through this proxy -
    # confirm that is intended.
    api_url = "http://10.0.220.110/api/app/model/datasets-list?modelId=4"
    try:
        token = get_cached_token()
        # NOTE(review): get_cached_token() returns the whole token object (it is
        # indexed with "fetch_time"), so this header embeds its string form -
        # confirm which field holds the actual bearer token.
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        # A timeout keeps a stalled upstream from blocking this worker forever.
        response = requests.get(url=api_url, headers=headers, timeout=30)
        result = response.json()
    except (requests.RequestException, ValueError) as exc:
        current_app.logger.exception("【获取数据集接口异常】: %s", exc)
        return jsonify({"code": 502, "message": f"获取数据集失败:{str(exc)}"}), 502
    return jsonify({"code": 200, "data": result})
@vlm_bp.route("/get-search", methods=["POST"])
def get_search_list():
    """Forward the request's JSON body to the VLM platform search API (model 4).

    Returns ``{"code": 200, "data": <upstream JSON>}``; when the upstream call
    fails or its body is not JSON, responds with HTTP 502 and
    ``{"code": 502, "message": ...}``.
    """
    # NOTE(review): no @jwt_required on this proxy - confirm it is meant to be public.
    api_url = "http://10.0.220.110/api/app/search/4"
    payload = request.json
    try:
        # Generous read timeout for search, but never wait forever on the upstream.
        response = requests.post(url=api_url, json=payload, timeout=60)
        result = response.json()
    except (requests.RequestException, ValueError) as exc:
        current_app.logger.exception("【搜索接口异常】: %s", exc)
        return jsonify({"code": 502, "message": f"搜索失败:{str(exc)}"}), 502
    return jsonify({"code": 200, "data": result})
@vlm_bp.route("/get-alltags", methods=["GET"])
def get_alltags():
    """Return every Fst tag as a flat list of ``{id, name, level, parent_id}``.

    Responds 200 with the list, or 500 with ``{"error": <message>}`` on failure.
    """
    try:
        payload = [
            {
                "id": fst_tag.id,
                "name": fst_tag.name,
                "level": fst_tag.level,
                "parent_id": fst_tag.parent_id,
            }
            for fst_tag in Fst.query.all()
        ]
        return jsonify(payload), 200
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500
@vlm_bp.route("/insert-vlm-filter", methods=["POST"])
@jwt_required()
def add_vlm_filter():
    """Create a VlmFilter record for a bag after validating its tag hierarchy.

    Request JSON body, e.g.::

        {
            "level1": {"id": 407, "name": "TOLL_STATION", "level": 1},
            "level2": {"id": 616, "name": "TOLL_STATION_LEAVE", "level": 2},
            "level3": {"id": 579,
                       "name": "TOLL_STATION_LEAVE_PASSING_NO_LANE_MARKING_SQUARE_AFTER_ROD",
                       "level": 3},
            "level4": null,
            "bagname": "PL162802_event_hmi_console_event_20250827-182132_0.bag",
            "comment": "12345",
            "status": 0,
            "video_url": ""
        }

    ``level1``..``level3``, ``bagname`` and ``status`` must be present (level
    values may be null); ``level4``, ``comment`` and ``video_url`` are
    optional. Every non-null tag must exist in Fst, and a tag at level N > 1
    requires a non-null level N-1 tag that is its ``parent_id``.

    Responses: 200 with the created record, 400 on validation failure, 500
    (after rollback) on any other error.
    """
    try:
        current_user_id = get_jwt_identity()

        # silent=True: a missing or non-JSON body is reported as 400 below
        # rather than surfacing as a 500.
        params = request.get_json(silent=True)
        if not isinstance(params, dict):
            return jsonify({"code": 400, "message": "请求体必须是JSON对象"}), 400

        required_fields = ["level1", "level2", "level3", "bagname", "status"]
        missing_fields = [key for key in required_fields if key not in params]
        if missing_fields:
            return (
                jsonify(
                    {
                        "code": 400,
                        "message": f"缺少必要参数: {', '.join(missing_fields)}",
                    }
                ),
                400,
            )

        # Tag id per level 1-4; a null/empty level object yields None.
        tag_ids = {}
        for level in (1, 2, 3, 4):
            tag_obj = params.get(f"level{level}")
            if tag_obj and not isinstance(tag_obj, dict):
                return (
                    jsonify({"code": 400, "message": f"level{level}必须是对象或null"}),
                    400,
                )
            tag_ids[level] = tag_obj.get("id") if tag_obj else None

        # Validate existence and parent linkage level by level. The Chinese
        # ordinals reproduce the existing error messages exactly.
        ordinals = {1: "一", 2: "二", 3: "三", 4: "四"}
        for level in (1, 2, 3, 4):
            tag_id = tag_ids[level]
            if tag_id is None:
                continue
            label = ordinals[level]
            tag = Fst.query.get(tag_id)
            if not tag:
                return (
                    jsonify({"code": 400, "message": f"{label}级标签不存在ID: {tag_id}"}),
                    400,
                )
            if level == 1:
                continue
            parent_label = ordinals[level - 1]
            parent_id = tag_ids[level - 1]
            # A child tag needs its parent level filled in ...
            if not parent_id:
                return (
                    jsonify(
                        {
                            "code": 400,
                            "message": f"{label}级标签存在时,{parent_label}级标签不能为空",
                        }
                    ),
                    400,
                )
            # ... and must actually hang under that parent.
            if tag.parent_id != parent_id:
                return (
                    jsonify(
                        {
                            "code": 400,
                            "message": f"{label}级标签的父级ID{tag.parent_id}与{parent_label}级标签ID{parent_id})不匹配",
                        }
                    ),
                    400,
                )

        new_vlm = VlmFilter(
            bag_name=params["bagname"],
            level1_tag_id=tag_ids[1],
            level2_tag_id=tag_ids[2],
            level3_tag_id=tag_ids[3],
            level4_tag_id=tag_ids[4],
            comment=params.get("comment", ""),
            user_id=current_user_id,
            create_time=datetime.now(),
            status=params["status"],
            # Optional field; defaults to an empty string when omitted.
            video_url=params.get("video_url", ""),
        )
        db.session.add(new_vlm)
        db.session.commit()

        return (
            jsonify(
                {
                    "code": 200,
                    "message": "数据插入成功",
                    "data": {
                        "id": new_vlm.id,
                        "bag_name": new_vlm.bag_name,
                        "level1_tag_id": new_vlm.level1_tag_id,
                        "level2_tag_id": new_vlm.level2_tag_id,
                        "level3_tag_id": new_vlm.level3_tag_id,
                        "level4_tag_id": new_vlm.level4_tag_id,
                        "user_id": new_vlm.user_id,
                        "create_time": new_vlm.create_time.strftime(
                            "%Y-%m-%d %H:%M:%S"
                        ),
                        "status": new_vlm.status,
                    },
                }
            ),
            200,
        )
    except Exception as e:
        db.session.rollback()
        return jsonify({"code": 500, "message": f"服务器错误: {str(e)}"}), 500
@vlm_bp.route("/get-vlm-filter-list", methods=["POST"])
@jwt_required()
def get_vlm_filter_list():
    """Return one page of VlmFilter records with status 0, newest first.

    JSON body (all optional):
        page, per_page: pagination; non-integer or < 1 values fall back to
            1 and 20.
        bag_name: substring match on ``bag_name`` (``%``/``_`` are literal).
        level1_tag: exact level-1 tag id.
        user_id: exact creator id.
        start_datetime, end_datetime: ISO-8601 bounds on ``create_time``
            (both inclusive).
    Unparseable filter values are logged and ignored.

    Each item carries its tag names and creator username; tags or users that
    no longer exist are reported as ``None`` instead of failing the request.
    """

    def to_int(value, default=None):
        """Return ``int(value)``, or ``default`` when it cannot be converted."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    def text_param(source, key):
        """Return ``source[key]`` as a stripped string; missing/null gives ""."""
        return str(source.get(key) or "").strip()

    page = 1
    try:
        params = request.get_json() or {}

        # 1. Pagination.
        page = to_int(params.get("page", 1), 1)
        if page < 1:
            page = 1
        per_page = to_int(params.get("per_page", 20), 20)
        if per_page < 1:
            per_page = 20

        # 2. Filters; only status=0 rows are ever listed.
        conditions = [VlmFilter.status == 0]

        bag_name = text_param(params, "bag_name")
        if bag_name:
            # autoescape: bag names are full of "_", which LIKE would treat as
            # a single-character wildcard.
            conditions.append(VlmFilter.bag_name.contains(bag_name, autoescape=True))

        level1_tag = params.get("level1_tag", "")
        if level1_tag:
            level1_id = to_int(level1_tag)
            if level1_id is None:
                current_app.logger.warning("【过滤条件】level1_tag转换失败非整数: %s", level1_tag)
            else:
                conditions.append(VlmFilter.level1_tag_id == level1_id)

        user_id = params.get("user_id")
        if user_id:
            user_id_int = to_int(user_id)
            if user_id_int is None:
                current_app.logger.warning("【过滤条件】user_id转换失败非整数: %s", user_id)
            else:
                conditions.append(VlmFilter.user_id == user_id_int)

        start_datetime = text_param(params, "start_datetime")
        if start_datetime:
            try:
                start_dt = datetime.fromisoformat(start_datetime)
                conditions.append(VlmFilter.create_time >= start_dt)
            except ValueError:
                current_app.logger.warning("【过滤条件】开始时间格式错误: %s", start_datetime)

        end_datetime = text_param(params, "end_datetime")
        if end_datetime:
            try:
                end_dt = datetime.fromisoformat(end_datetime)
                conditions.append(VlmFilter.create_time <= end_dt)
            except ValueError:
                current_app.logger.warning("【过滤条件】结束时间格式错误: %s", end_datetime)

        # 3. Query, newest first, one page.
        query = (
            VlmFilter.query.filter(and_(*conditions))
            .order_by(VlmFilter.create_time.desc())
        )
        pagination = query.paginate(page=page, per_page=per_page, error_out=False)
        items = pagination.items

        # 4. Resolve tag names in one query instead of one lookup per field;
        # ids that no longer exist simply map to None.
        tag_ids = {
            tag_id
            for item in items
            for tag_id in (
                item.level1_tag_id,
                item.level2_tag_id,
                item.level3_tag_id,
                item.level4_tag_id,
            )
            if tag_id
        }
        tag_names = (
            {tag.id: tag.name for tag in Fst.query.filter(Fst.id.in_(tag_ids)).all()}
            if tag_ids
            else {}
        )
        # Look each creator up once; a deleted user yields None.
        usernames = {}
        for uid in {item.user_id for item in items if item.user_id}:
            user = User.query.get(uid)
            usernames[uid] = user.username if user else None

        # 5. Serialize.
        result_list = [
            {
                "id": item.id,
                "bag_name": item.bag_name,
                "level1_tag_id": item.level1_tag_id,
                "level1_tag_name": tag_names.get(item.level1_tag_id),
                "level2_tag_id": item.level2_tag_id,
                "level2_tag_name": tag_names.get(item.level2_tag_id),
                "level3_tag_id": item.level3_tag_id,
                "level3_tag_name": tag_names.get(item.level3_tag_id),
                "level4_tag_id": item.level4_tag_id,
                "level4_tag_name": tag_names.get(item.level4_tag_id),
                "video_url": item.video_url,
                "comment": item.comment,
                "user_id": item.user_id,
                "username": usernames.get(item.user_id),
                "collection_time": (
                    item.collection_time.isoformat() if item.collection_time else None
                ),
                "create_time": (
                    item.create_time.isoformat() if item.create_time else None
                ),
                "init_label": item.init_label,
                "status": item.status,  # always 0 because of the fixed filter
            }
            for item in items
        ]

        return jsonify(
            {
                "code": 200,
                "data": result_list,
                "total": pagination.total,
                "page": page,
                "pages": pagination.pages,
                "message": "查询成功仅返回status=0的数据",
            }
        )
    except Exception as e:
        current_app.logger.exception("【接口异常】: %s", e)
        return (
            jsonify(
                {
                    "code": 500,
                    "data": [],
                    "total": 0,
                    "page": page,
                    "pages": 0,
                    "message": f"查询出错:{str(e)}",
                }
            ),
            500,
        )
@vlm_bp.route("/send-loacldb", methods=["POST"])
@jwt_required()
def send_filter_vlm_localdb():
    """Copy selected VlmFilter entries into BagFile and mark them as stored.

    JSON body:
        data: list of objects, each with ``bag_name`` and optionally ``id``,
            ``level1_tag_id``..``level4_tag_id``, ``video_url``, ``comment``,
            ``user_id`` and an ISO-8601 ``create_time``.
        add_status: integer written to ``BagFile.bag_status`` and
            ``VlmFilter.init_label`` (0 = annotation, 1 = QA per the existing
            convention).
        label_status: name of a ``BagStatus`` member for the new BagFile rows.

    Each item runs inside its own savepoint, so an item that fails (bad
    filename, bad create_time, DB error) is rolled back alone and reported in
    ``data.errors``. An item whose VlmFilter row (matched by ``id`` and
    ``bag_name``) is not found is still inserted and is also listed in
    ``data.errors``.

    Responses: 200 with counts, 400 on invalid top-level parameters, 500
    (after rollback) on any other error.
    """
    try:
        params = request.get_json() or {}
        if "data" not in params or not isinstance(params["data"], list):
            return (
                jsonify({"code": 400, "message": "缺少必要参数data或data不是数组"}),
                400,
            )
        if "add_status" not in params:
            return jsonify({"code": 400, "message": "缺少必要参数add_status"}), 400
        if "label_status" not in params:
            return jsonify({"code": 400, "message": "缺少必要参数label_status"}), 400

        data_list = params["data"]
        try:
            add_status = int(params["add_status"])
        except (TypeError, ValueError):
            return jsonify({"code": 400, "message": "add_status必须是整数"}), 400
        # Resolve label_status once; an invalid value would otherwise make every
        # single item fail.
        label_status = params["label_status"]
        try:
            bag_file_status = BagStatus[label_status]
        except (KeyError, TypeError):
            return (
                jsonify({"code": 400, "message": f"无效的label_status值: {label_status}"}),
                400,
            )

        current_user_id = get_jwt_identity()
        current_time = datetime.now()
        # One version tag per request, e.g. "send_<identity>_2025-08-27-18".
        # The f-string also works when the JWT identity is not a str.
        fst_version = f"send_{current_user_id}_{current_time.strftime('%Y-%m-%d-%H')}"

        inserted_count = 0
        updated_count = 0
        error_records = []
        for idx, item in enumerate(data_list):
            if not isinstance(item, dict) or not item.get("bag_name"):
                error_records.append(
                    {
                        "index": idx,
                        "reason": "缺少bag_name或bag_name为空",
                        "data": item,
                    }
                )
                continue
            bag_name = item["bag_name"]
            try:
                # Only the capture time is used; the stored video_url comes
                # from the item itself.
                datetime_full, _ = extract_datetime_from_filename(bag_name)
                create_time = (
                    datetime.fromisoformat(item["create_time"])
                    if item.get("create_time")
                    else current_time
                )
                # Savepoint per item: on failure only this item's insert and
                # update are undone and the outer transaction stays usable.
                with db.session.begin_nested():
                    db.session.add(
                        BagFile(
                            file_name=bag_name,
                            capture_datetime=datetime_full,
                            level1_tag_id=item.get("level1_tag_id"),
                            level2_tag_id=item.get("level2_tag_id"),
                            level3_tag_id=item.get("level3_tag_id"),
                            level4_tag_id=item.get("level4_tag_id"),
                            video_url=item.get("video_url"),
                            comment1=item.get("comment"),
                            user_id=item.get("user_id") or current_user_id,
                            create_time=create_time,
                            update_time=current_time,
                            fst_version=fst_version,
                            bag_status=add_status,  # 0 = annotation, 1 = QA
                            sync_status="SYNC_NOT_READY",
                            status=bag_file_status,
                        )
                    )
                    vlm_filter = VlmFilter.query.filter_by(
                        id=item.get("id"), bag_name=bag_name
                    ).first()
                    if vlm_filter:
                        vlm_filter.status = 1  # 1 = sent to the local DB
                        vlm_filter.init_label = add_status
            except Exception as e:
                error_records.append(
                    {"index": idx, "reason": f"处理错误: {str(e)}", "data": item}
                )
                continue

            inserted_count += 1
            if vlm_filter:
                updated_count += 1
            else:
                error_records.append(
                    {
                        "index": idx,
                        "reason": f"未找到对应的vlm_filter记录 (id: {item.get('id')}, bag_name: {bag_name})",
                        "data": item,
                    }
                )

        db.session.commit()

        return jsonify(
            {
                "code": 200,
                "message": f"批量处理完成,成功插入 {inserted_count} 条记录,{len(error_records)} 条记录处理失败",
                "data": {
                    "inserted_count": inserted_count,
                    "updated_count": updated_count,
                    "error_count": len(error_records),
                    "errors": error_records if error_records else None,
                },
            }
        )
    except Exception as e:
        db.session.rollback()
        current_app.logger.exception("【批量插入接口异常】: %s", e)
        return jsonify({"code": 500, "message": f"处理失败:{str(e)}"}), 500