第一次

This commit is contained in:
ZhuJW
2026-04-16 15:44:32 +08:00
commit 5a98242f2f
171 changed files with 42954 additions and 0 deletions

View File

@@ -0,0 +1,439 @@
#!/usr/bin/env python3
import os
import sys
import time
import argparse
import threading
import queue
import subprocess
import shutil
import logging
from datetime import datetime
from prometheus_client import start_http_server, Counter, Gauge, Summary, Histogram
# —— 常量 & 路径 —— #
BASE = os.getcwd()
INPUT_ROOT = os.path.join(BASE, "input")
OUTPUT_ROOT = os.path.join(BASE, "output")
EMPTY_DIR = os.path.join(BASE, "empty")
LOG_DIR = os.path.join(BASE, "logs")
EXCLUDE_SUBDIRS = ["struct_infos/*"]
DOCKER_IMAGE = (
"artifact.swfcn.i.mercedes-benz.com/swfcn_docker/perception-dnn/od-net:prod_v0.2"
)
BATCH_SIZE = 10
MAX_LOCAL = 20
MAX_RETRIES = 3
RETRY_DELAY_S = 2
METRICS_PORT = 8001
SENTINEL = (None, None)
# —— 日志配置 —— #
os.makedirs(LOG_DIR, exist_ok=True)
logger = logging.getLogger("pipeline")
logger.setLevel(logging.INFO)
h_info = logging.FileHandler(os.path.join(LOG_DIR, "pipeline.log"), encoding="utf-8")
h_err = logging.FileHandler(os.path.join(LOG_DIR, "error_tasks.log"), encoding="utf-8")
fmt = logging.Formatter("%(asctime)s %(levelname)s [%(threadName)s] %(message)s")
h_info.setFormatter(fmt)
h_err.setFormatter(fmt)
h_err.setLevel(logging.ERROR)
logger.addHandler(h_info)
logger.addHandler(h_err)
# —— Prometheus 指标 —— #
DL_TOTAL = Counter("pipeline_download_total", "下载尝试总数")
DL_FAIL = Counter("pipeline_download_failures", "下载失败总数")
DL_RETRY = Counter("pipeline_download_retries", "下载重试总数")
PR_TOTAL = Counter("pipeline_process_total", "处理尝试总数")
PR_FAIL = Counter("pipeline_process_failures", "处理失败总数")
PR_RETRY = Counter("pipeline_process_retries", "处理重试总数")
UP_TOTAL = Counter("pipeline_upload_total", "上传尝试总数")
UP_FAIL = Counter("pipeline_upload_failures", "上传失败总数")
UP_RETRY = Counter("pipeline_upload_retries", "上传重试总数")
DL_DUR = Summary("pipeline_download_duration_seconds", "单批下载耗时秒")
PR_DUR = Summary("pipeline_process_duration_seconds", "单批处理耗时秒")
UP_DUR = Summary("pipeline_upload_duration_seconds", "单批上传耗时秒")
BATCH_SIZE_HIST = Histogram(
"pipeline_batch_size",
"单批任务中文件夹数量分布",
buckets=[1, 10, 20, 50, 100, 200, 500],
)
FILE_DL_DUR = Histogram(
"pipeline_file_download_duration_seconds", "单文件/单目录下载耗时分布"
)
BATCH_OUT_FILES = Gauge(
"pipeline_batch_output_subfolder_count", "单批处理后 output 下指定子文件夹数"
)
Q_BATCH = Gauge("pipeline_queue_batches", "待下载批次数")
Q_PROC = Gauge("pipeline_queue_processing", "待处理批次数")
Q_UP = Gauge("pipeline_queue_uploading", "待上传批次数")
LOCAL_COUNT = Gauge("pipeline_local_subdir_count", "当前本地 input 子文件夹总数")
# —— 队列 & 控制 —— #
batch_q = queue.Queue()
proc_q = queue.Queue()
up_q = queue.Queue()
# —— 全局计数 & 锁,用于减少磁盘扫描 —— #
_local_counter = 0
_counter_lock = threading.Lock()
_downloaded_per_batch = {} # batch_id -> 成功下载的子目录数量
def incr_local(n=1, batch_id=None):
global _local_counter
with _counter_lock:
_local_counter += n
if batch_id:
_downloaded_per_batch.setdefault(batch_id, 0)
_downloaded_per_batch[batch_id] += n
def decr_local_batch(batch_id):
global _local_counter
with _counter_lock:
n = _downloaded_per_batch.pop(batch_id, 0)
_local_counter -= n
if _local_counter < 0:
_local_counter = 0
def get_local_count():
with _counter_lock:
return _local_counter
# —— 子目录限流 —— #
def count_local_subdirs_in_batch(batch_dir):
# 只统计当前 batch 下直接子目录
if not os.path.isdir(batch_dir):
return 0
return sum(
1 for e in os.listdir(batch_dir) if os.path.isdir(os.path.join(batch_dir, e))
)
# —— 统一 subprocess + 重试 —— #
def run(cmd, timeout=None):
"""
统一调用 subprocess记录日志并返回 (code, timed_out, output).
"""
logger.info("RUN: %s", " ".join(cmd))
try:
p = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
start = time.time()
output = []
timed_out = False
for line in p.stdout:
output.append(line)
if timeout and (time.time() - start) > timeout:
p.kill()
timed_out = True
break
code = p.wait()
out_str = "".join(output)
if code != 0:
logger.warning("CMD 返回非零(%d): %s", code, out_str.strip())
return code, timed_out, out_str
except Exception as e:
logger.exception("RUN 异常: %s", e)
return -1, False, ""
def with_retry(tag, func, *args):
"""
重试包装func 必须返回 (code, timed_out, output)
"""
for i in range(1, MAX_RETRIES + 1):
code, timed_out, out = func(*args)
if code == 0:
return True
if timed_out:
logger.error("%s 超时,不再重试", tag)
break
# 统计重试
if tag.startswith("DL["):
DL_RETRY.inc()
if tag.startswith("PR["):
PR_RETRY.inc()
if tag.startswith("UP["):
UP_RETRY.inc()
logger.warning("%s 重试 %d/%d, output: %s", tag, i, MAX_RETRIES, out.strip())
time.sleep(RETRY_DELAY_S)
logger.error("%s 最终失败", tag)
return False
# —— 删除软连接 —— #
def delete_symlinks(root_path):
"""
调用 find 一次性删除所有软链接,效率更高。
"""
logger.info("删除软连接: %s", root_path)
code, _, out = run(
["sudo", "find", root_path, "-type", "l", "-delete"], timeout=120
)
if code != 0:
logger.warning("删除软连接失败: %s", out.strip())
# —— 下载阶段 —— #
@DL_DUR.time()
def do_download(batch_id, remote_paths, batch_timeout):
if batch_id is None:
proc_q.put(SENTINEL)
return
DL_TOTAL.inc()
start = time.time()
in_dir = os.path.join(INPUT_ROOT, batch_id)
os.makedirs(in_dir, exist_ok=True)
logger.info("DL[%s] 开始, targets=%s", batch_id, remote_paths)
# 本地同 batch 子目录限流
while _local_counter >= MAX_LOCAL:
logger.warning("DL[%s] 本地子目录≥%dsleep 5min", batch_id, MAX_LOCAL)
time.sleep(300)
exclude_flags = []
for sub in EXCLUDE_SUBDIRS:
exclude_flags += ["--ignore", f"'{sub}'"]
success_count = 0
for remote in remote_paths:
# 检查 lidar_gt_pandar128
lidar_sub = remote.rstrip("/") + "/lidar_gt_pandar128"
code, _, out = run(["coscmd", "list", lidar_sub])
if code != 0 or not out.strip():
logger.info("DL[%s] 跳过无数据: %s", batch_id, lidar_sub)
continue
elapsed = time.time() - start
if elapsed > batch_timeout:
logger.error("DL[%s] 超时,停止下载", batch_id)
DL_FAIL.inc()
break
f_start = time.time()
local_bag = os.path.join(in_dir, os.path.basename(remote))
cmd = ["coscmd", "-s", "download", "-r", remote, local_bag] + exclude_flags
ok = with_retry(
f"DL[{batch_id}]", lambda c: run(c, timeout=batch_timeout - elapsed), cmd
)
FILE_DL_DUR.observe(time.time() - f_start)
if ok:
success_count += 1
incr_local(1, batch_id) # 成功下载一个子目录
else:
DL_FAIL.inc()
logger.info("DL[%s] 完成,成功 %d", batch_id, success_count)
proc_q.put((batch_id, (in_dir, remote_paths)))
# —— 处理阶段 —— #
@PR_DUR.time()
def do_process(batch_id, data, batch_timeout):
if batch_id is None:
up_q.put(SENTINEL)
return
in_dir, remote_paths = data
PR_TOTAL.inc()
start = time.time()
out_dir = os.path.join(OUTPUT_ROOT, batch_id)
os.makedirs(out_dir, exist_ok=True)
logger.info("PR[%s] 开始", batch_id)
# 统一用 run() 运行 Docker
docker_cmd = [
"docker",
"run",
"--rm",
"--gpus",
"all",
"--shm-size=16g",
"-v",
f"{in_dir}:/input",
"-v",
f"{out_dir}:/output",
DOCKER_IMAGE,
"bash",
"-c",
"cd /code/projects/od_net/lidarnet/ && "
"source ../.venv/bin/activate && "
"python pipeline.py --input_dir=/input --output_dir=/output",
]
ok = with_retry(
f"PR[{batch_id}]", lambda c: run(c, timeout=batch_timeout), docker_cmd
)
elapsed = time.time() - start
if ok:
logger.info("PR[%s] 成功 耗时 %.1fs", batch_id, elapsed)
else:
PR_FAIL.inc()
logger.error("PR[%s] 失败 耗时 %.1fs", batch_id, elapsed)
# 统计并清理输出子目录
existing = []
for name in os.listdir(out_dir):
path = os.path.join(out_dir, name)
if os.path.isdir(path) and name.endswith(".dir"):
existing.append(path)
BATCH_OUT_FILES.set(len(existing))
logger.info("PR[%s] 输出子目录: %s", batch_id, existing)
# for folder in existing:
# for item in os.listdir(folder):
# p = os.path.join(folder, item)
# if item != "object_tracking":
# # 非 object_tracking 全删
# if os.path.isdir(p):
# shutil.rmtree(p, ignore_errors=True)
# else:
# os.remove(p)
# else:
# # 重命名 object_tracking
# newp = os.path.join(folder, "object_auto_labeling")
# os.rename(p, newp)
shutil.rmtree(in_dir, ignore_errors=True)
up_q.put((batch_id, (out_dir, remote_paths)))
# —— 上传阶段 —— #
@UP_DUR.time()
def do_upload(batch_id, data, batch_timeout):
if batch_id is None:
return
out_dir, remote_paths = data
UP_TOTAL.inc()
start = time.time()
logger.info("UP[%s] 开始", batch_id)
# 删除中间产物
for sub in ("logs", "intermedia_products"):
shutil.rmtree(os.path.join(out_dir, sub), ignore_errors=True)
# 删除所有软连接
delete_symlinks(out_dir)
# 只上传 object_tracking -> remote/aaa
for remote in remote_paths:
base = remote.rstrip("/")
dest = f"{base}/derived/object_auto_labeling"
for root, dirs, _ in os.walk(out_dir):
if "object_tracking" in dirs:
src = os.path.join(root, "object_tracking")
elapsed = time.time() - start
if elapsed > batch_timeout:
logger.error("UP[%s] 超时", batch_id)
UP_FAIL.inc()
return
ok = with_retry(
f"UP[{batch_id}]",
lambda c: run(c, timeout=batch_timeout - elapsed),
["coscmd", "-s", "upload", "-r", src, dest],
)
if not ok:
UP_FAIL.inc()
logger.info("UP[%s] 上传完成", batch_id)
# 清理本地输出 & 输入目录
for cmd in [
["sudo", "rsync", "-a", "--delete", f"{EMPTY_DIR}/", f"{out_dir}/"],
]:
code, _, _ = run(cmd, timeout=60)
if code != 0:
logger.warning("UP[%s] rsync 失败: %s", batch_id, cmd)
shutil.rmtree(out_dir, ignore_errors=True)
# 新增:删除 input 下对应 batch释放磁盘
in_dir = os.path.join(INPUT_ROOT, batch_id)
shutil.rmtree(in_dir, ignore_errors=True)
decr_local_batch(batch_id)
logger.info("UP[%s] 本地清理完成", batch_id)
# —— Worker & 主流程 —— #
def worker(q, fn, timeout):
while True:
bid, data = q.get()
try:
fn(bid, data, timeout)
except Exception:
logger.exception("阶段异常batch=%s", bid)
finally:
q.task_done()
if bid is None:
break
def main():
p = argparse.ArgumentParser()
p.add_argument("--tasks-file", required=True, help="每行一个 COS 目录路径")
p.add_argument("--batch-size", type=int, default=BATCH_SIZE)
p.add_argument("--batch-timeout", type=int, default=3600)
args = p.parse_args()
T = args.batch_timeout
# 确保基础目录
for d in (INPUT_ROOT, OUTPUT_ROOT, EMPTY_DIR, LOG_DIR):
os.makedirs(d, exist_ok=True)
# 读取并分批
lines = [
line.strip() for line in open(args.tasks_file, encoding="utf-8") if line.strip()
]
for idx in range(0, len(lines), args.batch_size):
blk = lines[idx : idx + args.batch_size]
BATCH_SIZE_HIST.observe(len(blk))
bid = (
datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
+ f"_{idx // args.batch_size + 1}"
)
batch_q.put((bid, blk))
batch_q.put(SENTINEL)
# 启动监控
start_http_server(METRICS_PORT)
logger.info("Metrics 服务启动,端口 %d", METRICS_PORT)
# 启动线程
threads = [
threading.Thread(
target=worker, args=(batch_q, do_download, T), name="DL-Worker"
),
threading.Thread(target=worker, args=(proc_q, do_process, T), name="PR-Worker"),
threading.Thread(target=worker, args=(up_q, do_upload, T), name="UP-Worker"),
]
for t in threads:
t.start()
# 主线程仅更新队列长度和本地目录计数,无全盘扫描
while batch_q.unfinished_tasks or proc_q.unfinished_tasks or up_q.unfinished_tasks:
Q_BATCH.set(batch_q.qsize())
Q_PROC.set(proc_q.qsize())
Q_UP.set(up_q.qsize())
LOCAL_COUNT.set(get_local_count())
time.sleep(1)
for t in threads:
t.join()
logger.info("所有批次完成")
sys.exit(0)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,229 @@
#!/usr/bin/env python3
# cleanup_bags_psql.py
"""
批量清理 bag_list 中 fst_indexed=FALSE 且日期早于 N 个月的记录:
1. 删除 COS 目录
2. 删除关联 6 张表
3. 标记 bag_list.is_deleted=TRUE
4. 失败自动重试 3 次,失败明细写 delete_failed.log
5. 最终生成 CSV 并打印统计
"""
import csv
import logging
import re
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta
import psycopg2
import requests
from qcloud_cos import CosConfig, CosS3Client
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
# 如果原来用 ConfigManager保持不动这里为了单文件可跑直接用 os.environ 兜底
import os
# ===================== 配置 =====================
DB_HOST = os.getenv("ROOT_DB_HOST", "localhost")
DB_PORT = int(os.getenv("ROOT_DB_PORT", 5432))
DB_NAME = os.getenv("ROOT_DB_NAME", "default_dbname")
DB_USER = os.getenv("ROOT_DB_USER", "default_user")
DB_PASSWORD = os.getenv("ROOT_DB_PASSWD", "default_password")
COS_REGION = os.getenv("COS_CLIENT_REGION", "ap-guangzhou")
COS_SECRET_ID = os.getenv("COS_CLIENT_SECRET_ID", "default_id")
COS_SECRET_KEY = os.getenv("COS_CLIENT_SECRET_KEY", "default_key")
COS_BUCKET = os.getenv("COS_BUCKET", "b-perception-e2e-1318950322")
BASE_URL = os.getenv("ROOT_DB_API", "http://localhost")
PROJECT_IDS = (1, 2)
MONTHS = 6
MAX_WORKERS = 10
RETRY_TIMES = 3
CSV_FILE = f"deleted_bags_{datetime.now():%Y%m%d_%H%M%S}.csv"
DB_CONF = dict(
host=DB_HOST, port=DB_PORT, user=DB_USER, password=DB_PASSWORD, dbname=DB_NAME
)
# ===================== 日志 =====================
logging.basicConfig(
filename=f"cleanup_{datetime.now():%Y%m%d}.log",
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
)
# 失败明细日志
failure_handler = logging.FileHandler("delete_failed.log", encoding="utf-8")
failure_handler.setFormatter(logging.Formatter("%(asctime)s,%(message)s"))
logger_failure = logging.getLogger("failure")
logger_failure.setLevel(logging.INFO)
logger_failure.addHandler(failure_handler)
# ===================== COS 客户端 =====================
cos = CosS3Client(
CosConfig(Region=COS_REGION, SecretId=COS_SECRET_ID, SecretKey=COS_SECRET_KEY)
)
# ===================== 统计 =====================
stats_lock = threading.Lock()
stats = {"total": 0, "success": 0, "fail": 0}
def inc_stat(key: str):
with stats_lock:
stats[key] += 1
# ===================== 数据库工具 =====================
def get_conn():
return psycopg2.connect(**DB_CONF)
def fetch_candidates() -> list[tuple[int, str]]:
"""返回 [(id, name), ...]"""
sql = """
SELECT id, name
FROM bag_list
WHERE project_id = ANY(%s)
AND fst_indexed = FALSE
AND is_deleted = FALSE
"""
with get_conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, (PROJECT_IDS,))
return cur.fetchall()
# ===================== 日期过滤 =====================
DATE_RE = re.compile(r"_(\d{4})(\d{2})(\d{2})-(\d{2})(\d{2})(\d{2})_")
def need_delete(name: str) -> bool:
days = 30 * MONTHS
m = DATE_RE.search(name)
if not m:
return False
dt = datetime.strptime("".join(m.groups()), "%Y%m%d%H%M%S")
return dt < datetime.utcnow() - timedelta(days=days)
# ===================== 调接口拿 COS 路径 =====================
def get_pangu_detail(bag_name: str):
url = f"{BASE_URL}/api/bags/pangu/detail"
resp = requests.get(url, params={"bagName": bag_name}, timeout=10)
resp.raise_for_status()
data = resp.json()
return data["dataPath"], data["rawPath"]
# ===================== COS 批量删前缀 =====================
def cos_delete_prefix(prefix: str):
if not prefix:
return
marker = ""
while True:
resp = cos.list_objects(Bucket=COS_BUCKET, Prefix=prefix, Marker=marker)
contents = resp.get("Contents", [])
if not contents:
break
keys = [{"Key": obj["Key"]} for obj in contents]
cos.delete_objects(Bucket=COS_BUCKET, Delete={"Objects": keys})
if resp.get("IsTruncated") == "false":
break
marker = resp["NextMarker"]
# ===================== 重试包装 =====================
@retry(
stop=stop_after_attempt(RETRY_TIMES),
wait=wait_exponential(multiplier=2, min=4, max=30),
retry=retry_if_exception_type((
requests.exceptions.RequestException,
psycopg2.OperationalError,
)),
reraise=True,
)
def _do_cleanup(bag_id: int, bag_name: str):
# 1. 拿路径并删 COS
data_path, raw_path = get_pangu_detail(bag_name)
cos_delete_prefix(data_path)
cos_delete_prefix(raw_path)
# 2. 一个事务删关联表
with get_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
DELETE FROM bag_lifecycle WHERE bag_name = %s;
DELETE FROM secondary_pangu WHERE bag_id = %s;
DELETE FROM secondary_minerva WHERE bag_id = %s;
DELETE FROM bag_topic WHERE bag_id = %s;
DELETE FROM bag_reserved_tag WHERE bag_id = %s;
DELETE FROM fst_bag WHERE bag_id = %s;
UPDATE bag_list SET is_deleted = TRUE WHERE id = %s;
""",
(bag_name, bag_id, bag_id, bag_id, bag_id, bag_id, bag_id),
)
conn.commit()
def process_one(bag_id: int, bag_name: str) -> tuple[bool, int, str]:
"""外层统一入口:负责统计、重试、失败落盘"""
inc_stat("total")
try:
_do_cleanup(bag_id, bag_name)
inc_stat("success")
logging.info("Cleaned bag_id=%d name=%s", bag_id, bag_name)
return True, bag_id, bag_name
except Exception as e:
inc_stat("fail")
# 失败明细落盘
logger_failure.info("%d,%s,%s", bag_id, bag_name, str(e).replace(",", ";"))
logging.error("Failed bag_id=%d name=%s error=%s", bag_id, bag_name, e)
return False, bag_id, bag_name
# ===================== 主流程 =====================
def main():
candidates = fetch_candidates()
logging.info("Fetched %d candidate bags", len(candidates))
to_delete = [(bid, nm) for bid, nm in candidates if need_delete(nm)]
logging.info("Need delete %d bags", len(to_delete))
if not to_delete:
logging.info("Nothing to delete, exit.")
return
results = []
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
futures = [pool.submit(process_one, bid, nm) for bid, nm in to_delete]
for f in as_completed(futures):
results.append(f.result())
# 写 CSV
with open(CSV_FILE, "w", newline="", encoding="utf-8") as f:
writer = csv.writer(f)
writer.writerow(["id", "name", "success"])
for ok, bid, nm in results:
writer.writerow([bid, nm, ok])
# 最终汇总
logging.info(
"=== Done: Total=%d Success=%d Fail=%d. CSV=%s FailureDetail=delete_failed.log ===",
stats["total"],
stats["success"],
stats["fail"],
CSV_FILE,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,270 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import csv
import sys
import shutil
import logging
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path
import subprocess
import argparse
from concurrent.futures import ProcessPoolExecutor, as_completed
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
import psycopg2
from psycopg2 import OperationalError
from subprocess import CalledProcessError
from tqdm import tqdm
from fst_data_pipeline.core.config_manager import ConfigManager
# ==== 1. 全局目录管理 ====
ROOT_DIR = Path(__file__).parent.resolve()
LOG_DIR = ROOT_DIR / "logs"
DOWNLOAD_DIR = ROOT_DIR / "downloads"
VIDEO_DIR = ROOT_DIR / "videos"
for d in (LOG_DIR, DOWNLOAD_DIR, VIDEO_DIR):
d.mkdir(parents=True, exist_ok=True)
# ==== 2. 日志配置 ====
logger = logging.getLogger("bag_pipeline")
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.INFO)
ch.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
fh = TimedRotatingFileHandler(
LOG_DIR / "pipeline.log", when="midnight", backupCount=7, encoding="utf-8"
)
fh.setLevel(logging.DEBUG)
fh.setFormatter(
logging.Formatter("%(asctime)s %(levelname)s [%(processName)s] %(message)s")
)
logger.addHandler(ch)
logger.addHandler(fh)
# ==== 3. 数据库连接工厂 ====
config = ConfigManager()
DB_HOST = config.get("ROOT_DB_HOST", "localhost")
DB_PORT = config.get_int("ROOT_DB_PORT", 5432)
DB_NAME = config.get("ROOT_DB_NAME", "default_dbname")
DB_USER = config.get("ROOT_DB_USER", "default_user")
DB_PASSWORD = config.get("ROOT_DB_PASSWD", "default_password")
DB_CFG = dict(
host=DB_HOST, port=DB_PORT, dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD
)
def get_conn():
return psycopg2.connect(**DB_CFG)
def read_and_validate_csv(path):
required = {"bag name", "FST 1st node", "FST 2nd node", "FST 3rd node"}
records = []
seen_bags = set()
with open(path, encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
headers = set(reader.fieldnames or [])
if not required.issubset(headers):
missing = required - headers
raise ValueError(f"CSV 缺少必要列: {missing}")
for i, row in enumerate(reader, start=2):
bag = row["bag name"].strip()
if not bag:
raise ValueError(f"{i}行 bag name 为空")
if bag in seen_bags:
raise ValueError(f'{i}行 bag 重复: "{bag}"')
seen_bags.add(bag)
nodes = []
for col in ("FST 1st node", "FST 2nd node", "FST 3rd node"):
v = row.get(col, "").strip()
if v:
nodes.append(v)
if not nodes:
raise ValueError(f"{i}行没有任何 FST 节点")
records.append({"bag": bag, "nodes": nodes, "first": nodes[0]})
return records
@retry(
stop=stop_after_attempt(3),
wait=wait_fixed(5),
retry=retry_if_exception_type((OperationalError, CalledProcessError)),
)
def process_record(rec) -> bool:
bag, first = rec["bag"], rec["first"]
logger.info(f"[{bag}] 开始处理")
# 5.1 查 data_path
try:
conn = get_conn()
with conn.cursor() as cur:
cur.execute(
"SELECT data_path FROM main_pangu WHERE bag_path LIKE %s LIMIT 1",
(f"%{bag}%",),
)
row = cur.fetchone()
conn.close()
except Exception as e:
logger.error(f"[{bag}] 查询 data_path 出错: {e}")
raise
if not row:
logger.error(f"[{bag}] main_pangu 未找到记录,跳过")
return False
data_path = row[0]
# 5.2 下载图片
dst_dir = DOWNLOAD_DIR / bag / "camera_front_wide"
dst_dir.mkdir(parents=True, exist_ok=True)
src_path = data_path.partition("/")[2] + "/camera_front_wide/"
cmd_dl = ["coscmd", "download", "-r", src_path, str(dst_dir)]
try:
subprocess.run(cmd_dl, check=True, stdout=subprocess.DEVNULL)
except CalledProcessError as e:
logger.error(f"[{bag}] 下载失败: {e}")
raise
# 5.3 生成 list.txt
imgs = sorted(
p.name for p in dst_dir.iterdir() if p.suffix.lower() in (".jpg", ".png")
)
sampled = imgs[::3]
if not imgs:
logger.error(f"[{bag}] 未找到任何图片,跳过")
shutil.rmtree(DOWNLOAD_DIR / bag, ignore_errors=True)
return False
(dst_dir / "list.txt").write_text(
"\n".join(f'file "{fn}"' for fn in sampled), encoding="utf-8"
)
# 5.4 ffmpeg 合成
out_dir = VIDEO_DIR / first
out_dir.mkdir(parents=True, exist_ok=True)
out_mp4 = out_dir / f"{bag}.mp4"
cmd_ff = [
"ffmpeg",
"-y",
"-f",
"concat",
"-safe",
"0",
"-i",
"list.txt",
"-vf",
"scale=640:360",
"-c:v",
"libx264",
"-pix_fmt",
"yuv420p",
str(out_mp4),
]
try:
subprocess.run(cmd_ff, cwd=dst_dir, check=True)
except CalledProcessError as e:
logger.error(f"[{bag}] 合成失败: {e}")
raise
# 5.5 更新数据库
try:
conn = get_conn()
with conn:
with conn.cursor() as cur:
# 5.5.1 bag_list
cur.execute(
"SELECT id,fst_indexed FROM bag_list "
"WHERE name=%s AND is_deleted=FALSE AND is_active_data=TRUE",
(bag,),
)
r = cur.fetchone()
if not r:
logger.error(f"[{bag}] bag_list 未找到,跳过更新")
conn.rollback()
return False
bag_id, fst_indexed = r
if not fst_indexed:
cur.execute(
"UPDATE bag_list SET fst_indexed=TRUE WHERE id=%s", (bag_id,)
)
# 5.5.2 逐级关联 fst
parent = None
for name in rec["nodes"]:
cur.execute("SELECT id,parent_id FROM fst WHERE name=%s", (name,))
fr = cur.fetchone()
if not fr:
logger.error(f"[{bag}] FST 节点不存在 '{name}',跳过")
conn.rollback()
return False
fid, actual_parent = fr
if parent is not None and actual_parent != parent:
logger.error(
f"[{bag}] 节点 '{name}' 父级校验失败 "
f"(actual={actual_parent}!=expect={parent}),跳过"
)
conn.rollback()
return False
# ---- 确保没插入过才 update bag_sum ----
cur.execute(
"SELECT 1 FROM fst_bag WHERE bag_id=%s AND fst_node_id=%s",
(bag_id, fid),
)
if not cur.fetchone():
cur.execute(
"INSERT INTO fst_bag(bag_id,fst_node_id) VALUES(%s,%s)",
(bag_id, fid),
)
cur.execute(
"UPDATE fst SET bag_sum = bag_sum + 1 WHERE id=%s", (fid,)
)
parent = fid
except OperationalError as e:
logger.error(f"[{bag}] 更新DB连接错误: {e}")
raise
except Exception:
logger.exception(f"[{bag}] 更新DB时未知错误跳过")
return False
finally:
if conn:
conn.close()
# 5.6 清理 & 成功日志
shutil.rmtree(DOWNLOAD_DIR / bag, ignore_errors=True)
logger.info(f"[{bag}] 处理成功")
return True
def main():
parser = argparse.ArgumentParser(description="批量下载→合成→更新DB")
parser.add_argument("csv_path", help="CSV 文件路径")
parser.add_argument("-w", "--workers", type=int, default=4, help="并行 worker 数")
args = parser.parse_args()
try:
records = read_and_validate_csv(args.csv_path)
except Exception as e:
logger.error(f"CSV 校验失败: {e}")
sys.exit(1)
total = len(records)
logger.info(f"{total} 条,启动 {args.workers} 并行 worker")
success = skipped = failed = 0
with ProcessPoolExecutor(max_workers=args.workers) as exe:
fut2bag = {exe.submit(process_record, rec): rec["bag"] for rec in records}
for fut in tqdm(as_completed(fut2bag), total=total, desc="进度"):
bag = fut2bag[fut]
try:
ok = fut.result()
if ok:
success += 1
else:
skipped += 1
except Exception as e:
logger.error(f"[{bag}] 最终失败: {e}")
failed += 1
logger.info(f"结束:{success} 成功,{skipped} 跳过,{failed} 失败,共 {total}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,317 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
-----------------------------------------------------------------------------------
功能概述:
把数据库里的rosbag的经纬度降采样, 提取, 画图生成每周采集覆盖度在SD map上的分布情况
使用示例:
配置说明:
-----------------------------------------------------------------------------------
"""
import os
import re
import csv
import time
import math
import logging
import threading
import argparse
from io import StringIO
from http.server import HTTPServer, SimpleHTTPRequestHandler
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import psycopg2
from shapely import wkb
import folium
from folium import TileLayer, FeatureGroup, LayerControl
from branca.element import Element
from branca.colormap import linear
from qcloud_cos import CosConfig, CosS3Client
from tqdm import tqdm
from fst_data_pipeline.core.config_manager import ConfigManager
# ====== 配置区 ======
config = ConfigManager()
DB_HOST = config.get("ROOT_DB_HOST", "localhost")
DB_PORT = config.get_int("ROOT_DB_PORT", 5432)
DB_NAME = config.get("ROOT_DB_NAME", "default_dbname")
DB_USER = config.get("ROOT_DB_USER", "default_user")
DB_PASSWORD = config.get("ROOT_DB_PASSWD", "default_password")
COS_REGION = config.get("COS_CLIENT_REGION", "default_region")
COS_SECRET_ID = config.get("COS_CLIENT_SECRET_ID", "default_id")
COS_SECRET_KEY = config.get("COS_CLIENT_SECRET_KEY", "default_key")
# ===== 日志配置 =====
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s %(levelname)s [%(threadName)s] %(message)s",
datefmt="%Y-%m-%dT%H:%M:%S",
)
logger = logging.getLogger("gnss_map")
# ===== 静态文件服务 (带 CORS) =====
class CORSHandler(SimpleHTTPRequestHandler):
def end_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")
super().end_headers()
def serve_map(port: int, web_dir: str):
os.chdir(web_dir)
logger.info("Serving static files from %s", web_dir)
HTTPServer(("0.0.0.0", port), CORSHandler).serve_forever()
# ===== 工具haversine 计算距离 =====
def haversine(lat1, lon1, lat2, lon2):
R = 6371000.0
φ1, φ2 = math.radians(lat1), math.radians(lat2)
= math.radians(lat2 - lat1)
= math.radians(lon2 - lon1)
a = math.sin( / 2) ** 2 + math.cos(φ1) * math.cos(φ2) * math.sin( / 2) ** 2
return 2 * R * math.asin(math.sqrt(a))
# ===== 单 bag 处理:下载→去重→下采样 → 返回 (name, pts_wkt) =====
def process_bag(name, cos, bucket, prefix):
logger.debug("Start processing bag %s", name)
key = f"{prefix}/{name}.dir/raw_gnss.csv"
text = None
for attempt in range(1, 4):
try:
resp = cos.get_object(Bucket=bucket, Key=key)
text = resp["Body"].get_raw_stream().read().decode("utf-8")
logger.debug("[%s] downloaded %d bytes", name, len(text))
break
except Exception as e:
logger.warning("[%s] download failed (attempt %d): %s", name, attempt, e)
time.sleep(1)
if not text:
logger.error("[%s] all download attempts failed → skip", name)
return None
rows = list(csv.DictReader(StringIO(text)))
unique = []
for r in rows:
try:
lat, lon, alt = float(r["lat"]), float(r["lon"]), float(r.get("alt", 0))
except AttributeError:
continue
if not any(haversine(lat, lon, plat, plon) < 1.0 for plat, plon, _ in unique):
unique.append((lat, lon, alt))
logger.debug("[%s] unique points: %d", name, len(unique))
if not unique:
return None
# 下采样最多30点
if len(unique) > 30:
idx = np.linspace(0, len(unique) - 1, 30, dtype=int)
sampled = [unique[i] for i in idx]
else:
sampled = unique
logger.info("[%s] sampled %d points", name, len(sampled))
# 仅拼接 ST_MakePoint(...) 逗号列表,留给 SQL 拼 ARRAY[]
pts_wkt = ",".join(f"ST_MakePoint({lon},{lat},{alt})" for lat, lon, alt in sampled)
return name, pts_wkt
# ===== 主流程 =====
def main():
p = argparse.ArgumentParser()
p.add_argument("-p", "--map-port", type=int, default=8000)
p.add_argument("-u", "--tile-url", required=True)
p.add_argument("-w", "--workers", type=int, default=8)
args = p.parse_args()
# COS client
cos = CosS3Client(
CosConfig(
Region=COS_REGION,
SecretId=COS_SECRET_ID,
SecretKey=COS_SECRET_KEY,
)
)
BUCKET = config.get("bucket", "b-perception-e2e-1318950322")
PREFIX = config.get("prefix", "mb_raw_rosbag_decode_dirs")
# DB 配置
DB_CONF = dict(
host=DB_HOST, port=DB_PORT, dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD
)
# 1. 取待处理 bag
logger.info("Fetch unprocessed bags from DB")
with psycopg2.connect(**DB_CONF) as conn:
with conn.cursor() as cur:
cur.execute(
"""
SELECT bl.name
FROM bag_list bl
LEFT JOIN geometry_info gi ON gi.rosbag_name=bl.name
WHERE bl.is_decoded=TRUE
AND gi.rosbag_name IS NULL
"""
)
names = [r[0] for r in cur.fetchall()]
logger.info("To process: %d bags", len(names))
# 2. 并行下载 & 处理
results = []
with ThreadPoolExecutor(max_workers=args.workers) as ex:
futures = {ex.submit(process_bag, n, cos, BUCKET, PREFIX): n for n in names}
for f in tqdm(as_completed(futures), total=len(names), desc="Processing"):
n = futures[f]
try:
r = f.result()
if r:
results.append(r)
except Exception as e:
logger.error("[%s] error: %s", n, e)
logger.info("Valid processed bags: %d", len(results))
if not results:
logger.warning("No new data → skip insertion, proceed to visualization")
# 3. 批量插入 geometry_info
if results:
logger.info("Batch inserting into geometry_info")
with psycopg2.connect(**DB_CONF) as conn:
with conn.cursor() as cur:
batch = 50
for i in range(0, len(results), batch):
chunk = results[i : i + batch]
# 手工拼 VALUES 列表
vals = []
for name, pts in chunk:
esc = name.replace("'", "''")
vals.append(f"('{esc}', ARRAY[{pts}]::geometry(PointZ)[])")
sql = f"""
INSERT INTO geometry_info(rosbag_name,gnss_downsampled_points)
VALUES {",".join(vals)}
ON CONFLICT(rosbag_name) DO NOTHING
"""
cur.execute(sql)
conn.commit()
logger.info(
"Inserted batch %d~%d", i + 1, min(i + batch, len(results))
)
logger.info("Batch insert done, total: %d", len(results))
# 4. 从 DB 读取所有 geometry
logger.info("Load all geometries from DB")
with psycopg2.connect(**DB_CONF) as conn:
with conn.cursor() as cur:
cur.execute(
"SELECT rosbag_name, gnss_downsampled_points FROM geometry_info"
)
rows = cur.fetchall()
by_bag, all_pts = {}, []
for name, arr in tqdm(rows, desc="Loading geometry"):
pts = []
for hw in arr.strip("{}").split(","):
h = re.sub(r"[^0-9A-Fa-f]", "", hw)
p = wkb.loads(bytes.fromhex(h))
pts.append((p.y, p.x))
all_pts.append((p.y, p.x))
if pts:
by_bag[name] = pts
if all_pts:
lats = [p[0] for p in all_pts]
lons = [p[1] for p in all_pts]
center = (sum(lats) / len(lats), sum(lons) / len(lons))
else:
logger.warning("No points in DB → default center (0,0)")
center = (0.0, 0.0)
logger.info("Map center: %s", center)
# 5. Folium 地图 & 两个图层(视图)
m = folium.Map(
location=center,
zoom_start=6,
tiles=None,
prefer_canvas=True,
control_scale=True,
zoom_control=True,
)
TileLayer(tiles=args.tile_url, overlay=True, control=False, attr=" ").add_to(m)
header = m.get_root().header
for tag in [
'<link rel="stylesheet" href="leaflet.css"/>',
'<link rel="stylesheet" href="MarkerCluster.css"/>',
'<link rel="stylesheet" href="MarkerCluster.Default.css"/>',
'<link rel="stylesheet" href="leaflet.awesome-markers.css"/>',
'<link rel="stylesheet" href="leaflet.awesome.rotate.min.css"/>',
'<link rel="stylesheet" href="bootstrap.min.css"/>',
'<link rel="stylesheet" href="all.min.css"/>',
'<script src="jquery-1.12.4.min.js"></script>',
'<script src="bootstrap.bundle.min.js"></script>',
'<script src="leaflet.js"></script>',
'<script src="leaflet.markercluster.js"></script>',
'<script src="leaflet.awesome-markers.js"></script>',
]:
header.add_child(Element(tag))
# 轨迹视图
fg1 = FeatureGroup(name="Trajectory", overlay=False, show=True)
for pts in by_bag.values():
folium.PolyLine(pts, color="#3388ff", weight=1, opacity=0.7).add_to(fg1)
m.add_child(fg1)
# 聚类点视图
cnt = {}
for lat, lon in all_pts:
cnt[(lat, lon)] = cnt.get((lat, lon), 0) + 1
if cnt:
mi, ma = min(cnt.values()), max(cnt.values())
cmap = linear.YlOrRd_09.scale(mi, ma)
cmap.caption = "Point Count"
else:
cmap = None
fg2 = FeatureGroup(name="Cluster", overlay=False, show=False)
for (lat, lon), c in cnt.items():
color = cmap(c) if cmap else "#000000"
folium.CircleMarker(
location=(lat, lon),
radius=1 + min(c, 4),
color=color,
fill=True,
fill_color=color,
fill_opacity=0.7,
popup=f"count: {c}",
).add_to(fg2)
m.add_child(fg2)
if cmap:
m.add_child(cmap)
LayerControl(collapsed=False).add_to(m)
# 6. 保存并启动静态服务
script_dir = os.path.dirname(__file__)
html = os.path.join(script_dir, "map.html")
m.save(html)
logger.info("Saved map → %s", html)
threading.Thread(
target=serve_map,
args=(args.map_port, script_dir),
daemon=True,
name="HTTP-Server",
).start()
logger.info("Server running at http://127.0.0.1:%d/map.html", args.map_port)
threading.Event().wait()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,293 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""_summary_
✅ 批量读取 ROS bag 文件列表(一个 txt 文件里每行是文件路径)
✅ 根据文件名解析出车辆、时间、bag 名称
✅ 自动:
在数据库里插入 / 更新 bag_list(记录名称与 project_id)
插入 / 更新 main_pangu(记录 bag 的元信息)
插入 / 更新 secondary_pangu(记录 COS 上的文件、目录存在性及链接)
读取 meta_info.txt 文件里的 topic 列表, 更新 topic_list 与 bag_topic 关联表
✅ 使用腾讯云 COS (对象存储) API 检查文件、目录是否存在, 并提供文件访问链接
✅ 支持多线程并发处理, 加快大量 bag 的同步速度
✅ 自动维护数据库连接池 (psycopg2 ThreadedConnectionPool)
✅ 日志文件记录所有操作, 包括成功、跳过、失败, 便于后续排查
"""
import os
import re
import tempfile
import logging
import argparse
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from psycopg2.pool import ThreadedConnectionPool
from tqdm import tqdm
from qcloud_cos import CosConfig, CosS3Client
from qcloud_cos.cos_exception import CosServiceError
from fst_data_pipeline.core.config_manager import ConfigManager
# ====== 配置区 ======
config = ConfigManager()
DB_HOST = config.get("ROOT_DB_HOST", "localhost")
DB_PORT = config.get_int("ROOT_DB_PORT", 5432)
DB_NAME = config.get("ROOT_DB_NAME", "default_dbname")
DB_USER = config.get("ROOT_DB_USER", "default_user")
DB_PASSWORD = config.get("ROOT_DB_PASSWD", "default_password")
COS_REGION = config.get("COS_CLIENT_REGION", "default_region")
COS_SECRET_ID = config.get("COS_CLIENT_SECRET_ID", "default_id")
COS_SECRET_KEY = config.get("COS_CLIENT_SECRET_KEY", "default_key")
BUCKET = config.get("COS_BUCKET", "b-perception-e2e-1318950322")
config.require("COS_PREFIX_DECODE")
PREFIX = config.get("COS_PREFIX_DECODE", "mb_raw_rosbag_decode_dirs")
DOMAIN = f"https://{BUCKET}.cos.{COS_REGION}.myqcloud.com"
MIN_CONN, MAX_CONN = 1, 10
MAX_WORKERS = 8 # 并发线程数, 可调
# ====== 日志 ======
LOGFILE = f"{datetime.now():%Y-%m-%d}.log"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
handlers=[logging.StreamHandler(), logging.FileHandler(LOGFILE, mode="a")],
)
logger = logging.getLogger()
# ====== COS 客户端 ======
cos = CosS3Client(
CosConfig(Region=COS_REGION, SecretId=COS_SECRET_ID, SecretKey=COS_SECRET_KEY)
)
# ====== 全局 DB 连接池 ======
pool = ThreadedConnectionPool(
MIN_CONN, MAX_CONN, host=DB_HOST, dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD
)
# ====== 工具函数 ======
def get_db():
"""从池里取一个 conn, 设置 autocommit=True"""
conn = pool.getconn()
conn.autocommit = True
return conn, conn.cursor()
def release_db(conn, cur):
cur.close()
pool.putconn(conn)
def parse_rosbag_line(line):
raw = line.strip()
name = os.path.basename(raw).split("/")[0]
if not name.endswith(".bag"):
name += ".bag"
m = re.search(r"_(\d{4})(\d{2})(\d{2})-(\d{2})(\d{2})(\d{2})_", name)
if not m:
raise ValueError(f"无法解析 rosbag 名称/时间: {name}")
yy, mm, dd, HH, MN, SS = m.groups()
v = name[:8]
dt = datetime.strptime(f"{yy}-{mm}-{dd} {HH}:{MN}:{SS}", "%Y-%m-%d %H:%M:%S")
return name, {
"name": name,
"vehicle": v,
"datetime": dt,
"path": f"{DOMAIN}/mb_cuct_data_collection/{name}",
"data_path": f"{DOMAIN}/{PREFIX}{name}.dir",
}
def download_txt(name):
key = f"{name}.meta_info.txt"
try:
obj = cos.get_object(Bucket=BUCKET, Key=key)
except CosServiceError:
return None
fd, tmp = tempfile.mkstemp(prefix=name + "_", suffix=".txt")
os.close(fd)
with open(tmp, "wb") as f:
for chunk in obj["Body"].get_raw_stream():
f.write(chunk)
return tmp
def parse_rosbag_info(fp):
tops = []
on = False
with open(fp, "r", encoding="utf-8") as f:
for L in f:
if L.startswith("topics:"):
on = True
continue
if on:
s = L.strip()
if not s:
break
p = s.split()
if len(p) > 3:
tops.append((p[0], " ".join(p[3:]).rstrip(":")))
return tops
# ==== 读取已成功处理的 bag_name 列表 ====
def load_done(logfile):
done = set()
if not os.path.exists(logfile):
return done
pat = re.compile(r"SUCCESS\s+(.+\.bag)\b")
with open(logfile, "r", encoding="utf-8") as f:
for L in f:
m = pat.search(L)
if m:
done.add(m.group(1))
return done
# ==== 单条处理逻辑 ====
def process_one(line):
try:
bag_name, ent = parse_rosbag_line(line)
except ValueError as e:
logger.error(e)
return
conn, cur = get_db()
try:
# 1) bag_list
cur.execute(
"INSERT INTO bag_list(name,project_id) VALUES(%s,%s) "
"ON CONFLICT(name) DO UPDATE SET update_time=NOW() "
"RETURNING id",
(bag_name, 1),
)
bag_id = cur.fetchone()[0]
# 2) main_pangu
cur.execute(
"INSERT INTO main_pangu(bag_id,name,vehicle,datetime,path,data_path) "
"VALUES(%s,%s,%s,%s,%s,%s) "
"ON CONFLICT(bag_id) DO UPDATE SET "
"name=COALESCE(main_pangu.name,EXCLUDED.name),"
"vehicle=COALESCE(main_pangu.vehicle,EXCLUDED.vehicle),"
"datetime=COALESCE(main_pangu.datetime,EXCLUDED.datetime),"
"path=COALESCE(main_pangu.path,EXCLUDED.path),"
"data_path=COALESCE(main_pangu.data_path,EXCLUDED.data_path)",
(
bag_id,
ent["name"],
ent["vehicle"],
ent["datetime"],
ent["path"],
ent["data_path"],
),
)
# 3) secondary_pangu
file_fs = ["raw_gnss.csv", "raw_imu.csv", "ego_motion.csv", "vehicle_wheel.csv"]
dir_fs = [
"camera_front_wide",
"lidar_gt_pandar128",
"object_lidar_gt_pandar128_manual",
"lidar_fd_multi_scan_raw",
"camera_fisheye_left",
"camera_fisheye_right",
"calibration",
]
data = {}
for f in file_fs:
key = f"{PREFIX}{bag_name}.dir/{f}"
try:
cos.head_object(Bucket=BUCKET, Key=key)
data[f] = f"{DOMAIN}/{key}"
except CosServiceError:
data[f] = None
for d in dir_fs:
pfx = f"{PREFIX}{bag_name}.dir/{d}/"
try:
r = cos.list_objects_v2(Bucket=BUCKET, Prefix=pfx, MaxKeys=1)
data[d] = f"{DOMAIN}/{pfx}" if r.get("Contents") else None
except CosServiceError:
data[d] = None
data["opt"] = {}
cols = ",".join(["bag_id"] + file_fs + dir_fs + ["opt"])
vals = [bag_id] + [data[x] for x in file_fs + dir_fs] + [data["opt"]]
upd = (
",".join(
f"{x}=COALESCE(secondary_pangu.{x},EXCLUDED.{x})"
for x in file_fs + dir_fs
)
+ ",opt=EXCLUDED.opt"
)
cur.execute(
f"INSERT INTO secondary_pangu({cols}) VALUES({','.join(['%s'] * len(vals))}) "
f"ON CONFLICT(bag_id) DO UPDATE SET {upd}",
vals,
)
# 4) topics + 关联
txt = download_txt(bag_name)
if txt:
tops = parse_rosbag_info(txt)
os.remove(txt)
for nm, ty in tops:
cur.execute(
"INSERT INTO topic_list(name,type) VALUES(%s,%s) "
"ON CONFLICT(name) DO UPDATE "
"SET type=COALESCE(topic_list.type,EXCLUDED.type),update_time=NOW() "
"RETURNING id",
(nm, ty),
)
tid = cur.fetchone()[0]
cur.execute(
"INSERT INTO bag_topic(bag_id,topic_id) VALUES(%s,%s) "
"ON CONFLICT DO NOTHING",
(bag_id, tid),
)
logger.info("SUCCESS %s", bag_name)
except Exception:
logger.exception("FAIL %s", bag_name)
finally:
release_db(conn, cur)
# ==== 主入口 ====
def main(txtfile, threads):
done = load_done(LOGFILE)
lines = [L for L in open(txtfile, "r") if L.strip()]
tasks = []
with ThreadPoolExecutor(max_workers=threads) as ex:
for L in lines:
try:
nm = os.path.basename(L.strip())
if not nm.endswith(".bag"):
nm += ".bag"
except FileExistsError:
continue
if nm in done:
logger.info("SKIP %s", nm)
continue
tasks.append(ex.submit(process_one, L))
for _ in tqdm(as_completed(tasks), total=len(tasks), desc="Bags"):
pass
if __name__ == "__main__":
p = argparse.ArgumentParser()
p.add_argument("file", help="ROS bag 列表 txt")
p.add_argument("--threads", type=int, default=MAX_WORKERS)
args = p.parse_args()
main(args.file, args.threads)

View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import psycopg2
from qcloud_cos import CosConfig, CosS3Client
from fst_data_pipeline.core.config_manager import ConfigManager
# ====== 配置区 ======
config = ConfigManager()
DB_HOST = config.get("ROOT_DB_HOST", "localhost")
DB_PORT = config.get_int("ROOT_DB_PORT", 5432)
DB_NAME = config.get("ROOT_DB_NAME", "default_dbname")
DB_USER = config.get("ROOT_DB_USER", "default_user")
DB_PASSWORD = config.get("ROOT_DB_PASSWD", "default_password")
COS_REGION = config.get("COS_CLIENT_REGION", "default_region")
COS_SECRET_ID = config.get("COS_CLIENT_SECRET_ID", "default_id")
COS_SECRET_KEY = config.get("COS_CLIENT_SECRET_KEY", "default_key")
BUCKET = config.get("COS_BUCKET", "b-perception-e2e-1318950322")
config.require("COS_PREFIX_DECODE")
PREFIX = config.get("COS_PREFIX_DECODE", "mb_raw_rosbag_decode_dirs")
DOMAIN = f"https://{BUCKET}.cos.{COS_REGION}.myqcloud.com"
# 输出缺失列表的文件
MISSING_FILE = "missing.txt"
def fetch_paths():
conn = psycopg2.connect(
host=DB_HOST, port=DB_PORT, dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD
)
with conn.cursor() as cur:
cur.execute("SELECT data_path FROM bag_list;")
paths = [row[0] for row in cur.fetchall()]
conn.close()
return paths
def check_dirs(paths):
cfg = CosConfig(
Secret_id=COS_SECRET_ID, Secret_key=COS_SECRET_KEY, Region=COS_REGION
)
client = CosS3Client(cfg)
exist, missing = [], []
for p in paths:
prefix = p.rstrip("/") + "/lidar_gt_pandar128/"
resp = client.list_objects(Bucket=BUCKET, Prefix=prefix, MaxKeys=1)
if resp.get("Contents"):
exist.append(p)
else:
missing.append(p)
return exist, missing
def write_missing(missing, filepath):
with open(filepath, "w", encoding="utf-8") as f:
for p in missing:
f.write(p + "\n")
def main():
paths = fetch_paths()
exist, missing = check_dirs(paths)
print(f"总路径数: {len(paths)}")
print(f"存在 lidar_gt_pandar128 的: {len(exist)}")
for p in exist:
print(" [OK] ", p)
print(f"\n缺失 lidar_gt_pandar128 的: {len(missing)}")
for p in missing:
print(" [MISS]", p)
# 将缺失列表写入文件
write_missing(missing, MISSING_FILE)
print(f"\n已将 {len(missing)} 条缺失记录写入 `{MISSING_FILE}`")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,354 @@
#!/usr/bin/env python3
"""
四种互斥模式:
(默认) —— 全量扫描
file --tasks-file—— 指定文件
service —— HTTP 服务
check —— fst_bag 关联并推送缺失
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Tuple, Dict, Any
import psycopg2
import requests
from flask import Flask, jsonify, request
from qcloud_cos import CosConfig, CosS3Client
from fst_data_pipeline.core.config_manager import ConfigManager
# ---------- 配置 ----------
cfg = ConfigManager()
DB = {
"host": cfg.get("ROOT_DB_HOST"),
"port": cfg.get_int("ROOT_DB_PORT"),
"dbname": cfg.get("ROOT_DB_NAME"),
"user": cfg.get("ROOT_DB_USER"),
"password": cfg.get("ROOT_DB_PASSWD"),
}
COS_CFG = CosConfig(
Region=cfg.get("COS_CLIENT_REGION"),
SecretId=cfg.get("COS_CLIENT_SECRET"),
SecretKey=cfg.get("COS_CLIENT_KEY"),
)
BUCKET = cfg.get("COS_BUCKET")
LD_READY = cfg.get("LD_READY_URL", "http://ld-checker/ready")
OD_READY = cfg.get("OD_READY_URL", "http://od-checker/ready")
LD_NOTIFY = cfg.get("LD_NOTIFY_URL", "http://ld-checker/notify")
OD_NOTIFY = cfg.get("OD_NOTIFY_URL", "http://od-checker/notify")
LOCAL_API_HOST = cfg.get("ROOT_DB_API", "http://10.0.240.4:5232")
MAX_WORKERS = 16
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(message)s")
log = logging.getLogger("bag")
# ---------- 类型 ----------
BagRow = Tuple[
int, str, int, bool, int, int
] # id, name, project_id, is_decoded, od_old, ld_old
# ---------- 工具 ----------
def pg_conn():
conn = psycopg2.connect(**DB)
conn.autocommit = True
return conn, conn.cursor()
def cos_client() -> CosS3Client:
return CosS3Client(COS_CFG)
def cos_exists(cli: CosS3Client, prefix: str) -> bool:
try:
return bool(
cli.list_objects(Bucket=BUCKET, Prefix=prefix, MaxKeys=1).get("Contents")
)
except Exception:
return False
# ---------- HTTP ----------
def _get_all_fst_bags() -> List[str]:
url = f"{LOCAL_API_HOST}/api/fst/baglist"
try:
r = requests.get(url, timeout=10)
r.raise_for_status()
return r.json()
except Exception as e:
log.error("获取 fst bag list 失败: %s", e)
return []
def _resolve_paths(bags: List[str]) -> Dict[str, str]:
url = f"{LOCAL_API_HOST}/api/bags/pangu/detail"
if not bags:
return {}
try:
r = requests.post(url, json=bags, timeout=30)
r.raise_for_status()
return {
name: info["data_path"].split("/", 1)[-1]
for name, info in r.json().items()
if info.get("data_path")
}
except Exception as e:
log.error("解析 bag path 失败: %s", e)
return {}
# ---------- 推送 ----------
MISSING_LD: List[str] = []
MISSING_OD: List[str] = []
def push_list(ready_url: str, notify_url: str, bag_list: List[str], kind: str) -> None:
if not bag_list:
return
path_map = _resolve_paths(bag_list)
if not path_map:
log.warning("无可推送的 %s 路径", kind)
return
paths = [path_map[name] for name in bag_list if name in path_map]
log.info("%s: %d bags → pushing %d paths", kind, len(bag_list), len(paths))
while True:
try:
if requests.get(ready_url, timeout=5).json().get("ready"):
break
except Exception as e:
log.warning("ready check failed: %s", e)
time.sleep(2)
try:
resp = requests.post(notify_url, json=paths, timeout=30)
resp.raise_for_status()
log.info("%s pushed %d paths, status=%s", kind, len(paths), resp.status_code)
except Exception as e:
log.error("%s push failed: %s", kind, e)
# ---------- 核心处理 ----------
def process_row(row: BagRow, collect_missing: bool = False) -> None:
bag_id, name, project_id, _, od_old, ld_old = row
# 提前检查是否已“完美”
conn, cur = pg_conn()
try:
cur.execute("SELECT reserved_str FROM main_pangu WHERE name=%s", (name,))
reserved_str = (cur.fetchone() or ("",))[0]
has_md5 = reserved_str.startswith("md5:")
if ld_old == 3 and od_old == 3 and has_md5:
log.debug("[%s] 已完整,跳过", name)
return
finally:
conn.close()
# 需要真正检查 COS
conn, cur = pg_conn()
cli = cos_client()
try:
base = "mb" if project_id == 1 else "mmt"
derived = f"{base}_raw_rosbag_decode_dirs/{name}.dir/derived/"
ld_auto = cos_exists(cli, derived + "LDGT/")
ld_man = cos_exists(cli, derived + "LD_manual/")
od_auto = cos_exists(cli, derived + "object_auto_labeling/")
od_man = cos_exists(cli, derived + "object_manual_labeling/")
def calc_flag(a: bool, m: bool) -> int:
return 3 if a and m else 2 if m else 1 if a else 0
ld_new = calc_flag(ld_auto, ld_man)
od_new = calc_flag(od_auto, od_man)
# 更新状态
for field, old, new in (
("ld_annotated", ld_old, ld_new),
("od_annotated", od_old, od_new),
):
if old != new:
cur.execute(
f"UPDATE bag_list SET {field}=%s WHERE id=%s", (new, bag_id)
)
log.info("[%s] %s %s%s", name, field, old, new)
# 处理 md5
if not has_md5:
cur.execute("SELECT bag_path FROM main_pangu WHERE name=%s", (name,))
bag_path = (cur.fetchone() or ("",))[0]
if bag_path:
cos_key = bag_path.lstrip("/")
tmp_path = f"/tmp/{name}.bag"
try:
cli.download_file(Bucket=BUCKET, Key=cos_key, DestFilePath=tmp_path)
md5 = subprocess.check_output(
["md5sum", tmp_path], text=True
).split()[0]
os.remove(tmp_path)
except Exception as e:
log.warning("[%s] 下载/计算 md5 失败: %s", name, e)
md5 = None
if md5:
cur.execute(
"UPDATE main_pangu SET reserved_str='md5:'||%s WHERE name=%s",
(md5, name),
)
cur.execute(
"SELECT reserved_json FROM main_pangu WHERE name=%s", (name,)
)
reserved_json: Dict[str, Any] = json.loads(
cur.fetchone()[0] or "{}"
)
reserved_json.update({
"bag_meta": f"https://cla-dev.ca4ad.com/cdi/data/{md5}",
"bag_player": f"https://mviz-dev.ca4ad.com/player/v4/?bag_md5={md5}",
})
cur.execute(
"UPDATE main_pangu SET reserved_json=COALESCE(reserved_json,'{}'::jsonb)||%s WHERE name=%s",
(json.dumps(reserved_json, ensure_ascii=False), name),
)
log.info("[%s] 写入 md5:%s", name, md5)
# 收集缺失
if collect_missing:
if not ld_auto:
MISSING_LD.append(name)
if not od_auto:
MISSING_OD.append(name)
finally:
conn.close()
# ---------- 数据获取 ----------
def fetch_rows_from_file(path: Path) -> List[BagRow]:
with open(path, encoding="utf-8") as f:
names = [l.strip() for l in f if l.strip()]
if not names:
return []
conn, cur = pg_conn()
cur.execute("SELECT * FROM bag_list WHERE name = ANY(%s)", (names,))
rows = cur.fetchall()
conn.close()
return rows
def fetch_rows_from_all() -> List[BagRow]:
conn, cur = pg_conn()
cur.execute("SELECT * FROM bag_list")
rows = cur.fetchall()
conn.close()
return rows
def fetch_rows_from_fst() -> List[BagRow]:
bag_names = _get_all_fst_bags()
if not bag_names:
return []
conn, cur = pg_conn()
cur.execute("SELECT * FROM bag_list WHERE name = ANY(%s)", (bag_names,))
rows = cur.fetchall()
conn.close()
log.info("fst_bag 关联 %d", len(rows))
return rows
# ---------- 模式封装 ----------
def _run(rows: List[BagRow], collect_missing: bool):
MISSING_LD.clear()
MISSING_OD.clear()
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
list(ex.map(lambda r: process_row(r, collect_missing), rows))
if collect_missing:
push_list(LD_READY, LD_NOTIFY, MISSING_LD, "LD")
push_list(OD_READY, OD_NOTIFY, MISSING_OD, "OD")
def run_all():
rows = fetch_rows_from_all()
_run(rows, collect_missing=False)
log.info("全量扫描完成")
def run_file(path: Path):
rows = fetch_rows_from_file(path)
_run(rows, collect_missing=False)
log.info("file 模式完成")
def run_service(host: str, port: int):
app = Flask(__name__)
@app.route("/ready")
def ready():
return jsonify(ready=True)
def handle_notify(bags: List[str]):
rows = []
if bags:
conn, cur = pg_conn()
cur.execute("SELECT * FROM bag_list WHERE name = ANY(%s)", (bags,))
rows = cur.fetchall()
conn.close()
for r in rows:
process_row(r, collect_missing=False)
@app.route("/notify_ld", methods=["POST"])
def notify_ld():
bags = request.get_json(force=True)
if not isinstance(bags, list):
return jsonify(error="need list"), 400
handle_notify(bags)
return jsonify(status="accepted"), 202
@app.route("/notify_od", methods=["POST"])
def notify_od():
bags = request.get_json(force=True)
if not isinstance(bags, list):
return jsonify(error="need list"), 400
handle_notify(bags)
return jsonify(status="accepted"), 202
app.run(host=host, port=port, threaded=True)
def run_check():
rows = fetch_rows_from_fst()
_run(rows, collect_missing=True)
log.info("check 模式完成")
# ---------- CLI ----------
if __name__ == "__main__":
parser = argparse.ArgumentParser()
sub = parser.add_subparsers(dest="mode", help="运行模式")
parser.set_defaults(mode="all")
f = sub.add_parser("file", help="读取 bag 列表文件")
f.add_argument("--tasks-file", required=True)
s = sub.add_parser("service", help="启动 HTTP 服务")
s.add_argument("--host", default="0.0.0.0")
s.add_argument("--port", type=int, default=8000)
sub.add_parser("check", help="仅扫描 fst_bag 关联 bag 并推送缺失")
args = parser.parse_args()
if args.mode == "file":
run_file(Path(args.tasks_file))
elif args.mode == "service":
run_service(args.host, args.port)
elif args.mode == "check":
run_check()
else:
run_all()

View File

@@ -0,0 +1,221 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
-----------------------------------------------------------------------------------
功能概述:
该脚本用于同步腾讯云 COS(对象存储)上的原始 .bag 文件与 PostgreSQL 中的 bag_list 表。
同时根据 COS 上解码输出目录 b/ 下的 .dir 文件, 更新数据库中对应记录的解码状态。
支持在数据库同步完成后, 根据 --notify 参数轮询后端服务是否就绪(通过 /ready 接口),
并在就绪后将未解码文件列表 POST 到 /notify 接口, 驱动后续解码或处理流程。
主要流程:
1. 连接腾讯云 COS, 列出 a/ 前缀下所有 .bag 文件, 插入数据库(跳过已存在)。
2. 查询数据库中未解码且未删除的记录。
3. 列出 b/ 前缀下的 .dir 文件, 推断已解码的 .bag 文件, 批量更新数据库标记 is_decoded=TRUE。
4. 如果指定 --notify:
- 轮询 /ready 接口直到返回 { "ready": true }
- 再将剩余未解码文件列表(带前缀)POST 到 /notify 接口。
依赖环境:
- Python 3.6+
- psycopg2 (PostgreSQL 驱动)
- requests (HTTP 请求)
- qcloud_cos (腾讯云 COS Python SDK)
使用示例:
python3 cos_pg_sync.py
python3 cos_pg_sync.py --notify
配置说明:
可在脚本开头直接修改:
COS_SECRET_ID / COS_SECRET_KEY / COS_REGION / COS_BUCKET
PostgreSQL 连接信息 PG_HOST 等
GATE_URL / NOTIFY_URL 等服务端接口地址
-----------------------------------------------------------------------------------
"""
import os
import time
import argparse
import logging
import psycopg2
from psycopg2.extras import execute_values
import requests
from qcloud_cos import CosConfig, CosS3Client
from fst_data_pipeline.core.config_manager import ConfigManager
# ====== 配置区 ======
config = ConfigManager()
DB_HOST = config.get("ROOT_DB_HOST", "localhost")
DB_PORT = config.get_int("ROOT_DB_PORT", 5432)
DB_NAME = config.get("ROOT_DB_NAME", "default_dbname")
DB_USER = config.get("ROOT_DB_USER", "default_user")
DB_PASSWORD = config.get("ROOT_DB_PASSWD", "default_password")
COS_REGION = config.get("COS_CLIENT_REGION", "default_region")
COS_SECRET_ID = config.get("COS_CLIENT_SECRET_ID", "default_id")
COS_SECRET_KEY = config.get("COS_CLIENT_SECRET_KEY", "default_key")
BUCKET = config.get("COS_BUCKET", "b-perception-e2e-1318950322")
config.require("COS_PREFIX_DECODE")
DECODED_PREFIX = config.get("COS_PREFIX_DECODE", "mb_raw_rosbag_decode_dirs")
RAW_BAG_PREFIX = config.get("COS_PREFIX_BAG", "mb_cuct_data_collection")
DOMAIN = f"https://{BUCKET}.cos.{COS_REGION}.myqcloud.com"
# Service 模式的 HTTP 接口
GATE_URL = "https://example.com/ready"
NOTIFY_URL = "https://example.com/notify"
LOG_FILE = "sync.log"
def init_logger():
fmt = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(
level=logging.INFO,
format=fmt,
handlers=[
logging.StreamHandler(),
logging.FileHandler(LOG_FILE, encoding="utf-8"),
],
)
def get_cos_client():
cfg = CosConfig(Region=COS_REGION, SecretId=COS_SECRET_ID, SecretKey=COS_SECRET_KEY)
return CosS3Client(cfg)
def list_objs(cos, prefix):
"""分页列出 COS 上指定前缀下的所有对象 key"""
keys = []
token = None
while True:
params = {
"Bucket": BUCKET,
"Prefix": prefix,
"MaxKeys": 1000,
}
if token:
params["ContinuationToken"] = token
resp = cos.list_objects_v2(**params)
for obj in resp.get("Contents", []):
keys.append(obj["Key"])
if resp.get("IsTruncated"):
token = resp["NextContinuationToken"]
else:
break
return keys
def wait_for_ready():
"""轮询 /ready, 直到返回 JSON { "ready": true }"""
logging.info(f"开始轮询 Gate URL: {GATE_URL}")
while True:
try:
r = requests.get(GATE_URL, timeout=10)
if r.status_code == 200:
data = r.json()
if isinstance(data, dict) and data.get("ready") is True:
logging.info("Gate 就绪, 开始通知流程")
return
logging.info(
"Gate 未就绪(HTTP %d 或 ready=false), 10 分钟后重试", r.status_code
)
except Exception as e:
logging.warning(f"轮询 Gate 时发生异常: {e}")
time.sleep(600) # 10 分钟后重试
def main():
init_logger()
p = argparse.ArgumentParser(
description="同步 COS 与 bag_list, 并可选通知未解码列表"
)
p.add_argument(
"--notify",
"-n",
action="store_true",
help="轮询 /ready, 再 POST 剩余未解码列表(带前缀)",
)
args = p.parse_args()
# 1. 初始化 COS 客户端和数据库连接
cos = get_cos_client()
conn = psycopg2.connect(
host=DB_HOST, port=DB_PORT, user=DB_USER, password=DB_PASSWORD, dbname=DB_NAME
)
cur = conn.cursor()
# —— STEP1同步 a/ 下的 .bag 到 bag_list(name, project_id) ——
keys_a = list_objs(cos, RAW_BAG_PREFIX)
# 提取所有 .bag 的名字(不含前缀)
names = [os.path.basename(k) for k in keys_a if k.lower().endswith(".bag")]
if names:
# 默认 project_id=1
records = [(n, 1) for n in names]
execute_values(
cur,
"INSERT INTO bag_list(name, project_id) VALUES %s "
"ON CONFLICT(name) DO NOTHING",
records,
)
logging.info(f"STEP1: 尝试插入 {len(records)} 条记录, 冲突则跳过")
else:
logging.info("STEP1: 未发现 .bag 文件, 无需插入")
# —— STEP2查询所有 is_decoded=FALSE 且 is_deleted=FALSE 的 name ——
cur.execute("SELECT name FROM bag_list WHERE is_decoded=FALSE AND is_deleted=FALSE")
undecoded = [row[0] for row in cur.fetchall()]
logging.info(f"STEP2: 当前未解码记录数:{len(undecoded)}")
# —— STEP3扫描 b/ 下 .dir, 并批量更新 is_decoded ——
keys_b = list_objs(cos, DECODED_PREFIX)
# available 是所有已解码的 bag 名字(带 .bag 后缀)
available = {
os.path.basename(k)[:-4] + ".bag" for k in keys_b if k.lower().endswith(".dir")
}
to_update = list(set(undecoded) & available)
if to_update:
cur.execute(
"UPDATE bag_list "
"SET is_decoded=TRUE, update_time=CURRENT_TIMESTAMP "
"WHERE name = ANY(%s)",
(to_update,),
)
logging.info(f"STEP3: 标记已解码 {len(to_update)} 条记录")
else:
logging.info("STEP3: 无需更新解码状态")
conn.commit()
# —— 可选通知流程 ——
if args.notify:
# 1) 等待服务端就绪
wait_for_ready()
# 2) 再次查询剩余未解码名称
cur.execute(
"SELECT name FROM bag_list WHERE is_decoded=FALSE AND is_deleted=FALSE"
)
remaining = [row[0] for row in cur.fetchall()]
logging.info(f"通知前剩余未解码 {len(remaining)} 条: {remaining}")
# 3) 补上前缀, POST 给 /notify
paths = [RAW_BAG_PREFIX + name for name in remaining]
try:
resp = requests.post(NOTIFY_URL, json=paths, timeout=10)
logging.info(f"POST /notify 返回 HTTP {resp.status_code}: {resp.text}")
except Exception as e:
logging.error(f"通知失败: {e}")
# 清理资源
cur.close()
conn.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,216 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import re
import shutil
import subprocess
import logging
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import requests
import psycopg2
# =========================================================
# 日志
# =========================================================
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(message)s",
handlers=[logging.FileHandler("bag_merge.log"), logging.StreamHandler()],
)
log = logging.getLogger(__name__)
# =========================================================
# 配置(全部有默认值)
# =========================================================
# API
API_URL = os.getenv("API_URL", "http://127.0.0.1:8080/api/bag/mapping")
API_TIMEOUT = int(os.getenv("API_TIMEOUT", "30"))
# PostgreSQL
PG_DSN = os.getenv(
"PG_DSN",
"host=127.0.0.1 port=5432 dbname=test user=test password=test",
)
# 本地临时目录
TEMP_ROOT = Path(os.getenv("TEMP_ROOT", "/tmp/bag_merge"))
# coscmd
COSCMD_BIN = os.getenv("COSCMD_BIN", "coscmd")
COSCMD_TIMEOUT = int(os.getenv("COSCMD_TIMEOUT", "3600"))
# 并发数
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "4"))
# 用于拼可访问 URL仅用于存 DB不影响 coscmd
COS_ENDPOINT = os.getenv("COS_ENDPOINT", "")
COS_BUCKET = os.getenv("COS_BUCKET", "")
COS_REGION = os.getenv("COS_REGION", "")
# =========================================================
# 工具函数
# =========================================================
def safe_name(key: str) -> str:
"""
将 COS key 转成本地安全文件名
"""
k = (key or "").strip().replace("\\", "/")
k = re.sub(r"/+", "/", k)
k = k.replace("..", "__")
k = k.replace("/", "__")
return k or "empty"
def run_cmd(cmd: list[str], *, timeout: int | None = None):
log.debug("CMD: %s", " ".join(cmd))
subprocess.run(cmd, check=True, timeout=timeout)
def make_cos_url(key: str) -> str:
"""
生成一个用于入库的 URL
"""
k = key.lstrip("/")
if COS_ENDPOINT:
return f"https://{COS_ENDPOINT.rstrip('/')}/{k}"
if COS_BUCKET and COS_REGION:
return f"https://{COS_BUCKET}.cos.{COS_REGION}.myqcloud.com/{k}"
return key
# =========================================================
# 业务函数
# =========================================================
def fetch_mapping() -> dict:
"""
从 API 获取 parent -> children 映射
"""
log.info("POST %s", API_URL)
resp = requests.post(
API_URL,
json={"bag_names": ["*"]},
headers={"Content-Type": "application/json"},
timeout=API_TIMEOUT,
)
resp.raise_for_status()
data = resp.json()
if not isinstance(data, dict):
raise RuntimeError("mapping response is not a dict")
return data
def download_file(key: str, local: Path):
"""
coscmd download /xxx local
"""
local.parent.mkdir(parents=True, exist_ok=True)
cos_path = "/" + key.lstrip("/")
log.info("↓ download %s -> %s", cos_path, local)
run_cmd(
[COSCMD_BIN, "download", cos_path, str(local)],
timeout=COSCMD_TIMEOUT,
)
def upload_file(local: Path, key: str) -> str:
"""
coscmd upload local /xxx
"""
cos_path = "/" + key.lstrip("/")
log.info("↑ upload %s -> %s", local, cos_path)
run_cmd(
[COSCMD_BIN, "upload", str(local), cos_path],
timeout=COSCMD_TIMEOUT,
)
return make_cos_url(key)
def merge_bags(inputs: list[Path], output: Path):
"""
调用 rosbag-merge
"""
output.parent.mkdir(parents=True, exist_ok=True)
subprocess.check_call(
["rosbag-merge", "-o", str(output)] + [str(p) for p in inputs]
)
def update_db(parent: str, cos_url: str):
"""
更新数据库
"""
sql = "UPDATE bag_task SET tos_path = %s WHERE parent_bag = %s"
with psycopg2.connect(PG_DSN) as conn:
with conn.cursor() as cur:
cur.execute(sql, (cos_url, parent))
conn.commit()
log.info("[DB] %s -> %s", parent, cos_url)
def work_one(parent: str, children: list[str]) -> str:
"""
单个 parent 的完整处理流程
"""
log.info("start parent=%s children=%d", parent, len(children))
wd = TEMP_ROOT / safe_name(parent)
wd.mkdir(parents=True, exist_ok=True)
try:
subs: list[Path] = []
for c in children:
lp = wd / safe_name(c)
download_file(c, lp)
subs.append(lp)
out = wd / safe_name(parent)
merge_bags(subs, out)
url = upload_file(out, parent)
update_db(parent, url)
log.info("finish parent=%s", parent)
return url
finally:
shutil.rmtree(wd, ignore_errors=True)
# =========================================================
# 主入口
# =========================================================
def main():
TEMP_ROOT.mkdir(parents=True, exist_ok=True)
# 可执行文件检查(有问题会在启动阶段直接失败)
run_cmd([COSCMD_BIN, "--version"], timeout=10)
run_cmd(["rosbag-merge", "-h"], timeout=10)
mapping = fetch_mapping()
if not mapping:
log.warning("mapping empty, exit")
return
with ProcessPoolExecutor(max_workers=MAX_WORKERS) as pool:
futures = {
pool.submit(work_one, parent, children): parent
for parent, children in mapping.items()
}
for fu in as_completed(futures):
parent = futures[fu]
try:
url = fu.result()
log.info("done %s -> %s", parent, url)
except Exception as e:
log.exception("failed %s: %s", parent, e)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,7 @@
# auto copy MMT CLA rosbag to MB dev bucket
This is Tencent cloud function to filter the bag name and copy the bag from MMT CLA CDI storage bucket to MB dev bucket.
2 cloud functions are used in auto_copy phase:
* filter_bag_to_kafka.py: filter the rosbag with name rule, send bag list to kafka
* copy_cdi_cos_bag.py: listen to kafka queue and perform cloud bucket data copy

View File

@@ -0,0 +1,81 @@
# -*- coding=utf-8
#####----------------------------------------------------------------#####
##### #####
##### 使用教程/readme: #####
##### https://cloud.tencent.com/document/product/583/30722 #####
##### #####
#####----------------------------------------------------------------#####
import os
import logging
from qcloud_cos import CosConfig
from qcloud_cos import CosS3Client
from qcloud_cos import CosServiceError
import ast
# Setting user properties, including secret_id, secret_key, region, bucket
appid = "1318950322" # Please replace with your APPID. 请替换为您的 APPID
region = "ap-shanghai-adc" # Please replace with your region. 替换为用户的region
secret_id = os.environ.get("TENCENTCLOUD_SECRETID")
secret_key = os.environ.get("TENCENTCLOUD_SECRETKEY")
token = os.environ.get("TENCENTCLOUD_SESSIONTOKEN")
bucket = "b-perception-e2e-1318950322" # Please replace with your COS bucket. 替换为需要写入的COS Bucket
folder = "mb_cuct_data_collection/"
# Getting configuration object. 获取配置对象
config = CosConfig(
Region=region, Secret_id=secret_id, Secret_key=secret_key, Token=token
)
client = CosS3Client(config)
logger = logging.getLogger()
def copy_file(bucket_upload, key, src_key):
response1 = client.object_exists(Bucket=bucket_upload, Key=key)
if response1:
try:
destKey = folder + src_key
response2 = client.copy(
Bucket=bucket,
Key=destKey,
CopySource={"Bucket": bucket_upload, "Key": key, "Region": region},
CopyStatus="Replaced",
PartSize=50,
MAXThread=10,
)
return True
except CosServiceError as e:
print("e is", e)
return False
else:
print("resource file does not exist")
return False
def main_handler(event, context):
bucket_upload = ""
logger.info("start main handler")
for record in event["Records"]:
if "Ckafka" not in record.keys():
print("event: no ckafka")
continue
value = record["Ckafka"]["msgBody"]
try:
data = ast.literal_eval(value)
if "cos" not in data.keys():
print("event: no cos")
continue
bucket_upload = data["cos"]["cosBucket"]["name"]
key = record["Ckafka"]["msgKey"]
src_key = key.split("/")[-1]
print("file name is ", src_key)
except:
print("msgBody:", value)
return "message error"
if copy_file(bucket_upload, key, src_key):
print("copy success")
else:
print("copy fail")

View File

@@ -0,0 +1,94 @@
#!/usr/bin/env python
# -*- coding=utf-
import logging
import time
import json
from kafka import KafkaProducer
from kafka.errors import KafkaError
logger = logging.getLogger("COSToKafka")
logger.setLevel(logging.INFO)
# file index which need copy
file_index = "PL162802"
skip_keywords = ["recording_lpnp_recording", "recording_ddloc_recording"]
class CosToKafka(object):
def __init__(self, host, **kwargs):
self.host = host
self.producer = KafkaProducer(
bootstrap_servers=[self.host],
# retries = 10,
# max_in_flight_requests_per_connection = 1,
# request_timeout_ms = 30000,
# max_block_ms = 60000,
**kwargs,
)
def send(self, topic, event):
global count
count = 0
def on_send_success(record_metadata):
global count
count = count + 1
def on_send_error(excp):
logger.error("failed to send message", exc_info=excp)
s_time = time.time()
eventList = None
try:
if "Records" in event:
eventList = event["Records"]
for data in eventList:
if "cos" in data:
key_list = data["cos"]["cosObject"]["key"].split("/")
file_name = str(key_list[-1])
if file_name.startswith(file_index):
if any(kw in file_name for kw in skip_keywords):
print("skip by blacklist:", file_name)
continue
if file_name.endswith(".bag"):
today_str = time.strftime("%Y%m%d", time.localtime())
with open(
f"{today_str}.txt", "a", encoding="utf-8"
) as f:
f.write(file_name + "\n")
_key = "/".join(key_list[3:])
# "cos": {"cosBucket": {"appid": "1324295915", "cosRegion": "ap-shanghai", "name": "dis-source-cfdi-test", "region": "sh", "s3Region": "ap-shanghai"}
# 返回的value中, name并不是真的bucket的name, 用这个那么去访问会出问题的
# 真实的Bucket名字是 $name = $name-$appid
_name = data["cos"]["cosBucket"]["name"]
_appid = data["cos"]["cosBucket"]["appid"]
data["cos"]["cosBucket"]["name"] = _name + "-" + _appid
print("bucket:", _name)
print("ket:", _key)
key = _key.encode("utf-8")
value = json.dumps(data).encode("utf-8")
self.producer.send(
topic, key=key, value=value
).add_callback(on_send_success).add_errback(on_send_error)
else:
print("file index not match", str(key_list[-1]))
else:
print("message error")
# block until all async messages are sent
self.producer.flush()
except KafkaError as e:
return e
finally:
if self.producer is not None:
self.producer.close()
e_time = time.time()
return "{} messages delivered in {}s".format(count, e_time - s_time)

View File

@@ -0,0 +1,30 @@
#!/usr/bin/env python
# -*- coding=utf-8
import os
import logging
from cos_to_kafka import CosToKafka
logger = logging.getLogger("Index")
logger.setLevel(logging.INFO)
def main_handler(event, context):
logger.info("start main handler")
os.environ["KAFKA_ADDRESS"] = "10.0.210.45:9092"
os.environ["KAFKA_TOPIC_NAME"] = "synchronous_copy_release_1.0"
kafka_address = os.getenv("KAFKA_ADDRESS")
kafka_topic_name = os.getenv("KAFKA_TOPIC_NAME")
cos_to_kafka = CosToKafka(
kafka_address,
# security_protocol = "PLAINTEXT",
# sasl_mechanism = "PLAIN",
# sasl_plain_username = "ckafka-80o10xxx#lkoxx",
# sasl_plain_password = "kongllxxxx",
api_version=(1, 1, 1),
)
ret = cos_to_kafka.send(kafka_topic_name, event)
logger.info(ret)
return ret

View File

@@ -0,0 +1,439 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import time
import argparse
import threading
import subprocess
import shutil
import logging
from datetime import datetime
import queue
import re
from prometheus_client import start_http_server, Counter, Gauge, Summary, Histogram
from flask import Flask, request, jsonify
# —— 常量 & 路径 —— #
BASE = os.getcwd()
INPUT_ROOT = os.path.join(BASE, "input")
OUTPUT_ROOT = os.path.join(BASE, "output")
EMPTY_DIR = os.path.join(BASE, "empty")
LOG_DIR = os.path.join(BASE, "logs")
COS_BUCKET = "mb_raw_rosbag_decode_dirs"
DOCKER_IMAGE = (
"artifact.swfcn.i.mercedes-benz.com/swfcn_docker/perception-3d/mmtbag_decoder:v6.6"
)
DOCKER_CMD_TEMPLATE = [
"docker",
"run",
"--rm",
"-v",
"{in_dir}:/input",
"-v",
"{out_dir}:/output",
DOCKER_IMAGE,
"bash",
"-c",
"source /opt/ros/noetic/setup.bash && "
"/opt/perception-3d/scripts/tools/"
"mmt_bag_decoder_scripts/decoded-bag.sh /input /output 3 1",
]
BATCH_SIZE = 50
MAX_LOCAL = 100
MAX_RETRIES = 3
RETRY_DELAY_S = 2
METRICS_PORT = 8000
SENTINEL = (None, None)
# —— 日志配置 —— #
os.makedirs(LOG_DIR, exist_ok=True)
logger = logging.getLogger("pipeline")
logger.setLevel(logging.INFO)
h_info = logging.FileHandler(os.path.join(LOG_DIR, "pipeline.log"), encoding="utf-8")
h_err = logging.FileHandler(os.path.join(LOG_DIR, "error_tasks.log"), encoding="utf-8")
fmt = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
h_info.setFormatter(fmt)
h_err.setFormatter(fmt)
h_err.setLevel(logging.ERROR)
logger.addHandler(h_info)
logger.addHandler(h_err)
# —— Prometheus 指标 —— #
DL_TOTAL = Counter("pipeline_download_total", "下载尝试总数")
DL_FAIL = Counter("pipeline_download_failures", "下载失败总数")
DL_RETRY = Counter("pipeline_download_retries", "下载重试总数")
PR_TOTAL = Counter("pipeline_process_total", "处理尝试总数")
PR_FAIL = Counter("pipeline_process_failures", "处理失败总数")
PR_RETRY = Counter("pipeline_process_retries", "处理重试总数")
UP_TOTAL = Counter("pipeline_upload_total", "上传尝试总数")
UP_FAIL = Counter("pipeline_upload_failures", "上传失败总数")
UP_RETRY = Counter("pipeline_upload_retries", "上传重试总数")
DL_DUR = Summary("pipeline_download_duration_seconds", "单批下载耗时秒")
PR_DUR = Summary("pipeline_process_duration_seconds", "单批处理耗时秒")
UP_DUR = Summary("pipeline_upload_duration_seconds", "单批上传耗时秒")
BATCH_SIZE_HIST = Histogram(
"pipeline_batch_size",
"单批任务中文件数量分布",
buckets=[1, 10, 20, 50, 100, 200, 500],
)
FILE_DL_DUR = Histogram("pipeline_file_download_duration_seconds", "单文件下载耗时分布")
BATCH_OUT_FILES = Gauge("pipeline_batch_output_file_count", "单批处理后输出文件数")
Q_BATCH = Gauge("pipeline_queue_batches", "待下载批次数")
Q_PROC = Gauge("pipeline_queue_processing", "待处理批次数")
Q_UP = Gauge("pipeline_queue_uploading", "待上传批次数")
LOCAL_FILES = Gauge("pipeline_local_file_count", "本地 input 文件总数")
# —— 全局队列 & inflight 计数 —— #
batch_q = queue.Queue()
proc_q = queue.Queue()
up_q = queue.Queue()
inflight = 0
inflight_lock = threading.Lock()
# —— 辅助函数 —— #
def count_local_files():
return sum(len(files) for _, _, files in os.walk(INPUT_ROOT))
def run(cmd, timeout=None):
logger.info("CMD: %s", " ".join(cmd))
try:
p = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
start = time.time()
out_lines = []
timed_out = False
for line in p.stdout:
out_lines.append(line)
if timeout and (time.time() - start) > timeout:
p.kill()
timed_out = True
break
code = p.wait()
return code, timed_out, "".join(out_lines)
except Exception:
logger.exception("CMD 执行异常")
return -1, False, ""
def with_retry(tag, fn, *args):
for i in range(1, MAX_RETRIES + 1):
code, timed_out, _ = fn(*args)
if code == 0:
return True
if timed_out:
logger.error("%s 阶段超时,不再重试", tag)
break
# 计数重试
if tag.startswith("DL["):
DL_RETRY.inc()
if tag.startswith("PR["):
PR_RETRY.inc()
if tag.startswith("UP["):
UP_RETRY.inc()
logger.warning("%s 重试 %d/%d", tag, i, MAX_RETRIES)
time.sleep(RETRY_DELAY_S)
logger.error("%s 最终失败", tag)
return False
# —— 下载 —— #
@DL_DUR.time()
def do_download(batch_id, paths, batch_timeout):
if batch_id is None:
proc_q.put(SENTINEL)
return
DL_TOTAL.inc()
start = time.time()
in_dir = os.path.join(INPUT_ROOT, batch_id)
os.makedirs(in_dir, exist_ok=True)
# 限制本地文件数
while count_local_files() >= MAX_LOCAL:
logger.warning("本地文件过多暂停下载5分钟")
time.sleep(300)
for p in paths:
if time.time() - start > batch_timeout:
logger.error("DL[%s] 下载阶段超时,跳过剩余", batch_id)
DL_FAIL.inc()
break
dst = os.path.join(in_dir, os.path.basename(p))
f_start = time.time()
ok = with_retry(
f"DL[{batch_id}]",
lambda s, d: run(
["coscmd", "-s", "download", s, d],
timeout=batch_timeout - (time.time() - start),
),
p,
dst,
)
FILE_DL_DUR.observe(time.time() - f_start)
if not ok:
DL_FAIL.inc()
proc_q.put((batch_id, in_dir))
# —— 处理 —— #
@PR_DUR.time()
def do_process(batch_id, in_dir, batch_timeout):
if batch_id is None:
up_q.put(SENTINEL)
return
PR_TOTAL.inc()
start = time.time()
out_dir = os.path.join(OUTPUT_ROOT, batch_id)
os.makedirs(out_dir, exist_ok=True)
cmd = [c.format(in_dir=in_dir, out_dir=out_dir) for c in DOCKER_CMD_TEMPLATE]
def run_pr(command):
p = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
for line in p.stdout:
logger.info("[PR %s] %s", batch_id, line.rstrip())
if time.time() - start > batch_timeout:
p.kill()
return p.wait(), True, ""
return p.wait(), False, ""
ok = with_retry(f"PR[{batch_id}]", run_pr, cmd)
if not ok:
PR_FAIL.inc()
# 统计输出文件
files = []
for r, _, fs in os.walk(out_dir):
for fn in fs:
files.append(os.path.relpath(os.path.join(r, fn), out_dir))
BATCH_OUT_FILES.set(len(files))
shutil.rmtree(in_dir, ignore_errors=True)
up_q.put((batch_id, out_dir))
# —— 上传 —— #
@UP_DUR.time()
def do_upload(batch_id, out_dir, batch_timeout):
global inflight
if batch_id is None:
return
try:
UP_TOTAL.inc()
ok = with_retry(
f"UP[{batch_id}]",
lambda d: run(
["coscmd", "-s", "upload", "-r", d, COS_BUCKET], timeout=batch_timeout
),
out_dir,
)
if not ok:
UP_FAIL.inc()
return
# 删除目录结构
for cmd in [
["sudo", "rsync", "-av", "--delete", f"{EMPTY_DIR}/", f"{out_dir}/"],
["sudo", "rm", "-rf", out_dir],
]:
run(cmd, timeout=60)
logger.info("UP[%s] 完成", batch_id)
finally:
# 无论成功失败任务算完成inflight-1
with inflight_lock:
inflight -= 1
# —— Worker 模板 —— #
def worker(q, fn, timeout):
while True:
bid, data = q.get()
fn(bid, data, timeout)
q.task_done()
if bid is None:
break
# —— Service HTTP —— #
app = Flask(__name__)
@app.route("/ready", methods=["GET"])
def api_ready():
with inflight_lock:
busy = inflight > 0
return jsonify(ready=not busy)
@app.route("/notify", methods=["POST"])
def api_notify():
global inflight
data = request.get_json(force=True)
if not isinstance(data, list):
return jsonify(error="Expect JSON list"), 400
# 兼容 bag-checker只发 name 时补前缀
paths = []
for item in data:
if not isinstance(item, str):
continue
if item.startswith("mb_cuct_data_collection/"):
paths.append(item)
else:
paths.append("mb_cuct_data_collection/" + item)
TIME_RE = re.compile(r"_(\d{8})-(\d{6})_") # 匹配 20230803-160828
def extract_ts(p: str) -> datetime:
m = TIME_RE.search(os.path.basename(p))
if not m:
return datetime.min # 无法解析的放最后
date_part, time_part = m.groups()
ts_str = f"{date_part}{time_part}"
return datetime.strptime(ts_str, "%Y%m%d%H%M%S")
paths.sort(key=extract_ts, reverse=True)
with inflight_lock:
inflight += 1
for idx in range(0, len(paths), BATCH_SIZE):
blk = paths[idx : idx + BATCH_SIZE]
BATCH_SIZE_HIST.observe(len(blk))
bid = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") + f"_{idx // BATCH_SIZE + 1}"
batch_q.put((bid, blk))
batch_q.put(SENTINEL)
# batch_q.put((bid, paths))
return jsonify(status="accepted", batch_size=BATCH_SIZE), 202
def start_metric_updater():
def loop():
while True:
Q_BATCH.set(batch_q.qsize())
Q_PROC.set(proc_q.qsize())
Q_UP.set(up_q.qsize())
LOCAL_FILES.set(count_local_files())
time.sleep(1)
t = threading.Thread(target=loop, daemon=True)
t.start()
# —— 两种模式的入口 —— #
def file_mode(args):
# 读 tasks-file分批入队放入 sentinel然后启动处理
lines = [
line.strip() for line in open(args.tasks_file, encoding="utf-8") if line.strip()
]
for idx in range(0, len(lines), args.batch_size):
blk = lines[idx : idx + args.batch_size]
BATCH_SIZE_HIST.observe(len(blk))
bid = (
datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
+ f"_{idx // args.batch_size + 1}"
)
batch_q.put((bid, blk))
batch_q.put(SENTINEL)
start_http_server(METRICS_PORT)
logger.info("Metrics HTTP 启动,端口 %d", METRICS_PORT)
threads = [
threading.Thread(
target=worker, args=(batch_q, do_download, args.batch_timeout)
),
threading.Thread(target=worker, args=(proc_q, do_process, args.batch_timeout)),
threading.Thread(target=worker, args=(up_q, do_upload, args.batch_timeout)),
]
for t in threads:
t.start()
# 更新指标 & 等待完成
while batch_q.unfinished_tasks or proc_q.unfinished_tasks or up_q.unfinished_tasks:
Q_BATCH.set(batch_q.qsize())
Q_PROC.set(proc_q.qsize())
Q_UP.set(up_q.qsize())
LOCAL_FILES.set(count_local_files())
time.sleep(1)
for t in threads:
t.join()
logger.info("文件模式处理完成,退出。")
sys.exit(0)
def service_mode(args):
# 确保目录存在
for d in (INPUT_ROOT, OUTPUT_ROOT, EMPTY_DIR, LOG_DIR):
os.makedirs(d, exist_ok=True)
# 启动 Prometheus 和指标更新
start_http_server(METRICS_PORT)
logger.info("Metrics HTTP 启动,端口 %d", METRICS_PORT)
start_metric_updater()
# 启动后台 worker
threads = [
threading.Thread(
target=worker, args=(batch_q, do_download, args.batch_timeout), daemon=True
),
threading.Thread(
target=worker, args=(proc_q, do_process, args.batch_timeout), daemon=True
),
threading.Thread(
target=worker, args=(up_q, do_upload, args.batch_timeout), daemon=True
),
]
for t in threads:
t.start()
# 启动 Flask
logger.info("Decode Service 启动 HTTP on %s:%d", args.host, args.port)
app.run(host=args.host, port=args.port, threaded=True)
def main():
p = argparse.ArgumentParser()
sub = p.add_subparsers(dest="mode", required=True)
f = sub.add_parser("file", help="文件模式:--tasks-file")
f.add_argument("--tasks-file", required=True)
f.add_argument("--batch-size", type=int, default=BATCH_SIZE)
f.add_argument("--batch-timeout", type=int, default=3600)
s = sub.add_parser("service", help="服务模式:启动 HTTP ready/notify")
s.add_argument("--batch-timeout", type=int, default=3600)
s.add_argument("--host", default="0.0.0.0")
s.add_argument("--port", type=int, default=5000)
args = p.parse_args()
if args.mode == "file":
file_mode(args)
else:
service_mode(args)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,133 @@
#!/usr/bin/env python3
"""
python pipeline.py --check
python pipeline.py --tasks-file task.txt
"""
import argparse
import logging
import subprocess
import sys
from pathlib import Path
from typing import List
import requests
from tqdm import tqdm
from fst_data_pipeline.pipelines.tencent.bag_operation.bag_scanner import cfg
BASE = Path.cwd()
BAG, OSM, OUT, LOG = BASE / "bags", BASE / "osm", BASE / "result", BASE / "logs"
for d in (BAG, OSM, OUT, LOG):
d.mkdir(exist_ok=True)
# ---------- 单一日志 ----------
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(message)s",
handlers=[
logging.FileHandler(LOG / "run.log"),
logging.StreamHandler(sys.stdout),
],
)
log = logging.getLogger("pipeline")
# ---------- 常量 ----------
LOCAL_API_HOST = cfg.get("ROOT_DB_API", "http://10.0.240.4:5232")
IMAGE = "eval-mmt:latest"
DOCKER_CMD = [
"sudo",
"docker",
"run",
"--rm",
"-v",
f"{BAG}:/bag_data",
"-v",
f"{OSM}:/osm_data",
"-v",
f"{OUT}:/output_folder",
IMAGE,
"bash",
"-c",
"source ~/.bashrc && python3 /root/tools/eval_tool/eval_mmt_data.py "
"--data_folder /bag_data/ --osm_folder /osm_data/ --result_folder /output_folder",
]
# ---------- 工具 ----------
def runcmd(cmd: List[str]) -> None:
subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
def fetch_baglist() -> List[str]:
url = f"{LOCAL_API_HOST}/api/fst/baglist"
resp = requests.get(url, timeout=60)
resp.raise_for_status()
return resp.json()
def fetch_detail(names: List[str]) -> dict:
url = f"{LOCAL_API_HOST}/api/bags/pangu/detail"
resp = requests.post(url, json=names, timeout=30)
resp.raise_for_status()
return resp.json()
def download_one(bag: str, info: dict) -> bool:
try:
cos_path = info["data_path"].split("/", 1)[1]
dst_dir = BAG / f"{bag}.dir"
if not dst_dir.exists():
runcmd(["coscmd", "-s", "download", "-r", cos_path, str(dst_dir)])
derived = (
requests.get(info["reserved_json"], timeout=30)
.json()["derived_dir"]
.rstrip("/")
)
osm_cos_path = f"{derived}/LD_manual/{bag.replace('.bag', '.osm')}"
dst_osm = OSM / f"{bag.replace('.bag', '.osm')}"
if not dst_osm.exists():
runcmd(["coscmd", "-s", "download", osm_cos_path, str(dst_osm)])
return True
except Exception as e:
log.error("下载 %s 失败: %s", bag, e)
return False
def download_all(names: List[str]) -> None:
infos = fetch_detail(names)
for bag in tqdm(names, desc="Download"):
download_one(bag, infos[bag])
log.info("全部下载完成")
def eval_all() -> None:
log.info("开始 Docker 评估")
runcmd(DOCKER_CMD)
log.info("评估完成")
# ---------- CLI ----------
def main():
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--tasks-file", type=Path)
group.add_argument("--check", action="store_true")
args = parser.parse_args()
names = (
fetch_baglist()
if args.check
else [l.strip() for l in args.tasks_file.open() if l.strip()]
)
if not names:
log.info("列表为空")
return
download_all(names)
eval_all()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,424 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import logging
import os
import queue
import shutil
import subprocess
import sys
import threading
import time
from datetime import datetime
from flask import Flask, request, jsonify
from prometheus_client import start_http_server, Counter, Gauge, Summary, Histogram
# ============ 常量 & 全局队列 ============
BASE = os.getcwd()
INPUT_ROOT = os.path.join(BASE, "input")
OUTPUT_ROOT = os.path.join(BASE, "output")
EMPTY_DIR = os.path.join(BASE, "empty")
LOG_DIR = os.path.join(BASE, "logs")
DOWNLOAD_ENTRIES = [
"lidar_gt_pandar128",
"camera_front_wide",
"calibration",
"raw_gnss.csv",
"raw_imu.csv",
"vehicle_wheel.csv",
]
MAX_PL_DIRS = 10 # 本地 PL* 子目录限流阈值
BATCH_CHUNK = 5 # 文件模式每批任务数
DOCKER_IMAGE = "ldgt_cu11_1_devel"
MAX_RETRIES = 3
RETRY_DELAY_S = 2
METRICS_PORT = 8002
# sentinel 用于优雅停止
# 格式:(batch_id, remote_paths, run_slam, visualize)
SENTINEL = (None, None, False, False)
batch_q = queue.Queue()
proc_q = queue.Queue()
up_q = queue.Queue()
# —— 日志配置 —— #
os.makedirs(LOG_DIR, exist_ok=True)
logger = logging.getLogger("pipeline")
logger.setLevel(logging.INFO)
fh = logging.FileHandler(os.path.join(LOG_DIR, "pipeline.log"), encoding="utf-8")
fh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(fh)
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.INFO)
ch.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(ch)
# —— 全局自增 batch_id 计数 —— #
batch_counter = 0
counter_lock = threading.Lock()
# ============ Prometheus 指标 ============ #
DL_TOTAL = Counter("pipeline_download_total", "下载尝试总数")
DL_FAIL = Counter("pipeline_download_failures", "下载失败总数")
DL_RETRY = Counter("pipeline_download_retries", "下载重试总数")
PR_TOTAL = Counter("pipeline_process_total", "处理尝试总数")
PR_FAIL = Counter("pipeline_process_failures", "处理失败总数")
PR_RETRY = Counter("pipeline_process_retries", "处理重试总数")
UP_TOTAL = Counter("pipeline_upload_total", "上传尝试总数")
UP_FAIL = Counter("pipeline_upload_failures", "上传失败总数")
UP_RETRY = Counter("pipeline_upload_retries", "上传重试总数")
DL_DUR = Summary("pipeline_download_duration_seconds", "下载耗时秒")
PR_DUR = Summary("pipeline_process_duration_seconds", "处理耗时秒")
UP_DUR = Summary("pipeline_upload_duration_seconds", "上传耗时秒")
BATCH_SIZE_HIST = Histogram(
"pipeline_batch_size", "批次大小分布", buckets=[1, 5, 10, 20, 50, 100]
)
FILE_DL_DUR = Histogram("pipeline_file_download_duration_seconds", "单文件下载耗时分布")
BATCH_OUT_FILES = Gauge("pipeline_batch_output_file_count", "输出文件数")
Q_BATCH = Gauge("pipeline_queue_batches", "待下载批次数")
Q_PROC = Gauge("pipeline_queue_processing", "待处理批次数")
Q_UP = Gauge("pipeline_queue_uploading", "待上传批次数")
LOCAL_PL = Gauge("pipeline_local_pl_dirs", "本地 PL* 目录数")
# ============ 辅助函数 ============ #
def count_pl_dirs():
# return sum(len(files) for _, _, files in os.walk(INPUT_ROOT))
try:
return sum(
1
for d in os.listdir(INPUT_ROOT)
if d.startswith("PL") and os.path.isdir(os.path.join(INPUT_ROOT, d))
)
except FileNotFoundError:
return 0
def run(cmd, timeout=None):
logger.info("RUN %s", " ".join(cmd))
try:
p = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
start = time.time()
out = []
while True:
line = p.stdout.readline()
if not line:
break
out.append(line)
print(line, end="", flush=True)
if timeout and time.time() - start > timeout:
p.kill()
return p.wait(), True, "".join(out)
code = p.wait()
return code, False, "".join(out)
except Exception:
logger.exception("RUN 异常")
return -1, False, ""
def with_retry(tag, fn, *args):
for i in range(1, MAX_RETRIES + 1):
code, timed_out, _ = fn(*args)
if code == 0:
return True
if timed_out:
logger.error("%s 超时,停止重试", tag)
break
if tag.startswith("DL["):
DL_RETRY.inc()
if tag.startswith("PR["):
PR_RETRY.inc()
if tag.startswith("UP["):
UP_RETRY.inc()
logger.warning("%s 重试 %d/%d", tag, i, MAX_RETRIES)
time.sleep(RETRY_DELAY_S)
logger.error("%s 最终失败", tag)
return False
# ============ 1) 下载阶段 ============ #
@DL_DUR.time()
def do_download(batch_id, remote_paths, run_slam, visualize, timeout):
if batch_id is None:
proc_q.put(SENTINEL)
return
DL_TOTAL.inc()
in_dir = os.path.join(INPUT_ROOT, batch_id)
os.makedirs(in_dir, exist_ok=True)
logger.info("DL[%s] start, %d paths", batch_id, len(remote_paths))
for remote in remote_paths:
bag = os.path.basename(remote.rstrip("/"))
dst = os.path.join(in_dir, bag)
os.makedirs(dst, exist_ok=True)
for ent in DOWNLOAD_ENTRIES:
src = remote.rstrip("/") + "/" + ent
if ent.endswith(".csv"):
cmd = ["coscmd", "-s", "download", src, os.path.join(dst, ent)]
else:
cmd = ["coscmd", "-s", "download", "-r", src, os.path.join(dst, ent)]
t0 = time.time()
ok = with_retry(f"DL[{batch_id}-{bag}-{ent}]", run, cmd, timeout)
FILE_DL_DUR.observe(time.time() - t0)
if not ok:
DL_FAIL.inc()
proc_q.put((batch_id, remote_paths, run_slam, visualize))
# ============ 2) 处理阶段 ============ #
@PR_DUR.time()
def do_process(batch_id, remote_paths, run_slam, visualize, timeout):
if batch_id is None:
up_q.put(SENTINEL)
return
PR_TOTAL.inc()
out_dir = os.path.join(OUTPUT_ROOT, batch_id)
in_dir = os.path.join(INPUT_ROOT, batch_id)
os.makedirs(out_dir, exist_ok=True)
logger.info("PR[%s] input in {%s},output to {%s}", batch_id, in_dir, out_dir)
# 构造 Docker 脚本
parts = []
if run_slam:
parts.append("cd /root/slam_scripts && ./batch_run_map.sh /input_data")
parts.append(
"cd /root/latest/perception_dnn/projects/ldgt_net && "
"python tools/custom/generate_mb_lidar_map_with_intensity.py --data_dir /input_data --output /output_data/bev_image && "
"./tools/dist_produce_mk_gt.sh /output_data/bev_image/image_data "
"~/ckpt/epoch_500.pth /output_data/pipeline_out_box 8 && "
"./tools/dist_produce_gt.sh /output_data/bev_image/image_data ~/ckpt/epoch_50.pth /output_data/pipeline_out_line 8 && "
"python tools/custom/merge_osm.py --line_osm_folder /output_data/pipeline_out_line/default_batch/osm_out --marking_osm_folder /output_data/pipeline_out_box/default_batch/osm_out --output_folder /output_data/osm_out"
)
if visualize:
parts.append(
"python tools/vis_tool/split_lane_info.py --bag_data /input_data --pipeline_result /output_data/ &&"
" python tools/vis_tool/projection.py --bag_data /input_data --pipeline_result /output_data/ "
)
parts.append("chmod -R 777 /output_data")
script = " && ".join(parts)
cmd = [
"docker",
"run",
"--rm",
"--gpus",
"all",
"-v",
f"{in_dir}:/input_data",
"-v",
f"{out_dir}:/output_data",
DOCKER_IMAGE,
"bash",
"-i",
"-c",
script,
]
ok = with_retry(f"PR[{batch_id}-{in_dir}]", run, cmd, timeout)
if not ok:
PR_FAIL.inc()
# shutil.rmtree(in_dir, ignore_errors=True)
# shutil.rmtree(work, ignore_errors=True)
# 输出文件数
cnt = sum(
len(files) for _, _, files in os.walk(os.path.join(OUTPUT_ROOT, batch_id))
)
BATCH_OUT_FILES.set(cnt)
up_q.put((
batch_id,
remote_paths,
run_slam,
visualize,
))
# ============ 3) 上传阶段 ============ #
@UP_DUR.time()
def do_upload(batch_id, remote_paths, run_slam, visualize, timeout):
if batch_id is None:
return
UP_TOTAL.inc()
logger.info("UP[%s] start", batch_id)
base = os.path.join(OUTPUT_ROOT, batch_id)
input_dir = os.path.join(INPUT_ROOT, batch_id)
for remote in remote_paths:
bag = os.path.basename(remote.rstrip("/"))
short = bag.replace(".bag.dir", "")
target = remote.rstrip("/") + "/derived/LDGT"
# 1) osm
f1 = os.path.join(base, "osm_out", f"{short}.osm")
if os.path.isfile(f1):
cmd = ["coscmd", "-s", "upload", f1, target + f"/{short}.osm"]
with_retry(f"UP[{batch_id}-{short}-osm]", run, cmd, timeout)
# 2) split_json
d2 = os.path.join(base, "split_json", short)
if os.path.isdir(d2):
cmd = ["coscmd", "-s", "upload", "-r", d2, target + "/split_json/"]
with_retry(f"UP[{batch_id}-{short}-json]", run, cmd, timeout)
# 3) jpg
f3 = os.path.join(base, "bev_image", "image_data", f"{short}.jpg")
if os.path.isfile(f3):
cmd = ["coscmd", "-s", "upload", f3, target + "/bev_image.jpg"]
with_retry(f"UP[{batch_id}-{short}-jpg]", run, cmd, timeout)
f4 = os.path.join(input_dir, bag, "slam_lidar_ground")
if os.path.isdir(f4):
cmd = ["coscmd", "-s", "upload", "-r", f4, target]
with_retry(f"UP[{batch_id}-{short}-slam_lidar_ground]", run, cmd, timeout)
f5 = os.path.join(input_dir, bag, "slam_lidar_none_ground")
if os.path.isdir(f5):
cmd = ["coscmd", "-s", "upload", "-r", f5, target]
with_retry(
f"UP[{batch_id}-{short}-slam_lidar_none_ground]", run, cmd, timeout
)
f6 = os.path.join(input_dir, bag, "ego_motion_slam_lidar.csv")
if os.path.isfile(f6):
cmd = ["coscmd", "-s", "upload", f6, target + "/ego_motion_slam_lidar.csv"]
with_retry(f"UP[{batch_id}-{short}-csv]", run, cmd, timeout)
shutil.rmtree(base, ignore_errors=True)
logger.info("UP[%s] done", batch_id)
# ============ Worker 模板 ============ #
def worker(q, fn, timeout):
while True:
batch_id, paths, run_slam, visualize = q.get()
try:
fn(batch_id, paths, run_slam, visualize, timeout)
except Exception:
logger.exception("Stage %s 失败", batch_id)
finally:
q.task_done()
if batch_id is None:
break
# ============ Prometheus 指标更新 ============ #
def start_metric_updater():
def loop():
while True:
Q_BATCH.set(batch_q.qsize())
Q_PROC.set(proc_q.qsize())
Q_UP.set(up_q.qsize())
LOCAL_PL.set(count_pl_dirs())
time.sleep(1)
threading.Thread(target=loop, daemon=True).start()
# ============ Flask Service ============ #
app = Flask(__name__)
@app.route("/ready", methods=["GET"])
def api_ready():
busy = any([
batch_q.unfinished_tasks,
proc_q.unfinished_tasks,
up_q.unfinished_tasks,
])
return jsonify(ready=not busy)
@app.route("/notify", methods=["POST"])
def api_notify():
global batch_counter
data = request.get_json(force=True)
if isinstance(data, list):
paths, run_slam, visualize = data, True, True
elif isinstance(data, dict):
paths = data.get("paths", [])
run_slam = data.get("run_slam", True)
visualize = data.get("visualize", True)
else:
return jsonify(error="Unsupported JSON"), 400
if not all(isinstance(p, str) for p in paths):
return jsonify(error="paths must be strings"), 400
with counter_lock:
batch_counter += 1
bid = str(batch_counter)
batch_q.put((bid, paths, run_slam, visualize))
return jsonify(status="accepted", batch_id=bid), 202
# ============ 主入口 ============ #
def main():
global batch_counter
p = argparse.ArgumentParser()
sub = p.add_subparsers(dest="mode", required=True)
f = sub.add_parser("file", help="文件模式")
f.add_argument("--tasks-file", required=True, help="每行一个 COS 路径")
f.add_argument("--batch-timeout", type=int, default=3600)
f.set_defaults(run_slam=True, visualize=True)
s = sub.add_parser("service", help="服务模式")
s.add_argument("--batch-timeout", type=int, default=3600)
s.add_argument("--host", default="0.0.0.0")
s.add_argument("--port", type=int, default=5600)
args = p.parse_args()
# 启动 Prometheus HTTP & 指标更新
start_http_server(METRICS_PORT)
start_metric_updater()
# 启动三个阶段 worker
for q, fn in ((batch_q, do_download), (proc_q, do_process), (up_q, do_upload)):
t = threading.Thread(
target=worker, args=(q, fn, args.batch_timeout), daemon=True
)
t.start()
if args.mode == "file":
lines = [
l.strip() for l in open(args.tasks_file, encoding="utf-8") if l.strip()
]
for i in range(0, len(lines), BATCH_CHUNK):
while count_pl_dirs() >= MAX_PL_DIRS:
logger.warning(
"本地 PL* %d%d,暂停入队", count_pl_dirs(), MAX_PL_DIRS
)
time.sleep(60)
blk = lines[i : i + BATCH_CHUNK]
with counter_lock:
batch_counter += 1
bid = (
datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
+ f"_{i // BATCH_CHUNK + 1}"
)
BATCH_SIZE_HIST.observe(len(blk))
batch_q.put((bid, blk, args.run_slam, args.visualize))
batch_q.put(SENTINEL)
batch_q.join()
proc_q.join()
up_q.join()
logger.info("File mode done.")
sys.exit(0)
else:
for d in (INPUT_ROOT, OUTPUT_ROOT, EMPTY_DIR, LOG_DIR):
os.makedirs(d, exist_ok=True)
logger.info("Starting service on %s:%d", args.host, args.port)
app.run(host=args.host, port=args.port, threaded=True)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,68 @@
# Trajectory Processing and Visualization System
## Overview
This system consists of two core scripts for processing GNSS trajectory data, identifying spatially overlapping trajectories, generating tile maps, and creating visualizations:
1. **tile_generate_to_db.py** - Processes raw trajectory data, computes spatial relationships, and stores results in MongoDB
2. **tile_visualization_from_db.py** - Reads data from database and generates interactive trajectory visualizations
## Key Features
- **Trajectory Gridding**: Divides GNSS points into 3D grids (10m planar grid + 5m height strata)
- **Overlap Detection**: Identifies trajectories sharing more than threshold number of grids
- **Tile Mapping**: Maps trajectories to map tile system
- **Spatial Indexing**: Stores trajectories, tiles and overlap relationships in MongoDB
- **Visualization**: Generates interactive maps with trajectory paths and direction arrows
## config.yml
-you should set the db param and tile_server_url in config.yml
-for example
mongodb:
uri: "mongodb://admin:admin@IP:port/"
db_name: "your db_name"
tile_server:
url: "http://IP:port/styles/maptiler-basic/512/{z}/{x}/{y}.png"
## Usage
1. Data Processing
```
python3 tile_generate_to_db.py <data_root> [zoom_level]
```
- Required Directory Structure:
- data_root/
- ├── bag_1/
- │ ├── raw_gnss.csv
- ├── bag_2/
- │ ├── raw_gnss.csv
2. Visualization Generation
```
python3 tile_visualization_from_db.py <data_root>
```
- Output will be saved in:
```
<data_root>/tile_visualizations/
```
# Output Files
1. Database Collections:
- tile_db: Tile-to-trajectory mappings
- processed_bags: Processed trajectory data
- overlap_db: Trajectory overlap sets
- processed_bags: Processed bag records
2. Local Files:
- <bag_name>_overlap.json: Overlap information per trajectory
- processed_tiles.txt: List of processed tile IDs
- tile_visualizations/: Folder containing HTML visualizations

View File

@@ -0,0 +1,7 @@
# config.yml
mongodb:
uri: "mongodb uri"
db_name: "your db_name"
tile_server:
url: "tiler_server_url"

View File

@@ -0,0 +1,605 @@
import os
import yaml
import json
import math
import pandas as pd
import utm
import hashlib
from collections import defaultdict
import logging
import sys
from typing import Dict, List, Tuple, Set, Any
from pymongo import MongoClient
from bson import ObjectId
# 配置日志
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
# 全局参数
MIN_SHARED_GRIDS = 6 # 最小共享3D网格数阈值
MIN_GRID_FOR_ONE_TRAJECTORY = 0 # 单条轨迹最少网格数
GRID_SIZE = 10.0 # 平面网格大小(米)
HEIGHT_STRATUM = 5.0 # 高度分层阈值(米)
def load_config(config_path="config.yml"):
"""加载配置文件"""
try:
with open(config_path) as f:
return yaml.safe_load(f)
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
raise
config = load_config()
def get_mongo_client():
"""获取MongoDB客户端"""
return MongoClient(config["mongodb"]["uri"])
def get_db_collections():
"""获取数据库集合"""
client = get_mongo_client()
db = client[config["mongodb"]["db_name"]]
return {
"overlap_db": db.overlap_db,
"tile_db": db.tile_db,
"processed_bags": db.processed_bags,
}
def lon2tilex(lon: float, zoom: int) -> int:
"""将经度转换为瓦片x坐标"""
return int((lon + 180) / 360 * (1 << zoom))
def lat2tiley(lat: float, zoom: int) -> int:
"""将纬度转换为瓦片y坐标"""
return int(
(
1
- math.log(math.tan(math.radians(lat)) + 1 / math.cos(math.radians(lat)))
/ math.pi
)
/ 2
* (1 << zoom)
)
def get_tile_neighbors(tileid: str) -> List[str]:
"""获取九宫格相邻瓦片ID"""
zoom, x, y = map(int, tileid.split("/"))
neighbors = []
for dx in [-1, 0, 1]:
for dy in [-1, 0, 1]:
if dx == 0 and dy == 0:
continue # 跳过自身
new_x, new_y = x + dx, y + dy
# 检查是否超出瓦片范围
max_tile = 1 << zoom
if 0 <= new_x < max_tile and 0 <= new_y < max_tile:
neighbors.append(f"{zoom}/{new_x}/{new_y}")
return neighbors
def extract_xyz_from_gnss(csv_path: str) -> List[Tuple[float, float, float]]:
"""提取包含高度信息的轨迹点"""
try:
df = pd.read_csv(csv_path)
except Exception as e:
logger.error(f"读取 {csv_path} 失败: {e}")
return []
# 自动适配列名
lat_col = next((col for col in df.columns if "lat" in col.lower()), None)
lon_col = next((col for col in df.columns if "lon" in col.lower()), None)
alt_col = next(
(col for col in df.columns if "alt" in col.lower() or "height" in col.lower()),
None,
)
if not lat_col or not lon_col:
logger.warning(f"{csv_path} 中缺少经纬度列")
return []
# 如果没有高度列使用默认高度0
if not alt_col:
logger.warning(f"{csv_path} 中未找到高度列将使用默认高度0")
df["altitude"] = 0
alt_col = "altitude"
# 过滤无效数据
df = df.dropna(subset=[lat_col, lon_col, alt_col])
df = df[(df[lat_col].between(-90, 90)) & (df[lon_col].between(-180, 180))]
try:
points = []
for lat, lon, alt in zip(df[lat_col], df[lon_col], df[alt_col]):
easting, northing, zone_num, zone_letter = utm.from_latlon(lat, lon)
points.append((easting, northing, alt, lat, lon))
return points
except Exception as e:
logger.error(f"{csv_path} UTM 转换失败: {e}")
return []
def get_3d_grid(x: float, y: float, z: float) -> Tuple[int, int, int]:
"""三维网格划分"""
grid_x = int(x // GRID_SIZE)
grid_y = int(y // GRID_SIZE)
grid_z = int(z // HEIGHT_STRATUM) # 高度分层
return (grid_x, grid_y, grid_z)
def get_tile_and_grids_from_path(
bag_path: str, zoom: int = 14
) -> Tuple[Set[str], List[Tuple[float, float, float]], Set[Tuple[int, int, int]]]:
"""获取瓦片ID集合和三维网格集合"""
csv_path = os.path.join(bag_path, "raw_gnss.csv")
if not os.path.exists(csv_path):
logger.warning(f"{bag_path} 不存在 raw_gnss.csv")
return None
xyz_points = extract_xyz_from_gnss(csv_path)
if not xyz_points:
logger.warning(f"{bag_path} 无有效 GNSS 数据")
return None
# 获取所有瓦片ID
tile_ids = set()
for x, y, z, lat, lon in xyz_points:
tile_id = f"{zoom}/{lon2tilex(lon, zoom)}/{lat2tiley(lat, zoom)}"
tile_ids.add(tile_id)
# 获取所有3D网格
grids = set(get_3d_grid(x, y, z) for x, y, z, _, _ in xyz_points)
logger.info(
f"{bag_path} 共有 {len(xyz_points)} 点,映射到 {len(tile_ids)} 个瓦片和 {len(grids)} 个3D网格"
)
return tile_ids, [(x, y, z) for x, y, z, _, _ in xyz_points], grids
def calculate_overlap(grids1, grids2):
"""计算两个轨迹的3D网格重叠数"""
set1 = set(tuple(grid) for grid in grids1) if isinstance(grids1, list) else grids1
set2 = set(tuple(grid) for grid in grids2) if isinstance(grids2, list) else grids2
return len(set1 & set2)
def update_tile_db(
tile_db_collection,
tile_id: str,
bag_name: str,
lat: float = None,
lon: float = None,
alt: float = None,
):
"""更新瓦片数据库"""
tile_data = tile_db_collection.find_one({"tileid": tile_id})
if tile_data:
if bag_name not in tile_data["trajectories"]:
tile_db_collection.update_one(
{"tileid": tile_id}, {"$addToSet": {"trajectories": bag_name}}
)
if lat is not None and lon is not None and "gps" not in tile_data:
tile_db_collection.update_one(
{"tileid": tile_id},
{
"$set": {
"gps": {
"latitude": lat,
"longitude": lon,
"altitude": alt or 0.0,
}
}
},
)
else:
new_tile = {"tileid": tile_id, "trajectories": [bag_name]}
if lat is not None and lon is not None:
new_tile["gps"] = {
"latitude": lat,
"longitude": lon,
"altitude": alt or 0.0,
}
tile_db_collection.insert_one(new_tile)
def initialize_overlap_db(collections):
"""初始化或加载现有的overlap数据库"""
overlap_db_data = collections["overlap_db"].find_one({"_id": "overlap_db"})
if overlap_db_data:
# 转换数据结构
overlap_db = {
"next_id": overlap_db_data.get("next_id", 1),
"sets": {
k: {"bags": set(v["bags"]), "source_bags": set(v["source_bags"])}
for k, v in overlap_db_data.get("sets", {}).items()
},
"hash_to_id": overlap_db_data.get("hash_to_id", {}),
"bag_to_set": overlap_db_data.get("bag_to_set", {}),
"overlap_sets": {
k: set(v) for k, v in overlap_db_data.get("overlap_sets", {}).items()
},
}
logger.info(
f"已加载现有的overlap数据库当前最大ID: {overlap_db['next_id'] - 1}"
)
else:
overlap_db = {
"next_id": 1,
"sets": {},
"hash_to_id": {},
"bag_to_set": {},
"overlap_sets": {},
}
collections["overlap_db"].insert_one({
"_id": "overlap_db",
"next_id": 1,
"sets": {},
"hash_to_id": {},
"bag_to_set": {},
"overlap_sets": {},
})
logger.info("创建新的overlap数据库")
return overlap_db
def update_overlap_sets(
collections, overlap_db: Dict, bag_name: str, overlap_set: Set[str]
):
"""更新overlap_sets容器不分配set_id"""
# 保存当前包的overlap集合
if not isinstance(overlap_set, set):
overlap_set = set(overlap_set)
overlap_db["overlap_sets"][bag_name] = overlap_set
# 更新所有相关包的overlap集合
for other_bag in overlap_set:
if other_bag == bag_name:
continue
if other_bag not in overlap_db["overlap_sets"]:
overlap_db["overlap_sets"][other_bag] = {other_bag}
# 合并两个包的overlap集合
overlap_db["overlap_sets"][other_bag].add(bag_name)
def assign_set_ids(collections, overlap_db: Dict):
"""为所有overlap集合分配set_id"""
# 收集所有唯一的overlap集合
unique_sets = {}
set_hashes = set()
# 首先处理已经存在的集合保持它们的set_id不变
existing_sets = {}
for set_id, set_data in overlap_db["sets"].items():
set_hash = hashlib.md5(",".join(sorted(set_data["bags"])).encode()).hexdigest()
existing_sets[set_hash] = int(set_id)
# 找出所有唯一的overlap集合
for bag_name, overlap_set in overlap_db["overlap_sets"].items():
if len(overlap_set) <= 1:
continue # 跳过单独的包
sorted_bags = sorted(overlap_set)
set_hash = hashlib.md5(",".join(sorted_bags).encode()).hexdigest()
if set_hash not in unique_sets:
unique_sets[set_hash] = {"bags": overlap_set, "source_bags": set()}
# 记录来源包
unique_sets[set_hash]["source_bags"].add(bag_name)
# 分配set_id
new_sets = {}
for set_hash, set_data in unique_sets.items():
if set_hash in existing_sets:
# 使用现有的set_id
set_id = existing_sets[set_hash]
new_sets[set_id] = set_data
else:
# 分配新的set_id
set_id = overlap_db["next_id"]
overlap_db["next_id"] += 1
new_sets[set_id] = set_data
# 更新overlap_db
overlap_db["sets"] = {
str(k): {"bags": sorted(v["bags"]), "source_bags": sorted(v["source_bags"])}
for k, v in new_sets.items()
}
# 更新hash_to_id映射
overlap_db["hash_to_id"] = {
hashlib.md5(",".join(sorted(v["bags"])).encode()).hexdigest(): int(k)
for k, v in new_sets.items()
}
# 更新bag_to_set映射
overlap_db["bag_to_set"] = {}
for set_id, set_data in new_sets.items():
for bag in set_data["source_bags"]:
overlap_db["bag_to_set"][bag] = set_id
def save_overlap_db(collections, overlap_db: Dict):
"""保存overlap数据库"""
# 在保存前确保所有set_id已分配
assign_set_ids(collections, overlap_db)
serializable_db = {
"next_id": overlap_db["next_id"],
"sets": {
str(set_id): {
"bags": sorted(data["bags"]),
"source_bags": sorted(data["source_bags"]),
}
for set_id, data in overlap_db["sets"].items()
},
"hash_to_id": overlap_db["hash_to_id"],
"bag_to_set": overlap_db["bag_to_set"],
"overlap_sets": {k: sorted(v) for k, v in overlap_db["overlap_sets"].items()},
}
collections["overlap_db"].replace_one(
{"_id": "overlap_db"}, serializable_db, upsert=True
)
logger.info(f"已保存重叠集合数据库到 MongoDB (共 {len(overlap_db['sets'])} 个集合)")
def update_json_files_with_set_ids(data_root: str, overlap_db: Dict, collections):
"""更新所有包的JSON文件添加set_id信息并同步更新MongoDB"""
for bag_name, set_id in overlap_db["bag_to_set"].items():
bag_path = os.path.join(data_root, bag_name)
overlap_json_path = os.path.join(bag_path, f"{bag_name}_overlap.json")
# 从MongoDB获取当前包数据
bag_data = collections["processed_bags"].find_one({"_id": bag_name})
if not bag_data:
logger.warning(f"在MongoDB中未找到包 {bag_name} 的数据")
continue
# 更新重叠信息
updated_data = {
"set_id": set_id,
"trajectories": sorted(overlap_db["sets"][str(set_id)]["bags"]),
"count": len(overlap_db["sets"][str(set_id)]["bags"]),
}
# 更新MongoDB
collections["processed_bags"].update_one(
{"_id": bag_name}, {"$set": {"overlaps_bags": updated_data}}
)
# 更新本地JSON文件可选
if os.path.exists(overlap_json_path):
with open(overlap_json_path, "r+") as f:
try:
data = json.load(f)
data["overlap_bags"] = updated_data
f.seek(0)
json.dump(data, f, indent=2)
f.truncate()
except Exception as e:
logger.error(f"更新 {overlap_json_path} 失败: {e}")
logger.info("已更新所有包的set_id信息到MongoDB和本地JSON文件")
def process_bag(bag_path: str, collections, overlap_db: Dict, zoom: int = 14):
"""处理单个数据包"""
bag_name = os.path.basename(bag_path)
overlap_json_path = os.path.join(bag_path, f"{bag_name}_overlap.json")
if not os.path.exists(overlap_json_path):
with open(overlap_json_path, "w") as f:
json.dump(
{
"overlap_bags": {
"trajectories": [bag_name],
"count": 1,
"set_id": None,
}
},
f,
)
# 获取瓦片ID和3D网格
result = get_tile_and_grids_from_path(bag_path, zoom)
if result is None:
return
tile_ids, xyz_points, grids = result
if len(grids) <= MIN_GRID_FOR_ONE_TRAJECTORY:
logger.info(f"{bag_name} (仅 {len(grids)} 个3D网格跳过)")
return
processed_tiles_file = os.path.join(
os.path.dirname(data_root), "processed_tiles.txt"
)
existing_tiles = set()
if os.path.exists(processed_tiles_file):
with open(processed_tiles_file, "r") as f:
existing_tiles = set(line.strip() for line in f if line.strip())
new_tiles = set(tile_ids) - existing_tiles
if new_tiles:
with open(processed_tiles_file, "a") as f:
for tile_id in new_tiles:
f.write(f"{tile_id}\n")
print("write tile_id:")
print(tile_id)
csv_path = os.path.join(bag_path, "raw_gnss.csv")
df = pd.read_csv(csv_path)
lat_col = next((col for col in df.columns if "lat" in col.lower()), None)
lon_col = next((col for col in df.columns if "lon" in col.lower()), None)
alt_col = next(
(col for col in df.columns if "alt" in col.lower() or "height" in col.lower()),
None,
)
# 过滤无效数据
if lat_col and lon_col:
df = df.dropna(subset=[lat_col, lon_col, alt_col])
df = df[(df[lat_col].between(-90, 90)) & (df[lon_col].between(-180, 180))]
latlon_points = list(
zip(df[lat_col], df[lon_col], df[alt_col] if alt_col else [0] * len(df))
)
else:
latlon_points = []
# 初始化当前包的overlap集合
overlap_set = {bag_name}
# 查找所有相关瓦片(包括相邻瓦片)
all_related_tiles = set(tile_ids)
for tile_id in tile_ids:
all_related_tiles.update(get_tile_neighbors(tile_id))
# 查找所有可能重叠的包
candidate_bags = set()
for tile in collections["tile_db"].find({
"tileid": {"$in": list(all_related_tiles)}
}):
candidate_bags.update(tile["trajectories"])
candidate_bags.discard(bag_name)
# 检查与每个候选包的重叠情况
for other_bag in candidate_bags:
other_data = collections["processed_bags"].find_one({"_id": other_bag})
if not other_data:
continue
other_grids = set(tuple(grid) for grid in other_data.get("grids", []))
current_grids_set = set(tuple(grid) for grid in grids)
overlap_count = calculate_overlap(current_grids_set, other_grids)
if overlap_count >= MIN_SHARED_GRIDS:
overlap_set.add(other_bag)
# 更新overlap_sets容器
update_overlap_sets(collections, overlap_db, bag_name, overlap_set)
# 更新每个包的overlap JSON文件不包含set_id
for b in overlap_set:
if b == bag_name:
continue
other_path = os.path.join(os.path.dirname(bag_path), b)
other_json_path = os.path.join(other_path, f"{b}_overlap.json")
if os.path.exists(other_json_path):
with open(other_json_path, "r+") as f:
try:
other_data = json.load(f)
if bag_name not in other_data["overlap_bags"]["trajectories"]:
other_data["overlap_bags"]["trajectories"].append(bag_name)
other_data["overlap_bags"]["count"] = len(
overlap_db["overlap_sets"][b]
)
f.seek(0)
json.dump(other_data, f, indent=2)
f.truncate()
except Exception as e:
logger.error(f"更新 {other_json_path} 失败: {e}")
# 保存当前包的overlap信息到JSON文件不包含set_id
with open(overlap_json_path, "w") as f:
json.dump(
{
"overlap_bags": {
"trajectories": sorted(overlap_set),
"count": len(overlap_set),
"set_id": None, # 将在最后阶段分配
}
},
f,
indent=2,
)
# 保存到processed_bags集合
processed_data = {
"_id": bag_name,
"tile_ids": list(tile_ids),
"grids": list(grids),
"xyz_points": xyz_points,
"latlon_points": latlon_points,
"overlaps_bags": {
"trajectories": sorted(overlap_set),
"count": len(overlap_set),
"set_id": None, # 将在最后阶段分配
},
}
collections["processed_bags"].replace_one(
{"_id": bag_name}, processed_data, upsert=True
)
# 更新瓦片数据库
for tile_id in tile_ids:
lat, lon = None, None
for x, y, z, pt_lat, pt_lon in extract_xyz_from_gnss(
os.path.join(bag_path, "raw_gnss.csv")
):
if f"{zoom}/{lon2tilex(pt_lon, zoom)}/{lat2tiley(pt_lat, zoom)}" == tile_id:
lat, lon = pt_lat, pt_lon
break
update_tile_db(collections["tile_db"], tile_id, bag_name, lat, lon)
logger.info(f"{bag_name}{len(overlap_set) - 1} 个包重叠")
def main(data_root: str, zoom: int = 14):
"""主函数"""
if not os.path.isdir(data_root):
logger.error(f"数据目录不存在: {data_root}")
return
# 初始化数据库连接
collections = get_db_collections()
# 初始化数据库
overlap_db = initialize_overlap_db(collections)
# 处理所有数据包
bag_dirs = [
os.path.join(data_root, d)
for d in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, d))
]
if not bag_dirs:
logger.error(f"目录中没有子文件夹: {data_root}")
return
logger.info(f"开始处理 {len(bag_dirs)} 个数据包...")
for bag_path in bag_dirs:
process_bag(bag_path, collections, overlap_db, zoom)
# 保存所有结果到MongoDB
save_overlap_db(collections, overlap_db)
# 更新set_id信息这会同时更新MongoDB和本地JSON
update_json_files_with_set_ids(data_root, overlap_db, collections)
logger.info("处理完成")
if __name__ == "__main__":
if len(sys.argv) < 2:
logger.error("用法: python3 tile_generate_to_db.py <data_root> [zoom_level]")
sys.exit(1)
data_root = sys.argv[1]
zoom_level = int(sys.argv[2]) if len(sys.argv) > 2 else 14
main(data_root, zoom_level)

View File

@@ -0,0 +1,467 @@
import os
import yaml
import json
import math
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from typing import List, Dict, Set, Tuple
import utm
import sys
import logging
import numpy as np
from plotly.subplots import make_subplots
import requests
from io import BytesIO
from PIL import Image
from pymongo import MongoClient
import hashlib
# 配置日志
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
def load_config(config_path="config.yml"):
"""加载配置文件"""
try:
with open(config_path) as f:
return yaml.safe_load(f)
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
raise
config = load_config()
def calculate_bearing(x1, y1, x2, y2):
"""计算两点之间的朝向角(弧度)"""
dx = x2 - x1
dy = y2 - y1
return math.atan2(dy, dx)
def calculate_cluster_center(bag_grid_map, cluster):
"""计算单个聚类的中心点"""
all_x = []
all_y = []
all_z = []
for bag_name in cluster:
if bag_name in bag_grid_map:
xyz_points = bag_grid_map[bag_name]
all_x.extend([x for x, y, z in xyz_points])
all_y.extend([y for x, y, z in xyz_points])
all_z.extend([z for x, y, z in xyz_points])
if not all_x:
return 0, 0, 0 # 默认值
return np.mean(all_x), np.mean(all_y), np.mean(all_z)
def create_direction_arrows(xy_points, spacing=5):
"""创建方向箭头数据"""
arrows = []
for i in range(0, len(xy_points) - 1, spacing):
if i + 1 >= len(xy_points):
continue
x1, y1 = xy_points[i]
x2, y2 = xy_points[i + 1]
angle = calculate_bearing(x1, y1, x2, y2)
# 箭头中间点
mid_x = (x1 + x2) / 2
mid_y = (y1 + y2) / 2
arrows.append({
"x": mid_x,
"y": mid_y,
"angle": -angle,
"angle_deg": 90 - rad_to_deg(angle),
})
return arrows
def rad_to_deg(rad):
"""将弧度转换为角度"""
return rad * 180 / math.pi
def get_3d_grids_from_db(collections, bag_name):
"""从数据库获取三维网格集合"""
bag_data = collections["processed_bags"].find_one({"_id": bag_name})
if not bag_data:
logger.warning(f"未找到包 {bag_name} 的数据")
return None
if "xyz_points" in bag_data:
return bag_data["xyz_points"]
elif "grids" in bag_data:
# 如果只有网格数据,没有原始点,则返回网格中心点
grid_points = []
for grid in bag_data["grids"]:
grid_x, grid_y, grid_z = grid
center_x = (grid_x + 0.5) * 10.0
center_y = (grid_y + 0.5) * 10.0
center_z = (grid_z + 0.5) * 5.0
grid_points.append((center_x, center_y, center_z))
return grid_points
else:
logger.warning(f"{bag_name} 没有轨迹点数据")
return None
def utm_to_latlon(easting, northing, zone_number, zone_letter):
"""将UTM坐标转换为经纬度"""
try:
lat, lon = utm.to_latlon(easting, northing, zone_number, zone_letter)
return lat, lon
except Exception as e:
logger.error(f"UTM转换失败: {e}")
return None, None
class TileVisualizer:
def __init__(self, output_dir: str = "tile_visualizations"):
self.mongo_uri = config["mongodb"]["uri"]
self.db_name = config["mongodb"]["db_name"]
self.tile_server_url = config["tile_server"]["url"]
self.output_dir = output_dir
self.client = MongoClient(self.mongo_uri)
self.db = self.client[self.db_name]
self.collections = {
"tile_db": self.db.tile_db,
"processed_bags": self.db.processed_bags,
}
self.tile_db = self._load_tile_db()
self.color_cycle = px.colors.qualitative.Plotly
def close(self):
if self.client is not None:
self.client.close()
self.client = None
self.db = None
def _load_tile_db(self) -> Dict:
"""从MongoDB加载tileDB数据"""
try:
tiles = list(self.collections["tile_db"].find({}, {"_id": 0}))
return {"TileDB": tiles}
except Exception as e:
logger.error(f"从MongoDB加载tileDB失败: {e}")
return {"TileDB": []}
def _get_trajectory_points_from_db(
self, bag_name: str
) -> List[Tuple[float, float, float]]:
"""从数据库获取轨迹点(经纬度+高度)"""
bag_data = self.collections["processed_bags"].find_one({"_id": bag_name})
if not bag_data:
logger.warning(f"未找到包 {bag_name} 的数据")
return []
# 尝试直接获取经纬度点
if "latlon_points" in bag_data:
return [
(point[0], point[1], point[2]) for point in bag_data["latlon_points"]
]
# 如果有UTM坐标则转换为经纬度
if "xyz_points" in bag_data:
# 注意这里需要假设所有点在同一个UTM区域
# 实际上,每个点可能有不同的区域,但为了简化,我们使用第一个点的区域
if not bag_data["xyz_points"]:
return []
# 尝试获取UTM区域信息如果存储了
zone_num = bag_data.get("utm_zone_num", 50) # 默认值
zone_letter = bag_data.get("utm_zone_letter", "N") # 默认值
latlon_points = []
for point in bag_data["xyz_points"]:
easting, northing, alt = point
lat, lon = utm_to_latlon(easting, northing, zone_num, zone_letter)
if lat is not None and lon is not None:
latlon_points.append((lat, lon, alt))
return latlon_points
logger.warning(f"{bag_name} 没有可用的轨迹点数据")
return []
def _calculate_bearing(
self, lat1: float, lon1: float, lat2: float, lon2: float
) -> float:
"""计算两点之间的朝向角(度数)"""
lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2])
dlon = lon2 - lon1
x = math.sin(dlon) * math.cos(lat2)
y = math.cos(lat1) * math.sin(lat2) - math.sin(lat1) * math.cos(
lat2
) * math.cos(dlon)
return math.degrees(math.atan2(x, y))
def _create_arrows(
self, lats: List[float], lons: List[float], spacing: int = 5
) -> List[Dict]:
"""创建方向箭头数据"""
arrows = []
for i in range(0, len(lats) - 1, spacing):
if i + 1 >= len(lats):
continue
angle = self._calculate_bearing(lats[i], lons[i], lats[i + 1], lons[i + 1])
arrows.append({
"lat": (lats[i] + lats[i + 1]) / 2,
"lon": (lons[i] + lons[i + 1]) / 2,
"angle": -angle,
"angle_deg": 90 - angle,
})
return arrows
def visualize_tile(self, tile_id: str):
"""可视化单个瓦片的轨迹"""
tile_data = next(
(t for t in self.tile_db["TileDB"] if t["tileid"] == tile_id), None
)
if not tile_data:
logger.warning(f"未找到瓦片 {tile_id} 的数据")
return
trajectories = tile_data.get("trajectories", [])
if not trajectories:
logger.info(f"瓦片 {tile_id} 没有轨迹数据")
return
bag_grid_map = dict()
for bag_name in trajectories:
xyz_points = get_3d_grids_from_db(self.collections, bag_name)
if xyz_points:
bag_grid_map[bag_name] = xyz_points
fig = go.Figure()
colors = px.colors.qualitative.Plotly
all_lats, all_lons = [], []
# 计算聚类中心
center_x, center_y, center_z = calculate_cluster_center(
bag_grid_map, trajectories
)
logger.info(
f"聚类 {tile_id} 中心点: ({center_x:.1f}, {center_y:.1f}, {center_z:.1f})"
)
tz, tx, ty = map(int, tile_id.split("/"))
tile_images = {}
# 下载周边瓦片
for dx in [-1, 0, 1]:
for dy in [-1, 0, 1]:
ttx = tx + dx
tty = ty + dy
tile_url = self.tile_server_url.format(z=tz, x=ttx, y=tty)
try:
response = requests.get(tile_url, timeout=5)
if response.status_code == 200:
tile_images[(dx, dy)] = Image.open(BytesIO(response.content))
else:
logger.warning(
f"无法获取瓦片 {tz}/{ttx}/{tty},使用空白图片代替"
)
tile_images[(dx, dy)] = Image.new(
"RGB", (512, 512), (255, 255, 255)
)
except Exception as e:
logger.error(f"下载瓦片 {tz}/{ttx}/{tty} 失败: {e}")
tile_images[(dx, dy)] = Image.new(
"RGB", (512, 512), (255, 255, 255)
)
# 拼接瓦片
width = 512 * 3
height = 512 * 3
combined_image = Image.new("RGB", (width, height))
for dx in [-1, 0, 1]:
for dy in [-1, 0, 1]:
ttx = (dx + 1) * 512
tty = (dy + 1) * 512
combined_image.paste(tile_images[(dx, dy)], (ttx, tty))
# 添加底图
fig.add_layout_image(
dict(
source=combined_image,
xref="x",
yref="y",
x=0,
y=height,
sizex=width,
sizey=height,
sizing="stretch",
opacity=0.8,
layer="below",
)
)
def latlon_to_tile_pixel(lat, lon, tile_x, tile_y, zoom):
"""将经纬度转换为瓦片像素坐标"""
n = 2**zoom
xtile = (lon + 180) / 360 * n
ytile = (
(
1
- math.log(
math.tan(math.radians(lat)) + 1 / math.cos(math.radians(lat))
)
/ math.pi
)
/ 2
* n
)
x_pixel = (xtile - tile_x - 0.386) * 512
y_pixel = (ytile - tile_y - 0.23) * 512
return x_pixel, y_pixel
for i, bag_name in enumerate(trajectories):
# 从数据库获取轨迹点
points = self._get_trajectory_points_from_db(bag_name)
if not points:
logger.warning(f"{bag_name} 没有轨迹点数据")
continue
x = []
y = []
for point in points:
lat, lon, _ = point
pt_x, pt_y = latlon_to_tile_pixel(lat, lon, tx, ty, tz)
# 转换为大图的坐标
pt_x += 512 # 中心瓦片x偏移
pt_y += 512 # 中心瓦片y偏移
# 调整坐标系(原点在左上角)
pt_y = height - pt_y
if 0 <= pt_x <= width and 0 <= pt_y <= height:
x.append(pt_x)
y.append(pt_y)
if not x:
logger.warning(f"{bag_name} 没有在瓦片范围内的点")
continue
# 轨迹线
fig.add_trace(
go.Scatter(
x=x,
y=y,
mode="lines",
name=bag_name,
line=dict(width=4, color=colors[i % len(colors)]),
showlegend=True,
legendgroup=bag_name,
hoverinfo="text",
text=[f"{bag_name}<br>点 {j + 1}/{len(x)}" for j in range(len(x))],
)
)
# 方向箭头2D投影
if len(x) > 1:
xy_points = list(zip(x, y))
arrows = create_direction_arrows(xy_points)
if arrows:
arrow_x = [arrow["x"] for arrow in arrows]
arrow_y = [arrow["y"] for arrow in arrows]
arrow_text = [
f"{bag_name}<br>朝向: {arrow['angle_deg']:.1f}°"
for arrow in arrows
]
fig.add_trace(
go.Scatter(
x=arrow_x,
y=arrow_y,
mode="markers",
name=f"{bag_name} 方向",
marker=dict(
symbol="arrow-up",
size=8,
color=colors[i % len(colors)],
angle=[arrow["angle_deg"] for arrow in arrows],
line=dict(width=1, color="black"),
),
hoverinfo="text",
text=arrow_text,
showlegend=False,
legendgroup=bag_name,
)
)
fig.update_layout(
title=f"瓦片 {tile_id} 轨迹可视化 (共{len(trajectories)}条轨迹)",
xaxis=dict(
range=[0, width],
scaleanchor="y",
title="X px",
showgrid=False,
zeroline=False,
constrain="domain",
),
yaxis=dict(
range=[0, height],
scaleanchor="x",
title="Y px",
showgrid=False,
zeroline=False,
),
showlegend=True,
legend=dict(
orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1
),
height=800,
width=800,
)
# 保存HTML文件
os.makedirs(self.output_dir, exist_ok=True)
safe_tile_id = tile_id.replace("/", "_")
output_path = os.path.join(self.output_dir, f"tile_{safe_tile_id}.html")
fig.write_html(output_path)
logger.info(f"已保存瓦片 {tile_id} 可视化结果到 {output_path}")
def visualize_all_tiles(self):
"""可视化所有瓦片"""
for tile in self.tile_db["TileDB"]:
self.visualize_tile(tile["tileid"])
def visualize_from_txt(self, txt_path: str):
try:
with open(txt_path, "r") as f:
tile_ids = set(line.strip() for line in f if line.strip())
except FileNotFoundError:
logger.error(f"Not Found txt")
return
tiles_data = list(
self.collections["tile_db"].find({"tileid": {"$in": list(tile_ids)}})
)
if not tiles_data:
logger.warning("no match tile")
return
for tile_data in tiles_data:
self.visualize_tile(tile_data["tileid"])
if __name__ == "__main__":
if len(sys.argv) < 2:
logger.error("用法: python3 tile_visualization_from_db.py <data_root>")
sys.exit(1)
data_root = sys.argv[1]
# 创建可视化器并运行
visualizer = TileVisualizer(
output_dir=os.path.join(data_root, "tile_visualizations")
)
txt_path = os.path.join(data_root, "processed_tiles.txt")
visualizer.visualize_from_txt(txt_path)
visualizer.close()