第一次
This commit is contained in:
2
fst_data_pipeline/pipelines/README.md
Normal file
2
fst_data_pipeline/pipelines/README.md
Normal file
@@ -0,0 +1,2 @@
|
||||
# fst data production automation pipeline
|
||||
all automation scripts to enable fst data production line.
|
||||
0
fst_data_pipeline/pipelines/__init__.py
Normal file
0
fst_data_pipeline/pipelines/__init__.py
Normal file
439
fst_data_pipeline/pipelines/tencent/4dod_prod.py
Normal file
439
fst_data_pipeline/pipelines/tencent/4dod_prod.py
Normal file
@@ -0,0 +1,439 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
import threading
|
||||
import queue
|
||||
import subprocess
|
||||
import shutil
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from prometheus_client import start_http_server, Counter, Gauge, Summary, Histogram
|
||||
|
||||
# —— 常量 & 路径 —— #
|
||||
BASE = os.getcwd()
|
||||
INPUT_ROOT = os.path.join(BASE, "input")
|
||||
OUTPUT_ROOT = os.path.join(BASE, "output")
|
||||
EMPTY_DIR = os.path.join(BASE, "empty")
|
||||
LOG_DIR = os.path.join(BASE, "logs")
|
||||
|
||||
EXCLUDE_SUBDIRS = ["struct_infos/*"]
|
||||
DOCKER_IMAGE = (
|
||||
"artifact.swfcn.i.mercedes-benz.com/swfcn_docker/perception-dnn/od-net:prod_v0.2"
|
||||
)
|
||||
|
||||
BATCH_SIZE = 10
|
||||
MAX_LOCAL = 20
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY_S = 2
|
||||
METRICS_PORT = 8001
|
||||
SENTINEL = (None, None)
|
||||
|
||||
# —— 日志配置 —— #
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
logger = logging.getLogger("pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
h_info = logging.FileHandler(os.path.join(LOG_DIR, "pipeline.log"), encoding="utf-8")
|
||||
h_err = logging.FileHandler(os.path.join(LOG_DIR, "error_tasks.log"), encoding="utf-8")
|
||||
fmt = logging.Formatter("%(asctime)s %(levelname)s [%(threadName)s] %(message)s")
|
||||
h_info.setFormatter(fmt)
|
||||
h_err.setFormatter(fmt)
|
||||
h_err.setLevel(logging.ERROR)
|
||||
logger.addHandler(h_info)
|
||||
logger.addHandler(h_err)
|
||||
|
||||
# —— Prometheus 指标 —— #
|
||||
DL_TOTAL = Counter("pipeline_download_total", "下载尝试总数")
|
||||
DL_FAIL = Counter("pipeline_download_failures", "下载失败总数")
|
||||
DL_RETRY = Counter("pipeline_download_retries", "下载重试总数")
|
||||
PR_TOTAL = Counter("pipeline_process_total", "处理尝试总数")
|
||||
PR_FAIL = Counter("pipeline_process_failures", "处理失败总数")
|
||||
PR_RETRY = Counter("pipeline_process_retries", "处理重试总数")
|
||||
UP_TOTAL = Counter("pipeline_upload_total", "上传尝试总数")
|
||||
UP_FAIL = Counter("pipeline_upload_failures", "上传失败总数")
|
||||
UP_RETRY = Counter("pipeline_upload_retries", "上传重试总数")
|
||||
|
||||
DL_DUR = Summary("pipeline_download_duration_seconds", "单批下载耗时秒")
|
||||
PR_DUR = Summary("pipeline_process_duration_seconds", "单批处理耗时秒")
|
||||
UP_DUR = Summary("pipeline_upload_duration_seconds", "单批上传耗时秒")
|
||||
|
||||
BATCH_SIZE_HIST = Histogram(
|
||||
"pipeline_batch_size",
|
||||
"单批任务中文件夹数量分布",
|
||||
buckets=[1, 10, 20, 50, 100, 200, 500],
|
||||
)
|
||||
FILE_DL_DUR = Histogram(
|
||||
"pipeline_file_download_duration_seconds", "单文件/单目录下载耗时分布"
|
||||
)
|
||||
BATCH_OUT_FILES = Gauge(
|
||||
"pipeline_batch_output_subfolder_count", "单批处理后 output 下指定子文件夹数"
|
||||
)
|
||||
|
||||
Q_BATCH = Gauge("pipeline_queue_batches", "待下载批次数")
|
||||
Q_PROC = Gauge("pipeline_queue_processing", "待处理批次数")
|
||||
Q_UP = Gauge("pipeline_queue_uploading", "待上传批次数")
|
||||
LOCAL_COUNT = Gauge("pipeline_local_subdir_count", "当前本地 input 子文件夹总数")
|
||||
|
||||
# —— 队列 & 控制 —— #
|
||||
batch_q = queue.Queue()
|
||||
proc_q = queue.Queue()
|
||||
up_q = queue.Queue()
|
||||
|
||||
# —— 全局计数 & 锁,用于减少磁盘扫描 —— #
|
||||
_local_counter = 0
|
||||
_counter_lock = threading.Lock()
|
||||
_downloaded_per_batch = {} # batch_id -> 成功下载的子目录数量
|
||||
|
||||
|
||||
def incr_local(n=1, batch_id=None):
|
||||
global _local_counter
|
||||
with _counter_lock:
|
||||
_local_counter += n
|
||||
if batch_id:
|
||||
_downloaded_per_batch.setdefault(batch_id, 0)
|
||||
_downloaded_per_batch[batch_id] += n
|
||||
|
||||
|
||||
def decr_local_batch(batch_id):
|
||||
global _local_counter
|
||||
with _counter_lock:
|
||||
n = _downloaded_per_batch.pop(batch_id, 0)
|
||||
_local_counter -= n
|
||||
if _local_counter < 0:
|
||||
_local_counter = 0
|
||||
|
||||
|
||||
def get_local_count():
|
||||
with _counter_lock:
|
||||
return _local_counter
|
||||
|
||||
|
||||
# —— 子目录限流 —— #
|
||||
def count_local_subdirs_in_batch(batch_dir):
|
||||
# 只统计当前 batch 下直接子目录
|
||||
if not os.path.isdir(batch_dir):
|
||||
return 0
|
||||
return sum(
|
||||
1 for e in os.listdir(batch_dir) if os.path.isdir(os.path.join(batch_dir, e))
|
||||
)
|
||||
|
||||
|
||||
# —— 统一 subprocess + 重试 —— #
|
||||
def run(cmd, timeout=None):
|
||||
"""
|
||||
统一调用 subprocess,记录日志,并返回 (code, timed_out, output).
|
||||
"""
|
||||
logger.info("RUN: %s", " ".join(cmd))
|
||||
try:
|
||||
p = subprocess.Popen(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
start = time.time()
|
||||
output = []
|
||||
timed_out = False
|
||||
for line in p.stdout:
|
||||
output.append(line)
|
||||
if timeout and (time.time() - start) > timeout:
|
||||
p.kill()
|
||||
timed_out = True
|
||||
break
|
||||
code = p.wait()
|
||||
out_str = "".join(output)
|
||||
if code != 0:
|
||||
logger.warning("CMD 返回非零(%d): %s", code, out_str.strip())
|
||||
return code, timed_out, out_str
|
||||
except Exception as e:
|
||||
logger.exception("RUN 异常: %s", e)
|
||||
return -1, False, ""
|
||||
|
||||
|
||||
def with_retry(tag, func, *args):
|
||||
"""
|
||||
重试包装,func 必须返回 (code, timed_out, output)
|
||||
"""
|
||||
for i in range(1, MAX_RETRIES + 1):
|
||||
code, timed_out, out = func(*args)
|
||||
if code == 0:
|
||||
return True
|
||||
if timed_out:
|
||||
logger.error("%s 超时,不再重试", tag)
|
||||
break
|
||||
# 统计重试
|
||||
if tag.startswith("DL["):
|
||||
DL_RETRY.inc()
|
||||
if tag.startswith("PR["):
|
||||
PR_RETRY.inc()
|
||||
if tag.startswith("UP["):
|
||||
UP_RETRY.inc()
|
||||
logger.warning("%s 重试 %d/%d, output: %s", tag, i, MAX_RETRIES, out.strip())
|
||||
time.sleep(RETRY_DELAY_S)
|
||||
logger.error("%s 最终失败", tag)
|
||||
return False
|
||||
|
||||
|
||||
# —— 删除软连接 —— #
|
||||
def delete_symlinks(root_path):
|
||||
"""
|
||||
调用 find 一次性删除所有软链接,效率更高。
|
||||
"""
|
||||
logger.info("删除软连接: %s", root_path)
|
||||
code, _, out = run(
|
||||
["sudo", "find", root_path, "-type", "l", "-delete"], timeout=120
|
||||
)
|
||||
if code != 0:
|
||||
logger.warning("删除软连接失败: %s", out.strip())
|
||||
|
||||
|
||||
# —— 下载阶段 —— #
|
||||
@DL_DUR.time()
|
||||
def do_download(batch_id, remote_paths, batch_timeout):
|
||||
if batch_id is None:
|
||||
proc_q.put(SENTINEL)
|
||||
return
|
||||
|
||||
DL_TOTAL.inc()
|
||||
start = time.time()
|
||||
in_dir = os.path.join(INPUT_ROOT, batch_id)
|
||||
os.makedirs(in_dir, exist_ok=True)
|
||||
logger.info("DL[%s] 开始, targets=%s", batch_id, remote_paths)
|
||||
|
||||
# 本地同 batch 子目录限流
|
||||
while _local_counter >= MAX_LOCAL:
|
||||
logger.warning("DL[%s] 本地子目录≥%d,sleep 5min", batch_id, MAX_LOCAL)
|
||||
time.sleep(300)
|
||||
|
||||
exclude_flags = []
|
||||
for sub in EXCLUDE_SUBDIRS:
|
||||
exclude_flags += ["--ignore", f"'{sub}'"]
|
||||
|
||||
success_count = 0
|
||||
for remote in remote_paths:
|
||||
# 检查 lidar_gt_pandar128
|
||||
lidar_sub = remote.rstrip("/") + "/lidar_gt_pandar128"
|
||||
code, _, out = run(["coscmd", "list", lidar_sub])
|
||||
if code != 0 or not out.strip():
|
||||
logger.info("DL[%s] 跳过无数据: %s", batch_id, lidar_sub)
|
||||
continue
|
||||
|
||||
elapsed = time.time() - start
|
||||
if elapsed > batch_timeout:
|
||||
logger.error("DL[%s] 超时,停止下载", batch_id)
|
||||
DL_FAIL.inc()
|
||||
break
|
||||
|
||||
f_start = time.time()
|
||||
local_bag = os.path.join(in_dir, os.path.basename(remote))
|
||||
cmd = ["coscmd", "-s", "download", "-r", remote, local_bag] + exclude_flags
|
||||
ok = with_retry(
|
||||
f"DL[{batch_id}]", lambda c: run(c, timeout=batch_timeout - elapsed), cmd
|
||||
)
|
||||
FILE_DL_DUR.observe(time.time() - f_start)
|
||||
if ok:
|
||||
success_count += 1
|
||||
incr_local(1, batch_id) # 成功下载一个子目录
|
||||
else:
|
||||
DL_FAIL.inc()
|
||||
|
||||
logger.info("DL[%s] 完成,成功 %d", batch_id, success_count)
|
||||
proc_q.put((batch_id, (in_dir, remote_paths)))
|
||||
|
||||
|
||||
# —— 处理阶段 —— #
|
||||
@PR_DUR.time()
|
||||
def do_process(batch_id, data, batch_timeout):
|
||||
if batch_id is None:
|
||||
up_q.put(SENTINEL)
|
||||
return
|
||||
|
||||
in_dir, remote_paths = data
|
||||
PR_TOTAL.inc()
|
||||
start = time.time()
|
||||
out_dir = os.path.join(OUTPUT_ROOT, batch_id)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
logger.info("PR[%s] 开始", batch_id)
|
||||
|
||||
# 统一用 run() 运行 Docker
|
||||
docker_cmd = [
|
||||
"docker",
|
||||
"run",
|
||||
"--rm",
|
||||
"--gpus",
|
||||
"all",
|
||||
"--shm-size=16g",
|
||||
"-v",
|
||||
f"{in_dir}:/input",
|
||||
"-v",
|
||||
f"{out_dir}:/output",
|
||||
DOCKER_IMAGE,
|
||||
"bash",
|
||||
"-c",
|
||||
"cd /code/projects/od_net/lidarnet/ && "
|
||||
"source ../.venv/bin/activate && "
|
||||
"python pipeline.py --input_dir=/input --output_dir=/output",
|
||||
]
|
||||
ok = with_retry(
|
||||
f"PR[{batch_id}]", lambda c: run(c, timeout=batch_timeout), docker_cmd
|
||||
)
|
||||
elapsed = time.time() - start
|
||||
if ok:
|
||||
logger.info("PR[%s] 成功 耗时 %.1fs", batch_id, elapsed)
|
||||
else:
|
||||
PR_FAIL.inc()
|
||||
logger.error("PR[%s] 失败 耗时 %.1fs", batch_id, elapsed)
|
||||
|
||||
# 统计并清理输出子目录
|
||||
existing = []
|
||||
for name in os.listdir(out_dir):
|
||||
path = os.path.join(out_dir, name)
|
||||
if os.path.isdir(path) and name.endswith(".dir"):
|
||||
existing.append(path)
|
||||
BATCH_OUT_FILES.set(len(existing))
|
||||
logger.info("PR[%s] 输出子目录: %s", batch_id, existing)
|
||||
|
||||
# for folder in existing:
|
||||
# for item in os.listdir(folder):
|
||||
# p = os.path.join(folder, item)
|
||||
# if item != "object_tracking":
|
||||
# # 非 object_tracking 全删
|
||||
# if os.path.isdir(p):
|
||||
# shutil.rmtree(p, ignore_errors=True)
|
||||
# else:
|
||||
# os.remove(p)
|
||||
# else:
|
||||
# # 重命名 object_tracking
|
||||
# newp = os.path.join(folder, "object_auto_labeling")
|
||||
# os.rename(p, newp)
|
||||
|
||||
shutil.rmtree(in_dir, ignore_errors=True)
|
||||
up_q.put((batch_id, (out_dir, remote_paths)))
|
||||
|
||||
|
||||
# —— 上传阶段 —— #
|
||||
@UP_DUR.time()
|
||||
def do_upload(batch_id, data, batch_timeout):
|
||||
if batch_id is None:
|
||||
return
|
||||
|
||||
out_dir, remote_paths = data
|
||||
UP_TOTAL.inc()
|
||||
start = time.time()
|
||||
logger.info("UP[%s] 开始", batch_id)
|
||||
|
||||
# 删除中间产物
|
||||
for sub in ("logs", "intermedia_products"):
|
||||
shutil.rmtree(os.path.join(out_dir, sub), ignore_errors=True)
|
||||
|
||||
# 删除所有软连接
|
||||
delete_symlinks(out_dir)
|
||||
|
||||
# 只上传 object_tracking -> remote/aaa
|
||||
for remote in remote_paths:
|
||||
base = remote.rstrip("/")
|
||||
dest = f"{base}/derived/object_auto_labeling"
|
||||
for root, dirs, _ in os.walk(out_dir):
|
||||
if "object_tracking" in dirs:
|
||||
src = os.path.join(root, "object_tracking")
|
||||
elapsed = time.time() - start
|
||||
if elapsed > batch_timeout:
|
||||
logger.error("UP[%s] 超时", batch_id)
|
||||
UP_FAIL.inc()
|
||||
return
|
||||
ok = with_retry(
|
||||
f"UP[{batch_id}]",
|
||||
lambda c: run(c, timeout=batch_timeout - elapsed),
|
||||
["coscmd", "-s", "upload", "-r", src, dest],
|
||||
)
|
||||
if not ok:
|
||||
UP_FAIL.inc()
|
||||
logger.info("UP[%s] 上传完成", batch_id)
|
||||
|
||||
# 清理本地输出 & 输入目录
|
||||
for cmd in [
|
||||
["sudo", "rsync", "-a", "--delete", f"{EMPTY_DIR}/", f"{out_dir}/"],
|
||||
]:
|
||||
code, _, _ = run(cmd, timeout=60)
|
||||
if code != 0:
|
||||
logger.warning("UP[%s] rsync 失败: %s", batch_id, cmd)
|
||||
shutil.rmtree(out_dir, ignore_errors=True)
|
||||
# 新增:删除 input 下对应 batch,释放磁盘
|
||||
in_dir = os.path.join(INPUT_ROOT, batch_id)
|
||||
shutil.rmtree(in_dir, ignore_errors=True)
|
||||
decr_local_batch(batch_id)
|
||||
|
||||
logger.info("UP[%s] 本地清理完成", batch_id)
|
||||
|
||||
|
||||
# —— Worker & 主流程 —— #
|
||||
def worker(q, fn, timeout):
|
||||
while True:
|
||||
bid, data = q.get()
|
||||
try:
|
||||
fn(bid, data, timeout)
|
||||
except Exception:
|
||||
logger.exception("阶段异常,batch=%s", bid)
|
||||
finally:
|
||||
q.task_done()
|
||||
if bid is None:
|
||||
break
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--tasks-file", required=True, help="每行一个 COS 目录路径")
|
||||
p.add_argument("--batch-size", type=int, default=BATCH_SIZE)
|
||||
p.add_argument("--batch-timeout", type=int, default=3600)
|
||||
args = p.parse_args()
|
||||
T = args.batch_timeout
|
||||
|
||||
# 确保基础目录
|
||||
for d in (INPUT_ROOT, OUTPUT_ROOT, EMPTY_DIR, LOG_DIR):
|
||||
os.makedirs(d, exist_ok=True)
|
||||
|
||||
# 读取并分批
|
||||
lines = [
|
||||
line.strip() for line in open(args.tasks_file, encoding="utf-8") if line.strip()
|
||||
]
|
||||
for idx in range(0, len(lines), args.batch_size):
|
||||
blk = lines[idx : idx + args.batch_size]
|
||||
BATCH_SIZE_HIST.observe(len(blk))
|
||||
bid = (
|
||||
datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
|
||||
+ f"_{idx // args.batch_size + 1}"
|
||||
)
|
||||
batch_q.put((bid, blk))
|
||||
batch_q.put(SENTINEL)
|
||||
|
||||
# 启动监控
|
||||
start_http_server(METRICS_PORT)
|
||||
logger.info("Metrics 服务启动,端口 %d", METRICS_PORT)
|
||||
|
||||
# 启动线程
|
||||
threads = [
|
||||
threading.Thread(
|
||||
target=worker, args=(batch_q, do_download, T), name="DL-Worker"
|
||||
),
|
||||
threading.Thread(target=worker, args=(proc_q, do_process, T), name="PR-Worker"),
|
||||
threading.Thread(target=worker, args=(up_q, do_upload, T), name="UP-Worker"),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
# 主线程仅更新队列长度和本地目录计数,无全盘扫描
|
||||
while batch_q.unfinished_tasks or proc_q.unfinished_tasks or up_q.unfinished_tasks:
|
||||
Q_BATCH.set(batch_q.qsize())
|
||||
Q_PROC.set(proc_q.qsize())
|
||||
Q_UP.set(up_q.qsize())
|
||||
LOCAL_COUNT.set(get_local_count())
|
||||
time.sleep(1)
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
logger.info("所有批次完成")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
fst_data_pipeline/pipelines/tencent/__init__.py
Normal file
0
fst_data_pipeline/pipelines/tencent/__init__.py
Normal file
@@ -0,0 +1,229 @@
|
||||
#!/usr/bin/env python3
|
||||
# cleanup_bags_psql.py
|
||||
"""
|
||||
批量清理 bag_list 中 fst_indexed=FALSE 且日期早于 N 个月的记录:
|
||||
1. 删除 COS 目录
|
||||
2. 删除关联 6 张表
|
||||
3. 标记 bag_list.is_deleted=TRUE
|
||||
4. 失败自动重试 3 次,失败明细写 delete_failed.log
|
||||
5. 最终生成 CSV 并打印统计
|
||||
"""
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import psycopg2
|
||||
import requests
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
# 如果原来用 ConfigManager,保持不动;这里为了单文件可跑,直接用 os.environ 兜底
|
||||
import os
|
||||
|
||||
# ===================== 配置 =====================
|
||||
DB_HOST = os.getenv("ROOT_DB_HOST", "localhost")
|
||||
DB_PORT = int(os.getenv("ROOT_DB_PORT", 5432))
|
||||
DB_NAME = os.getenv("ROOT_DB_NAME", "default_dbname")
|
||||
DB_USER = os.getenv("ROOT_DB_USER", "default_user")
|
||||
DB_PASSWORD = os.getenv("ROOT_DB_PASSWD", "default_password")
|
||||
|
||||
COS_REGION = os.getenv("COS_CLIENT_REGION", "ap-guangzhou")
|
||||
COS_SECRET_ID = os.getenv("COS_CLIENT_SECRET_ID", "default_id")
|
||||
COS_SECRET_KEY = os.getenv("COS_CLIENT_SECRET_KEY", "default_key")
|
||||
COS_BUCKET = os.getenv("COS_BUCKET", "b-perception-e2e-1318950322")
|
||||
|
||||
BASE_URL = os.getenv("ROOT_DB_API", "http://localhost")
|
||||
PROJECT_IDS = (1, 2)
|
||||
MONTHS = 6
|
||||
MAX_WORKERS = 10
|
||||
RETRY_TIMES = 3
|
||||
|
||||
CSV_FILE = f"deleted_bags_{datetime.now():%Y%m%d_%H%M%S}.csv"
|
||||
|
||||
DB_CONF = dict(
|
||||
host=DB_HOST, port=DB_PORT, user=DB_USER, password=DB_PASSWORD, dbname=DB_NAME
|
||||
)
|
||||
|
||||
# ===================== 日志 =====================
|
||||
logging.basicConfig(
|
||||
filename=f"cleanup_{datetime.now():%Y%m%d}.log",
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
)
|
||||
|
||||
# 失败明细日志
|
||||
failure_handler = logging.FileHandler("delete_failed.log", encoding="utf-8")
|
||||
failure_handler.setFormatter(logging.Formatter("%(asctime)s,%(message)s"))
|
||||
logger_failure = logging.getLogger("failure")
|
||||
logger_failure.setLevel(logging.INFO)
|
||||
logger_failure.addHandler(failure_handler)
|
||||
|
||||
# ===================== COS 客户端 =====================
|
||||
cos = CosS3Client(
|
||||
CosConfig(Region=COS_REGION, SecretId=COS_SECRET_ID, SecretKey=COS_SECRET_KEY)
|
||||
)
|
||||
|
||||
# ===================== 统计 =====================
|
||||
stats_lock = threading.Lock()
|
||||
stats = {"total": 0, "success": 0, "fail": 0}
|
||||
|
||||
|
||||
def inc_stat(key: str):
|
||||
with stats_lock:
|
||||
stats[key] += 1
|
||||
|
||||
|
||||
# ===================== 数据库工具 =====================
|
||||
def get_conn():
|
||||
return psycopg2.connect(**DB_CONF)
|
||||
|
||||
|
||||
def fetch_candidates() -> list[tuple[int, str]]:
|
||||
"""返回 [(id, name), ...]"""
|
||||
sql = """
|
||||
SELECT id, name
|
||||
FROM bag_list
|
||||
WHERE project_id = ANY(%s)
|
||||
AND fst_indexed = FALSE
|
||||
AND is_deleted = FALSE
|
||||
"""
|
||||
with get_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (PROJECT_IDS,))
|
||||
return cur.fetchall()
|
||||
|
||||
|
||||
# ===================== 日期过滤 =====================
|
||||
DATE_RE = re.compile(r"_(\d{4})(\d{2})(\d{2})-(\d{2})(\d{2})(\d{2})_")
|
||||
|
||||
|
||||
def need_delete(name: str) -> bool:
|
||||
days = 30 * MONTHS
|
||||
m = DATE_RE.search(name)
|
||||
if not m:
|
||||
return False
|
||||
dt = datetime.strptime("".join(m.groups()), "%Y%m%d%H%M%S")
|
||||
return dt < datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
|
||||
# ===================== 调接口拿 COS 路径 =====================
|
||||
def get_pangu_detail(bag_name: str):
|
||||
url = f"{BASE_URL}/api/bags/pangu/detail"
|
||||
resp = requests.get(url, params={"bagName": bag_name}, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data["dataPath"], data["rawPath"]
|
||||
|
||||
|
||||
# ===================== COS 批量删前缀 =====================
|
||||
def cos_delete_prefix(prefix: str):
|
||||
if not prefix:
|
||||
return
|
||||
marker = ""
|
||||
while True:
|
||||
resp = cos.list_objects(Bucket=COS_BUCKET, Prefix=prefix, Marker=marker)
|
||||
contents = resp.get("Contents", [])
|
||||
if not contents:
|
||||
break
|
||||
keys = [{"Key": obj["Key"]} for obj in contents]
|
||||
cos.delete_objects(Bucket=COS_BUCKET, Delete={"Objects": keys})
|
||||
if resp.get("IsTruncated") == "false":
|
||||
break
|
||||
marker = resp["NextMarker"]
|
||||
|
||||
|
||||
# ===================== 重试包装 =====================
|
||||
@retry(
|
||||
stop=stop_after_attempt(RETRY_TIMES),
|
||||
wait=wait_exponential(multiplier=2, min=4, max=30),
|
||||
retry=retry_if_exception_type((
|
||||
requests.exceptions.RequestException,
|
||||
psycopg2.OperationalError,
|
||||
)),
|
||||
reraise=True,
|
||||
)
|
||||
def _do_cleanup(bag_id: int, bag_name: str):
|
||||
# 1. 拿路径并删 COS
|
||||
data_path, raw_path = get_pangu_detail(bag_name)
|
||||
cos_delete_prefix(data_path)
|
||||
cos_delete_prefix(raw_path)
|
||||
|
||||
# 2. 一个事务删关联表
|
||||
with get_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
DELETE FROM bag_lifecycle WHERE bag_name = %s;
|
||||
DELETE FROM secondary_pangu WHERE bag_id = %s;
|
||||
DELETE FROM secondary_minerva WHERE bag_id = %s;
|
||||
DELETE FROM bag_topic WHERE bag_id = %s;
|
||||
DELETE FROM bag_reserved_tag WHERE bag_id = %s;
|
||||
DELETE FROM fst_bag WHERE bag_id = %s;
|
||||
UPDATE bag_list SET is_deleted = TRUE WHERE id = %s;
|
||||
""",
|
||||
(bag_name, bag_id, bag_id, bag_id, bag_id, bag_id, bag_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def process_one(bag_id: int, bag_name: str) -> tuple[bool, int, str]:
|
||||
"""外层统一入口:负责统计、重试、失败落盘"""
|
||||
inc_stat("total")
|
||||
try:
|
||||
_do_cleanup(bag_id, bag_name)
|
||||
inc_stat("success")
|
||||
logging.info("Cleaned bag_id=%d name=%s", bag_id, bag_name)
|
||||
return True, bag_id, bag_name
|
||||
except Exception as e:
|
||||
inc_stat("fail")
|
||||
# 失败明细落盘
|
||||
logger_failure.info("%d,%s,%s", bag_id, bag_name, str(e).replace(",", ";"))
|
||||
logging.error("Failed bag_id=%d name=%s error=%s", bag_id, bag_name, e)
|
||||
return False, bag_id, bag_name
|
||||
|
||||
|
||||
# ===================== 主流程 =====================
|
||||
def main():
|
||||
candidates = fetch_candidates()
|
||||
logging.info("Fetched %d candidate bags", len(candidates))
|
||||
|
||||
to_delete = [(bid, nm) for bid, nm in candidates if need_delete(nm)]
|
||||
logging.info("Need delete %d bags", len(to_delete))
|
||||
if not to_delete:
|
||||
logging.info("Nothing to delete, exit.")
|
||||
return
|
||||
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
|
||||
futures = [pool.submit(process_one, bid, nm) for bid, nm in to_delete]
|
||||
for f in as_completed(futures):
|
||||
results.append(f.result())
|
||||
|
||||
# 写 CSV
|
||||
with open(CSV_FILE, "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["id", "name", "success"])
|
||||
for ok, bid, nm in results:
|
||||
writer.writerow([bid, nm, ok])
|
||||
|
||||
# 最终汇总
|
||||
logging.info(
|
||||
"=== Done: Total=%d Success=%d Fail=%d. CSV=%s FailureDetail=delete_failed.log ===",
|
||||
stats["total"],
|
||||
stats["success"],
|
||||
stats["fail"],
|
||||
CSV_FILE,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,270 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import csv
|
||||
import sys
|
||||
import shutil
|
||||
import logging
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import argparse
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
|
||||
import psycopg2
|
||||
from psycopg2 import OperationalError
|
||||
from subprocess import CalledProcessError
|
||||
from tqdm import tqdm
|
||||
from fst_data_pipeline.core.config_manager import ConfigManager
|
||||
|
||||
# ==== 1. 全局目录管理 ====
|
||||
ROOT_DIR = Path(__file__).parent.resolve()
|
||||
LOG_DIR = ROOT_DIR / "logs"
|
||||
DOWNLOAD_DIR = ROOT_DIR / "downloads"
|
||||
VIDEO_DIR = ROOT_DIR / "videos"
|
||||
for d in (LOG_DIR, DOWNLOAD_DIR, VIDEO_DIR):
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ==== 2. 日志配置 ====
|
||||
logger = logging.getLogger("bag_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
ch = logging.StreamHandler(sys.stdout)
|
||||
ch.setLevel(logging.INFO)
|
||||
ch.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
|
||||
fh = TimedRotatingFileHandler(
|
||||
LOG_DIR / "pipeline.log", when="midnight", backupCount=7, encoding="utf-8"
|
||||
)
|
||||
fh.setLevel(logging.DEBUG)
|
||||
fh.setFormatter(
|
||||
logging.Formatter("%(asctime)s %(levelname)s [%(processName)s] %(message)s")
|
||||
)
|
||||
logger.addHandler(ch)
|
||||
logger.addHandler(fh)
|
||||
|
||||
# ==== 3. 数据库连接工厂 ====
|
||||
config = ConfigManager()
|
||||
DB_HOST = config.get("ROOT_DB_HOST", "localhost")
|
||||
DB_PORT = config.get_int("ROOT_DB_PORT", 5432)
|
||||
DB_NAME = config.get("ROOT_DB_NAME", "default_dbname")
|
||||
DB_USER = config.get("ROOT_DB_USER", "default_user")
|
||||
DB_PASSWORD = config.get("ROOT_DB_PASSWD", "default_password")
|
||||
|
||||
DB_CFG = dict(
|
||||
host=DB_HOST, port=DB_PORT, dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD
|
||||
)
|
||||
|
||||
|
||||
def get_conn():
|
||||
return psycopg2.connect(**DB_CFG)
|
||||
|
||||
|
||||
def read_and_validate_csv(path):
|
||||
required = {"bag name", "FST 1st node", "FST 2nd node", "FST 3rd node"}
|
||||
records = []
|
||||
seen_bags = set()
|
||||
with open(path, encoding="utf-8", newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
headers = set(reader.fieldnames or [])
|
||||
if not required.issubset(headers):
|
||||
missing = required - headers
|
||||
raise ValueError(f"CSV 缺少必要列: {missing}")
|
||||
for i, row in enumerate(reader, start=2):
|
||||
bag = row["bag name"].strip()
|
||||
if not bag:
|
||||
raise ValueError(f"第{i}行 bag name 为空")
|
||||
if bag in seen_bags:
|
||||
raise ValueError(f'第{i}行 bag 重复: "{bag}"')
|
||||
seen_bags.add(bag)
|
||||
nodes = []
|
||||
for col in ("FST 1st node", "FST 2nd node", "FST 3rd node"):
|
||||
v = row.get(col, "").strip()
|
||||
if v:
|
||||
nodes.append(v)
|
||||
if not nodes:
|
||||
raise ValueError(f"第{i}行没有任何 FST 节点")
|
||||
records.append({"bag": bag, "nodes": nodes, "first": nodes[0]})
|
||||
return records
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_fixed(5),
|
||||
retry=retry_if_exception_type((OperationalError, CalledProcessError)),
|
||||
)
|
||||
def process_record(rec) -> bool:
|
||||
bag, first = rec["bag"], rec["first"]
|
||||
logger.info(f"[{bag}] 开始处理")
|
||||
|
||||
# 5.1 查 data_path
|
||||
try:
|
||||
conn = get_conn()
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT data_path FROM main_pangu WHERE bag_path LIKE %s LIMIT 1",
|
||||
(f"%{bag}%",),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[{bag}] 查询 data_path 出错: {e}")
|
||||
raise
|
||||
if not row:
|
||||
logger.error(f"[{bag}] main_pangu 未找到记录,跳过")
|
||||
return False
|
||||
data_path = row[0]
|
||||
|
||||
# 5.2 下载图片
|
||||
dst_dir = DOWNLOAD_DIR / bag / "camera_front_wide"
|
||||
dst_dir.mkdir(parents=True, exist_ok=True)
|
||||
src_path = data_path.partition("/")[2] + "/camera_front_wide/"
|
||||
cmd_dl = ["coscmd", "download", "-r", src_path, str(dst_dir)]
|
||||
try:
|
||||
subprocess.run(cmd_dl, check=True, stdout=subprocess.DEVNULL)
|
||||
except CalledProcessError as e:
|
||||
logger.error(f"[{bag}] 下载失败: {e}")
|
||||
raise
|
||||
|
||||
# 5.3 生成 list.txt
|
||||
imgs = sorted(
|
||||
p.name for p in dst_dir.iterdir() if p.suffix.lower() in (".jpg", ".png")
|
||||
)
|
||||
sampled = imgs[::3]
|
||||
if not imgs:
|
||||
logger.error(f"[{bag}] 未找到任何图片,跳过")
|
||||
shutil.rmtree(DOWNLOAD_DIR / bag, ignore_errors=True)
|
||||
return False
|
||||
(dst_dir / "list.txt").write_text(
|
||||
"\n".join(f'file "{fn}"' for fn in sampled), encoding="utf-8"
|
||||
)
|
||||
|
||||
# 5.4 ffmpeg 合成
|
||||
out_dir = VIDEO_DIR / first
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_mp4 = out_dir / f"{bag}.mp4"
|
||||
cmd_ff = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-f",
|
||||
"concat",
|
||||
"-safe",
|
||||
"0",
|
||||
"-i",
|
||||
"list.txt",
|
||||
"-vf",
|
||||
"scale=640:360",
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
str(out_mp4),
|
||||
]
|
||||
try:
|
||||
subprocess.run(cmd_ff, cwd=dst_dir, check=True)
|
||||
except CalledProcessError as e:
|
||||
logger.error(f"[{bag}] 合成失败: {e}")
|
||||
raise
|
||||
|
||||
# 5.5 更新数据库
|
||||
try:
|
||||
conn = get_conn()
|
||||
with conn:
|
||||
with conn.cursor() as cur:
|
||||
# 5.5.1 bag_list
|
||||
cur.execute(
|
||||
"SELECT id,fst_indexed FROM bag_list "
|
||||
"WHERE name=%s AND is_deleted=FALSE AND is_active_data=TRUE",
|
||||
(bag,),
|
||||
)
|
||||
r = cur.fetchone()
|
||||
if not r:
|
||||
logger.error(f"[{bag}] bag_list 未找到,跳过更新")
|
||||
conn.rollback()
|
||||
return False
|
||||
bag_id, fst_indexed = r
|
||||
if not fst_indexed:
|
||||
cur.execute(
|
||||
"UPDATE bag_list SET fst_indexed=TRUE WHERE id=%s", (bag_id,)
|
||||
)
|
||||
# 5.5.2 逐级关联 fst
|
||||
parent = None
|
||||
for name in rec["nodes"]:
|
||||
cur.execute("SELECT id,parent_id FROM fst WHERE name=%s", (name,))
|
||||
fr = cur.fetchone()
|
||||
if not fr:
|
||||
logger.error(f"[{bag}] FST 节点不存在 '{name}',跳过")
|
||||
conn.rollback()
|
||||
return False
|
||||
fid, actual_parent = fr
|
||||
if parent is not None and actual_parent != parent:
|
||||
logger.error(
|
||||
f"[{bag}] 节点 '{name}' 父级校验失败 "
|
||||
f"(actual={actual_parent}!=expect={parent}),跳过"
|
||||
)
|
||||
conn.rollback()
|
||||
return False
|
||||
|
||||
# ---- 确保没插入过才 update bag_sum ----
|
||||
cur.execute(
|
||||
"SELECT 1 FROM fst_bag WHERE bag_id=%s AND fst_node_id=%s",
|
||||
(bag_id, fid),
|
||||
)
|
||||
if not cur.fetchone():
|
||||
cur.execute(
|
||||
"INSERT INTO fst_bag(bag_id,fst_node_id) VALUES(%s,%s)",
|
||||
(bag_id, fid),
|
||||
)
|
||||
cur.execute(
|
||||
"UPDATE fst SET bag_sum = bag_sum + 1 WHERE id=%s", (fid,)
|
||||
)
|
||||
parent = fid
|
||||
except OperationalError as e:
|
||||
logger.error(f"[{bag}] 更新DB连接错误: {e}")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f"[{bag}] 更新DB时未知错误,跳过")
|
||||
return False
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
# 5.6 清理 & 成功日志
|
||||
shutil.rmtree(DOWNLOAD_DIR / bag, ignore_errors=True)
|
||||
logger.info(f"[{bag}] 处理成功")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="批量下载→合成→更新DB")
|
||||
parser.add_argument("csv_path", help="CSV 文件路径")
|
||||
parser.add_argument("-w", "--workers", type=int, default=4, help="并行 worker 数")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
records = read_and_validate_csv(args.csv_path)
|
||||
except Exception as e:
|
||||
logger.error(f"CSV 校验失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
total = len(records)
|
||||
logger.info(f"共 {total} 条,启动 {args.workers} 并行 worker")
|
||||
|
||||
success = skipped = failed = 0
|
||||
with ProcessPoolExecutor(max_workers=args.workers) as exe:
|
||||
fut2bag = {exe.submit(process_record, rec): rec["bag"] for rec in records}
|
||||
for fut in tqdm(as_completed(fut2bag), total=total, desc="进度"):
|
||||
bag = fut2bag[fut]
|
||||
try:
|
||||
ok = fut.result()
|
||||
if ok:
|
||||
success += 1
|
||||
else:
|
||||
skipped += 1
|
||||
except Exception as e:
|
||||
logger.error(f"[{bag}] 最终失败: {e}")
|
||||
failed += 1
|
||||
|
||||
logger.info(f"结束:{success} 成功,{skipped} 跳过,{failed} 失败,共 {total}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,317 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
-----------------------------------------------------------------------------------
|
||||
功能概述:
|
||||
把数据库里的rosbag的经纬度降采样, 提取, 画图生成每周采集覆盖度在SD map上的分布情况
|
||||
|
||||
使用示例:
|
||||
|
||||
|
||||
配置说明:
|
||||
-----------------------------------------------------------------------------------
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import csv
|
||||
import time
|
||||
import math
|
||||
import logging
|
||||
import threading
|
||||
import argparse
|
||||
from io import StringIO
|
||||
from http.server import HTTPServer, SimpleHTTPRequestHandler
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import numpy as np
|
||||
import psycopg2
|
||||
from shapely import wkb
|
||||
import folium
|
||||
from folium import TileLayer, FeatureGroup, LayerControl
|
||||
from branca.element import Element
|
||||
from branca.colormap import linear
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
from tqdm import tqdm
|
||||
from fst_data_pipeline.core.config_manager import ConfigManager
|
||||
|
||||
|
||||
# ====== 配置区 ======
|
||||
config = ConfigManager()
|
||||
DB_HOST = config.get("ROOT_DB_HOST", "localhost")
|
||||
DB_PORT = config.get_int("ROOT_DB_PORT", 5432)
|
||||
DB_NAME = config.get("ROOT_DB_NAME", "default_dbname")
|
||||
DB_USER = config.get("ROOT_DB_USER", "default_user")
|
||||
DB_PASSWORD = config.get("ROOT_DB_PASSWD", "default_password")
|
||||
|
||||
COS_REGION = config.get("COS_CLIENT_REGION", "default_region")
|
||||
COS_SECRET_ID = config.get("COS_CLIENT_SECRET_ID", "default_id")
|
||||
COS_SECRET_KEY = config.get("COS_CLIENT_SECRET_KEY", "default_key")
|
||||
|
||||
# ===== 日志配置 =====
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s %(levelname)s [%(threadName)s] %(message)s",
|
||||
datefmt="%Y-%m-%dT%H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("gnss_map")
|
||||
|
||||
|
||||
# ===== 静态文件服务 (带 CORS) =====
|
||||
class CORSHandler(SimpleHTTPRequestHandler):
|
||||
def end_headers(self):
|
||||
self.send_header("Access-Control-Allow-Origin", "*")
|
||||
super().end_headers()
|
||||
|
||||
|
||||
def serve_map(port: int, web_dir: str):
|
||||
os.chdir(web_dir)
|
||||
logger.info("Serving static files from %s", web_dir)
|
||||
HTTPServer(("0.0.0.0", port), CORSHandler).serve_forever()
|
||||
|
||||
|
||||
# ===== 工具:haversine 计算距离 =====
|
||||
def haversine(lat1, lon1, lat2, lon2):
|
||||
R = 6371000.0
|
||||
φ1, φ2 = math.radians(lat1), math.radians(lat2)
|
||||
dφ = math.radians(lat2 - lat1)
|
||||
dλ = math.radians(lon2 - lon1)
|
||||
a = math.sin(dφ / 2) ** 2 + math.cos(φ1) * math.cos(φ2) * math.sin(dλ / 2) ** 2
|
||||
return 2 * R * math.asin(math.sqrt(a))
|
||||
|
||||
|
||||
# ===== 单 bag 处理:下载→去重→下采样 → 返回 (name, pts_wkt) =====
|
||||
def process_bag(name, cos, bucket, prefix):
|
||||
logger.debug("Start processing bag %s", name)
|
||||
key = f"{prefix}/{name}.dir/raw_gnss.csv"
|
||||
text = None
|
||||
for attempt in range(1, 4):
|
||||
try:
|
||||
resp = cos.get_object(Bucket=bucket, Key=key)
|
||||
text = resp["Body"].get_raw_stream().read().decode("utf-8")
|
||||
logger.debug("[%s] downloaded %d bytes", name, len(text))
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("[%s] download failed (attempt %d): %s", name, attempt, e)
|
||||
time.sleep(1)
|
||||
if not text:
|
||||
logger.error("[%s] all download attempts failed → skip", name)
|
||||
return None
|
||||
|
||||
rows = list(csv.DictReader(StringIO(text)))
|
||||
unique = []
|
||||
for r in rows:
|
||||
try:
|
||||
lat, lon, alt = float(r["lat"]), float(r["lon"]), float(r.get("alt", 0))
|
||||
except AttributeError:
|
||||
continue
|
||||
if not any(haversine(lat, lon, plat, plon) < 1.0 for plat, plon, _ in unique):
|
||||
unique.append((lat, lon, alt))
|
||||
logger.debug("[%s] unique points: %d", name, len(unique))
|
||||
if not unique:
|
||||
return None
|
||||
|
||||
# 下采样最多30点
|
||||
if len(unique) > 30:
|
||||
idx = np.linspace(0, len(unique) - 1, 30, dtype=int)
|
||||
sampled = [unique[i] for i in idx]
|
||||
else:
|
||||
sampled = unique
|
||||
logger.info("[%s] sampled %d points", name, len(sampled))
|
||||
|
||||
# 仅拼接 ST_MakePoint(...) 逗号列表,留给 SQL 拼 ARRAY[]
|
||||
pts_wkt = ",".join(f"ST_MakePoint({lon},{lat},{alt})" for lat, lon, alt in sampled)
|
||||
return name, pts_wkt
|
||||
|
||||
|
||||
# ===== 主流程 =====
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("-p", "--map-port", type=int, default=8000)
|
||||
p.add_argument("-u", "--tile-url", required=True)
|
||||
p.add_argument("-w", "--workers", type=int, default=8)
|
||||
args = p.parse_args()
|
||||
|
||||
# COS client
|
||||
cos = CosS3Client(
|
||||
CosConfig(
|
||||
Region=COS_REGION,
|
||||
SecretId=COS_SECRET_ID,
|
||||
SecretKey=COS_SECRET_KEY,
|
||||
)
|
||||
)
|
||||
BUCKET = config.get("bucket", "b-perception-e2e-1318950322")
|
||||
PREFIX = config.get("prefix", "mb_raw_rosbag_decode_dirs")
|
||||
|
||||
# DB 配置
|
||||
DB_CONF = dict(
|
||||
host=DB_HOST, port=DB_PORT, dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD
|
||||
)
|
||||
|
||||
# 1. 取待处理 bag
|
||||
logger.info("Fetch unprocessed bags from DB")
|
||||
with psycopg2.connect(**DB_CONF) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT bl.name
|
||||
FROM bag_list bl
|
||||
LEFT JOIN geometry_info gi ON gi.rosbag_name=bl.name
|
||||
WHERE bl.is_decoded=TRUE
|
||||
AND gi.rosbag_name IS NULL
|
||||
"""
|
||||
)
|
||||
names = [r[0] for r in cur.fetchall()]
|
||||
logger.info("To process: %d bags", len(names))
|
||||
|
||||
# 2. 并行下载 & 处理
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=args.workers) as ex:
|
||||
futures = {ex.submit(process_bag, n, cos, BUCKET, PREFIX): n for n in names}
|
||||
for f in tqdm(as_completed(futures), total=len(names), desc="Processing"):
|
||||
n = futures[f]
|
||||
try:
|
||||
r = f.result()
|
||||
if r:
|
||||
results.append(r)
|
||||
except Exception as e:
|
||||
logger.error("[%s] error: %s", n, e)
|
||||
logger.info("Valid processed bags: %d", len(results))
|
||||
if not results:
|
||||
logger.warning("No new data → skip insertion, proceed to visualization")
|
||||
|
||||
# 3. 批量插入 geometry_info
|
||||
if results:
|
||||
logger.info("Batch inserting into geometry_info")
|
||||
with psycopg2.connect(**DB_CONF) as conn:
|
||||
with conn.cursor() as cur:
|
||||
batch = 50
|
||||
for i in range(0, len(results), batch):
|
||||
chunk = results[i : i + batch]
|
||||
# 手工拼 VALUES 列表
|
||||
vals = []
|
||||
for name, pts in chunk:
|
||||
esc = name.replace("'", "''")
|
||||
vals.append(f"('{esc}', ARRAY[{pts}]::geometry(PointZ)[])")
|
||||
sql = f"""
|
||||
INSERT INTO geometry_info(rosbag_name,gnss_downsampled_points)
|
||||
VALUES {",".join(vals)}
|
||||
ON CONFLICT(rosbag_name) DO NOTHING
|
||||
"""
|
||||
cur.execute(sql)
|
||||
conn.commit()
|
||||
logger.info(
|
||||
"Inserted batch %d~%d", i + 1, min(i + batch, len(results))
|
||||
)
|
||||
logger.info("Batch insert done, total: %d", len(results))
|
||||
|
||||
# 4. 从 DB 读取所有 geometry
|
||||
logger.info("Load all geometries from DB")
|
||||
with psycopg2.connect(**DB_CONF) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT rosbag_name, gnss_downsampled_points FROM geometry_info"
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
by_bag, all_pts = {}, []
|
||||
for name, arr in tqdm(rows, desc="Loading geometry"):
|
||||
pts = []
|
||||
for hw in arr.strip("{}").split(","):
|
||||
h = re.sub(r"[^0-9A-Fa-f]", "", hw)
|
||||
p = wkb.loads(bytes.fromhex(h))
|
||||
pts.append((p.y, p.x))
|
||||
all_pts.append((p.y, p.x))
|
||||
if pts:
|
||||
by_bag[name] = pts
|
||||
|
||||
if all_pts:
|
||||
lats = [p[0] for p in all_pts]
|
||||
lons = [p[1] for p in all_pts]
|
||||
center = (sum(lats) / len(lats), sum(lons) / len(lons))
|
||||
else:
|
||||
logger.warning("No points in DB → default center (0,0)")
|
||||
center = (0.0, 0.0)
|
||||
logger.info("Map center: %s", center)
|
||||
|
||||
# 5. Folium 地图 & 两个图层(视图)
|
||||
m = folium.Map(
|
||||
location=center,
|
||||
zoom_start=6,
|
||||
tiles=None,
|
||||
prefer_canvas=True,
|
||||
control_scale=True,
|
||||
zoom_control=True,
|
||||
)
|
||||
TileLayer(tiles=args.tile_url, overlay=True, control=False, attr=" ").add_to(m)
|
||||
|
||||
header = m.get_root().header
|
||||
for tag in [
|
||||
'<link rel="stylesheet" href="leaflet.css"/>',
|
||||
'<link rel="stylesheet" href="MarkerCluster.css"/>',
|
||||
'<link rel="stylesheet" href="MarkerCluster.Default.css"/>',
|
||||
'<link rel="stylesheet" href="leaflet.awesome-markers.css"/>',
|
||||
'<link rel="stylesheet" href="leaflet.awesome.rotate.min.css"/>',
|
||||
'<link rel="stylesheet" href="bootstrap.min.css"/>',
|
||||
'<link rel="stylesheet" href="all.min.css"/>',
|
||||
'<script src="jquery-1.12.4.min.js"></script>',
|
||||
'<script src="bootstrap.bundle.min.js"></script>',
|
||||
'<script src="leaflet.js"></script>',
|
||||
'<script src="leaflet.markercluster.js"></script>',
|
||||
'<script src="leaflet.awesome-markers.js"></script>',
|
||||
]:
|
||||
header.add_child(Element(tag))
|
||||
|
||||
# 轨迹视图
|
||||
fg1 = FeatureGroup(name="Trajectory", overlay=False, show=True)
|
||||
for pts in by_bag.values():
|
||||
folium.PolyLine(pts, color="#3388ff", weight=1, opacity=0.7).add_to(fg1)
|
||||
m.add_child(fg1)
|
||||
|
||||
# 聚类点视图
|
||||
cnt = {}
|
||||
for lat, lon in all_pts:
|
||||
cnt[(lat, lon)] = cnt.get((lat, lon), 0) + 1
|
||||
if cnt:
|
||||
mi, ma = min(cnt.values()), max(cnt.values())
|
||||
cmap = linear.YlOrRd_09.scale(mi, ma)
|
||||
cmap.caption = "Point Count"
|
||||
else:
|
||||
cmap = None
|
||||
fg2 = FeatureGroup(name="Cluster", overlay=False, show=False)
|
||||
for (lat, lon), c in cnt.items():
|
||||
color = cmap(c) if cmap else "#000000"
|
||||
folium.CircleMarker(
|
||||
location=(lat, lon),
|
||||
radius=1 + min(c, 4),
|
||||
color=color,
|
||||
fill=True,
|
||||
fill_color=color,
|
||||
fill_opacity=0.7,
|
||||
popup=f"count: {c}",
|
||||
).add_to(fg2)
|
||||
m.add_child(fg2)
|
||||
if cmap:
|
||||
m.add_child(cmap)
|
||||
|
||||
LayerControl(collapsed=False).add_to(m)
|
||||
|
||||
# 6. 保存并启动静态服务
|
||||
script_dir = os.path.dirname(__file__)
|
||||
html = os.path.join(script_dir, "map.html")
|
||||
m.save(html)
|
||||
logger.info("Saved map → %s", html)
|
||||
|
||||
threading.Thread(
|
||||
target=serve_map,
|
||||
args=(args.map_port, script_dir),
|
||||
daemon=True,
|
||||
name="HTTP-Server",
|
||||
).start()
|
||||
logger.info("Server running at http://127.0.0.1:%d/map.html", args.map_port)
|
||||
threading.Event().wait()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,293 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""_summary_
|
||||
✅ 批量读取 ROS bag 文件列表(一个 txt 文件里每行是文件路径)
|
||||
✅ 根据文件名解析出车辆、时间、bag 名称
|
||||
✅ 自动:
|
||||
在数据库里插入 / 更新 bag_list(记录名称与 project_id)
|
||||
插入 / 更新 main_pangu(记录 bag 的元信息)
|
||||
插入 / 更新 secondary_pangu(记录 COS 上的文件、目录存在性及链接)
|
||||
读取 meta_info.txt 文件里的 topic 列表, 更新 topic_list 与 bag_topic 关联表
|
||||
✅ 使用腾讯云 COS (对象存储) API 检查文件、目录是否存在, 并提供文件访问链接
|
||||
✅ 支持多线程并发处理, 加快大量 bag 的同步速度
|
||||
✅ 自动维护数据库连接池 (psycopg2 ThreadedConnectionPool)
|
||||
✅ 日志文件记录所有操作, 包括成功、跳过、失败, 便于后续排查
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import logging
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
from tqdm import tqdm
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
from qcloud_cos.cos_exception import CosServiceError
|
||||
from fst_data_pipeline.core.config_manager import ConfigManager
|
||||
|
||||
# ====== 配置区 ======
|
||||
config = ConfigManager()
|
||||
DB_HOST = config.get("ROOT_DB_HOST", "localhost")
|
||||
DB_PORT = config.get_int("ROOT_DB_PORT", 5432)
|
||||
DB_NAME = config.get("ROOT_DB_NAME", "default_dbname")
|
||||
DB_USER = config.get("ROOT_DB_USER", "default_user")
|
||||
DB_PASSWORD = config.get("ROOT_DB_PASSWD", "default_password")
|
||||
|
||||
COS_REGION = config.get("COS_CLIENT_REGION", "default_region")
|
||||
COS_SECRET_ID = config.get("COS_CLIENT_SECRET_ID", "default_id")
|
||||
COS_SECRET_KEY = config.get("COS_CLIENT_SECRET_KEY", "default_key")
|
||||
|
||||
BUCKET = config.get("COS_BUCKET", "b-perception-e2e-1318950322")
|
||||
config.require("COS_PREFIX_DECODE")
|
||||
PREFIX = config.get("COS_PREFIX_DECODE", "mb_raw_rosbag_decode_dirs")
|
||||
DOMAIN = f"https://{BUCKET}.cos.{COS_REGION}.myqcloud.com"
|
||||
|
||||
MIN_CONN, MAX_CONN = 1, 10
|
||||
MAX_WORKERS = 8 # 并发线程数, 可调
|
||||
|
||||
# ====== 日志 ======
|
||||
LOGFILE = f"{datetime.now():%Y-%m-%d}.log"
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(message)s",
|
||||
handlers=[logging.StreamHandler(), logging.FileHandler(LOGFILE, mode="a")],
|
||||
)
|
||||
logger = logging.getLogger()
|
||||
|
||||
# ====== COS 客户端 ======
|
||||
cos = CosS3Client(
|
||||
CosConfig(Region=COS_REGION, SecretId=COS_SECRET_ID, SecretKey=COS_SECRET_KEY)
|
||||
)
|
||||
|
||||
# ====== 全局 DB 连接池 ======
|
||||
pool = ThreadedConnectionPool(
|
||||
MIN_CONN, MAX_CONN, host=DB_HOST, dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD
|
||||
)
|
||||
|
||||
# ====== 工具函数 ======
|
||||
|
||||
|
||||
def get_db():
|
||||
"""从池里取一个 conn, 设置 autocommit=True"""
|
||||
conn = pool.getconn()
|
||||
conn.autocommit = True
|
||||
return conn, conn.cursor()
|
||||
|
||||
|
||||
def release_db(conn, cur):
|
||||
cur.close()
|
||||
pool.putconn(conn)
|
||||
|
||||
|
||||
def parse_rosbag_line(line):
|
||||
raw = line.strip()
|
||||
name = os.path.basename(raw).split("/")[0]
|
||||
if not name.endswith(".bag"):
|
||||
name += ".bag"
|
||||
m = re.search(r"_(\d{4})(\d{2})(\d{2})-(\d{2})(\d{2})(\d{2})_", name)
|
||||
if not m:
|
||||
raise ValueError(f"无法解析 rosbag 名称/时间: {name}")
|
||||
yy, mm, dd, HH, MN, SS = m.groups()
|
||||
v = name[:8]
|
||||
dt = datetime.strptime(f"{yy}-{mm}-{dd} {HH}:{MN}:{SS}", "%Y-%m-%d %H:%M:%S")
|
||||
return name, {
|
||||
"name": name,
|
||||
"vehicle": v,
|
||||
"datetime": dt,
|
||||
"path": f"{DOMAIN}/mb_cuct_data_collection/{name}",
|
||||
"data_path": f"{DOMAIN}/{PREFIX}{name}.dir",
|
||||
}
|
||||
|
||||
|
||||
def download_txt(name):
|
||||
key = f"{name}.meta_info.txt"
|
||||
try:
|
||||
obj = cos.get_object(Bucket=BUCKET, Key=key)
|
||||
except CosServiceError:
|
||||
return None
|
||||
fd, tmp = tempfile.mkstemp(prefix=name + "_", suffix=".txt")
|
||||
os.close(fd)
|
||||
with open(tmp, "wb") as f:
|
||||
for chunk in obj["Body"].get_raw_stream():
|
||||
f.write(chunk)
|
||||
return tmp
|
||||
|
||||
|
||||
def parse_rosbag_info(fp):
|
||||
tops = []
|
||||
on = False
|
||||
with open(fp, "r", encoding="utf-8") as f:
|
||||
for L in f:
|
||||
if L.startswith("topics:"):
|
||||
on = True
|
||||
continue
|
||||
if on:
|
||||
s = L.strip()
|
||||
if not s:
|
||||
break
|
||||
p = s.split()
|
||||
if len(p) > 3:
|
||||
tops.append((p[0], " ".join(p[3:]).rstrip(":")))
|
||||
return tops
|
||||
|
||||
|
||||
# ==== 读取已成功处理的 bag_name 列表 ====
|
||||
|
||||
|
||||
def load_done(logfile):
|
||||
done = set()
|
||||
if not os.path.exists(logfile):
|
||||
return done
|
||||
pat = re.compile(r"SUCCESS\s+(.+\.bag)\b")
|
||||
with open(logfile, "r", encoding="utf-8") as f:
|
||||
for L in f:
|
||||
m = pat.search(L)
|
||||
if m:
|
||||
done.add(m.group(1))
|
||||
return done
|
||||
|
||||
|
||||
# ==== 单条处理逻辑 ====
|
||||
|
||||
|
||||
def process_one(line):
|
||||
try:
|
||||
bag_name, ent = parse_rosbag_line(line)
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
return
|
||||
|
||||
conn, cur = get_db()
|
||||
try:
|
||||
# 1) bag_list
|
||||
cur.execute(
|
||||
"INSERT INTO bag_list(name,project_id) VALUES(%s,%s) "
|
||||
"ON CONFLICT(name) DO UPDATE SET update_time=NOW() "
|
||||
"RETURNING id",
|
||||
(bag_name, 1),
|
||||
)
|
||||
bag_id = cur.fetchone()[0]
|
||||
|
||||
# 2) main_pangu
|
||||
cur.execute(
|
||||
"INSERT INTO main_pangu(bag_id,name,vehicle,datetime,path,data_path) "
|
||||
"VALUES(%s,%s,%s,%s,%s,%s) "
|
||||
"ON CONFLICT(bag_id) DO UPDATE SET "
|
||||
"name=COALESCE(main_pangu.name,EXCLUDED.name),"
|
||||
"vehicle=COALESCE(main_pangu.vehicle,EXCLUDED.vehicle),"
|
||||
"datetime=COALESCE(main_pangu.datetime,EXCLUDED.datetime),"
|
||||
"path=COALESCE(main_pangu.path,EXCLUDED.path),"
|
||||
"data_path=COALESCE(main_pangu.data_path,EXCLUDED.data_path)",
|
||||
(
|
||||
bag_id,
|
||||
ent["name"],
|
||||
ent["vehicle"],
|
||||
ent["datetime"],
|
||||
ent["path"],
|
||||
ent["data_path"],
|
||||
),
|
||||
)
|
||||
|
||||
# 3) secondary_pangu
|
||||
file_fs = ["raw_gnss.csv", "raw_imu.csv", "ego_motion.csv", "vehicle_wheel.csv"]
|
||||
dir_fs = [
|
||||
"camera_front_wide",
|
||||
"lidar_gt_pandar128",
|
||||
"object_lidar_gt_pandar128_manual",
|
||||
"lidar_fd_multi_scan_raw",
|
||||
"camera_fisheye_left",
|
||||
"camera_fisheye_right",
|
||||
"calibration",
|
||||
]
|
||||
data = {}
|
||||
for f in file_fs:
|
||||
key = f"{PREFIX}{bag_name}.dir/{f}"
|
||||
try:
|
||||
cos.head_object(Bucket=BUCKET, Key=key)
|
||||
data[f] = f"{DOMAIN}/{key}"
|
||||
except CosServiceError:
|
||||
data[f] = None
|
||||
for d in dir_fs:
|
||||
pfx = f"{PREFIX}{bag_name}.dir/{d}/"
|
||||
try:
|
||||
r = cos.list_objects_v2(Bucket=BUCKET, Prefix=pfx, MaxKeys=1)
|
||||
data[d] = f"{DOMAIN}/{pfx}" if r.get("Contents") else None
|
||||
except CosServiceError:
|
||||
data[d] = None
|
||||
data["opt"] = {}
|
||||
cols = ",".join(["bag_id"] + file_fs + dir_fs + ["opt"])
|
||||
vals = [bag_id] + [data[x] for x in file_fs + dir_fs] + [data["opt"]]
|
||||
upd = (
|
||||
",".join(
|
||||
f"{x}=COALESCE(secondary_pangu.{x},EXCLUDED.{x})"
|
||||
for x in file_fs + dir_fs
|
||||
)
|
||||
+ ",opt=EXCLUDED.opt"
|
||||
)
|
||||
cur.execute(
|
||||
f"INSERT INTO secondary_pangu({cols}) VALUES({','.join(['%s'] * len(vals))}) "
|
||||
f"ON CONFLICT(bag_id) DO UPDATE SET {upd}",
|
||||
vals,
|
||||
)
|
||||
|
||||
# 4) topics + 关联
|
||||
txt = download_txt(bag_name)
|
||||
if txt:
|
||||
tops = parse_rosbag_info(txt)
|
||||
os.remove(txt)
|
||||
for nm, ty in tops:
|
||||
cur.execute(
|
||||
"INSERT INTO topic_list(name,type) VALUES(%s,%s) "
|
||||
"ON CONFLICT(name) DO UPDATE "
|
||||
"SET type=COALESCE(topic_list.type,EXCLUDED.type),update_time=NOW() "
|
||||
"RETURNING id",
|
||||
(nm, ty),
|
||||
)
|
||||
tid = cur.fetchone()[0]
|
||||
cur.execute(
|
||||
"INSERT INTO bag_topic(bag_id,topic_id) VALUES(%s,%s) "
|
||||
"ON CONFLICT DO NOTHING",
|
||||
(bag_id, tid),
|
||||
)
|
||||
|
||||
logger.info("SUCCESS %s", bag_name)
|
||||
|
||||
except Exception:
|
||||
logger.exception("FAIL %s", bag_name)
|
||||
finally:
|
||||
release_db(conn, cur)
|
||||
|
||||
|
||||
# ==== 主入口 ====
|
||||
|
||||
|
||||
def main(txtfile, threads):
|
||||
done = load_done(LOGFILE)
|
||||
lines = [L for L in open(txtfile, "r") if L.strip()]
|
||||
tasks = []
|
||||
with ThreadPoolExecutor(max_workers=threads) as ex:
|
||||
for L in lines:
|
||||
try:
|
||||
nm = os.path.basename(L.strip())
|
||||
if not nm.endswith(".bag"):
|
||||
nm += ".bag"
|
||||
except FileExistsError:
|
||||
continue
|
||||
if nm in done:
|
||||
logger.info("SKIP %s", nm)
|
||||
continue
|
||||
tasks.append(ex.submit(process_one, L))
|
||||
|
||||
for _ in tqdm(as_completed(tasks), total=len(tasks), desc="Bags"):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("file", help="ROS bag 列表 txt")
|
||||
p.add_argument("--threads", type=int, default=MAX_WORKERS)
|
||||
args = p.parse_args()
|
||||
main(args.file, args.threads)
|
||||
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
import psycopg2
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
from fst_data_pipeline.core.config_manager import ConfigManager
|
||||
|
||||
# ====== 配置区 ======
|
||||
|
||||
config = ConfigManager()
|
||||
DB_HOST = config.get("ROOT_DB_HOST", "localhost")
|
||||
DB_PORT = config.get_int("ROOT_DB_PORT", 5432)
|
||||
DB_NAME = config.get("ROOT_DB_NAME", "default_dbname")
|
||||
DB_USER = config.get("ROOT_DB_USER", "default_user")
|
||||
DB_PASSWORD = config.get("ROOT_DB_PASSWD", "default_password")
|
||||
|
||||
COS_REGION = config.get("COS_CLIENT_REGION", "default_region")
|
||||
COS_SECRET_ID = config.get("COS_CLIENT_SECRET_ID", "default_id")
|
||||
COS_SECRET_KEY = config.get("COS_CLIENT_SECRET_KEY", "default_key")
|
||||
|
||||
BUCKET = config.get("COS_BUCKET", "b-perception-e2e-1318950322")
|
||||
config.require("COS_PREFIX_DECODE")
|
||||
PREFIX = config.get("COS_PREFIX_DECODE", "mb_raw_rosbag_decode_dirs")
|
||||
DOMAIN = f"https://{BUCKET}.cos.{COS_REGION}.myqcloud.com"
|
||||
|
||||
|
||||
# 输出缺失列表的文件
|
||||
MISSING_FILE = "missing.txt"
|
||||
|
||||
|
||||
def fetch_paths():
|
||||
conn = psycopg2.connect(
|
||||
host=DB_HOST, port=DB_PORT, dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD
|
||||
)
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT data_path FROM bag_list;")
|
||||
paths = [row[0] for row in cur.fetchall()]
|
||||
conn.close()
|
||||
return paths
|
||||
|
||||
|
||||
def check_dirs(paths):
|
||||
cfg = CosConfig(
|
||||
Secret_id=COS_SECRET_ID, Secret_key=COS_SECRET_KEY, Region=COS_REGION
|
||||
)
|
||||
client = CosS3Client(cfg)
|
||||
exist, missing = [], []
|
||||
for p in paths:
|
||||
prefix = p.rstrip("/") + "/lidar_gt_pandar128/"
|
||||
resp = client.list_objects(Bucket=BUCKET, Prefix=prefix, MaxKeys=1)
|
||||
if resp.get("Contents"):
|
||||
exist.append(p)
|
||||
else:
|
||||
missing.append(p)
|
||||
return exist, missing
|
||||
|
||||
|
||||
def write_missing(missing, filepath):
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
for p in missing:
|
||||
f.write(p + "\n")
|
||||
|
||||
|
||||
def main():
|
||||
paths = fetch_paths()
|
||||
exist, missing = check_dirs(paths)
|
||||
|
||||
print(f"总路径数: {len(paths)}")
|
||||
print(f"存在 lidar_gt_pandar128 的: {len(exist)}")
|
||||
for p in exist:
|
||||
print(" [OK] ", p)
|
||||
|
||||
print(f"\n缺失 lidar_gt_pandar128 的: {len(missing)}")
|
||||
for p in missing:
|
||||
print(" [MISS]", p)
|
||||
|
||||
# 将缺失列表写入文件
|
||||
write_missing(missing, MISSING_FILE)
|
||||
print(f"\n已将 {len(missing)} 条缺失记录写入 `{MISSING_FILE}`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
354
fst_data_pipeline/pipelines/tencent/bag_operation/bag_scanner.py
Normal file
354
fst_data_pipeline/pipelines/tencent/bag_operation/bag_scanner.py
Normal file
@@ -0,0 +1,354 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
四种互斥模式:
|
||||
(默认) —— 全量扫描
|
||||
file --tasks-file—— 指定文件
|
||||
service —— HTTP 服务
|
||||
check —— fst_bag 关联并推送缺失
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Any
|
||||
|
||||
import psycopg2
|
||||
import requests
|
||||
from flask import Flask, jsonify, request
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
|
||||
from fst_data_pipeline.core.config_manager import ConfigManager
|
||||
|
||||
# ---------- 配置 ----------
|
||||
cfg = ConfigManager()
|
||||
|
||||
DB = {
|
||||
"host": cfg.get("ROOT_DB_HOST"),
|
||||
"port": cfg.get_int("ROOT_DB_PORT"),
|
||||
"dbname": cfg.get("ROOT_DB_NAME"),
|
||||
"user": cfg.get("ROOT_DB_USER"),
|
||||
"password": cfg.get("ROOT_DB_PASSWD"),
|
||||
}
|
||||
|
||||
COS_CFG = CosConfig(
|
||||
Region=cfg.get("COS_CLIENT_REGION"),
|
||||
SecretId=cfg.get("COS_CLIENT_SECRET"),
|
||||
SecretKey=cfg.get("COS_CLIENT_KEY"),
|
||||
)
|
||||
BUCKET = cfg.get("COS_BUCKET")
|
||||
|
||||
LD_READY = cfg.get("LD_READY_URL", "http://ld-checker/ready")
|
||||
OD_READY = cfg.get("OD_READY_URL", "http://od-checker/ready")
|
||||
LD_NOTIFY = cfg.get("LD_NOTIFY_URL", "http://ld-checker/notify")
|
||||
OD_NOTIFY = cfg.get("OD_NOTIFY_URL", "http://od-checker/notify")
|
||||
|
||||
LOCAL_API_HOST = cfg.get("ROOT_DB_API", "http://10.0.240.4:5232")
|
||||
|
||||
MAX_WORKERS = 16
|
||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(message)s")
|
||||
log = logging.getLogger("bag")
|
||||
|
||||
# ---------- 类型 ----------
|
||||
BagRow = Tuple[
|
||||
int, str, int, bool, int, int
|
||||
] # id, name, project_id, is_decoded, od_old, ld_old
|
||||
|
||||
|
||||
# ---------- 工具 ----------
|
||||
def pg_conn():
|
||||
conn = psycopg2.connect(**DB)
|
||||
conn.autocommit = True
|
||||
return conn, conn.cursor()
|
||||
|
||||
|
||||
def cos_client() -> CosS3Client:
|
||||
return CosS3Client(COS_CFG)
|
||||
|
||||
|
||||
def cos_exists(cli: CosS3Client, prefix: str) -> bool:
|
||||
try:
|
||||
return bool(
|
||||
cli.list_objects(Bucket=BUCKET, Prefix=prefix, MaxKeys=1).get("Contents")
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# ---------- HTTP ----------
|
||||
def _get_all_fst_bags() -> List[str]:
|
||||
url = f"{LOCAL_API_HOST}/api/fst/baglist"
|
||||
try:
|
||||
r = requests.get(url, timeout=10)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
log.error("获取 fst bag list 失败: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
def _resolve_paths(bags: List[str]) -> Dict[str, str]:
|
||||
url = f"{LOCAL_API_HOST}/api/bags/pangu/detail"
|
||||
if not bags:
|
||||
return {}
|
||||
try:
|
||||
r = requests.post(url, json=bags, timeout=30)
|
||||
r.raise_for_status()
|
||||
return {
|
||||
name: info["data_path"].split("/", 1)[-1]
|
||||
for name, info in r.json().items()
|
||||
if info.get("data_path")
|
||||
}
|
||||
except Exception as e:
|
||||
log.error("解析 bag path 失败: %s", e)
|
||||
return {}
|
||||
|
||||
|
||||
# ---------- 推送 ----------
|
||||
MISSING_LD: List[str] = []
|
||||
MISSING_OD: List[str] = []
|
||||
|
||||
|
||||
def push_list(ready_url: str, notify_url: str, bag_list: List[str], kind: str) -> None:
|
||||
if not bag_list:
|
||||
return
|
||||
path_map = _resolve_paths(bag_list)
|
||||
if not path_map:
|
||||
log.warning("无可推送的 %s 路径", kind)
|
||||
return
|
||||
paths = [path_map[name] for name in bag_list if name in path_map]
|
||||
|
||||
log.info("%s: %d bags → pushing %d paths", kind, len(bag_list), len(paths))
|
||||
while True:
|
||||
try:
|
||||
if requests.get(ready_url, timeout=5).json().get("ready"):
|
||||
break
|
||||
except Exception as e:
|
||||
log.warning("ready check failed: %s", e)
|
||||
time.sleep(2)
|
||||
try:
|
||||
resp = requests.post(notify_url, json=paths, timeout=30)
|
||||
resp.raise_for_status()
|
||||
log.info("%s pushed %d paths, status=%s", kind, len(paths), resp.status_code)
|
||||
except Exception as e:
|
||||
log.error("%s push failed: %s", kind, e)
|
||||
|
||||
|
||||
# ---------- 核心处理 ----------
|
||||
def process_row(row: BagRow, collect_missing: bool = False) -> None:
|
||||
bag_id, name, project_id, _, od_old, ld_old = row
|
||||
|
||||
# 提前检查是否已“完美”
|
||||
conn, cur = pg_conn()
|
||||
try:
|
||||
cur.execute("SELECT reserved_str FROM main_pangu WHERE name=%s", (name,))
|
||||
reserved_str = (cur.fetchone() or ("",))[0]
|
||||
has_md5 = reserved_str.startswith("md5:")
|
||||
if ld_old == 3 and od_old == 3 and has_md5:
|
||||
log.debug("[%s] 已完整,跳过", name)
|
||||
return
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# 需要真正检查 COS
|
||||
conn, cur = pg_conn()
|
||||
cli = cos_client()
|
||||
try:
|
||||
base = "mb" if project_id == 1 else "mmt"
|
||||
derived = f"{base}_raw_rosbag_decode_dirs/{name}.dir/derived/"
|
||||
|
||||
ld_auto = cos_exists(cli, derived + "LDGT/")
|
||||
ld_man = cos_exists(cli, derived + "LD_manual/")
|
||||
od_auto = cos_exists(cli, derived + "object_auto_labeling/")
|
||||
od_man = cos_exists(cli, derived + "object_manual_labeling/")
|
||||
|
||||
def calc_flag(a: bool, m: bool) -> int:
|
||||
return 3 if a and m else 2 if m else 1 if a else 0
|
||||
|
||||
ld_new = calc_flag(ld_auto, ld_man)
|
||||
od_new = calc_flag(od_auto, od_man)
|
||||
|
||||
# 更新状态
|
||||
for field, old, new in (
|
||||
("ld_annotated", ld_old, ld_new),
|
||||
("od_annotated", od_old, od_new),
|
||||
):
|
||||
if old != new:
|
||||
cur.execute(
|
||||
f"UPDATE bag_list SET {field}=%s WHERE id=%s", (new, bag_id)
|
||||
)
|
||||
log.info("[%s] %s %s→%s", name, field, old, new)
|
||||
|
||||
# 处理 md5
|
||||
if not has_md5:
|
||||
cur.execute("SELECT bag_path FROM main_pangu WHERE name=%s", (name,))
|
||||
bag_path = (cur.fetchone() or ("",))[0]
|
||||
if bag_path:
|
||||
cos_key = bag_path.lstrip("/")
|
||||
tmp_path = f"/tmp/{name}.bag"
|
||||
try:
|
||||
cli.download_file(Bucket=BUCKET, Key=cos_key, DestFilePath=tmp_path)
|
||||
md5 = subprocess.check_output(
|
||||
["md5sum", tmp_path], text=True
|
||||
).split()[0]
|
||||
os.remove(tmp_path)
|
||||
except Exception as e:
|
||||
log.warning("[%s] 下载/计算 md5 失败: %s", name, e)
|
||||
md5 = None
|
||||
|
||||
if md5:
|
||||
cur.execute(
|
||||
"UPDATE main_pangu SET reserved_str='md5:'||%s WHERE name=%s",
|
||||
(md5, name),
|
||||
)
|
||||
cur.execute(
|
||||
"SELECT reserved_json FROM main_pangu WHERE name=%s", (name,)
|
||||
)
|
||||
reserved_json: Dict[str, Any] = json.loads(
|
||||
cur.fetchone()[0] or "{}"
|
||||
)
|
||||
reserved_json.update({
|
||||
"bag_meta": f"https://cla-dev.ca4ad.com/cdi/data/{md5}",
|
||||
"bag_player": f"https://mviz-dev.ca4ad.com/player/v4/?bag_md5={md5}",
|
||||
})
|
||||
cur.execute(
|
||||
"UPDATE main_pangu SET reserved_json=COALESCE(reserved_json,'{}'::jsonb)||%s WHERE name=%s",
|
||||
(json.dumps(reserved_json, ensure_ascii=False), name),
|
||||
)
|
||||
log.info("[%s] 写入 md5:%s", name, md5)
|
||||
|
||||
# 收集缺失
|
||||
if collect_missing:
|
||||
if not ld_auto:
|
||||
MISSING_LD.append(name)
|
||||
if not od_auto:
|
||||
MISSING_OD.append(name)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
# ---------- 数据获取 ----------
|
||||
def fetch_rows_from_file(path: Path) -> List[BagRow]:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
names = [l.strip() for l in f if l.strip()]
|
||||
if not names:
|
||||
return []
|
||||
conn, cur = pg_conn()
|
||||
cur.execute("SELECT * FROM bag_list WHERE name = ANY(%s)", (names,))
|
||||
rows = cur.fetchall()
|
||||
conn.close()
|
||||
return rows
|
||||
|
||||
|
||||
def fetch_rows_from_all() -> List[BagRow]:
|
||||
conn, cur = pg_conn()
|
||||
cur.execute("SELECT * FROM bag_list")
|
||||
rows = cur.fetchall()
|
||||
conn.close()
|
||||
return rows
|
||||
|
||||
|
||||
def fetch_rows_from_fst() -> List[BagRow]:
|
||||
bag_names = _get_all_fst_bags()
|
||||
if not bag_names:
|
||||
return []
|
||||
conn, cur = pg_conn()
|
||||
cur.execute("SELECT * FROM bag_list WHERE name = ANY(%s)", (bag_names,))
|
||||
rows = cur.fetchall()
|
||||
conn.close()
|
||||
log.info("fst_bag 关联 %d 条", len(rows))
|
||||
return rows
|
||||
|
||||
|
||||
# ---------- 模式封装 ----------
|
||||
def _run(rows: List[BagRow], collect_missing: bool):
|
||||
MISSING_LD.clear()
|
||||
MISSING_OD.clear()
|
||||
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
|
||||
list(ex.map(lambda r: process_row(r, collect_missing), rows))
|
||||
if collect_missing:
|
||||
push_list(LD_READY, LD_NOTIFY, MISSING_LD, "LD")
|
||||
push_list(OD_READY, OD_NOTIFY, MISSING_OD, "OD")
|
||||
|
||||
|
||||
def run_all():
|
||||
rows = fetch_rows_from_all()
|
||||
_run(rows, collect_missing=False)
|
||||
log.info("全量扫描完成")
|
||||
|
||||
|
||||
def run_file(path: Path):
|
||||
rows = fetch_rows_from_file(path)
|
||||
_run(rows, collect_missing=False)
|
||||
log.info("file 模式完成")
|
||||
|
||||
|
||||
def run_service(host: str, port: int):
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/ready")
|
||||
def ready():
|
||||
return jsonify(ready=True)
|
||||
|
||||
def handle_notify(bags: List[str]):
|
||||
rows = []
|
||||
if bags:
|
||||
conn, cur = pg_conn()
|
||||
cur.execute("SELECT * FROM bag_list WHERE name = ANY(%s)", (bags,))
|
||||
rows = cur.fetchall()
|
||||
conn.close()
|
||||
for r in rows:
|
||||
process_row(r, collect_missing=False)
|
||||
|
||||
@app.route("/notify_ld", methods=["POST"])
|
||||
def notify_ld():
|
||||
bags = request.get_json(force=True)
|
||||
if not isinstance(bags, list):
|
||||
return jsonify(error="need list"), 400
|
||||
handle_notify(bags)
|
||||
return jsonify(status="accepted"), 202
|
||||
|
||||
@app.route("/notify_od", methods=["POST"])
|
||||
def notify_od():
|
||||
bags = request.get_json(force=True)
|
||||
if not isinstance(bags, list):
|
||||
return jsonify(error="need list"), 400
|
||||
handle_notify(bags)
|
||||
return jsonify(status="accepted"), 202
|
||||
|
||||
app.run(host=host, port=port, threaded=True)
|
||||
|
||||
|
||||
def run_check():
|
||||
rows = fetch_rows_from_fst()
|
||||
_run(rows, collect_missing=True)
|
||||
log.info("check 模式完成")
|
||||
|
||||
|
||||
# ---------- CLI ----------
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
sub = parser.add_subparsers(dest="mode", help="运行模式")
|
||||
parser.set_defaults(mode="all")
|
||||
f = sub.add_parser("file", help="读取 bag 列表文件")
|
||||
f.add_argument("--tasks-file", required=True)
|
||||
s = sub.add_parser("service", help="启动 HTTP 服务")
|
||||
s.add_argument("--host", default="0.0.0.0")
|
||||
s.add_argument("--port", type=int, default=8000)
|
||||
sub.add_parser("check", help="仅扫描 fst_bag 关联 bag 并推送缺失")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.mode == "file":
|
||||
run_file(Path(args.tasks_file))
|
||||
elif args.mode == "service":
|
||||
run_service(args.host, args.port)
|
||||
elif args.mode == "check":
|
||||
run_check()
|
||||
else:
|
||||
run_all()
|
||||
@@ -0,0 +1,221 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
-----------------------------------------------------------------------------------
|
||||
功能概述:
|
||||
该脚本用于同步腾讯云 COS(对象存储)上的原始 .bag 文件与 PostgreSQL 中的 bag_list 表。
|
||||
同时根据 COS 上解码输出目录 b/ 下的 .dir 文件, 更新数据库中对应记录的解码状态。
|
||||
支持在数据库同步完成后, 根据 --notify 参数轮询后端服务是否就绪(通过 /ready 接口),
|
||||
并在就绪后将未解码文件列表 POST 到 /notify 接口, 驱动后续解码或处理流程。
|
||||
|
||||
主要流程:
|
||||
1. 连接腾讯云 COS, 列出 a/ 前缀下所有 .bag 文件, 插入数据库(跳过已存在)。
|
||||
2. 查询数据库中未解码且未删除的记录。
|
||||
3. 列出 b/ 前缀下的 .dir 文件, 推断已解码的 .bag 文件, 批量更新数据库标记 is_decoded=TRUE。
|
||||
4. 如果指定 --notify:
|
||||
- 轮询 /ready 接口直到返回 { "ready": true }
|
||||
- 再将剩余未解码文件列表(带前缀)POST 到 /notify 接口。
|
||||
|
||||
依赖环境:
|
||||
- Python 3.6+
|
||||
- psycopg2 (PostgreSQL 驱动)
|
||||
- requests (HTTP 请求)
|
||||
- qcloud_cos (腾讯云 COS Python SDK)
|
||||
|
||||
使用示例:
|
||||
python3 cos_pg_sync.py
|
||||
python3 cos_pg_sync.py --notify
|
||||
|
||||
配置说明:
|
||||
可在脚本开头直接修改:
|
||||
COS_SECRET_ID / COS_SECRET_KEY / COS_REGION / COS_BUCKET
|
||||
PostgreSQL 连接信息 PG_HOST 等
|
||||
GATE_URL / NOTIFY_URL 等服务端接口地址
|
||||
-----------------------------------------------------------------------------------
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import psycopg2
|
||||
from psycopg2.extras import execute_values
|
||||
import requests
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
from fst_data_pipeline.core.config_manager import ConfigManager
|
||||
|
||||
# ====== 配置区 ======
|
||||
config = ConfigManager()
|
||||
DB_HOST = config.get("ROOT_DB_HOST", "localhost")
|
||||
DB_PORT = config.get_int("ROOT_DB_PORT", 5432)
|
||||
DB_NAME = config.get("ROOT_DB_NAME", "default_dbname")
|
||||
DB_USER = config.get("ROOT_DB_USER", "default_user")
|
||||
DB_PASSWORD = config.get("ROOT_DB_PASSWD", "default_password")
|
||||
|
||||
COS_REGION = config.get("COS_CLIENT_REGION", "default_region")
|
||||
COS_SECRET_ID = config.get("COS_CLIENT_SECRET_ID", "default_id")
|
||||
COS_SECRET_KEY = config.get("COS_CLIENT_SECRET_KEY", "default_key")
|
||||
BUCKET = config.get("COS_BUCKET", "b-perception-e2e-1318950322")
|
||||
|
||||
config.require("COS_PREFIX_DECODE")
|
||||
DECODED_PREFIX = config.get("COS_PREFIX_DECODE", "mb_raw_rosbag_decode_dirs")
|
||||
RAW_BAG_PREFIX = config.get("COS_PREFIX_BAG", "mb_cuct_data_collection")
|
||||
|
||||
DOMAIN = f"https://{BUCKET}.cos.{COS_REGION}.myqcloud.com"
|
||||
|
||||
|
||||
# Service 模式的 HTTP 接口
|
||||
GATE_URL = "https://example.com/ready"
|
||||
NOTIFY_URL = "https://example.com/notify"
|
||||
LOG_FILE = "sync.log"
|
||||
|
||||
|
||||
def init_logger():
|
||||
fmt = "%(asctime)s [%(levelname)s] %(message)s"
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format=fmt,
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler(LOG_FILE, encoding="utf-8"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def get_cos_client():
|
||||
cfg = CosConfig(Region=COS_REGION, SecretId=COS_SECRET_ID, SecretKey=COS_SECRET_KEY)
|
||||
return CosS3Client(cfg)
|
||||
|
||||
|
||||
def list_objs(cos, prefix):
|
||||
"""分页列出 COS 上指定前缀下的所有对象 key"""
|
||||
keys = []
|
||||
token = None
|
||||
while True:
|
||||
params = {
|
||||
"Bucket": BUCKET,
|
||||
"Prefix": prefix,
|
||||
"MaxKeys": 1000,
|
||||
}
|
||||
if token:
|
||||
params["ContinuationToken"] = token
|
||||
resp = cos.list_objects_v2(**params)
|
||||
for obj in resp.get("Contents", []):
|
||||
keys.append(obj["Key"])
|
||||
if resp.get("IsTruncated"):
|
||||
token = resp["NextContinuationToken"]
|
||||
else:
|
||||
break
|
||||
return keys
|
||||
|
||||
|
||||
def wait_for_ready():
|
||||
"""轮询 /ready, 直到返回 JSON { "ready": true }"""
|
||||
logging.info(f"开始轮询 Gate URL: {GATE_URL}")
|
||||
while True:
|
||||
try:
|
||||
r = requests.get(GATE_URL, timeout=10)
|
||||
if r.status_code == 200:
|
||||
data = r.json()
|
||||
if isinstance(data, dict) and data.get("ready") is True:
|
||||
logging.info("Gate 就绪, 开始通知流程")
|
||||
return
|
||||
logging.info(
|
||||
"Gate 未就绪(HTTP %d 或 ready=false), 10 分钟后重试", r.status_code
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"轮询 Gate 时发生异常: {e}")
|
||||
time.sleep(600) # 10 分钟后重试
|
||||
|
||||
|
||||
def main():
|
||||
init_logger()
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description="同步 COS 与 bag_list, 并可选通知未解码列表"
|
||||
)
|
||||
p.add_argument(
|
||||
"--notify",
|
||||
"-n",
|
||||
action="store_true",
|
||||
help="轮询 /ready, 再 POST 剩余未解码列表(带前缀)",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
# 1. 初始化 COS 客户端和数据库连接
|
||||
cos = get_cos_client()
|
||||
conn = psycopg2.connect(
|
||||
host=DB_HOST, port=DB_PORT, user=DB_USER, password=DB_PASSWORD, dbname=DB_NAME
|
||||
)
|
||||
cur = conn.cursor()
|
||||
|
||||
# —— STEP1:同步 a/ 下的 .bag 到 bag_list(name, project_id) ——
|
||||
keys_a = list_objs(cos, RAW_BAG_PREFIX)
|
||||
# 提取所有 .bag 的名字(不含前缀)
|
||||
names = [os.path.basename(k) for k in keys_a if k.lower().endswith(".bag")]
|
||||
if names:
|
||||
# 默认 project_id=1
|
||||
records = [(n, 1) for n in names]
|
||||
execute_values(
|
||||
cur,
|
||||
"INSERT INTO bag_list(name, project_id) VALUES %s "
|
||||
"ON CONFLICT(name) DO NOTHING",
|
||||
records,
|
||||
)
|
||||
logging.info(f"STEP1: 尝试插入 {len(records)} 条记录, 冲突则跳过")
|
||||
else:
|
||||
logging.info("STEP1: 未发现 .bag 文件, 无需插入")
|
||||
|
||||
# —— STEP2:查询所有 is_decoded=FALSE 且 is_deleted=FALSE 的 name ——
|
||||
cur.execute("SELECT name FROM bag_list WHERE is_decoded=FALSE AND is_deleted=FALSE")
|
||||
undecoded = [row[0] for row in cur.fetchall()]
|
||||
logging.info(f"STEP2: 当前未解码记录数:{len(undecoded)}")
|
||||
|
||||
# —— STEP3:扫描 b/ 下 .dir, 并批量更新 is_decoded ——
|
||||
keys_b = list_objs(cos, DECODED_PREFIX)
|
||||
# available 是所有已解码的 bag 名字(带 .bag 后缀)
|
||||
available = {
|
||||
os.path.basename(k)[:-4] + ".bag" for k in keys_b if k.lower().endswith(".dir")
|
||||
}
|
||||
to_update = list(set(undecoded) & available)
|
||||
if to_update:
|
||||
cur.execute(
|
||||
"UPDATE bag_list "
|
||||
"SET is_decoded=TRUE, update_time=CURRENT_TIMESTAMP "
|
||||
"WHERE name = ANY(%s)",
|
||||
(to_update,),
|
||||
)
|
||||
logging.info(f"STEP3: 标记已解码 {len(to_update)} 条记录")
|
||||
else:
|
||||
logging.info("STEP3: 无需更新解码状态")
|
||||
|
||||
conn.commit()
|
||||
|
||||
# —— 可选通知流程 ——
|
||||
if args.notify:
|
||||
# 1) 等待服务端就绪
|
||||
wait_for_ready()
|
||||
|
||||
# 2) 再次查询剩余未解码名称
|
||||
cur.execute(
|
||||
"SELECT name FROM bag_list WHERE is_decoded=FALSE AND is_deleted=FALSE"
|
||||
)
|
||||
remaining = [row[0] for row in cur.fetchall()]
|
||||
logging.info(f"通知前剩余未解码 {len(remaining)} 条: {remaining}")
|
||||
|
||||
# 3) 补上前缀, POST 给 /notify
|
||||
paths = [RAW_BAG_PREFIX + name for name in remaining]
|
||||
try:
|
||||
resp = requests.post(NOTIFY_URL, json=paths, timeout=10)
|
||||
logging.info(f"POST /notify 返回 HTTP {resp.status_code}: {resp.text}")
|
||||
except Exception as e:
|
||||
logging.error(f"通知失败: {e}")
|
||||
|
||||
# 清理资源
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,216 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
import requests
|
||||
import psycopg2
|
||||
|
||||
# =========================================================
|
||||
# 日志
|
||||
# =========================================================
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)s | %(message)s",
|
||||
handlers=[logging.FileHandler("bag_merge.log"), logging.StreamHandler()],
|
||||
)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# =========================================================
|
||||
# 配置(全部有默认值)
|
||||
# =========================================================
|
||||
|
||||
# API
|
||||
API_URL = os.getenv("API_URL", "http://127.0.0.1:8080/api/bag/mapping")
|
||||
API_TIMEOUT = int(os.getenv("API_TIMEOUT", "30"))
|
||||
|
||||
# PostgreSQL
|
||||
PG_DSN = os.getenv(
|
||||
"PG_DSN",
|
||||
"host=127.0.0.1 port=5432 dbname=test user=test password=test",
|
||||
)
|
||||
|
||||
# 本地临时目录
|
||||
TEMP_ROOT = Path(os.getenv("TEMP_ROOT", "/tmp/bag_merge"))
|
||||
|
||||
# coscmd
|
||||
COSCMD_BIN = os.getenv("COSCMD_BIN", "coscmd")
|
||||
COSCMD_TIMEOUT = int(os.getenv("COSCMD_TIMEOUT", "3600"))
|
||||
|
||||
# 并发数
|
||||
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "4"))
|
||||
|
||||
# 用于拼可访问 URL(仅用于存 DB,不影响 coscmd)
|
||||
COS_ENDPOINT = os.getenv("COS_ENDPOINT", "")
|
||||
COS_BUCKET = os.getenv("COS_BUCKET", "")
|
||||
COS_REGION = os.getenv("COS_REGION", "")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 工具函数
|
||||
# =========================================================
|
||||
def safe_name(key: str) -> str:
|
||||
"""
|
||||
将 COS key 转成本地安全文件名
|
||||
"""
|
||||
k = (key or "").strip().replace("\\", "/")
|
||||
k = re.sub(r"/+", "/", k)
|
||||
k = k.replace("..", "__")
|
||||
k = k.replace("/", "__")
|
||||
return k or "empty"
|
||||
|
||||
|
||||
def run_cmd(cmd: list[str], *, timeout: int | None = None):
|
||||
log.debug("CMD: %s", " ".join(cmd))
|
||||
subprocess.run(cmd, check=True, timeout=timeout)
|
||||
|
||||
|
||||
def make_cos_url(key: str) -> str:
|
||||
"""
|
||||
生成一个用于入库的 URL
|
||||
"""
|
||||
k = key.lstrip("/")
|
||||
if COS_ENDPOINT:
|
||||
return f"https://{COS_ENDPOINT.rstrip('/')}/{k}"
|
||||
if COS_BUCKET and COS_REGION:
|
||||
return f"https://{COS_BUCKET}.cos.{COS_REGION}.myqcloud.com/{k}"
|
||||
return key
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 业务函数
|
||||
# =========================================================
|
||||
def fetch_mapping() -> dict:
|
||||
"""
|
||||
从 API 获取 parent -> children 映射
|
||||
"""
|
||||
log.info("POST %s", API_URL)
|
||||
resp = requests.post(
|
||||
API_URL,
|
||||
json={"bag_names": ["*"]},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=API_TIMEOUT,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if not isinstance(data, dict):
|
||||
raise RuntimeError("mapping response is not a dict")
|
||||
return data
|
||||
|
||||
|
||||
def download_file(key: str, local: Path):
|
||||
"""
|
||||
coscmd download /xxx local
|
||||
"""
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
cos_path = "/" + key.lstrip("/")
|
||||
|
||||
log.info("↓ download %s -> %s", cos_path, local)
|
||||
run_cmd(
|
||||
[COSCMD_BIN, "download", cos_path, str(local)],
|
||||
timeout=COSCMD_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
def upload_file(local: Path, key: str) -> str:
|
||||
"""
|
||||
coscmd upload local /xxx
|
||||
"""
|
||||
cos_path = "/" + key.lstrip("/")
|
||||
|
||||
log.info("↑ upload %s -> %s", local, cos_path)
|
||||
run_cmd(
|
||||
[COSCMD_BIN, "upload", str(local), cos_path],
|
||||
timeout=COSCMD_TIMEOUT,
|
||||
)
|
||||
return make_cos_url(key)
|
||||
|
||||
|
||||
def merge_bags(inputs: list[Path], output: Path):
|
||||
"""
|
||||
调用 rosbag-merge
|
||||
"""
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
subprocess.check_call(
|
||||
["rosbag-merge", "-o", str(output)] + [str(p) for p in inputs]
|
||||
)
|
||||
|
||||
|
||||
def update_db(parent: str, cos_url: str):
|
||||
"""
|
||||
更新数据库
|
||||
"""
|
||||
sql = "UPDATE bag_task SET tos_path = %s WHERE parent_bag = %s"
|
||||
with psycopg2.connect(PG_DSN) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (cos_url, parent))
|
||||
conn.commit()
|
||||
log.info("[DB] %s -> %s", parent, cos_url)
|
||||
|
||||
|
||||
def work_one(parent: str, children: list[str]) -> str:
|
||||
"""
|
||||
单个 parent 的完整处理流程
|
||||
"""
|
||||
log.info("start parent=%s children=%d", parent, len(children))
|
||||
|
||||
wd = TEMP_ROOT / safe_name(parent)
|
||||
wd.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
subs: list[Path] = []
|
||||
for c in children:
|
||||
lp = wd / safe_name(c)
|
||||
download_file(c, lp)
|
||||
subs.append(lp)
|
||||
|
||||
out = wd / safe_name(parent)
|
||||
merge_bags(subs, out)
|
||||
|
||||
url = upload_file(out, parent)
|
||||
update_db(parent, url)
|
||||
|
||||
log.info("finish parent=%s", parent)
|
||||
return url
|
||||
finally:
|
||||
shutil.rmtree(wd, ignore_errors=True)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 主入口
|
||||
# =========================================================
|
||||
def main():
|
||||
TEMP_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 可执行文件检查(有问题会在启动阶段直接失败)
|
||||
run_cmd([COSCMD_BIN, "--version"], timeout=10)
|
||||
run_cmd(["rosbag-merge", "-h"], timeout=10)
|
||||
|
||||
mapping = fetch_mapping()
|
||||
if not mapping:
|
||||
log.warning("mapping empty, exit")
|
||||
return
|
||||
|
||||
with ProcessPoolExecutor(max_workers=MAX_WORKERS) as pool:
|
||||
futures = {
|
||||
pool.submit(work_one, parent, children): parent
|
||||
for parent, children in mapping.items()
|
||||
}
|
||||
|
||||
for fu in as_completed(futures):
|
||||
parent = futures[fu]
|
||||
try:
|
||||
url = fu.result()
|
||||
log.info("done %s -> %s", parent, url)
|
||||
except Exception as e:
|
||||
log.exception("failed %s: %s", parent, e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,7 @@
|
||||
# auto copy MMT CLA rosbag to MB dev bucket
|
||||
|
||||
This is Tencent cloud function to filter the bag name and copy the bag from MMT CLA CDI storage bucket to MB dev bucket.
|
||||
|
||||
2 cloud functions are used in auto_copy phase:
|
||||
* filter_bag_to_kafka.py: filter the rosbag with name rule, send bag list to kafka
|
||||
* copy_cdi_cos_bag.py: listen to kafka queue and perform cloud bucket data copy
|
||||
@@ -0,0 +1,81 @@
|
||||
# -*- coding=utf-8
|
||||
|
||||
|
||||
#####----------------------------------------------------------------#####
|
||||
##### #####
|
||||
##### 使用教程/readme: #####
|
||||
##### https://cloud.tencent.com/document/product/583/30722 #####
|
||||
##### #####
|
||||
#####----------------------------------------------------------------#####
|
||||
|
||||
import os
|
||||
import logging
|
||||
from qcloud_cos import CosConfig
|
||||
from qcloud_cos import CosS3Client
|
||||
from qcloud_cos import CosServiceError
|
||||
import ast
|
||||
|
||||
# Setting user properties, including secret_id, secret_key, region, bucket
|
||||
appid = "1318950322" # Please replace with your APPID. 请替换为您的 APPID
|
||||
region = "ap-shanghai-adc" # Please replace with your region. 替换为用户的region
|
||||
|
||||
secret_id = os.environ.get("TENCENTCLOUD_SECRETID")
|
||||
secret_key = os.environ.get("TENCENTCLOUD_SECRETKEY")
|
||||
token = os.environ.get("TENCENTCLOUD_SESSIONTOKEN")
|
||||
|
||||
bucket = "b-perception-e2e-1318950322" # Please replace with your COS bucket. 替换为需要写入的COS Bucket
|
||||
folder = "mb_cuct_data_collection/"
|
||||
# Getting configuration object. 获取配置对象
|
||||
config = CosConfig(
|
||||
Region=region, Secret_id=secret_id, Secret_key=secret_key, Token=token
|
||||
)
|
||||
client = CosS3Client(config)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def copy_file(bucket_upload, key, src_key):
|
||||
response1 = client.object_exists(Bucket=bucket_upload, Key=key)
|
||||
if response1:
|
||||
try:
|
||||
destKey = folder + src_key
|
||||
response2 = client.copy(
|
||||
Bucket=bucket,
|
||||
Key=destKey,
|
||||
CopySource={"Bucket": bucket_upload, "Key": key, "Region": region},
|
||||
CopyStatus="Replaced",
|
||||
PartSize=50,
|
||||
MAXThread=10,
|
||||
)
|
||||
return True
|
||||
except CosServiceError as e:
|
||||
print("e is", e)
|
||||
return False
|
||||
else:
|
||||
print("resource file does not exist")
|
||||
return False
|
||||
|
||||
|
||||
def main_handler(event, context):
|
||||
bucket_upload = ""
|
||||
logger.info("start main handler")
|
||||
for record in event["Records"]:
|
||||
if "Ckafka" not in record.keys():
|
||||
print("event: no ckafka")
|
||||
continue
|
||||
value = record["Ckafka"]["msgBody"]
|
||||
try:
|
||||
data = ast.literal_eval(value)
|
||||
if "cos" not in data.keys():
|
||||
print("event: no cos")
|
||||
continue
|
||||
bucket_upload = data["cos"]["cosBucket"]["name"]
|
||||
key = record["Ckafka"]["msgKey"]
|
||||
src_key = key.split("/")[-1]
|
||||
print("file name is ", src_key)
|
||||
except:
|
||||
print("msgBody:", value)
|
||||
return "message error"
|
||||
if copy_file(bucket_upload, key, src_key):
|
||||
print("copy success")
|
||||
else:
|
||||
print("copy fail")
|
||||
@@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding=utf-
|
||||
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
from kafka import KafkaProducer
|
||||
from kafka.errors import KafkaError
|
||||
|
||||
logger = logging.getLogger("COSToKafka")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# file index which need copy
|
||||
file_index = "PL162802"
|
||||
skip_keywords = ["recording_lpnp_recording", "recording_ddloc_recording"]
|
||||
|
||||
|
||||
class CosToKafka(object):
|
||||
def __init__(self, host, **kwargs):
|
||||
self.host = host
|
||||
|
||||
self.producer = KafkaProducer(
|
||||
bootstrap_servers=[self.host],
|
||||
# retries = 10,
|
||||
# max_in_flight_requests_per_connection = 1,
|
||||
# request_timeout_ms = 30000,
|
||||
# max_block_ms = 60000,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def send(self, topic, event):
|
||||
global count
|
||||
count = 0
|
||||
|
||||
def on_send_success(record_metadata):
|
||||
global count
|
||||
count = count + 1
|
||||
|
||||
def on_send_error(excp):
|
||||
logger.error("failed to send message", exc_info=excp)
|
||||
|
||||
s_time = time.time()
|
||||
|
||||
eventList = None
|
||||
try:
|
||||
if "Records" in event:
|
||||
eventList = event["Records"]
|
||||
for data in eventList:
|
||||
if "cos" in data:
|
||||
key_list = data["cos"]["cosObject"]["key"].split("/")
|
||||
file_name = str(key_list[-1])
|
||||
if file_name.startswith(file_index):
|
||||
if any(kw in file_name for kw in skip_keywords):
|
||||
print("skip by blacklist:", file_name)
|
||||
continue
|
||||
|
||||
if file_name.endswith(".bag"):
|
||||
today_str = time.strftime("%Y%m%d", time.localtime())
|
||||
with open(
|
||||
f"{today_str}.txt", "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write(file_name + "\n")
|
||||
|
||||
_key = "/".join(key_list[3:])
|
||||
# "cos": {"cosBucket": {"appid": "1324295915", "cosRegion": "ap-shanghai", "name": "dis-source-cfdi-test", "region": "sh", "s3Region": "ap-shanghai"}
|
||||
# 返回的value中, name并不是真的bucket的name, 用这个那么去访问会出问题的
|
||||
# 真实的Bucket名字是 $name = $name-$appid
|
||||
_name = data["cos"]["cosBucket"]["name"]
|
||||
_appid = data["cos"]["cosBucket"]["appid"]
|
||||
data["cos"]["cosBucket"]["name"] = _name + "-" + _appid
|
||||
|
||||
print("bucket:", _name)
|
||||
print("ket:", _key)
|
||||
key = _key.encode("utf-8")
|
||||
value = json.dumps(data).encode("utf-8")
|
||||
|
||||
self.producer.send(
|
||||
topic, key=key, value=value
|
||||
).add_callback(on_send_success).add_errback(on_send_error)
|
||||
else:
|
||||
print("file index not match", str(key_list[-1]))
|
||||
else:
|
||||
print("message error")
|
||||
# block until all async messages are sent
|
||||
self.producer.flush()
|
||||
except KafkaError as e:
|
||||
return e
|
||||
finally:
|
||||
if self.producer is not None:
|
||||
self.producer.close()
|
||||
|
||||
e_time = time.time()
|
||||
|
||||
return "{} messages delivered in {}s".format(count, e_time - s_time)
|
||||
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding=utf-8
|
||||
|
||||
import os
|
||||
import logging
|
||||
from cos_to_kafka import CosToKafka
|
||||
|
||||
logger = logging.getLogger("Index")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def main_handler(event, context):
|
||||
logger.info("start main handler")
|
||||
os.environ["KAFKA_ADDRESS"] = "10.0.210.45:9092"
|
||||
os.environ["KAFKA_TOPIC_NAME"] = "synchronous_copy_release_1.0"
|
||||
kafka_address = os.getenv("KAFKA_ADDRESS")
|
||||
kafka_topic_name = os.getenv("KAFKA_TOPIC_NAME")
|
||||
|
||||
cos_to_kafka = CosToKafka(
|
||||
kafka_address,
|
||||
# security_protocol = "PLAINTEXT",
|
||||
# sasl_mechanism = "PLAIN",
|
||||
# sasl_plain_username = "ckafka-80o10xxx#lkoxx",
|
||||
# sasl_plain_password = "kongllxxxx",
|
||||
api_version=(1, 1, 1),
|
||||
)
|
||||
|
||||
ret = cos_to_kafka.send(kafka_topic_name, event)
|
||||
logger.info(ret)
|
||||
return ret
|
||||
439
fst_data_pipeline/pipelines/tencent/decoder.py
Normal file
439
fst_data_pipeline/pipelines/tencent/decoder.py
Normal file
@@ -0,0 +1,439 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
import threading
|
||||
import subprocess
|
||||
import shutil
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import queue
|
||||
import re
|
||||
|
||||
from prometheus_client import start_http_server, Counter, Gauge, Summary, Histogram
|
||||
from flask import Flask, request, jsonify
|
||||
|
||||
# —— 常量 & 路径 —— #
|
||||
BASE = os.getcwd()
|
||||
INPUT_ROOT = os.path.join(BASE, "input")
|
||||
OUTPUT_ROOT = os.path.join(BASE, "output")
|
||||
EMPTY_DIR = os.path.join(BASE, "empty")
|
||||
LOG_DIR = os.path.join(BASE, "logs")
|
||||
|
||||
COS_BUCKET = "mb_raw_rosbag_decode_dirs"
|
||||
DOCKER_IMAGE = (
|
||||
"artifact.swfcn.i.mercedes-benz.com/swfcn_docker/perception-3d/mmtbag_decoder:v6.6"
|
||||
)
|
||||
DOCKER_CMD_TEMPLATE = [
|
||||
"docker",
|
||||
"run",
|
||||
"--rm",
|
||||
"-v",
|
||||
"{in_dir}:/input",
|
||||
"-v",
|
||||
"{out_dir}:/output",
|
||||
DOCKER_IMAGE,
|
||||
"bash",
|
||||
"-c",
|
||||
"source /opt/ros/noetic/setup.bash && "
|
||||
"/opt/perception-3d/scripts/tools/"
|
||||
"mmt_bag_decoder_scripts/decoded-bag.sh /input /output 3 1",
|
||||
]
|
||||
|
||||
BATCH_SIZE = 50
|
||||
MAX_LOCAL = 100
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY_S = 2
|
||||
METRICS_PORT = 8000
|
||||
|
||||
SENTINEL = (None, None)
|
||||
|
||||
# —— 日志配置 —— #
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
logger = logging.getLogger("pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
h_info = logging.FileHandler(os.path.join(LOG_DIR, "pipeline.log"), encoding="utf-8")
|
||||
h_err = logging.FileHandler(os.path.join(LOG_DIR, "error_tasks.log"), encoding="utf-8")
|
||||
fmt = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
|
||||
h_info.setFormatter(fmt)
|
||||
h_err.setFormatter(fmt)
|
||||
h_err.setLevel(logging.ERROR)
|
||||
logger.addHandler(h_info)
|
||||
logger.addHandler(h_err)
|
||||
|
||||
# —— Prometheus 指标 —— #
|
||||
DL_TOTAL = Counter("pipeline_download_total", "下载尝试总数")
|
||||
DL_FAIL = Counter("pipeline_download_failures", "下载失败总数")
|
||||
DL_RETRY = Counter("pipeline_download_retries", "下载重试总数")
|
||||
PR_TOTAL = Counter("pipeline_process_total", "处理尝试总数")
|
||||
PR_FAIL = Counter("pipeline_process_failures", "处理失败总数")
|
||||
PR_RETRY = Counter("pipeline_process_retries", "处理重试总数")
|
||||
UP_TOTAL = Counter("pipeline_upload_total", "上传尝试总数")
|
||||
UP_FAIL = Counter("pipeline_upload_failures", "上传失败总数")
|
||||
UP_RETRY = Counter("pipeline_upload_retries", "上传重试总数")
|
||||
|
||||
DL_DUR = Summary("pipeline_download_duration_seconds", "单批下载耗时秒")
|
||||
PR_DUR = Summary("pipeline_process_duration_seconds", "单批处理耗时秒")
|
||||
UP_DUR = Summary("pipeline_upload_duration_seconds", "单批上传耗时秒")
|
||||
|
||||
BATCH_SIZE_HIST = Histogram(
|
||||
"pipeline_batch_size",
|
||||
"单批任务中文件数量分布",
|
||||
buckets=[1, 10, 20, 50, 100, 200, 500],
|
||||
)
|
||||
FILE_DL_DUR = Histogram("pipeline_file_download_duration_seconds", "单文件下载耗时分布")
|
||||
BATCH_OUT_FILES = Gauge("pipeline_batch_output_file_count", "单批处理后输出文件数")
|
||||
|
||||
Q_BATCH = Gauge("pipeline_queue_batches", "待下载批次数")
|
||||
Q_PROC = Gauge("pipeline_queue_processing", "待处理批次数")
|
||||
Q_UP = Gauge("pipeline_queue_uploading", "待上传批次数")
|
||||
LOCAL_FILES = Gauge("pipeline_local_file_count", "本地 input 文件总数")
|
||||
|
||||
# —— 全局队列 & inflight 计数 —— #
|
||||
batch_q = queue.Queue()
|
||||
proc_q = queue.Queue()
|
||||
up_q = queue.Queue()
|
||||
|
||||
inflight = 0
|
||||
inflight_lock = threading.Lock()
|
||||
|
||||
|
||||
# —— 辅助函数 —— #
|
||||
def count_local_files():
|
||||
return sum(len(files) for _, _, files in os.walk(INPUT_ROOT))
|
||||
|
||||
|
||||
def run(cmd, timeout=None):
|
||||
logger.info("CMD: %s", " ".join(cmd))
|
||||
try:
|
||||
p = subprocess.Popen(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
start = time.time()
|
||||
out_lines = []
|
||||
timed_out = False
|
||||
for line in p.stdout:
|
||||
out_lines.append(line)
|
||||
if timeout and (time.time() - start) > timeout:
|
||||
p.kill()
|
||||
timed_out = True
|
||||
break
|
||||
code = p.wait()
|
||||
return code, timed_out, "".join(out_lines)
|
||||
except Exception:
|
||||
logger.exception("CMD 执行异常")
|
||||
return -1, False, ""
|
||||
|
||||
|
||||
def with_retry(tag, fn, *args):
|
||||
for i in range(1, MAX_RETRIES + 1):
|
||||
code, timed_out, _ = fn(*args)
|
||||
if code == 0:
|
||||
return True
|
||||
if timed_out:
|
||||
logger.error("%s 阶段超时,不再重试", tag)
|
||||
break
|
||||
# 计数重试
|
||||
if tag.startswith("DL["):
|
||||
DL_RETRY.inc()
|
||||
if tag.startswith("PR["):
|
||||
PR_RETRY.inc()
|
||||
if tag.startswith("UP["):
|
||||
UP_RETRY.inc()
|
||||
logger.warning("%s 重试 %d/%d", tag, i, MAX_RETRIES)
|
||||
time.sleep(RETRY_DELAY_S)
|
||||
logger.error("%s 最终失败", tag)
|
||||
return False
|
||||
|
||||
|
||||
# —— 下载 —— #
|
||||
@DL_DUR.time()
|
||||
def do_download(batch_id, paths, batch_timeout):
|
||||
if batch_id is None:
|
||||
proc_q.put(SENTINEL)
|
||||
return
|
||||
|
||||
DL_TOTAL.inc()
|
||||
start = time.time()
|
||||
in_dir = os.path.join(INPUT_ROOT, batch_id)
|
||||
os.makedirs(in_dir, exist_ok=True)
|
||||
|
||||
# 限制本地文件数
|
||||
while count_local_files() >= MAX_LOCAL:
|
||||
logger.warning("本地文件过多,暂停下载5分钟")
|
||||
time.sleep(300)
|
||||
|
||||
for p in paths:
|
||||
if time.time() - start > batch_timeout:
|
||||
logger.error("DL[%s] 下载阶段超时,跳过剩余", batch_id)
|
||||
DL_FAIL.inc()
|
||||
break
|
||||
dst = os.path.join(in_dir, os.path.basename(p))
|
||||
f_start = time.time()
|
||||
ok = with_retry(
|
||||
f"DL[{batch_id}]",
|
||||
lambda s, d: run(
|
||||
["coscmd", "-s", "download", s, d],
|
||||
timeout=batch_timeout - (time.time() - start),
|
||||
),
|
||||
p,
|
||||
dst,
|
||||
)
|
||||
FILE_DL_DUR.observe(time.time() - f_start)
|
||||
if not ok:
|
||||
DL_FAIL.inc()
|
||||
|
||||
proc_q.put((batch_id, in_dir))
|
||||
|
||||
|
||||
# —— 处理 —— #
|
||||
@PR_DUR.time()
|
||||
def do_process(batch_id, in_dir, batch_timeout):
|
||||
if batch_id is None:
|
||||
up_q.put(SENTINEL)
|
||||
return
|
||||
|
||||
PR_TOTAL.inc()
|
||||
start = time.time()
|
||||
out_dir = os.path.join(OUTPUT_ROOT, batch_id)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
cmd = [c.format(in_dir=in_dir, out_dir=out_dir) for c in DOCKER_CMD_TEMPLATE]
|
||||
|
||||
def run_pr(command):
|
||||
p = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
for line in p.stdout:
|
||||
logger.info("[PR %s] %s", batch_id, line.rstrip())
|
||||
if time.time() - start > batch_timeout:
|
||||
p.kill()
|
||||
return p.wait(), True, ""
|
||||
return p.wait(), False, ""
|
||||
|
||||
ok = with_retry(f"PR[{batch_id}]", run_pr, cmd)
|
||||
if not ok:
|
||||
PR_FAIL.inc()
|
||||
|
||||
# 统计输出文件
|
||||
files = []
|
||||
for r, _, fs in os.walk(out_dir):
|
||||
for fn in fs:
|
||||
files.append(os.path.relpath(os.path.join(r, fn), out_dir))
|
||||
BATCH_OUT_FILES.set(len(files))
|
||||
|
||||
shutil.rmtree(in_dir, ignore_errors=True)
|
||||
up_q.put((batch_id, out_dir))
|
||||
|
||||
|
||||
# —— 上传 —— #
|
||||
@UP_DUR.time()
|
||||
def do_upload(batch_id, out_dir, batch_timeout):
|
||||
global inflight
|
||||
if batch_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
UP_TOTAL.inc()
|
||||
ok = with_retry(
|
||||
f"UP[{batch_id}]",
|
||||
lambda d: run(
|
||||
["coscmd", "-s", "upload", "-r", d, COS_BUCKET], timeout=batch_timeout
|
||||
),
|
||||
out_dir,
|
||||
)
|
||||
if not ok:
|
||||
UP_FAIL.inc()
|
||||
return
|
||||
|
||||
# 删除目录结构
|
||||
for cmd in [
|
||||
["sudo", "rsync", "-av", "--delete", f"{EMPTY_DIR}/", f"{out_dir}/"],
|
||||
["sudo", "rm", "-rf", out_dir],
|
||||
]:
|
||||
run(cmd, timeout=60)
|
||||
|
||||
logger.info("UP[%s] 完成", batch_id)
|
||||
finally:
|
||||
# 无论成功失败,任务算完成,inflight-1
|
||||
with inflight_lock:
|
||||
inflight -= 1
|
||||
|
||||
|
||||
# —— Worker 模板 —— #
|
||||
def worker(q, fn, timeout):
|
||||
while True:
|
||||
bid, data = q.get()
|
||||
fn(bid, data, timeout)
|
||||
q.task_done()
|
||||
if bid is None:
|
||||
break
|
||||
|
||||
|
||||
# —— Service HTTP —— #
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route("/ready", methods=["GET"])
|
||||
def api_ready():
|
||||
with inflight_lock:
|
||||
busy = inflight > 0
|
||||
return jsonify(ready=not busy)
|
||||
|
||||
|
||||
@app.route("/notify", methods=["POST"])
|
||||
def api_notify():
|
||||
global inflight
|
||||
data = request.get_json(force=True)
|
||||
if not isinstance(data, list):
|
||||
return jsonify(error="Expect JSON list"), 400
|
||||
|
||||
# 兼容 bag-checker,只发 name 时补前缀
|
||||
paths = []
|
||||
for item in data:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
if item.startswith("mb_cuct_data_collection/"):
|
||||
paths.append(item)
|
||||
else:
|
||||
paths.append("mb_cuct_data_collection/" + item)
|
||||
|
||||
TIME_RE = re.compile(r"_(\d{8})-(\d{6})_") # 匹配 20230803-160828
|
||||
|
||||
def extract_ts(p: str) -> datetime:
|
||||
m = TIME_RE.search(os.path.basename(p))
|
||||
if not m:
|
||||
return datetime.min # 无法解析的放最后
|
||||
date_part, time_part = m.groups()
|
||||
ts_str = f"{date_part}{time_part}"
|
||||
return datetime.strptime(ts_str, "%Y%m%d%H%M%S")
|
||||
|
||||
paths.sort(key=extract_ts, reverse=True)
|
||||
|
||||
with inflight_lock:
|
||||
inflight += 1
|
||||
|
||||
for idx in range(0, len(paths), BATCH_SIZE):
|
||||
blk = paths[idx : idx + BATCH_SIZE]
|
||||
BATCH_SIZE_HIST.observe(len(blk))
|
||||
bid = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") + f"_{idx // BATCH_SIZE + 1}"
|
||||
batch_q.put((bid, blk))
|
||||
batch_q.put(SENTINEL)
|
||||
|
||||
# batch_q.put((bid, paths))
|
||||
return jsonify(status="accepted", batch_size=BATCH_SIZE), 202
|
||||
|
||||
|
||||
def start_metric_updater():
|
||||
def loop():
|
||||
while True:
|
||||
Q_BATCH.set(batch_q.qsize())
|
||||
Q_PROC.set(proc_q.qsize())
|
||||
Q_UP.set(up_q.qsize())
|
||||
LOCAL_FILES.set(count_local_files())
|
||||
time.sleep(1)
|
||||
|
||||
t = threading.Thread(target=loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
|
||||
# —— 两种模式的入口 —— #
|
||||
def file_mode(args):
|
||||
# 读 tasks-file,分批入队,放入 sentinel,然后启动处理
|
||||
lines = [
|
||||
line.strip() for line in open(args.tasks_file, encoding="utf-8") if line.strip()
|
||||
]
|
||||
for idx in range(0, len(lines), args.batch_size):
|
||||
blk = lines[idx : idx + args.batch_size]
|
||||
BATCH_SIZE_HIST.observe(len(blk))
|
||||
bid = (
|
||||
datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
|
||||
+ f"_{idx // args.batch_size + 1}"
|
||||
)
|
||||
batch_q.put((bid, blk))
|
||||
batch_q.put(SENTINEL)
|
||||
|
||||
start_http_server(METRICS_PORT)
|
||||
logger.info("Metrics HTTP 启动,端口 %d", METRICS_PORT)
|
||||
|
||||
threads = [
|
||||
threading.Thread(
|
||||
target=worker, args=(batch_q, do_download, args.batch_timeout)
|
||||
),
|
||||
threading.Thread(target=worker, args=(proc_q, do_process, args.batch_timeout)),
|
||||
threading.Thread(target=worker, args=(up_q, do_upload, args.batch_timeout)),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
# 更新指标 & 等待完成
|
||||
while batch_q.unfinished_tasks or proc_q.unfinished_tasks or up_q.unfinished_tasks:
|
||||
Q_BATCH.set(batch_q.qsize())
|
||||
Q_PROC.set(proc_q.qsize())
|
||||
Q_UP.set(up_q.qsize())
|
||||
LOCAL_FILES.set(count_local_files())
|
||||
time.sleep(1)
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
logger.info("文件模式处理完成,退出。")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def service_mode(args):
|
||||
# 确保目录存在
|
||||
for d in (INPUT_ROOT, OUTPUT_ROOT, EMPTY_DIR, LOG_DIR):
|
||||
os.makedirs(d, exist_ok=True)
|
||||
|
||||
# 启动 Prometheus 和指标更新
|
||||
start_http_server(METRICS_PORT)
|
||||
logger.info("Metrics HTTP 启动,端口 %d", METRICS_PORT)
|
||||
start_metric_updater()
|
||||
|
||||
# 启动后台 worker
|
||||
threads = [
|
||||
threading.Thread(
|
||||
target=worker, args=(batch_q, do_download, args.batch_timeout), daemon=True
|
||||
),
|
||||
threading.Thread(
|
||||
target=worker, args=(proc_q, do_process, args.batch_timeout), daemon=True
|
||||
),
|
||||
threading.Thread(
|
||||
target=worker, args=(up_q, do_upload, args.batch_timeout), daemon=True
|
||||
),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
# 启动 Flask
|
||||
logger.info("Decode Service 启动 HTTP on %s:%d", args.host, args.port)
|
||||
app.run(host=args.host, port=args.port, threaded=True)
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
sub = p.add_subparsers(dest="mode", required=True)
|
||||
|
||||
f = sub.add_parser("file", help="文件模式:--tasks-file")
|
||||
f.add_argument("--tasks-file", required=True)
|
||||
f.add_argument("--batch-size", type=int, default=BATCH_SIZE)
|
||||
f.add_argument("--batch-timeout", type=int, default=3600)
|
||||
|
||||
s = sub.add_parser("service", help="服务模式:启动 HTTP ready/notify")
|
||||
s.add_argument("--batch-timeout", type=int, default=3600)
|
||||
s.add_argument("--host", default="0.0.0.0")
|
||||
s.add_argument("--port", type=int, default=5000)
|
||||
|
||||
args = p.parse_args()
|
||||
if args.mode == "file":
|
||||
file_mode(args)
|
||||
else:
|
||||
service_mode(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
133
fst_data_pipeline/pipelines/tencent/eval_rosbag.py
Normal file
133
fst_data_pipeline/pipelines/tencent/eval_rosbag.py
Normal file
@@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
python pipeline.py --check
|
||||
python pipeline.py --tasks-file task.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from fst_data_pipeline.pipelines.tencent.bag_operation.bag_scanner import cfg
|
||||
|
||||
BASE = Path.cwd()
|
||||
BAG, OSM, OUT, LOG = BASE / "bags", BASE / "osm", BASE / "result", BASE / "logs"
|
||||
for d in (BAG, OSM, OUT, LOG):
|
||||
d.mkdir(exist_ok=True)
|
||||
|
||||
# ---------- 单一日志 ----------
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(LOG / "run.log"),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
)
|
||||
log = logging.getLogger("pipeline")
|
||||
|
||||
# ---------- 常量 ----------
|
||||
LOCAL_API_HOST = cfg.get("ROOT_DB_API", "http://10.0.240.4:5232")
|
||||
IMAGE = "eval-mmt:latest"
|
||||
DOCKER_CMD = [
|
||||
"sudo",
|
||||
"docker",
|
||||
"run",
|
||||
"--rm",
|
||||
"-v",
|
||||
f"{BAG}:/bag_data",
|
||||
"-v",
|
||||
f"{OSM}:/osm_data",
|
||||
"-v",
|
||||
f"{OUT}:/output_folder",
|
||||
IMAGE,
|
||||
"bash",
|
||||
"-c",
|
||||
"source ~/.bashrc && python3 /root/tools/eval_tool/eval_mmt_data.py "
|
||||
"--data_folder /bag_data/ --osm_folder /osm_data/ --result_folder /output_folder",
|
||||
]
|
||||
|
||||
|
||||
# ---------- 工具 ----------
|
||||
def runcmd(cmd: List[str]) -> None:
|
||||
subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
|
||||
|
||||
|
||||
def fetch_baglist() -> List[str]:
|
||||
url = f"{LOCAL_API_HOST}/api/fst/baglist"
|
||||
resp = requests.get(url, timeout=60)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def fetch_detail(names: List[str]) -> dict:
|
||||
url = f"{LOCAL_API_HOST}/api/bags/pangu/detail"
|
||||
resp = requests.post(url, json=names, timeout=30)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def download_one(bag: str, info: dict) -> bool:
|
||||
try:
|
||||
cos_path = info["data_path"].split("/", 1)[1]
|
||||
dst_dir = BAG / f"{bag}.dir"
|
||||
if not dst_dir.exists():
|
||||
runcmd(["coscmd", "-s", "download", "-r", cos_path, str(dst_dir)])
|
||||
|
||||
derived = (
|
||||
requests.get(info["reserved_json"], timeout=30)
|
||||
.json()["derived_dir"]
|
||||
.rstrip("/")
|
||||
)
|
||||
osm_cos_path = f"{derived}/LD_manual/{bag.replace('.bag', '.osm')}"
|
||||
dst_osm = OSM / f"{bag.replace('.bag', '.osm')}"
|
||||
if not dst_osm.exists():
|
||||
runcmd(["coscmd", "-s", "download", osm_cos_path, str(dst_osm)])
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error("下载 %s 失败: %s", bag, e)
|
||||
return False
|
||||
|
||||
|
||||
def download_all(names: List[str]) -> None:
|
||||
infos = fetch_detail(names)
|
||||
for bag in tqdm(names, desc="Download"):
|
||||
download_one(bag, infos[bag])
|
||||
log.info("全部下载完成")
|
||||
|
||||
|
||||
def eval_all() -> None:
|
||||
log.info("开始 Docker 评估")
|
||||
runcmd(DOCKER_CMD)
|
||||
log.info("评估完成")
|
||||
|
||||
|
||||
# ---------- CLI ----------
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument("--tasks-file", type=Path)
|
||||
group.add_argument("--check", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
names = (
|
||||
fetch_baglist()
|
||||
if args.check
|
||||
else [l.strip() for l in args.tasks_file.open() if l.strip()]
|
||||
)
|
||||
if not names:
|
||||
log.info("列表为空")
|
||||
return
|
||||
|
||||
download_all(names)
|
||||
eval_all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
424
fst_data_pipeline/pipelines/tencent/ldgt_prod.py
Normal file
424
fst_data_pipeline/pipelines/tencent/ldgt_prod.py
Normal file
@@ -0,0 +1,424 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Flask, request, jsonify
|
||||
from prometheus_client import start_http_server, Counter, Gauge, Summary, Histogram
|
||||
|
||||
# ============ 常量 & 全局队列 ============
|
||||
BASE = os.getcwd()
|
||||
INPUT_ROOT = os.path.join(BASE, "input")
|
||||
OUTPUT_ROOT = os.path.join(BASE, "output")
|
||||
EMPTY_DIR = os.path.join(BASE, "empty")
|
||||
LOG_DIR = os.path.join(BASE, "logs")
|
||||
|
||||
DOWNLOAD_ENTRIES = [
|
||||
"lidar_gt_pandar128",
|
||||
"camera_front_wide",
|
||||
"calibration",
|
||||
"raw_gnss.csv",
|
||||
"raw_imu.csv",
|
||||
"vehicle_wheel.csv",
|
||||
]
|
||||
|
||||
MAX_PL_DIRS = 10 # 本地 PL* 子目录限流阈值
|
||||
BATCH_CHUNK = 5 # 文件模式每批任务数
|
||||
DOCKER_IMAGE = "ldgt_cu11_1_devel"
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY_S = 2
|
||||
METRICS_PORT = 8002
|
||||
|
||||
# sentinel 用于优雅停止
|
||||
# 格式:(batch_id, remote_paths, run_slam, visualize)
|
||||
SENTINEL = (None, None, False, False)
|
||||
|
||||
batch_q = queue.Queue()
|
||||
proc_q = queue.Queue()
|
||||
up_q = queue.Queue()
|
||||
|
||||
# —— 日志配置 —— #
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
logger = logging.getLogger("pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
fh = logging.FileHandler(os.path.join(LOG_DIR, "pipeline.log"), encoding="utf-8")
|
||||
fh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
|
||||
logger.addHandler(fh)
|
||||
|
||||
ch = logging.StreamHandler(sys.stdout)
|
||||
ch.setLevel(logging.INFO)
|
||||
ch.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
|
||||
logger.addHandler(ch)
|
||||
|
||||
# —— 全局自增 batch_id 计数 —— #
|
||||
batch_counter = 0
|
||||
counter_lock = threading.Lock()
|
||||
|
||||
# ============ Prometheus 指标 ============ #
|
||||
DL_TOTAL = Counter("pipeline_download_total", "下载尝试总数")
|
||||
DL_FAIL = Counter("pipeline_download_failures", "下载失败总数")
|
||||
DL_RETRY = Counter("pipeline_download_retries", "下载重试总数")
|
||||
PR_TOTAL = Counter("pipeline_process_total", "处理尝试总数")
|
||||
PR_FAIL = Counter("pipeline_process_failures", "处理失败总数")
|
||||
PR_RETRY = Counter("pipeline_process_retries", "处理重试总数")
|
||||
UP_TOTAL = Counter("pipeline_upload_total", "上传尝试总数")
|
||||
UP_FAIL = Counter("pipeline_upload_failures", "上传失败总数")
|
||||
UP_RETRY = Counter("pipeline_upload_retries", "上传重试总数")
|
||||
|
||||
DL_DUR = Summary("pipeline_download_duration_seconds", "下载耗时秒")
|
||||
PR_DUR = Summary("pipeline_process_duration_seconds", "处理耗时秒")
|
||||
UP_DUR = Summary("pipeline_upload_duration_seconds", "上传耗时秒")
|
||||
|
||||
BATCH_SIZE_HIST = Histogram(
|
||||
"pipeline_batch_size", "批次大小分布", buckets=[1, 5, 10, 20, 50, 100]
|
||||
)
|
||||
FILE_DL_DUR = Histogram("pipeline_file_download_duration_seconds", "单文件下载耗时分布")
|
||||
BATCH_OUT_FILES = Gauge("pipeline_batch_output_file_count", "输出文件数")
|
||||
|
||||
Q_BATCH = Gauge("pipeline_queue_batches", "待下载批次数")
|
||||
Q_PROC = Gauge("pipeline_queue_processing", "待处理批次数")
|
||||
Q_UP = Gauge("pipeline_queue_uploading", "待上传批次数")
|
||||
LOCAL_PL = Gauge("pipeline_local_pl_dirs", "本地 PL* 目录数")
|
||||
|
||||
|
||||
# ============ 辅助函数 ============ #
|
||||
def count_pl_dirs():
|
||||
# return sum(len(files) for _, _, files in os.walk(INPUT_ROOT))
|
||||
try:
|
||||
return sum(
|
||||
1
|
||||
for d in os.listdir(INPUT_ROOT)
|
||||
if d.startswith("PL") and os.path.isdir(os.path.join(INPUT_ROOT, d))
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return 0
|
||||
|
||||
|
||||
def run(cmd, timeout=None):
|
||||
logger.info("RUN %s", " ".join(cmd))
|
||||
try:
|
||||
p = subprocess.Popen(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
start = time.time()
|
||||
out = []
|
||||
while True:
|
||||
line = p.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
out.append(line)
|
||||
print(line, end="", flush=True)
|
||||
if timeout and time.time() - start > timeout:
|
||||
p.kill()
|
||||
return p.wait(), True, "".join(out)
|
||||
code = p.wait()
|
||||
return code, False, "".join(out)
|
||||
except Exception:
|
||||
logger.exception("RUN 异常")
|
||||
return -1, False, ""
|
||||
|
||||
|
||||
def with_retry(tag, fn, *args):
|
||||
for i in range(1, MAX_RETRIES + 1):
|
||||
code, timed_out, _ = fn(*args)
|
||||
if code == 0:
|
||||
return True
|
||||
if timed_out:
|
||||
logger.error("%s 超时,停止重试", tag)
|
||||
break
|
||||
if tag.startswith("DL["):
|
||||
DL_RETRY.inc()
|
||||
if tag.startswith("PR["):
|
||||
PR_RETRY.inc()
|
||||
if tag.startswith("UP["):
|
||||
UP_RETRY.inc()
|
||||
logger.warning("%s 重试 %d/%d", tag, i, MAX_RETRIES)
|
||||
time.sleep(RETRY_DELAY_S)
|
||||
logger.error("%s 最终失败", tag)
|
||||
return False
|
||||
|
||||
|
||||
# ============ 1) 下载阶段 ============ #
|
||||
@DL_DUR.time()
|
||||
def do_download(batch_id, remote_paths, run_slam, visualize, timeout):
|
||||
if batch_id is None:
|
||||
proc_q.put(SENTINEL)
|
||||
return
|
||||
DL_TOTAL.inc()
|
||||
in_dir = os.path.join(INPUT_ROOT, batch_id)
|
||||
os.makedirs(in_dir, exist_ok=True)
|
||||
logger.info("DL[%s] start, %d paths", batch_id, len(remote_paths))
|
||||
|
||||
for remote in remote_paths:
|
||||
bag = os.path.basename(remote.rstrip("/"))
|
||||
dst = os.path.join(in_dir, bag)
|
||||
os.makedirs(dst, exist_ok=True)
|
||||
for ent in DOWNLOAD_ENTRIES:
|
||||
src = remote.rstrip("/") + "/" + ent
|
||||
if ent.endswith(".csv"):
|
||||
cmd = ["coscmd", "-s", "download", src, os.path.join(dst, ent)]
|
||||
else:
|
||||
cmd = ["coscmd", "-s", "download", "-r", src, os.path.join(dst, ent)]
|
||||
t0 = time.time()
|
||||
ok = with_retry(f"DL[{batch_id}-{bag}-{ent}]", run, cmd, timeout)
|
||||
FILE_DL_DUR.observe(time.time() - t0)
|
||||
if not ok:
|
||||
DL_FAIL.inc()
|
||||
proc_q.put((batch_id, remote_paths, run_slam, visualize))
|
||||
|
||||
|
||||
# ============ 2) 处理阶段 ============ #
|
||||
@PR_DUR.time()
|
||||
def do_process(batch_id, remote_paths, run_slam, visualize, timeout):
|
||||
if batch_id is None:
|
||||
up_q.put(SENTINEL)
|
||||
return
|
||||
PR_TOTAL.inc()
|
||||
|
||||
out_dir = os.path.join(OUTPUT_ROOT, batch_id)
|
||||
in_dir = os.path.join(INPUT_ROOT, batch_id)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
logger.info("PR[%s] input in {%s},output to {%s}", batch_id, in_dir, out_dir)
|
||||
# 构造 Docker 脚本
|
||||
parts = []
|
||||
if run_slam:
|
||||
parts.append("cd /root/slam_scripts && ./batch_run_map.sh /input_data")
|
||||
parts.append(
|
||||
"cd /root/latest/perception_dnn/projects/ldgt_net && "
|
||||
"python tools/custom/generate_mb_lidar_map_with_intensity.py --data_dir /input_data --output /output_data/bev_image && "
|
||||
"./tools/dist_produce_mk_gt.sh /output_data/bev_image/image_data "
|
||||
"~/ckpt/epoch_500.pth /output_data/pipeline_out_box 8 && "
|
||||
"./tools/dist_produce_gt.sh /output_data/bev_image/image_data ~/ckpt/epoch_50.pth /output_data/pipeline_out_line 8 && "
|
||||
"python tools/custom/merge_osm.py --line_osm_folder /output_data/pipeline_out_line/default_batch/osm_out --marking_osm_folder /output_data/pipeline_out_box/default_batch/osm_out --output_folder /output_data/osm_out"
|
||||
)
|
||||
if visualize:
|
||||
parts.append(
|
||||
"python tools/vis_tool/split_lane_info.py --bag_data /input_data --pipeline_result /output_data/ &&"
|
||||
" python tools/vis_tool/projection.py --bag_data /input_data --pipeline_result /output_data/ "
|
||||
)
|
||||
parts.append("chmod -R 777 /output_data")
|
||||
script = " && ".join(parts)
|
||||
cmd = [
|
||||
"docker",
|
||||
"run",
|
||||
"--rm",
|
||||
"--gpus",
|
||||
"all",
|
||||
"-v",
|
||||
f"{in_dir}:/input_data",
|
||||
"-v",
|
||||
f"{out_dir}:/output_data",
|
||||
DOCKER_IMAGE,
|
||||
"bash",
|
||||
"-i",
|
||||
"-c",
|
||||
script,
|
||||
]
|
||||
ok = with_retry(f"PR[{batch_id}-{in_dir}]", run, cmd, timeout)
|
||||
if not ok:
|
||||
PR_FAIL.inc()
|
||||
|
||||
# shutil.rmtree(in_dir, ignore_errors=True)
|
||||
# shutil.rmtree(work, ignore_errors=True)
|
||||
# 输出文件数
|
||||
cnt = sum(
|
||||
len(files) for _, _, files in os.walk(os.path.join(OUTPUT_ROOT, batch_id))
|
||||
)
|
||||
BATCH_OUT_FILES.set(cnt)
|
||||
up_q.put((
|
||||
batch_id,
|
||||
remote_paths,
|
||||
run_slam,
|
||||
visualize,
|
||||
))
|
||||
|
||||
|
||||
# ============ 3) 上传阶段 ============ #
|
||||
@UP_DUR.time()
|
||||
def do_upload(batch_id, remote_paths, run_slam, visualize, timeout):
|
||||
if batch_id is None:
|
||||
return
|
||||
UP_TOTAL.inc()
|
||||
logger.info("UP[%s] start", batch_id)
|
||||
base = os.path.join(OUTPUT_ROOT, batch_id)
|
||||
input_dir = os.path.join(INPUT_ROOT, batch_id)
|
||||
|
||||
for remote in remote_paths:
|
||||
bag = os.path.basename(remote.rstrip("/"))
|
||||
short = bag.replace(".bag.dir", "")
|
||||
target = remote.rstrip("/") + "/derived/LDGT"
|
||||
# 1) osm
|
||||
f1 = os.path.join(base, "osm_out", f"{short}.osm")
|
||||
if os.path.isfile(f1):
|
||||
cmd = ["coscmd", "-s", "upload", f1, target + f"/{short}.osm"]
|
||||
with_retry(f"UP[{batch_id}-{short}-osm]", run, cmd, timeout)
|
||||
# 2) split_json
|
||||
d2 = os.path.join(base, "split_json", short)
|
||||
if os.path.isdir(d2):
|
||||
cmd = ["coscmd", "-s", "upload", "-r", d2, target + "/split_json/"]
|
||||
with_retry(f"UP[{batch_id}-{short}-json]", run, cmd, timeout)
|
||||
# 3) jpg
|
||||
f3 = os.path.join(base, "bev_image", "image_data", f"{short}.jpg")
|
||||
if os.path.isfile(f3):
|
||||
cmd = ["coscmd", "-s", "upload", f3, target + "/bev_image.jpg"]
|
||||
with_retry(f"UP[{batch_id}-{short}-jpg]", run, cmd, timeout)
|
||||
|
||||
f4 = os.path.join(input_dir, bag, "slam_lidar_ground")
|
||||
if os.path.isdir(f4):
|
||||
cmd = ["coscmd", "-s", "upload", "-r", f4, target]
|
||||
with_retry(f"UP[{batch_id}-{short}-slam_lidar_ground]", run, cmd, timeout)
|
||||
|
||||
f5 = os.path.join(input_dir, bag, "slam_lidar_none_ground")
|
||||
if os.path.isdir(f5):
|
||||
cmd = ["coscmd", "-s", "upload", "-r", f5, target]
|
||||
with_retry(
|
||||
f"UP[{batch_id}-{short}-slam_lidar_none_ground]", run, cmd, timeout
|
||||
)
|
||||
|
||||
f6 = os.path.join(input_dir, bag, "ego_motion_slam_lidar.csv")
|
||||
if os.path.isfile(f6):
|
||||
cmd = ["coscmd", "-s", "upload", f6, target + "/ego_motion_slam_lidar.csv"]
|
||||
with_retry(f"UP[{batch_id}-{short}-csv]", run, cmd, timeout)
|
||||
|
||||
shutil.rmtree(base, ignore_errors=True)
|
||||
logger.info("UP[%s] done", batch_id)
|
||||
|
||||
|
||||
# ============ Worker 模板 ============ #
|
||||
def worker(q, fn, timeout):
|
||||
while True:
|
||||
batch_id, paths, run_slam, visualize = q.get()
|
||||
try:
|
||||
fn(batch_id, paths, run_slam, visualize, timeout)
|
||||
except Exception:
|
||||
logger.exception("Stage %s 失败", batch_id)
|
||||
finally:
|
||||
q.task_done()
|
||||
if batch_id is None:
|
||||
break
|
||||
|
||||
|
||||
# ============ Prometheus 指标更新 ============ #
|
||||
def start_metric_updater():
|
||||
def loop():
|
||||
while True:
|
||||
Q_BATCH.set(batch_q.qsize())
|
||||
Q_PROC.set(proc_q.qsize())
|
||||
Q_UP.set(up_q.qsize())
|
||||
LOCAL_PL.set(count_pl_dirs())
|
||||
time.sleep(1)
|
||||
|
||||
threading.Thread(target=loop, daemon=True).start()
|
||||
|
||||
|
||||
# ============ Flask Service ============ #
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route("/ready", methods=["GET"])
|
||||
def api_ready():
|
||||
busy = any([
|
||||
batch_q.unfinished_tasks,
|
||||
proc_q.unfinished_tasks,
|
||||
up_q.unfinished_tasks,
|
||||
])
|
||||
return jsonify(ready=not busy)
|
||||
|
||||
|
||||
@app.route("/notify", methods=["POST"])
|
||||
def api_notify():
|
||||
global batch_counter
|
||||
data = request.get_json(force=True)
|
||||
if isinstance(data, list):
|
||||
paths, run_slam, visualize = data, True, True
|
||||
elif isinstance(data, dict):
|
||||
paths = data.get("paths", [])
|
||||
run_slam = data.get("run_slam", True)
|
||||
visualize = data.get("visualize", True)
|
||||
else:
|
||||
return jsonify(error="Unsupported JSON"), 400
|
||||
|
||||
if not all(isinstance(p, str) for p in paths):
|
||||
return jsonify(error="paths must be strings"), 400
|
||||
|
||||
with counter_lock:
|
||||
batch_counter += 1
|
||||
bid = str(batch_counter)
|
||||
|
||||
batch_q.put((bid, paths, run_slam, visualize))
|
||||
return jsonify(status="accepted", batch_id=bid), 202
|
||||
|
||||
|
||||
# ============ 主入口 ============ #
|
||||
def main():
|
||||
global batch_counter
|
||||
p = argparse.ArgumentParser()
|
||||
sub = p.add_subparsers(dest="mode", required=True)
|
||||
|
||||
f = sub.add_parser("file", help="文件模式")
|
||||
f.add_argument("--tasks-file", required=True, help="每行一个 COS 路径")
|
||||
f.add_argument("--batch-timeout", type=int, default=3600)
|
||||
f.set_defaults(run_slam=True, visualize=True)
|
||||
|
||||
s = sub.add_parser("service", help="服务模式")
|
||||
s.add_argument("--batch-timeout", type=int, default=3600)
|
||||
s.add_argument("--host", default="0.0.0.0")
|
||||
s.add_argument("--port", type=int, default=5600)
|
||||
|
||||
args = p.parse_args()
|
||||
|
||||
# 启动 Prometheus HTTP & 指标更新
|
||||
start_http_server(METRICS_PORT)
|
||||
start_metric_updater()
|
||||
|
||||
# 启动三个阶段 worker
|
||||
for q, fn in ((batch_q, do_download), (proc_q, do_process), (up_q, do_upload)):
|
||||
t = threading.Thread(
|
||||
target=worker, args=(q, fn, args.batch_timeout), daemon=True
|
||||
)
|
||||
t.start()
|
||||
|
||||
if args.mode == "file":
|
||||
lines = [
|
||||
l.strip() for l in open(args.tasks_file, encoding="utf-8") if l.strip()
|
||||
]
|
||||
for i in range(0, len(lines), BATCH_CHUNK):
|
||||
while count_pl_dirs() >= MAX_PL_DIRS:
|
||||
logger.warning(
|
||||
"本地 PL* %d ≥ %d,暂停入队", count_pl_dirs(), MAX_PL_DIRS
|
||||
)
|
||||
time.sleep(60)
|
||||
blk = lines[i : i + BATCH_CHUNK]
|
||||
with counter_lock:
|
||||
batch_counter += 1
|
||||
bid = (
|
||||
datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
|
||||
+ f"_{i // BATCH_CHUNK + 1}"
|
||||
)
|
||||
BATCH_SIZE_HIST.observe(len(blk))
|
||||
batch_q.put((bid, blk, args.run_slam, args.visualize))
|
||||
batch_q.put(SENTINEL)
|
||||
batch_q.join()
|
||||
proc_q.join()
|
||||
up_q.join()
|
||||
logger.info("File mode done.")
|
||||
sys.exit(0)
|
||||
else:
|
||||
for d in (INPUT_ROOT, OUTPUT_ROOT, EMPTY_DIR, LOG_DIR):
|
||||
os.makedirs(d, exist_ok=True)
|
||||
logger.info("Starting service on %s:%d", args.host, args.port)
|
||||
app.run(host=args.host, port=args.port, threaded=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,68 @@
|
||||
# Trajectory Processing and Visualization System
|
||||
|
||||
## Overview
|
||||
|
||||
This system consists of two core scripts for processing GNSS trajectory data, identifying spatially overlapping trajectories, generating tile maps, and creating visualizations:
|
||||
|
||||
1. **tile_generate_to_db.py** - Processes raw trajectory data, computes spatial relationships, and stores results in MongoDB
|
||||
2. **tile_visualization_from_db.py** - Reads data from database and generates interactive trajectory visualizations
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Trajectory Gridding**: Divides GNSS points into 3D grids (10m planar grid + 5m height strata)
|
||||
- **Overlap Detection**: Identifies trajectories sharing more than threshold number of grids
|
||||
- **Tile Mapping**: Maps trajectories to map tile system
|
||||
- **Spatial Indexing**: Stores trajectories, tiles and overlap relationships in MongoDB
|
||||
- **Visualization**: Generates interactive maps with trajectory paths and direction arrows
|
||||
|
||||
## config.yml
|
||||
-you should set the db param and tile_server_url in config.yml
|
||||
-for example
|
||||
mongodb:
|
||||
uri: "mongodb://admin:admin@IP:port/"
|
||||
db_name: "your db_name"
|
||||
|
||||
tile_server:
|
||||
url: "http://IP:port/styles/maptiler-basic/512/{z}/{x}/{y}.png"
|
||||
|
||||
|
||||
## Usage
|
||||
1. Data Processing
|
||||
```
|
||||
python3 tile_generate_to_db.py <data_root> [zoom_level]
|
||||
```
|
||||
- Required Directory Structure:
|
||||
- data_root/
|
||||
- ├── bag_1/
|
||||
- │ ├── raw_gnss.csv
|
||||
- ├── bag_2/
|
||||
- │ ├── raw_gnss.csv
|
||||
|
||||
2. Visualization Generation
|
||||
```
|
||||
python3 tile_visualization_from_db.py <data_root>
|
||||
```
|
||||
- Output will be saved in:
|
||||
```
|
||||
<data_root>/tile_visualizations/
|
||||
```
|
||||
|
||||
# Output Files
|
||||
|
||||
1. Database Collections:
|
||||
|
||||
- tile_db: Tile-to-trajectory mappings
|
||||
|
||||
- processed_bags: Processed trajectory data
|
||||
|
||||
- overlap_db: Trajectory overlap sets
|
||||
|
||||
- processed_bags: Processed bag records
|
||||
|
||||
2. Local Files:
|
||||
|
||||
- <bag_name>_overlap.json: Overlap information per trajectory
|
||||
|
||||
- processed_tiles.txt: List of processed tile IDs
|
||||
|
||||
- tile_visualizations/: Folder containing HTML visualizations
|
||||
@@ -0,0 +1,7 @@
|
||||
# config.yml
|
||||
mongodb:
|
||||
uri: "mongodb uri"
|
||||
db_name: "your db_name"
|
||||
|
||||
tile_server:
|
||||
url: "tiler_server_url"
|
||||
@@ -0,0 +1,605 @@
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import math
|
||||
import pandas as pd
|
||||
import utm
|
||||
import hashlib
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, List, Tuple, Set, Any
|
||||
from pymongo import MongoClient
|
||||
from bson import ObjectId
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局参数
|
||||
MIN_SHARED_GRIDS = 6 # 最小共享3D网格数阈值
|
||||
MIN_GRID_FOR_ONE_TRAJECTORY = 0 # 单条轨迹最少网格数
|
||||
GRID_SIZE = 10.0 # 平面网格大小(米)
|
||||
HEIGHT_STRATUM = 5.0 # 高度分层阈值(米)
|
||||
|
||||
|
||||
def load_config(config_path="config.yml"):
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
return yaml.safe_load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
config = load_config()
|
||||
|
||||
|
||||
def get_mongo_client():
|
||||
"""获取MongoDB客户端"""
|
||||
return MongoClient(config["mongodb"]["uri"])
|
||||
|
||||
|
||||
def get_db_collections():
|
||||
"""获取数据库集合"""
|
||||
client = get_mongo_client()
|
||||
db = client[config["mongodb"]["db_name"]]
|
||||
return {
|
||||
"overlap_db": db.overlap_db,
|
||||
"tile_db": db.tile_db,
|
||||
"processed_bags": db.processed_bags,
|
||||
}
|
||||
|
||||
|
||||
def lon2tilex(lon: float, zoom: int) -> int:
|
||||
"""将经度转换为瓦片x坐标"""
|
||||
return int((lon + 180) / 360 * (1 << zoom))
|
||||
|
||||
|
||||
def lat2tiley(lat: float, zoom: int) -> int:
|
||||
"""将纬度转换为瓦片y坐标"""
|
||||
return int(
|
||||
(
|
||||
1
|
||||
- math.log(math.tan(math.radians(lat)) + 1 / math.cos(math.radians(lat)))
|
||||
/ math.pi
|
||||
)
|
||||
/ 2
|
||||
* (1 << zoom)
|
||||
)
|
||||
|
||||
|
||||
def get_tile_neighbors(tileid: str) -> List[str]:
|
||||
"""获取九宫格相邻瓦片ID"""
|
||||
zoom, x, y = map(int, tileid.split("/"))
|
||||
neighbors = []
|
||||
|
||||
for dx in [-1, 0, 1]:
|
||||
for dy in [-1, 0, 1]:
|
||||
if dx == 0 and dy == 0:
|
||||
continue # 跳过自身
|
||||
new_x, new_y = x + dx, y + dy
|
||||
# 检查是否超出瓦片范围
|
||||
max_tile = 1 << zoom
|
||||
if 0 <= new_x < max_tile and 0 <= new_y < max_tile:
|
||||
neighbors.append(f"{zoom}/{new_x}/{new_y}")
|
||||
return neighbors
|
||||
|
||||
|
||||
def extract_xyz_from_gnss(csv_path: str) -> List[Tuple[float, float, float]]:
|
||||
"""提取包含高度信息的轨迹点"""
|
||||
try:
|
||||
df = pd.read_csv(csv_path)
|
||||
except Exception as e:
|
||||
logger.error(f"读取 {csv_path} 失败: {e}")
|
||||
return []
|
||||
|
||||
# 自动适配列名
|
||||
lat_col = next((col for col in df.columns if "lat" in col.lower()), None)
|
||||
lon_col = next((col for col in df.columns if "lon" in col.lower()), None)
|
||||
alt_col = next(
|
||||
(col for col in df.columns if "alt" in col.lower() or "height" in col.lower()),
|
||||
None,
|
||||
)
|
||||
|
||||
if not lat_col or not lon_col:
|
||||
logger.warning(f"{csv_path} 中缺少经纬度列")
|
||||
return []
|
||||
|
||||
# 如果没有高度列,使用默认高度0
|
||||
if not alt_col:
|
||||
logger.warning(f"{csv_path} 中未找到高度列,将使用默认高度0")
|
||||
df["altitude"] = 0
|
||||
alt_col = "altitude"
|
||||
|
||||
# 过滤无效数据
|
||||
df = df.dropna(subset=[lat_col, lon_col, alt_col])
|
||||
df = df[(df[lat_col].between(-90, 90)) & (df[lon_col].between(-180, 180))]
|
||||
|
||||
try:
|
||||
points = []
|
||||
for lat, lon, alt in zip(df[lat_col], df[lon_col], df[alt_col]):
|
||||
easting, northing, zone_num, zone_letter = utm.from_latlon(lat, lon)
|
||||
points.append((easting, northing, alt, lat, lon))
|
||||
return points
|
||||
except Exception as e:
|
||||
logger.error(f"{csv_path} UTM 转换失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_3d_grid(x: float, y: float, z: float) -> Tuple[int, int, int]:
|
||||
"""三维网格划分"""
|
||||
grid_x = int(x // GRID_SIZE)
|
||||
grid_y = int(y // GRID_SIZE)
|
||||
grid_z = int(z // HEIGHT_STRATUM) # 高度分层
|
||||
return (grid_x, grid_y, grid_z)
|
||||
|
||||
|
||||
def get_tile_and_grids_from_path(
|
||||
bag_path: str, zoom: int = 14
|
||||
) -> Tuple[Set[str], List[Tuple[float, float, float]], Set[Tuple[int, int, int]]]:
|
||||
"""获取瓦片ID集合和三维网格集合"""
|
||||
csv_path = os.path.join(bag_path, "raw_gnss.csv")
|
||||
if not os.path.exists(csv_path):
|
||||
logger.warning(f"{bag_path} 不存在 raw_gnss.csv")
|
||||
return None
|
||||
|
||||
xyz_points = extract_xyz_from_gnss(csv_path)
|
||||
if not xyz_points:
|
||||
logger.warning(f"{bag_path} 无有效 GNSS 数据")
|
||||
return None
|
||||
|
||||
# 获取所有瓦片ID
|
||||
tile_ids = set()
|
||||
for x, y, z, lat, lon in xyz_points:
|
||||
tile_id = f"{zoom}/{lon2tilex(lon, zoom)}/{lat2tiley(lat, zoom)}"
|
||||
tile_ids.add(tile_id)
|
||||
|
||||
# 获取所有3D网格
|
||||
grids = set(get_3d_grid(x, y, z) for x, y, z, _, _ in xyz_points)
|
||||
|
||||
logger.info(
|
||||
f"{bag_path} 共有 {len(xyz_points)} 点,映射到 {len(tile_ids)} 个瓦片和 {len(grids)} 个3D网格"
|
||||
)
|
||||
return tile_ids, [(x, y, z) for x, y, z, _, _ in xyz_points], grids
|
||||
|
||||
|
||||
def calculate_overlap(grids1, grids2):
|
||||
"""计算两个轨迹的3D网格重叠数"""
|
||||
set1 = set(tuple(grid) for grid in grids1) if isinstance(grids1, list) else grids1
|
||||
set2 = set(tuple(grid) for grid in grids2) if isinstance(grids2, list) else grids2
|
||||
return len(set1 & set2)
|
||||
|
||||
|
||||
def update_tile_db(
|
||||
tile_db_collection,
|
||||
tile_id: str,
|
||||
bag_name: str,
|
||||
lat: float = None,
|
||||
lon: float = None,
|
||||
alt: float = None,
|
||||
):
|
||||
"""更新瓦片数据库"""
|
||||
tile_data = tile_db_collection.find_one({"tileid": tile_id})
|
||||
|
||||
if tile_data:
|
||||
if bag_name not in tile_data["trajectories"]:
|
||||
tile_db_collection.update_one(
|
||||
{"tileid": tile_id}, {"$addToSet": {"trajectories": bag_name}}
|
||||
)
|
||||
if lat is not None and lon is not None and "gps" not in tile_data:
|
||||
tile_db_collection.update_one(
|
||||
{"tileid": tile_id},
|
||||
{
|
||||
"$set": {
|
||||
"gps": {
|
||||
"latitude": lat,
|
||||
"longitude": lon,
|
||||
"altitude": alt or 0.0,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
new_tile = {"tileid": tile_id, "trajectories": [bag_name]}
|
||||
if lat is not None and lon is not None:
|
||||
new_tile["gps"] = {
|
||||
"latitude": lat,
|
||||
"longitude": lon,
|
||||
"altitude": alt or 0.0,
|
||||
}
|
||||
tile_db_collection.insert_one(new_tile)
|
||||
|
||||
|
||||
def initialize_overlap_db(collections):
|
||||
"""初始化或加载现有的overlap数据库"""
|
||||
overlap_db_data = collections["overlap_db"].find_one({"_id": "overlap_db"})
|
||||
|
||||
if overlap_db_data:
|
||||
# 转换数据结构
|
||||
overlap_db = {
|
||||
"next_id": overlap_db_data.get("next_id", 1),
|
||||
"sets": {
|
||||
k: {"bags": set(v["bags"]), "source_bags": set(v["source_bags"])}
|
||||
for k, v in overlap_db_data.get("sets", {}).items()
|
||||
},
|
||||
"hash_to_id": overlap_db_data.get("hash_to_id", {}),
|
||||
"bag_to_set": overlap_db_data.get("bag_to_set", {}),
|
||||
"overlap_sets": {
|
||||
k: set(v) for k, v in overlap_db_data.get("overlap_sets", {}).items()
|
||||
},
|
||||
}
|
||||
logger.info(
|
||||
f"已加载现有的overlap数据库,当前最大ID: {overlap_db['next_id'] - 1}"
|
||||
)
|
||||
else:
|
||||
overlap_db = {
|
||||
"next_id": 1,
|
||||
"sets": {},
|
||||
"hash_to_id": {},
|
||||
"bag_to_set": {},
|
||||
"overlap_sets": {},
|
||||
}
|
||||
collections["overlap_db"].insert_one({
|
||||
"_id": "overlap_db",
|
||||
"next_id": 1,
|
||||
"sets": {},
|
||||
"hash_to_id": {},
|
||||
"bag_to_set": {},
|
||||
"overlap_sets": {},
|
||||
})
|
||||
logger.info("创建新的overlap数据库")
|
||||
return overlap_db
|
||||
|
||||
|
||||
def update_overlap_sets(
|
||||
collections, overlap_db: Dict, bag_name: str, overlap_set: Set[str]
|
||||
):
|
||||
"""更新overlap_sets容器,不分配set_id"""
|
||||
# 保存当前包的overlap集合
|
||||
if not isinstance(overlap_set, set):
|
||||
overlap_set = set(overlap_set)
|
||||
overlap_db["overlap_sets"][bag_name] = overlap_set
|
||||
|
||||
# 更新所有相关包的overlap集合
|
||||
for other_bag in overlap_set:
|
||||
if other_bag == bag_name:
|
||||
continue
|
||||
|
||||
if other_bag not in overlap_db["overlap_sets"]:
|
||||
overlap_db["overlap_sets"][other_bag] = {other_bag}
|
||||
|
||||
# 合并两个包的overlap集合
|
||||
overlap_db["overlap_sets"][other_bag].add(bag_name)
|
||||
|
||||
|
||||
def assign_set_ids(collections, overlap_db: Dict):
|
||||
"""为所有overlap集合分配set_id"""
|
||||
# 收集所有唯一的overlap集合
|
||||
unique_sets = {}
|
||||
set_hashes = set()
|
||||
|
||||
# 首先处理已经存在的集合,保持它们的set_id不变
|
||||
existing_sets = {}
|
||||
for set_id, set_data in overlap_db["sets"].items():
|
||||
set_hash = hashlib.md5(",".join(sorted(set_data["bags"])).encode()).hexdigest()
|
||||
existing_sets[set_hash] = int(set_id)
|
||||
|
||||
# 找出所有唯一的overlap集合
|
||||
for bag_name, overlap_set in overlap_db["overlap_sets"].items():
|
||||
if len(overlap_set) <= 1:
|
||||
continue # 跳过单独的包
|
||||
|
||||
sorted_bags = sorted(overlap_set)
|
||||
set_hash = hashlib.md5(",".join(sorted_bags).encode()).hexdigest()
|
||||
|
||||
if set_hash not in unique_sets:
|
||||
unique_sets[set_hash] = {"bags": overlap_set, "source_bags": set()}
|
||||
|
||||
# 记录来源包
|
||||
unique_sets[set_hash]["source_bags"].add(bag_name)
|
||||
|
||||
# 分配set_id
|
||||
new_sets = {}
|
||||
for set_hash, set_data in unique_sets.items():
|
||||
if set_hash in existing_sets:
|
||||
# 使用现有的set_id
|
||||
set_id = existing_sets[set_hash]
|
||||
new_sets[set_id] = set_data
|
||||
else:
|
||||
# 分配新的set_id
|
||||
set_id = overlap_db["next_id"]
|
||||
overlap_db["next_id"] += 1
|
||||
new_sets[set_id] = set_data
|
||||
|
||||
# 更新overlap_db
|
||||
overlap_db["sets"] = {
|
||||
str(k): {"bags": sorted(v["bags"]), "source_bags": sorted(v["source_bags"])}
|
||||
for k, v in new_sets.items()
|
||||
}
|
||||
|
||||
# 更新hash_to_id映射
|
||||
overlap_db["hash_to_id"] = {
|
||||
hashlib.md5(",".join(sorted(v["bags"])).encode()).hexdigest(): int(k)
|
||||
for k, v in new_sets.items()
|
||||
}
|
||||
|
||||
# 更新bag_to_set映射
|
||||
overlap_db["bag_to_set"] = {}
|
||||
for set_id, set_data in new_sets.items():
|
||||
for bag in set_data["source_bags"]:
|
||||
overlap_db["bag_to_set"][bag] = set_id
|
||||
|
||||
|
||||
def save_overlap_db(collections, overlap_db: Dict):
|
||||
"""保存overlap数据库"""
|
||||
# 在保存前确保所有set_id已分配
|
||||
assign_set_ids(collections, overlap_db)
|
||||
|
||||
serializable_db = {
|
||||
"next_id": overlap_db["next_id"],
|
||||
"sets": {
|
||||
str(set_id): {
|
||||
"bags": sorted(data["bags"]),
|
||||
"source_bags": sorted(data["source_bags"]),
|
||||
}
|
||||
for set_id, data in overlap_db["sets"].items()
|
||||
},
|
||||
"hash_to_id": overlap_db["hash_to_id"],
|
||||
"bag_to_set": overlap_db["bag_to_set"],
|
||||
"overlap_sets": {k: sorted(v) for k, v in overlap_db["overlap_sets"].items()},
|
||||
}
|
||||
|
||||
collections["overlap_db"].replace_one(
|
||||
{"_id": "overlap_db"}, serializable_db, upsert=True
|
||||
)
|
||||
logger.info(f"已保存重叠集合数据库到 MongoDB (共 {len(overlap_db['sets'])} 个集合)")
|
||||
|
||||
|
||||
def update_json_files_with_set_ids(data_root: str, overlap_db: Dict, collections):
|
||||
"""更新所有包的JSON文件,添加set_id信息,并同步更新MongoDB"""
|
||||
for bag_name, set_id in overlap_db["bag_to_set"].items():
|
||||
bag_path = os.path.join(data_root, bag_name)
|
||||
overlap_json_path = os.path.join(bag_path, f"{bag_name}_overlap.json")
|
||||
|
||||
# 从MongoDB获取当前包数据
|
||||
bag_data = collections["processed_bags"].find_one({"_id": bag_name})
|
||||
if not bag_data:
|
||||
logger.warning(f"在MongoDB中未找到包 {bag_name} 的数据")
|
||||
continue
|
||||
|
||||
# 更新重叠信息
|
||||
updated_data = {
|
||||
"set_id": set_id,
|
||||
"trajectories": sorted(overlap_db["sets"][str(set_id)]["bags"]),
|
||||
"count": len(overlap_db["sets"][str(set_id)]["bags"]),
|
||||
}
|
||||
|
||||
# 更新MongoDB
|
||||
collections["processed_bags"].update_one(
|
||||
{"_id": bag_name}, {"$set": {"overlaps_bags": updated_data}}
|
||||
)
|
||||
|
||||
# 更新本地JSON文件(可选)
|
||||
if os.path.exists(overlap_json_path):
|
||||
with open(overlap_json_path, "r+") as f:
|
||||
try:
|
||||
data = json.load(f)
|
||||
data["overlap_bags"] = updated_data
|
||||
f.seek(0)
|
||||
json.dump(data, f, indent=2)
|
||||
f.truncate()
|
||||
except Exception as e:
|
||||
logger.error(f"更新 {overlap_json_path} 失败: {e}")
|
||||
|
||||
logger.info("已更新所有包的set_id信息到MongoDB和本地JSON文件")
|
||||
|
||||
|
||||
def process_bag(bag_path: str, collections, overlap_db: Dict, zoom: int = 14):
|
||||
"""处理单个数据包"""
|
||||
bag_name = os.path.basename(bag_path)
|
||||
|
||||
overlap_json_path = os.path.join(bag_path, f"{bag_name}_overlap.json")
|
||||
if not os.path.exists(overlap_json_path):
|
||||
with open(overlap_json_path, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"overlap_bags": {
|
||||
"trajectories": [bag_name],
|
||||
"count": 1,
|
||||
"set_id": None,
|
||||
}
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
# 获取瓦片ID和3D网格
|
||||
result = get_tile_and_grids_from_path(bag_path, zoom)
|
||||
if result is None:
|
||||
return
|
||||
|
||||
tile_ids, xyz_points, grids = result
|
||||
if len(grids) <= MIN_GRID_FOR_ONE_TRAJECTORY:
|
||||
logger.info(f"{bag_name} (仅 {len(grids)} 个3D网格,跳过)")
|
||||
return
|
||||
|
||||
processed_tiles_file = os.path.join(
|
||||
os.path.dirname(data_root), "processed_tiles.txt"
|
||||
)
|
||||
|
||||
existing_tiles = set()
|
||||
if os.path.exists(processed_tiles_file):
|
||||
with open(processed_tiles_file, "r") as f:
|
||||
existing_tiles = set(line.strip() for line in f if line.strip())
|
||||
|
||||
new_tiles = set(tile_ids) - existing_tiles
|
||||
if new_tiles:
|
||||
with open(processed_tiles_file, "a") as f:
|
||||
for tile_id in new_tiles:
|
||||
f.write(f"{tile_id}\n")
|
||||
print("write tile_id:")
|
||||
print(tile_id)
|
||||
|
||||
csv_path = os.path.join(bag_path, "raw_gnss.csv")
|
||||
df = pd.read_csv(csv_path)
|
||||
lat_col = next((col for col in df.columns if "lat" in col.lower()), None)
|
||||
lon_col = next((col for col in df.columns if "lon" in col.lower()), None)
|
||||
alt_col = next(
|
||||
(col for col in df.columns if "alt" in col.lower() or "height" in col.lower()),
|
||||
None,
|
||||
)
|
||||
|
||||
# 过滤无效数据
|
||||
if lat_col and lon_col:
|
||||
df = df.dropna(subset=[lat_col, lon_col, alt_col])
|
||||
df = df[(df[lat_col].between(-90, 90)) & (df[lon_col].between(-180, 180))]
|
||||
latlon_points = list(
|
||||
zip(df[lat_col], df[lon_col], df[alt_col] if alt_col else [0] * len(df))
|
||||
)
|
||||
else:
|
||||
latlon_points = []
|
||||
# 初始化当前包的overlap集合
|
||||
overlap_set = {bag_name}
|
||||
|
||||
# 查找所有相关瓦片(包括相邻瓦片)
|
||||
all_related_tiles = set(tile_ids)
|
||||
for tile_id in tile_ids:
|
||||
all_related_tiles.update(get_tile_neighbors(tile_id))
|
||||
|
||||
# 查找所有可能重叠的包
|
||||
candidate_bags = set()
|
||||
for tile in collections["tile_db"].find({
|
||||
"tileid": {"$in": list(all_related_tiles)}
|
||||
}):
|
||||
candidate_bags.update(tile["trajectories"])
|
||||
|
||||
candidate_bags.discard(bag_name)
|
||||
|
||||
# 检查与每个候选包的重叠情况
|
||||
for other_bag in candidate_bags:
|
||||
other_data = collections["processed_bags"].find_one({"_id": other_bag})
|
||||
if not other_data:
|
||||
continue
|
||||
|
||||
other_grids = set(tuple(grid) for grid in other_data.get("grids", []))
|
||||
current_grids_set = set(tuple(grid) for grid in grids)
|
||||
overlap_count = calculate_overlap(current_grids_set, other_grids)
|
||||
|
||||
if overlap_count >= MIN_SHARED_GRIDS:
|
||||
overlap_set.add(other_bag)
|
||||
|
||||
# 更新overlap_sets容器
|
||||
update_overlap_sets(collections, overlap_db, bag_name, overlap_set)
|
||||
|
||||
# 更新每个包的overlap JSON文件(不包含set_id)
|
||||
for b in overlap_set:
|
||||
if b == bag_name:
|
||||
continue
|
||||
|
||||
other_path = os.path.join(os.path.dirname(bag_path), b)
|
||||
other_json_path = os.path.join(other_path, f"{b}_overlap.json")
|
||||
|
||||
if os.path.exists(other_json_path):
|
||||
with open(other_json_path, "r+") as f:
|
||||
try:
|
||||
other_data = json.load(f)
|
||||
if bag_name not in other_data["overlap_bags"]["trajectories"]:
|
||||
other_data["overlap_bags"]["trajectories"].append(bag_name)
|
||||
other_data["overlap_bags"]["count"] = len(
|
||||
overlap_db["overlap_sets"][b]
|
||||
)
|
||||
f.seek(0)
|
||||
json.dump(other_data, f, indent=2)
|
||||
f.truncate()
|
||||
except Exception as e:
|
||||
logger.error(f"更新 {other_json_path} 失败: {e}")
|
||||
|
||||
# 保存当前包的overlap信息到JSON文件(不包含set_id)
|
||||
with open(overlap_json_path, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"overlap_bags": {
|
||||
"trajectories": sorted(overlap_set),
|
||||
"count": len(overlap_set),
|
||||
"set_id": None, # 将在最后阶段分配
|
||||
}
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
# 保存到processed_bags集合
|
||||
processed_data = {
|
||||
"_id": bag_name,
|
||||
"tile_ids": list(tile_ids),
|
||||
"grids": list(grids),
|
||||
"xyz_points": xyz_points,
|
||||
"latlon_points": latlon_points,
|
||||
"overlaps_bags": {
|
||||
"trajectories": sorted(overlap_set),
|
||||
"count": len(overlap_set),
|
||||
"set_id": None, # 将在最后阶段分配
|
||||
},
|
||||
}
|
||||
collections["processed_bags"].replace_one(
|
||||
{"_id": bag_name}, processed_data, upsert=True
|
||||
)
|
||||
|
||||
# 更新瓦片数据库
|
||||
for tile_id in tile_ids:
|
||||
lat, lon = None, None
|
||||
for x, y, z, pt_lat, pt_lon in extract_xyz_from_gnss(
|
||||
os.path.join(bag_path, "raw_gnss.csv")
|
||||
):
|
||||
if f"{zoom}/{lon2tilex(pt_lon, zoom)}/{lat2tiley(pt_lat, zoom)}" == tile_id:
|
||||
lat, lon = pt_lat, pt_lon
|
||||
break
|
||||
update_tile_db(collections["tile_db"], tile_id, bag_name, lat, lon)
|
||||
|
||||
logger.info(f"{bag_name} 与 {len(overlap_set) - 1} 个包重叠")
|
||||
|
||||
|
||||
def main(data_root: str, zoom: int = 14):
|
||||
"""主函数"""
|
||||
if not os.path.isdir(data_root):
|
||||
logger.error(f"数据目录不存在: {data_root}")
|
||||
return
|
||||
|
||||
# 初始化数据库连接
|
||||
collections = get_db_collections()
|
||||
|
||||
# 初始化数据库
|
||||
overlap_db = initialize_overlap_db(collections)
|
||||
|
||||
# 处理所有数据包
|
||||
bag_dirs = [
|
||||
os.path.join(data_root, d)
|
||||
for d in os.listdir(data_root)
|
||||
if os.path.isdir(os.path.join(data_root, d))
|
||||
]
|
||||
if not bag_dirs:
|
||||
logger.error(f"目录中没有子文件夹: {data_root}")
|
||||
return
|
||||
|
||||
logger.info(f"开始处理 {len(bag_dirs)} 个数据包...")
|
||||
for bag_path in bag_dirs:
|
||||
process_bag(bag_path, collections, overlap_db, zoom)
|
||||
|
||||
# 保存所有结果到MongoDB
|
||||
save_overlap_db(collections, overlap_db)
|
||||
|
||||
# 更新set_id信息(这会同时更新MongoDB和本地JSON)
|
||||
update_json_files_with_set_ids(data_root, overlap_db, collections)
|
||||
|
||||
logger.info("处理完成")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
logger.error("用法: python3 tile_generate_to_db.py <data_root> [zoom_level]")
|
||||
sys.exit(1)
|
||||
|
||||
data_root = sys.argv[1]
|
||||
zoom_level = int(sys.argv[2]) if len(sys.argv) > 2 else 14
|
||||
main(data_root, zoom_level)
|
||||
@@ -0,0 +1,467 @@
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import math
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import plotly.express as px
|
||||
from typing import List, Dict, Set, Tuple
|
||||
import utm
|
||||
import sys
|
||||
import logging
|
||||
import numpy as np
|
||||
from plotly.subplots import make_subplots
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from pymongo import MongoClient
|
||||
import hashlib
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_config(config_path="config.yml"):
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
return yaml.safe_load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
config = load_config()
|
||||
|
||||
|
||||
def calculate_bearing(x1, y1, x2, y2):
|
||||
"""计算两点之间的朝向角(弧度)"""
|
||||
dx = x2 - x1
|
||||
dy = y2 - y1
|
||||
return math.atan2(dy, dx)
|
||||
|
||||
|
||||
def calculate_cluster_center(bag_grid_map, cluster):
|
||||
"""计算单个聚类的中心点"""
|
||||
all_x = []
|
||||
all_y = []
|
||||
all_z = []
|
||||
for bag_name in cluster:
|
||||
if bag_name in bag_grid_map:
|
||||
xyz_points = bag_grid_map[bag_name]
|
||||
all_x.extend([x for x, y, z in xyz_points])
|
||||
all_y.extend([y for x, y, z in xyz_points])
|
||||
all_z.extend([z for x, y, z in xyz_points])
|
||||
if not all_x:
|
||||
return 0, 0, 0 # 默认值
|
||||
return np.mean(all_x), np.mean(all_y), np.mean(all_z)
|
||||
|
||||
|
||||
def create_direction_arrows(xy_points, spacing=5):
|
||||
"""创建方向箭头数据"""
|
||||
arrows = []
|
||||
for i in range(0, len(xy_points) - 1, spacing):
|
||||
if i + 1 >= len(xy_points):
|
||||
continue
|
||||
x1, y1 = xy_points[i]
|
||||
x2, y2 = xy_points[i + 1]
|
||||
angle = calculate_bearing(x1, y1, x2, y2)
|
||||
|
||||
# 箭头中间点
|
||||
mid_x = (x1 + x2) / 2
|
||||
mid_y = (y1 + y2) / 2
|
||||
|
||||
arrows.append({
|
||||
"x": mid_x,
|
||||
"y": mid_y,
|
||||
"angle": -angle,
|
||||
"angle_deg": 90 - rad_to_deg(angle),
|
||||
})
|
||||
return arrows
|
||||
|
||||
|
||||
def rad_to_deg(rad):
|
||||
"""将弧度转换为角度"""
|
||||
return rad * 180 / math.pi
|
||||
|
||||
|
||||
def get_3d_grids_from_db(collections, bag_name):
|
||||
"""从数据库获取三维网格集合"""
|
||||
bag_data = collections["processed_bags"].find_one({"_id": bag_name})
|
||||
if not bag_data:
|
||||
logger.warning(f"未找到包 {bag_name} 的数据")
|
||||
return None
|
||||
|
||||
if "xyz_points" in bag_data:
|
||||
return bag_data["xyz_points"]
|
||||
elif "grids" in bag_data:
|
||||
# 如果只有网格数据,没有原始点,则返回网格中心点
|
||||
grid_points = []
|
||||
for grid in bag_data["grids"]:
|
||||
grid_x, grid_y, grid_z = grid
|
||||
center_x = (grid_x + 0.5) * 10.0
|
||||
center_y = (grid_y + 0.5) * 10.0
|
||||
center_z = (grid_z + 0.5) * 5.0
|
||||
grid_points.append((center_x, center_y, center_z))
|
||||
return grid_points
|
||||
else:
|
||||
logger.warning(f"包 {bag_name} 没有轨迹点数据")
|
||||
return None
|
||||
|
||||
|
||||
def utm_to_latlon(easting, northing, zone_number, zone_letter):
|
||||
"""将UTM坐标转换为经纬度"""
|
||||
try:
|
||||
lat, lon = utm.to_latlon(easting, northing, zone_number, zone_letter)
|
||||
return lat, lon
|
||||
except Exception as e:
|
||||
logger.error(f"UTM转换失败: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
class TileVisualizer:
|
||||
def __init__(self, output_dir: str = "tile_visualizations"):
|
||||
self.mongo_uri = config["mongodb"]["uri"]
|
||||
self.db_name = config["mongodb"]["db_name"]
|
||||
self.tile_server_url = config["tile_server"]["url"]
|
||||
self.output_dir = output_dir
|
||||
self.client = MongoClient(self.mongo_uri)
|
||||
self.db = self.client[self.db_name]
|
||||
self.collections = {
|
||||
"tile_db": self.db.tile_db,
|
||||
"processed_bags": self.db.processed_bags,
|
||||
}
|
||||
self.tile_db = self._load_tile_db()
|
||||
self.color_cycle = px.colors.qualitative.Plotly
|
||||
|
||||
def close(self):
|
||||
if self.client is not None:
|
||||
self.client.close()
|
||||
self.client = None
|
||||
self.db = None
|
||||
|
||||
def _load_tile_db(self) -> Dict:
|
||||
"""从MongoDB加载tileDB数据"""
|
||||
try:
|
||||
tiles = list(self.collections["tile_db"].find({}, {"_id": 0}))
|
||||
return {"TileDB": tiles}
|
||||
except Exception as e:
|
||||
logger.error(f"从MongoDB加载tileDB失败: {e}")
|
||||
return {"TileDB": []}
|
||||
|
||||
def _get_trajectory_points_from_db(
|
||||
self, bag_name: str
|
||||
) -> List[Tuple[float, float, float]]:
|
||||
"""从数据库获取轨迹点(经纬度+高度)"""
|
||||
bag_data = self.collections["processed_bags"].find_one({"_id": bag_name})
|
||||
if not bag_data:
|
||||
logger.warning(f"未找到包 {bag_name} 的数据")
|
||||
return []
|
||||
|
||||
# 尝试直接获取经纬度点
|
||||
if "latlon_points" in bag_data:
|
||||
return [
|
||||
(point[0], point[1], point[2]) for point in bag_data["latlon_points"]
|
||||
]
|
||||
|
||||
# 如果有UTM坐标,则转换为经纬度
|
||||
if "xyz_points" in bag_data:
|
||||
# 注意:这里需要假设所有点在同一个UTM区域
|
||||
# 实际上,每个点可能有不同的区域,但为了简化,我们使用第一个点的区域
|
||||
if not bag_data["xyz_points"]:
|
||||
return []
|
||||
|
||||
# 尝试获取UTM区域信息(如果存储了)
|
||||
zone_num = bag_data.get("utm_zone_num", 50) # 默认值
|
||||
zone_letter = bag_data.get("utm_zone_letter", "N") # 默认值
|
||||
|
||||
latlon_points = []
|
||||
for point in bag_data["xyz_points"]:
|
||||
easting, northing, alt = point
|
||||
lat, lon = utm_to_latlon(easting, northing, zone_num, zone_letter)
|
||||
if lat is not None and lon is not None:
|
||||
latlon_points.append((lat, lon, alt))
|
||||
return latlon_points
|
||||
|
||||
logger.warning(f"包 {bag_name} 没有可用的轨迹点数据")
|
||||
return []
|
||||
|
||||
def _calculate_bearing(
|
||||
self, lat1: float, lon1: float, lat2: float, lon2: float
|
||||
) -> float:
|
||||
"""计算两点之间的朝向角(度数)"""
|
||||
lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2])
|
||||
dlon = lon2 - lon1
|
||||
x = math.sin(dlon) * math.cos(lat2)
|
||||
y = math.cos(lat1) * math.sin(lat2) - math.sin(lat1) * math.cos(
|
||||
lat2
|
||||
) * math.cos(dlon)
|
||||
return math.degrees(math.atan2(x, y))
|
||||
|
||||
def _create_arrows(
|
||||
self, lats: List[float], lons: List[float], spacing: int = 5
|
||||
) -> List[Dict]:
|
||||
"""创建方向箭头数据"""
|
||||
arrows = []
|
||||
for i in range(0, len(lats) - 1, spacing):
|
||||
if i + 1 >= len(lats):
|
||||
continue
|
||||
angle = self._calculate_bearing(lats[i], lons[i], lats[i + 1], lons[i + 1])
|
||||
arrows.append({
|
||||
"lat": (lats[i] + lats[i + 1]) / 2,
|
||||
"lon": (lons[i] + lons[i + 1]) / 2,
|
||||
"angle": -angle,
|
||||
"angle_deg": 90 - angle,
|
||||
})
|
||||
return arrows
|
||||
|
||||
def visualize_tile(self, tile_id: str):
|
||||
"""可视化单个瓦片的轨迹"""
|
||||
tile_data = next(
|
||||
(t for t in self.tile_db["TileDB"] if t["tileid"] == tile_id), None
|
||||
)
|
||||
if not tile_data:
|
||||
logger.warning(f"未找到瓦片 {tile_id} 的数据")
|
||||
return
|
||||
|
||||
trajectories = tile_data.get("trajectories", [])
|
||||
if not trajectories:
|
||||
logger.info(f"瓦片 {tile_id} 没有轨迹数据")
|
||||
return
|
||||
|
||||
bag_grid_map = dict()
|
||||
for bag_name in trajectories:
|
||||
xyz_points = get_3d_grids_from_db(self.collections, bag_name)
|
||||
if xyz_points:
|
||||
bag_grid_map[bag_name] = xyz_points
|
||||
|
||||
fig = go.Figure()
|
||||
colors = px.colors.qualitative.Plotly
|
||||
all_lats, all_lons = [], []
|
||||
|
||||
# 计算聚类中心
|
||||
center_x, center_y, center_z = calculate_cluster_center(
|
||||
bag_grid_map, trajectories
|
||||
)
|
||||
logger.info(
|
||||
f"聚类 {tile_id} 中心点: ({center_x:.1f}, {center_y:.1f}, {center_z:.1f})"
|
||||
)
|
||||
|
||||
tz, tx, ty = map(int, tile_id.split("/"))
|
||||
tile_images = {}
|
||||
|
||||
# 下载周边瓦片
|
||||
for dx in [-1, 0, 1]:
|
||||
for dy in [-1, 0, 1]:
|
||||
ttx = tx + dx
|
||||
tty = ty + dy
|
||||
tile_url = self.tile_server_url.format(z=tz, x=ttx, y=tty)
|
||||
try:
|
||||
response = requests.get(tile_url, timeout=5)
|
||||
if response.status_code == 200:
|
||||
tile_images[(dx, dy)] = Image.open(BytesIO(response.content))
|
||||
else:
|
||||
logger.warning(
|
||||
f"无法获取瓦片 {tz}/{ttx}/{tty},使用空白图片代替"
|
||||
)
|
||||
tile_images[(dx, dy)] = Image.new(
|
||||
"RGB", (512, 512), (255, 255, 255)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"下载瓦片 {tz}/{ttx}/{tty} 失败: {e}")
|
||||
tile_images[(dx, dy)] = Image.new(
|
||||
"RGB", (512, 512), (255, 255, 255)
|
||||
)
|
||||
|
||||
# 拼接瓦片
|
||||
width = 512 * 3
|
||||
height = 512 * 3
|
||||
combined_image = Image.new("RGB", (width, height))
|
||||
|
||||
for dx in [-1, 0, 1]:
|
||||
for dy in [-1, 0, 1]:
|
||||
ttx = (dx + 1) * 512
|
||||
tty = (dy + 1) * 512
|
||||
combined_image.paste(tile_images[(dx, dy)], (ttx, tty))
|
||||
|
||||
# 添加底图
|
||||
fig.add_layout_image(
|
||||
dict(
|
||||
source=combined_image,
|
||||
xref="x",
|
||||
yref="y",
|
||||
x=0,
|
||||
y=height,
|
||||
sizex=width,
|
||||
sizey=height,
|
||||
sizing="stretch",
|
||||
opacity=0.8,
|
||||
layer="below",
|
||||
)
|
||||
)
|
||||
|
||||
def latlon_to_tile_pixel(lat, lon, tile_x, tile_y, zoom):
|
||||
"""将经纬度转换为瓦片像素坐标"""
|
||||
n = 2**zoom
|
||||
xtile = (lon + 180) / 360 * n
|
||||
|
||||
ytile = (
|
||||
(
|
||||
1
|
||||
- math.log(
|
||||
math.tan(math.radians(lat)) + 1 / math.cos(math.radians(lat))
|
||||
)
|
||||
/ math.pi
|
||||
)
|
||||
/ 2
|
||||
* n
|
||||
)
|
||||
x_pixel = (xtile - tile_x - 0.386) * 512
|
||||
y_pixel = (ytile - tile_y - 0.23) * 512
|
||||
return x_pixel, y_pixel
|
||||
|
||||
for i, bag_name in enumerate(trajectories):
|
||||
# 从数据库获取轨迹点
|
||||
points = self._get_trajectory_points_from_db(bag_name)
|
||||
if not points:
|
||||
logger.warning(f"包 {bag_name} 没有轨迹点数据")
|
||||
continue
|
||||
|
||||
x = []
|
||||
y = []
|
||||
|
||||
for point in points:
|
||||
lat, lon, _ = point
|
||||
pt_x, pt_y = latlon_to_tile_pixel(lat, lon, tx, ty, tz)
|
||||
|
||||
# 转换为大图的坐标
|
||||
pt_x += 512 # 中心瓦片x偏移
|
||||
pt_y += 512 # 中心瓦片y偏移
|
||||
|
||||
# 调整坐标系(原点在左上角)
|
||||
pt_y = height - pt_y
|
||||
|
||||
if 0 <= pt_x <= width and 0 <= pt_y <= height:
|
||||
x.append(pt_x)
|
||||
y.append(pt_y)
|
||||
|
||||
if not x:
|
||||
logger.warning(f"包 {bag_name} 没有在瓦片范围内的点")
|
||||
continue
|
||||
|
||||
# 轨迹线
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=x,
|
||||
y=y,
|
||||
mode="lines",
|
||||
name=bag_name,
|
||||
line=dict(width=4, color=colors[i % len(colors)]),
|
||||
showlegend=True,
|
||||
legendgroup=bag_name,
|
||||
hoverinfo="text",
|
||||
text=[f"{bag_name}<br>点 {j + 1}/{len(x)}" for j in range(len(x))],
|
||||
)
|
||||
)
|
||||
|
||||
# 方向箭头(2D投影)
|
||||
if len(x) > 1:
|
||||
xy_points = list(zip(x, y))
|
||||
arrows = create_direction_arrows(xy_points)
|
||||
if arrows:
|
||||
arrow_x = [arrow["x"] for arrow in arrows]
|
||||
arrow_y = [arrow["y"] for arrow in arrows]
|
||||
arrow_text = [
|
||||
f"{bag_name}<br>朝向: {arrow['angle_deg']:.1f}°"
|
||||
for arrow in arrows
|
||||
]
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=arrow_x,
|
||||
y=arrow_y,
|
||||
mode="markers",
|
||||
name=f"{bag_name} 方向",
|
||||
marker=dict(
|
||||
symbol="arrow-up",
|
||||
size=8,
|
||||
color=colors[i % len(colors)],
|
||||
angle=[arrow["angle_deg"] for arrow in arrows],
|
||||
line=dict(width=1, color="black"),
|
||||
),
|
||||
hoverinfo="text",
|
||||
text=arrow_text,
|
||||
showlegend=False,
|
||||
legendgroup=bag_name,
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title=f"瓦片 {tile_id} 轨迹可视化 (共{len(trajectories)}条轨迹)",
|
||||
xaxis=dict(
|
||||
range=[0, width],
|
||||
scaleanchor="y",
|
||||
title="X px",
|
||||
showgrid=False,
|
||||
zeroline=False,
|
||||
constrain="domain",
|
||||
),
|
||||
yaxis=dict(
|
||||
range=[0, height],
|
||||
scaleanchor="x",
|
||||
title="Y px",
|
||||
showgrid=False,
|
||||
zeroline=False,
|
||||
),
|
||||
showlegend=True,
|
||||
legend=dict(
|
||||
orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1
|
||||
),
|
||||
height=800,
|
||||
width=800,
|
||||
)
|
||||
|
||||
# 保存HTML文件
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
safe_tile_id = tile_id.replace("/", "_")
|
||||
output_path = os.path.join(self.output_dir, f"tile_{safe_tile_id}.html")
|
||||
fig.write_html(output_path)
|
||||
logger.info(f"已保存瓦片 {tile_id} 可视化结果到 {output_path}")
|
||||
|
||||
def visualize_all_tiles(self):
|
||||
"""可视化所有瓦片"""
|
||||
for tile in self.tile_db["TileDB"]:
|
||||
self.visualize_tile(tile["tileid"])
|
||||
|
||||
def visualize_from_txt(self, txt_path: str):
|
||||
try:
|
||||
with open(txt_path, "r") as f:
|
||||
tile_ids = set(line.strip() for line in f if line.strip())
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Not Found txt")
|
||||
return
|
||||
tiles_data = list(
|
||||
self.collections["tile_db"].find({"tileid": {"$in": list(tile_ids)}})
|
||||
)
|
||||
|
||||
if not tiles_data:
|
||||
logger.warning("no match tile")
|
||||
return
|
||||
for tile_data in tiles_data:
|
||||
self.visualize_tile(tile_data["tileid"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
logger.error("用法: python3 tile_visualization_from_db.py <data_root>")
|
||||
sys.exit(1)
|
||||
|
||||
data_root = sys.argv[1]
|
||||
# 创建可视化器并运行
|
||||
visualizer = TileVisualizer(
|
||||
output_dir=os.path.join(data_root, "tile_visualizations")
|
||||
)
|
||||
txt_path = os.path.join(data_root, "processed_tiles.txt")
|
||||
visualizer.visualize_from_txt(txt_path)
|
||||
visualizer.close()
|
||||
0
fst_data_pipeline/pipelines/volc/__init__.py
Normal file
0
fst_data_pipeline/pipelines/volc/__init__.py
Normal file
97
fst_data_pipeline/pipelines/volc/bag-copy.sh
Normal file
97
fst_data_pipeline/pipelines/volc/bag-copy.sh
Normal file
@@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
# 外部必须导出:
|
||||
# INPUT_ROOT WORKFLOW_ID SUB_DIR OUTPUT_ROOT
|
||||
|
||||
INPUT_ROOT="${INPUT_ROOT}/${WORKFLOW_ID}"
|
||||
INPUT_DIR="${INPUT_ROOT}/${SUB_DIR}" # 实际指向 bagdir_splits 上一级
|
||||
OUTPUT_DIR="${OUTPUT_ROOT}" # 通常就是 /nas_perception/.../output/wfxxx
|
||||
|
||||
log(){
|
||||
echo "[$(date '+%F %T')] [$$] $*"
|
||||
}
|
||||
|
||||
#----------------------------------------------------------
|
||||
# 统一封装:目录同步 / 文件拷贝
|
||||
#----------------------------------------------------------
|
||||
sync_dir(){
|
||||
local srcDir="$1" dstDir="$2"
|
||||
if [[ -d ${srcDir} ]]; then
|
||||
log " + $(basename "${srcDir}")/"
|
||||
mkdir -p "${dstDir}"
|
||||
rsync -a --delete "${srcDir}/" "${dstDir}/"
|
||||
fi
|
||||
}
|
||||
|
||||
sync_file(){
|
||||
local srcFile="$1" dstDir="$2"
|
||||
if [[ -f ${srcFile} ]]; then
|
||||
log " + $(basename "${srcFile}")"
|
||||
mkdir -p "${dstDir}"
|
||||
cp -p "${srcFile}" "${dstDir}/"
|
||||
fi
|
||||
}
|
||||
|
||||
#----------------------------------------------------------
|
||||
# 主逻辑
|
||||
#----------------------------------------------------------
|
||||
log "Script started"
|
||||
log "INPUT_DIR = ${INPUT_DIR}"
|
||||
log "OUTPUT_DIR = ${OUTPUT_DIR}"
|
||||
|
||||
[[ -d ${INPUT_DIR} ]] || { log "ERROR: INPUT_DIR not found: ${INPUT_DIR}"; exit 1; }
|
||||
|
||||
while IFS= read -r -d '' src; do
|
||||
# 去掉 split_N 层级,得到纯 bag.dir 名
|
||||
rel="${src#${INPUT_DIR}/*/}"
|
||||
dest="${OUTPUT_DIR}/${rel}/derived/${SUB_DIR}"
|
||||
mkdir -p "${dest}"
|
||||
|
||||
basename_bag=$(basename "$src") # xxx.bag.dir
|
||||
pkgname="${basename_bag%.bag.dir}" # xxx
|
||||
split_name=$(basename "$(dirname "$src")") # split_0 / split_1 / ...
|
||||
truth_root="${INPUT_DIR}/${split_name}"
|
||||
|
||||
log "==================== Processing ${pkgname} ==================="
|
||||
|
||||
# 1. 老 object 重命名同步
|
||||
sync_dir "${src}/object_det_ep20" "${dest}/object_det_al"
|
||||
sync_dir "${src}/object_tracking" "${dest}/object_tracking_al"
|
||||
|
||||
# 2. 目录类:slam + lidar_gt + 新增 6 目录
|
||||
for item in slam_lidar_ground slam_lidar_none_ground \
|
||||
lidar_gt_pandar128_5f_front lidar_gt_pandar128_5f_rear \
|
||||
object_det_ep20_lrgt_front object_det_ep20_lrgt_rear \
|
||||
object_lrgt_filter object_postprocess; do
|
||||
sync_dir "${src}/${item}" "${dest}/${item}"
|
||||
done
|
||||
|
||||
# 3. 文件类
|
||||
for item in bev_image_ground.png ego_motion_slam_lidar.csv; do
|
||||
sync_file "${src}/${item}" "${dest}"
|
||||
done
|
||||
|
||||
# 4. osm & split_json(源在 OUTPUT_ROOT/SUB_DIR/split_N/)
|
||||
osm_src="${truth_root}/osm_out/${pkgname}.osm"
|
||||
if [[ -f ${osm_src} ]]; then
|
||||
log " + ${pkgname}.osm -> input bag.dir"
|
||||
cp -p "${osm_src}" "${src}/"
|
||||
log " + ${pkgname}.osm -> output"
|
||||
cp -p "${src}/${pkgname}.osm" "${dest}/"
|
||||
fi
|
||||
|
||||
split_src="${truth_root}/split_json/${pkgname}"
|
||||
if [[ -d ${split_src} ]]; then
|
||||
log " + split_json/ -> input bag.dir"
|
||||
rsync -a --delete "${split_src}/" "${src}/split_json/"
|
||||
log " + split_json/ -> output"
|
||||
sync_dir "${src}/split_json" "${dest}/split_json"
|
||||
fi
|
||||
|
||||
# 5. 2dseg and occ
|
||||
sync_dir "${src}/${SUB_DIR}" "${dest}"
|
||||
|
||||
done < <(find "${INPUT_DIR}" -mindepth 2 -maxdepth 2 -type d -name '*.bag.dir' -print0)
|
||||
|
||||
log "============================================================"
|
||||
log "All done, success!"
|
||||
139
fst_data_pipeline/pipelines/volc/bag_operation/bag_scanner.py
Normal file
139
fst_data_pipeline/pipelines/volc/bag_operation/bag_scanner.py
Normal file
@@ -0,0 +1,139 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
check_bags.py
|
||||
|
||||
环境变量:
|
||||
BAG_DIR 必填 bag 根目录
|
||||
GT_API_URL 可选 获取 pipeline 路径的接口,默认 http://10.204.22.135:30000/api/gt/types
|
||||
OUTPUT_PREFIX 可选 输出前缀(直接拼接)
|
||||
其余变量 真值控制,示例:
|
||||
OBJECT_DETECTION=true
|
||||
LANE_DETECTION=false
|
||||
SLAM_GROUND=true
|
||||
…
|
||||
仅当变量值为 true/false 时参与检查;
|
||||
true → 该 path 必须存在(若是目录则不能为空)
|
||||
false → 该 path 必须不存在
|
||||
其它值或缺失 → 忽略
|
||||
|
||||
结果同时输出到 stdout 和 list.txt(每行一条完整拼接路径)
|
||||
新增:
|
||||
- 扫描前检查 BAG_DIR 是否为空
|
||||
- 统计:总 bag 数、规则数、通过数、失败数
|
||||
- 目录存在时额外检查“非空”
|
||||
"""
|
||||
|
||||
import os
|
||||
import requests
|
||||
import logging
|
||||
import sys
|
||||
|
||||
# ---------- 日志 ----------
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[%(asctime)s][%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger("check_bags")
|
||||
|
||||
# ---------- 1. 基础目录 ----------
|
||||
BASE = os.environ.get("BAG_DIR")
|
||||
if not BASE or not os.path.isdir(BASE):
|
||||
log.error("BAG_DIR not set or not a directory")
|
||||
sys.exit(1)
|
||||
|
||||
PREFIX = os.environ.get("OUTPUT_PREFIX", "")
|
||||
GT_API_URL = os.environ.get(
|
||||
"GT_API_URL", "http://10.204.22.135:30000/api/gt/types"
|
||||
).rstrip()
|
||||
log.info("GT_API_URL = %s", GT_API_URL)
|
||||
|
||||
|
||||
# ---------- 2. 拉取 API ----------
|
||||
try:
|
||||
log.info("fetching pipeline list from %s", GT_API_URL)
|
||||
api = requests.get(GT_API_URL, timeout=10).json()
|
||||
log.info("got %d items from API", len(api))
|
||||
except Exception as e:
|
||||
log.error("API unreachable: %s", e)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# ---------- 3. 收集检查规则 ----------
|
||||
checks = []
|
||||
for item in api:
|
||||
if item.get("type") != "pipeline":
|
||||
continue
|
||||
name = item["name"]
|
||||
env_val = os.environ.get(name, "").lower()
|
||||
if env_val in ("true", "false"):
|
||||
path = item["path"].lstrip("/")
|
||||
must_exist = env_val == "true"
|
||||
checks.append((path, must_exist))
|
||||
log.info("check rule: %-30s must_exist=%-5s path=%s", name, must_exist, path)
|
||||
|
||||
if not checks:
|
||||
log.error("No pipeline paths enabled for check")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# ---------- 4. 遍历 bag + 统计 ----------
|
||||
def _empty_dir(p: str) -> bool:
|
||||
"""目录存在且为空返回 True"""
|
||||
return os.path.isdir(p) and not bool(os.listdir(p))
|
||||
|
||||
|
||||
valid_cnt = invalid_cnt = 0
|
||||
bag_dirs = [
|
||||
d
|
||||
for d in os.listdir(BASE)
|
||||
if d.endswith(".bag.dir") and os.path.isdir(os.path.join(BASE, d))
|
||||
]
|
||||
if not bag_dirs:
|
||||
log.error("No *.bag.dir found under BAG_DIR (%s), aborting", BASE)
|
||||
sys.exit(1)
|
||||
|
||||
total_bag = len(bag_dirs)
|
||||
log.info("start scanning %d bag(s) against %d rule(s)", total_bag, len(checks))
|
||||
|
||||
valid = []
|
||||
for bag in bag_dirs:
|
||||
bag_path = os.path.join(BASE, bag)
|
||||
ok = True
|
||||
for rel, must_exist in checks:
|
||||
full = os.path.join(bag_path, rel)
|
||||
exists = os.path.exists(full)
|
||||
# 关键:目录不能为空
|
||||
if must_exist and os.path.isdir(full) and _empty_dir(full):
|
||||
exists = False
|
||||
if exists != must_exist:
|
||||
log.debug(
|
||||
"bag %s failed: %s exists=%s required=%s",
|
||||
bag,
|
||||
rel,
|
||||
exists,
|
||||
must_exist,
|
||||
)
|
||||
ok = False
|
||||
break
|
||||
if ok:
|
||||
valid_cnt += 1
|
||||
valid.append(bag_path)
|
||||
log.info("valid bag: %s", bag)
|
||||
else:
|
||||
invalid_cnt += 1
|
||||
|
||||
|
||||
# ---------- 5. 输出结果 & 统计 ----------
|
||||
out_file = "list.txt"
|
||||
with open(out_file, "w") as f:
|
||||
for bag_path in valid:
|
||||
line = f"{PREFIX}{os.path.basename(bag_path)}"
|
||||
f.write(line + "\n")
|
||||
|
||||
log.info("==== summary ====")
|
||||
log.info("total bags : %d", total_bag)
|
||||
log.info("rules : %d", len(checks))
|
||||
log.info("passed : %d", valid_cnt)
|
||||
log.info("failed : %d", invalid_cnt)
|
||||
log.info("wrote %d bags to %s and stdout", len(valid), out_file)
|
||||
113
fst_data_pipeline/pipelines/volc/bag_operation/merge_rosbag.py
Normal file
113
fst_data_pipeline/pipelines/volc/bag_operation/merge_rosbag.py
Normal file
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
import requests
|
||||
import tos
|
||||
import psycopg2
|
||||
from tqdm import tqdm
|
||||
|
||||
# ---------- 日志 ----------
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)s | %(message)s",
|
||||
handlers=[logging.FileHandler("bag_merge.log"), logging.StreamHandler()],
|
||||
)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# ---------- 环境变量 ----------
|
||||
API_URL = os.getenv("API_URL")
|
||||
TOS_ENDPOINT = os.getenv("TOS_ENDPOINT")
|
||||
TOS_REGION = os.getenv("TOS_REGION")
|
||||
TOS_BUCKET = os.getenv("TOS_BUCKET")
|
||||
TOS_AK = os.getenv("TOS_ACCESS_KEY")
|
||||
TOS_SK = os.getenv("TOS_SECRET_KEY")
|
||||
PG_DSN = os.getenv("PG_DSN")
|
||||
TEMP_ROOT = Path(os.getenv("TEMP_ROOT", "/tmp/bag_merge"))
|
||||
|
||||
# ---------- TOS 客户端 ----------
|
||||
tos_client = tos.TosClientV2(TOS_AK, TOS_SK, TOS_ENDPOINT, TOS_REGION)
|
||||
|
||||
|
||||
# ---------- 原子函数 ----------
|
||||
def fetch_mapping() -> dict:
|
||||
log.info("POST %s", API_URL)
|
||||
resp = requests.post(
|
||||
API_URL,
|
||||
json={"bag_names": ["*"]},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def download_file(key: str, local: Path):
|
||||
meta = tos_client.head_object(TOS_BUCKET, key)
|
||||
total = int(meta.content_length)
|
||||
with tqdm(total=total, unit="B", unit_scale=True, desc=f"↓ {key}") as bar:
|
||||
tos_client.get_object_to_file(
|
||||
TOS_BUCKET,
|
||||
key,
|
||||
str(local),
|
||||
progress_callback=lambda c, t: bar.update(t - c),
|
||||
)
|
||||
|
||||
|
||||
def upload_file(local: Path, key: str) -> str:
|
||||
tos_client.put_object_from_file(TOS_BUCKET, key, str(local))
|
||||
return f"https://{TOS_BUCKET}.{TOS_ENDPOINT}/{key}"
|
||||
|
||||
|
||||
def merge_bags(inputs: list[Path], output: Path):
|
||||
subprocess.check_call(
|
||||
["rosbag-merge", "-o", str(output)] + [str(p) for p in inputs]
|
||||
)
|
||||
|
||||
|
||||
def update_db(parent: str, tos_url: str):
|
||||
sql = "UPDATE bag_task SET tos_path = %s WHERE parent_bag = %s"
|
||||
with psycopg2.connect(PG_DSN) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (tos_url, parent))
|
||||
conn.commit()
|
||||
log.info("[DB] %s tos_path ⇢ %s", parent, tos_url)
|
||||
|
||||
|
||||
def work_one(parent: str, children: list[str]) -> str:
|
||||
log.info("start parent=%s children=%d", parent, len(children))
|
||||
wd = TEMP_ROOT / parent
|
||||
wd.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
subs = [wd / c for c in children]
|
||||
for c, s in zip(children, subs):
|
||||
download_file(c, s)
|
||||
|
||||
out = wd / parent
|
||||
merge_bags(subs, out)
|
||||
|
||||
url = upload_file(out, parent)
|
||||
update_db(parent, url)
|
||||
|
||||
shutil.rmtree(wd)
|
||||
log.info("finish parent=%s", parent)
|
||||
return url
|
||||
|
||||
|
||||
# ---------- 主入口 ----------
|
||||
def main():
|
||||
TEMP_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
mapping = fetch_mapping()
|
||||
with ProcessPoolExecutor() as pool:
|
||||
futures = {pool.submit(work_one, p, c): p for p, c in mapping.items()}
|
||||
for fu in as_completed(futures):
|
||||
log.info("done %s -> %s", futures[fu], fu.result())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user