AI

The SmolVLA Model

Robotics

Posted by LXG on April 16, 2026

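The SmolVLA model

SmolVLA is a lightweight yet practical model family that has emerged in recent years at the intersection of robotics and multimodal learning. The name breaks into two parts:

  • Smol: internet slang for "small", meaning a small model
  • VLA (Vision-Language-Action): a single model that unifies vision, language, and action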

Model architecture


        Image (Camera)
             ↓
     Vision Encoder (lightweight CNN/ViT)
             ↓
        Feature vectors
             ↓
Text instruction → Language Encoder (small Transformer)
             ↓
        Fusion (Cross Attention)
             ↓
      Action Head (action prediction)
             ↓
        Control signal
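
To make the fusion step in the diagram concrete, here is a minimal single-head cross-attention sketch in plain Python. It is an illustration only: the helper names (`softmax`, `cross_attention`) and the toy vectors are assumptions, and real layers are learned, multi-headed, and operate on tensors.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries: List[Vector], keys: List[Vector], values: List[Vector]) -> List[Vector]:
    """Each query attends over all keys and returns a weighted mix of the values."""
    scale = math.sqrt(len(keys[0]))
    outputs = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
        weights = softmax(scores)
        mixed = [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]
        outputs.append(mixed)
    return outputs


# Toy example: one text-side query attends over two image-side tokens.
text_queries = [[1.0, 1.0]]
image_keys = [[1.0, 0.0], [1.0, 0.0]]    # identical keys, so both get equal weight
image_values = [[0.0, 2.0], [4.0, 6.0]]
fused = cross_attention(text_queries, image_keys, image_values)
print(fused)
```

Because the two keys are identical, the attention weights are equal and the fused vector is simply the average of the two value vectors.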

lerobot/smolvla_base

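smolvla_base on ModelScope

  • Input: images (multiple views), proprioception/state, and an optional language instruction
  • Output: continuous actions
  • Training objective: flow matching
  • Action representation: continuous
  • Intended use: a base model to fine-tune on your specific use case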
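
Flow matching trains a network to predict a velocity field that transports noise to data; at inference time the action is obtained by integrating that field. The sketch below replaces the learned network with a hand-written field for a single known target, purely to show the integration loop. `toy_velocity`, `sample_action`, and the convention that t=0 is noise and t=1 is the action are assumptions for illustration, not SmolVLA's implementation.

```python
import random
from typing import List

Vector = List[float]


def toy_velocity(x: Vector, t: float, target: Vector) -> Vector:
    """Stand-in for the learned network: velocity of the straight path toward `target`."""
    return [(g - xi) / (1.0 - t) for xi, g in zip(x, target)]


def sample_action(target: Vector, num_steps: int = 10, seed: int = 0) -> Vector:
    """Euler-integrate the velocity field from Gaussian noise (t=0) to an action (t=1)."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in target]
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt  # t stays below 1, so the division in toy_velocity is safe
        v = toy_velocity(x, t, target)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# A 6-dimensional continuous action used as the toy target.
target_action = [0.12, -0.05, 0.03, 0.01, 0.02, -0.01]
action = sample_action(target_action)
print([round(a, 6) for a in action])
```

In a real policy the velocity comes from a network conditioned on the images, state, and instruction, so the integration ends at an action that depends on the observation rather than at a fixed target.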

smolvla_base$ tree
.
├── collage_small.gif
├── config.json
├── configuration.json
├── Finetune_SmolVLA_notebook.ipynb
├── model.safetensors
├── policy_postprocessor.json
├── policy_postprocessor_step_0_unnormalizer_processor.safetensors
├── policy_preprocessor.json
├── policy_preprocessor_step_5_normalizer_processor.safetensors
└── README.md

SmolVLM2-500M-Video-Instruct

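SmolVLM2-500M-Video-Instruct on ModelScope

SmolVLM2-500M-Video is a lightweight multimodal model designed for analyzing video content. It takes video, image, and text inputs and generates text outputs, whether that means answering questions about media files, comparing visual content, or transcribing text from images. Despite its small size, needing only 1.8 GB of GPU memory for video inference, it performs well on complex multimodal tasks. This efficiency makes it particularly suitable for on-device applications where compute resources may be limited.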


SmolVLM2-500M-Video-Instruct$ tree 
.
├── added_tokens.json
├── chat_template.json
├── config.json
├── configuration.json
├── generation_config.json
├── merges.txt
├── model.safetensors
├── onnx
│   ├── decoder_model_merged_bnb4.onnx
│   ├── decoder_model_merged_fp16.onnx
│   ├── decoder_model_merged_int8.onnx
│   ├── decoder_model_merged.onnx
│   ├── decoder_model_merged_q4f16.onnx
│   ├── decoder_model_merged_q4.onnx
│   ├── decoder_model_merged_quantized.onnx
│   ├── decoder_model_merged_uint8.onnx
│   ├── embed_tokens_bnb4.onnx
│   ├── embed_tokens_fp16.onnx
│   ├── embed_tokens_int8.onnx
│   ├── embed_tokens.onnx
│   ├── embed_tokens_q4f16.onnx
│   ├── embed_tokens_q4.onnx
│   ├── embed_tokens_quantized.onnx
│   ├── embed_tokens_uint8.onnx
│   ├── vision_encoder_bnb4.onnx
│   ├── vision_encoder_fp16.onnx
│   ├── vision_encoder_int8.onnx
│   ├── vision_encoder.onnx
│   ├── vision_encoder_q4f16.onnx
│   ├── vision_encoder_q4.onnx
│   ├── vision_encoder_quantized.onnx
│   └── vision_encoder_uint8.onnx
├── preprocessor_config.json
├── processor_config.json
├── README.md
├── special_tokens_map.json
├── tokenizer_config.json
├── tokenizer.json
└── vocab.json

Item              SmolVLM2        SmolVLA
Input             image + text    image features + state
Output            semantics       action
Controls a robot  no              yes
Needs state       no              yes

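Why split into two models

The split decouples general understanding from concrete control, which lowers training difficulty and improves reuse and generalization.

1️⃣ Vision model (SmolVLM2)


image / video + task → semantic features (embedding)

👉 The output here is not a sentence, but:

a sequence of vectors (for example, 1024-dimensional)

2️⃣ Fused input


semantic features + state + task → concatenated

For example:

[visual embedding | state vector | task embedding]

3️⃣ Policy model (SmolVLA)


input → action

For example:

→ [steering, speed]
or
→ 6D robot-arm action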
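
The three stages can be sketched with plain lists. Everything here is a toy: the vector sizes, the `fuse` and `linear_policy_head` helpers, and the weights are hypothetical, chosen only to show how concatenated features map to a continuous action.

```python
from typing import List

Vector = List[float]


def fuse(visual: Vector, state: Vector, task: Vector) -> Vector:
    """Concatenate [visual embedding | state vector | task embedding]."""
    return visual + state + task


def linear_policy_head(features: Vector, weights: List[Vector], bias: Vector) -> Vector:
    """Map fused features to an action with one linear layer (a stand-in for the policy)."""
    return [sum(w * f for w, f in zip(row, features)) + b for row, b in zip(weights, bias)]


visual_embedding = [0.5, -0.5, 1.0, 0.0]   # would come from the vision-language model
state_vector = [0.1, 0.2]                  # e.g. joint positions
task_embedding = [1.0, 0.0]                # encoded instruction

features = fuse(visual_embedding, state_vector, task_embedding)

# Two output dimensions, e.g. [steering, speed]; the weights are arbitrary.
weights = [
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
]
bias = [0.0, 0.5]
action = linear_policy_head(features, weights, bias)
print(len(features), action)
```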

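Use cases

Robot-arm manipulation (the core use case)


Input:
- camera1: tabletop
- camera2: robot-arm view
- task: Pick up the red cup

Output:
- 6-DoF robot-arm action

Example

The robot arm observes a tabletop scene and carries out the instruction: "Pick up the red block and place it in the blue region."

What the inputs are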


camera1: [3,256,256]
camera2: [3,256,256]
camera3: [3,256,256]

Camera   Role
camera1  Front view (the whole tabletop)
camera2  Side view (judging height/depth)
camera3  Wrist camera (fine grasping)

Robot state (state)


"observation.state": [6]

Language task (task)


"task": "Pick and place the object."
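
A quick way to catch shape mistakes before inference is to check the observation against the expected layout. The sketch below works on shape tuples rather than real tensors. The key names follow the `observation.images.camera{i}` / `observation.state` / `task` pattern used by the test script in this post, while `EXPECTED_SHAPES` and `find_problems` are hypothetical helpers.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]

# Expected layout for the three-camera, 6-dimensional-state example.
EXPECTED_SHAPES: Dict[str, Shape] = {
    "observation.images.camera1": (3, 256, 256),
    "observation.images.camera2": (3, 256, 256),
    "observation.images.camera3": (3, 256, 256),
    "observation.state": (6,),
}


def find_problems(shapes: Dict[str, Shape], task: str) -> List[str]:
    """Return human-readable problems; an empty list means the layout matches."""
    problems = []
    for key, expected in EXPECTED_SHAPES.items():
        actual = shapes.get(key)
        if actual is None:
            problems.append(f"missing key: {key}")
        elif tuple(actual) != expected:
            problems.append(f"{key}: expected {expected}, got {tuple(actual)}")
    if not task.strip():
        problems.append("task must be a non-empty string")
    return problems


good = dict(EXPECTED_SHAPES)
bad = {**EXPECTED_SHAPES, "observation.state": (7,)}
del bad["observation.images.camera3"]

ok_report = find_problems(good, "Pick and place the object.")
bad_report = find_problems(bad, "Pick and place the object.")
print(ok_report)
print(bad_report)
```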

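In essence


Vision (images) + Language (task) + State (robot pose)
            ↓
        Action

"Given this view + the current pose + the task requirement → how to move next"

What exactly is the output?


[0.12, -0.05, 0.03, 0.01, 0.02, -0.01]

👉 This is not a "semantic action"; it is:

👉 a continuous control quantity (this is the key point)

Full execution loop


Camera captures a frame
   ↓
Build the observation
   ↓
SmolVLA inference
   ↓
Output an action
   ↓
Send it to the robot-arm controller
   ↓
The arm moves a small step
   ↓
Capture again (next frame)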
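
The loop above can be sketched as a plain function that wires together four callables. The hardware and the model are replaced with stubs, and every name here (`run_control_loop`, `fake_capture`, `fake_policy`, and so on) is a hypothetical placeholder rather than a lerobot API.

```python
from typing import Callable, Dict, List

Observation = Dict[str, object]
Action = List[float]


def run_control_loop(
    capture: Callable[[], object],
    read_state: Callable[[], List[float]],
    policy: Callable[[Observation], Action],
    send: Callable[[Action], None],
    task: str,
    num_steps: int,
) -> List[Action]:
    """Capture -> build observation -> infer -> send, repeated `num_steps` times."""
    history = []
    for _ in range(num_steps):
        observation = {
            "observation.images.camera1": capture(),
            "observation.state": read_state(),
            "task": task,
        }
        action = policy(observation)
        send(action)  # the arm moves a little before the next capture
        history.append(action)
    return history


# Stubs standing in for real hardware and the real model.
joint_positions = [0.0] * 6


def fake_capture() -> str:
    return "frame"


def fake_read_state() -> List[float]:
    return list(joint_positions)


def fake_policy(observation: Observation) -> Action:
    # Move every joint 10% of the remaining way toward 1.0.
    return [0.1 * (1.0 - q) for q in observation["observation.state"]]


def fake_send(action: Action) -> None:
    for i, delta in enumerate(action):
        joint_positions[i] += delta


actions = run_control_loop(
    fake_capture, fake_read_state, fake_policy, fake_send,
    task="Pick and place the object.", num_steps=3,
)
print(joint_positions)
```

Each iteration sees the state produced by the previous action, which is what makes the loop closed rather than open.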

Environment setup

Install Conda



wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -b -p $HOME/miniconda

# Initialize the current shell
eval "$($HOME/miniconda/bin/conda shell.bash hook)"

# Write the shell config so it takes effect on the next login
$HOME/miniconda/bin/conda init

# Accept the terms of service
conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main
conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r

Configure the environment

Create the conda environment


conda create -n smolvla python=3.10 -y
conda activate smolvla

pip install torch torchvision
pip install lerobot
pip install opencv-python
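
After installation, a small check can confirm that the packages are visible to the active environment. This is a sketch using only the standard library; `check_packages` is a hypothetical helper, and it only tests whether a module can be located, not whether it works with your GPU driver.

```python
import importlib.util
from typing import Dict, Iterable


def check_packages(names: Iterable[str]) -> Dict[str, bool]:
    """Report which top-level modules can be found, without importing them."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


# cv2 is the import name of the opencv-python package.
status = check_packages(["torch", "torchvision", "lerobot", "cv2"])
for name, found in status.items():
    print(f"{name}: {'ok' if found else 'MISSING'}")
```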

Download the lerobot/smolvla_base model


pip install distro six

pip install modelscope

modelscope download --model lerobot/smolvla_base --local_dir ./smolvla_base

Download the SmolVLM2-500M-Video-Instruct model


modelscope download --model HuggingFaceTB/SmolVLM2-500M-Video-Instruct --local_dir ./SmolVLM2-500M-Video-Instruct

Runtime dependencies


pip install "lerobot[smolvla]" torch

# Install via the Tsinghua mirror
pip install "lerobot[smolvla]" torch -i https://pypi.tuna.tsinghua.edu.cn/simple

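Test design

The robot arm observes a tabletop scene and carries out the instruction: "Pick up the red block and place it in the blue region."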

smolvla_base_test

smolvla_base_test_2

smolvla_base_test_3

Single-inference run results


(smolvla) xt@xt-2288H-V5:~/smolvla_test$ python test_smolvla_rtx4090.py --model-path ./smolvla_base --vlm-path ./SmolVLM2-500M-Video-Instruct --image-path smolvla_base_test.png --extra-image-paths smolvla_base_test_2.png smolvla_base_test_3.png --gpu-id 1 --load-on-cpu --weights-dtype fp16 --autocast fp16 --warmup 0 --iters 1 --verify-vlm --first-breakdown --report-json ./benchmark_results/rtx3090_3images_report.json
[信息] ===== SmolVLA RTX3090 推理测试开始 =====
[信息] 机器架构: x86_64 (amd64=True)
[信息] 使用GPU编号: 1 / 总GPU数: 2
[信息] GPU类型: NVIDIA GeForce RTX 3090
[信息] 模型路径: /home/xt/smolvla_test/smolvla_base
[信息] VLM路径: /home/xt/smolvla_test/SmolVLM2-500M-Video-Instruct
[信息] 测试图像: /home/xt/smolvla_test/smolvla_base_test.png
[信息] 额外图像数量: 2
[信息]   额外图像1: /home/xt/smolvla_test/smolvla_base_test_2.png
[信息]   额外图像2: /home/xt/smolvla_test/smolvla_base_test_3.png
[信息] 预热次数: 0
[信息] 统计次数: 1
[信息] 混精模式: fp16
[信息] 加载策略: load_on_cpu=True weights_dtype=fp16
[信息] 每帧强制reset: True
[信息] 输入模态: 视觉(3路相机) + 语言(任务) + 状态(6维姿态)
[信息] 本地化模型视图: /tmp/smolvla_rtx4090_j_js7ler
[信息] 正在加载 SmolVLA 模型...
[显存] 加载前: allocated=0.00GB reserved=0.00GB total=23.56GB
[时间] 模型加载耗时: 6.546173 秒
[显存] 加载后: allocated=0.86GB reserved=0.88GB total=23.56GB
[验证] 总参数量: 450046176
[验证] CUDA参数量: 450046176
[验证] VLM参数量(按名称匹配): 448411024
[验证] VLM CUDA参数量: 448411024
[时间] 预处理/后处理器构建耗时: 0.109936 秒
[信息] 观测构造完成,共3张图像
[信息] 任务文本: Pick the red block and place it in the blue region
[信息] 状态维度: (6,)
[时间] 首图预处理耗时: 0.068545 秒
[信息] 开始正式计时...
[信息] 多图模式已启用:每张图像测试一次
[结果] 推理测试完成
[结果] 单次推理耗时(最新一次): 601.595 ms
[结果] 单次推理耗时(唯一样本): 1581.034 ms
[结果] 平均推理耗时: 928.209 ms
[结果] 首次推理耗时: 1581.034 ms
[结果] 稳态推理均值(去掉首帧): 601.796 ms
[结果] P95 推理耗时: 1581.034 ms
[结果] 推理吞吐: 1.077 FPS
[结果] 动作输出(Action): [0.061508119106292725, -0.12993311882019043, 0.051314592361450195, -0.19741106033325195, 0.03245288133621216, 0.3086456060409546]
[结果] 每张图像单次耗时:
[结果]   /home/xt/smolvla_test/smolvla_base_test.png: preprocess=43.346 ms, inference=1581.034 ms, vlm_sub_calls=2407
[结果]   /home/xt/smolvla_test/smolvla_base_test_2.png: preprocess=43.767 ms, inference=601.997 ms, vlm_sub_calls=2407
[结果]   /home/xt/smolvla_test/smolvla_base_test_3.png: preprocess=41.844 ms, inference=601.595 ms, vlm_sub_calls=2407
[结果] 首次推理分段耗时:
[结果]   模型前向(select_action): 297.448 ms
[结果]   后处理(postprocess): 0.100 ms
[结果]   动作展平(flatten): 0.025 ms
[结果]   推理内合计(inside_infer): 297.574 ms
[结果]   迭代端到端总耗时: 1581.034 ms
[结果] 分阶段耗时不可用: hooks registered but no module forward events captured 
[验证] VLM前向调用次数: 0 (模块: model.vlm_with_expert)
[验证] VLM子模块前向调用总数: 2407
[验证] 单次推理峰值显存(allocated): 0.903 GB
[验证] 单次推理峰值显存(reserved): 0.938 GB
{
  "machine": "x86_64",
  "is_amd64": true,
  "gpu_name": "NVIDIA GeForce RTX 3090",
  "model_path": "/home/xt/smolvla_test/smolvla_base",
  "vlm_path": "/home/xt/smolvla_test/SmolVLM2-500M-Video-Instruct",
  "image_path": "/home/xt/smolvla_test/smolvla_base_test.png",
  "tested_images": [
    "/home/xt/smolvla_test/smolvla_base_test.png",
    "/home/xt/smolvla_test/smolvla_base_test_2.png",
    "/home/xt/smolvla_test/smolvla_base_test_3.png"
  ],
  "load_seconds": 6.546173128299415,
  "warmup": 0,
  "iterations": 3,
  "autocast": "fp16",
  "latency_ms": {
    "mean": 928.2086109742522,
    "first_call": 1581.0344032943249,
    "steady_mean": 601.7957148142159,
    "min": 601.5947433188558,
    "max": 1581.0344032943249,
    "p50": 601.996686309576,
    "p90": 1581.0344032943249,
    "p95": 1581.0344032943249,
    "p99": 1581.0344032943249,
    "sum": 2784.6258329227567
  },
  "throughput_fps": 1.0773440239370278,
  "action_output": {
    "last_action": [
      0.061508119106292725,
      -0.12993311882019043,
      0.051314592361450195,
      -0.19741106033325195,
      0.03245288133621216,
      0.3086456060409546
    ],
    "action_dim": 6
  },
  "phase_breakdown": {
    "enabled": false,
    "reason": "hooks registered but no module forward events captured"
  },
  "vlm_verification": {
    "enabled": true,
    "primary_vlm_module": "model.vlm_with_expert",
    "vlm_forward_calls": 0,
    "vlm_submodule_forward_calls": 2407,
    "vlm_top_called_submodules": [
      [
        "lm_expert.layers.0.input_layernorm",
        10
      ],
      [
        "lm_expert.layers.0.self_attn.q_proj",
        10
      ],
      [
        "lm_expert.layers.0.self_attn.k_proj",
        10
      ],
      [
        "lm_expert.layers.0.self_attn.v_proj",
        10
      ],
      [
        "lm_expert.layers.0.self_attn.o_proj",
        10
      ],
      [
        "lm_expert.layers.0.post_attention_layernorm",
        10
      ],
      [
        "lm_expert.layers.0.mlp",
        10
      ],
      [
        "lm_expert.layers.0.mlp.gate_proj",
        10
      ],
      [
        "lm_expert.layers.0.mlp.act_fn",
        10
      ],
      [
        "lm_expert.layers.0.mlp.up_proj",
        10
      ]
    ],
    "select_total_ms": 306.9929936900735,
    "memory_gb": {
      "before_allocated": 0.8698334693908691,
      "before_reserved": 0.9375,
      "after_allocated": 0.8698396682739258,
      "after_reserved": 0.9375,
      "peak_allocated": 0.9031515121459961,
      "peak_reserved": 0.9375
    }
  },
  "per_image_results": [
    {
      "image_path": "/home/xt/smolvla_test/smolvla_base_test.png",
      "preprocess_ms": 43.3460446074605,
      "latency_ms": 1581.0344032943249,
      "action": [
        0.005486726760864258,
        -0.07818222045898438,
        -0.02676069736480713,
        -0.050153493881225586,
        0.08656781911849976,
        0.4369911551475525
      ],
      "vlm_submodule_forward_calls": 2407
    },
    {
      "image_path": "/home/xt/smolvla_test/smolvla_base_test_2.png",
      "preprocess_ms": 43.76712907105684,
      "latency_ms": 601.996686309576,
      "action": [
        0.05755782127380371,
        -0.11391529440879822,
        -0.024209022521972656,
        -0.08604282140731812,
        0.10651946067810059,
        0.1984783411026001
      ],
      "vlm_submodule_forward_calls": 2407
    },
    {
      "image_path": "/home/xt/smolvla_test/smolvla_base_test_3.png",
      "preprocess_ms": 41.84360522776842,
      "latency_ms": 601.5947433188558,
      "action": [
        0.061508119106292725,
        -0.12993311882019043,
        0.051314592361450195,
        -0.19741106033325195,
        0.03245288133621216,
        0.3086456060409546
      ],
      "vlm_submodule_forward_calls": 2407
    }
  ],
  "first_call_breakdown": {
    "select_action_ms": 297.4484558105469,
    "postprocess_ms": 0.10041240602731705,
    "flatten_ms": 0.025027431547641754,
    "inside_infer_ms": 297.57389564812183,
    "iteration_total_ms": 1581.0344032943249
  }
}
[结果] 报告文件: /home/xt/smolvla_test/benchmark_results/rtx3090_3images_report.json
[信息] ===== SmolVLA RTX3090 推理测试结束 =====
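
The summary figures in the report can be reproduced from the three per-image latencies. The sketch below uses the same nearest-rank percentile rule as the test script's `percentile` helper; the variable names are chosen here for illustration.

```python
from typing import List

# Per-image inference latencies (ms) copied from the report.
latencies_ms = [1581.0344032943249, 601.996686309576, 601.5947433188558]


def percentile(sorted_vals: List[float], q: float) -> float:
    """Nearest-rank percentile on an already sorted list."""
    idx = int(round((q / 100.0) * (len(sorted_vals) - 1)))
    return sorted_vals[max(0, min(len(sorted_vals) - 1, idx))]


mean_ms = sum(latencies_ms) / len(latencies_ms)
steady_mean_ms = sum(latencies_ms[1:]) / len(latencies_ms[1:])  # drop the first call
throughput_fps = 1000.0 / mean_ms
ordered = sorted(latencies_ms)
p50_ms = percentile(ordered, 50)
p95_ms = percentile(ordered, 95)

print(f"mean={mean_ms:.3f} ms, steady={steady_mean_ms:.3f} ms, "
      f"fps={throughput_fps:.3f}, p50={p50_ms:.3f} ms, p95={p95_ms:.3f} ms")
```

With only three samples and no warmup, every percentile above the median collapses to the maximum, which is the cold first call. The P90/P95/P99 values and the 1.077 FPS throughput therefore mostly reflect that first call; the steady-state mean of about 602 ms is the more representative number for this run.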

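Industrial scenarios

Mode 1: cache the VLM output (the current setup)


First call:
image → SmolVLM2 → embedding (slow, ~800 ms)

Afterwards:
embedding → SmolVLA head → action (fast, ~3 ms)

Mode 2: re-run vision on every frame (real robots)


Every frame:
new image → SmolVLM2 → action

👉 Suitable when:

objects move
positions change during grasping
occlusion changes
visual feedback control is needed (closed loop)

👉 Cost:

≈ 800 ms per frame ❌ too slow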
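
The two modes differ in how often the expensive forward pass runs. One way to picture the fast path is a policy that predicts a chunk of actions in one slow call and then serves them from a queue. The test script's `--reset-policy-each-frame` option hints at such a queue, but the class below (`ChunkedPolicy`, `chunk_size`, `slow_calls`) is a hypothetical sketch, not lerobot's implementation.

```python
from collections import deque
from typing import Callable, Deque, List

Action = List[float]


class ChunkedPolicy:
    """Runs the slow model only when the action queue is empty."""

    def __init__(self, predict_chunk: Callable[[object], List[Action]]):
        self._predict_chunk = predict_chunk
        self._queue: Deque[Action] = deque()
        self.slow_calls = 0

    def reset(self) -> None:
        """Drop queued actions so the next call re-runs the model on a fresh frame."""
        self._queue.clear()

    def select_action(self, observation: object) -> Action:
        if not self._queue:
            self.slow_calls += 1
            self._queue.extend(self._predict_chunk(observation))
        return self._queue.popleft()


def fake_predict_chunk(observation: object, chunk_size: int = 5) -> List[Action]:
    # Stand-in for the slow vision-language forward pass.
    return [[0.01 * i] for i in range(chunk_size)]


# Mode 1: reuse the queue across frames.
cached = ChunkedPolicy(fake_predict_chunk)
for frame in range(10):
    cached.select_action(frame)

# Mode 2: reset before every frame, so every frame pays for the slow pass.
fresh = ChunkedPolicy(fake_predict_chunk)
for frame in range(10):
    fresh.reset()
    fresh.select_action(frame)

print(cached.slow_calls, fresh.slow_calls)
```

Mode 1 trades reactivity for speed: actions served from the queue were planned from an older frame, so they cannot respond to anything that changed since.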

On a real robot


Camera capture                        ~10 ms
↓
Image preprocessing                   ~40 ms
↓
VLM encoding (visual understanding)   ~500–800 ms
↓
Policy head (action)                  ~3 ms
↓
Output action
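
Adding up these rough stage estimates gives the end-to-end latency and the control rate a fully sequential loop could reach. The numbers are the approximate figures from the breakdown, not new measurements, and the helper names are chosen here for illustration.

```python
from typing import Dict, Tuple

# Rough per-stage latencies in ms as (low, high), taken from the breakdown.
STAGES_MS: Dict[str, Tuple[float, float]] = {
    "camera capture": (10.0, 10.0),
    "image preprocessing": (40.0, 40.0),
    "VLM encoding": (500.0, 800.0),
    "policy head": (3.0, 3.0),
}


def total_latency_ms(stages: Dict[str, Tuple[float, float]]) -> Tuple[float, float]:
    """Sum the low and high estimates across all stages."""
    low = sum(lo for lo, _ in stages.values())
    high = sum(hi for _, hi in stages.values())
    return low, high


def rate_hz(latency_ms: float) -> float:
    """Convert a per-iteration latency into an achievable loop rate."""
    return 1000.0 / latency_ms


low_ms, high_ms = total_latency_ms(STAGES_MS)
best_hz, worst_hz = rate_hz(low_ms), rate_hz(high_ms)
print(f"end-to-end: {low_ms:.0f}-{high_ms:.0f} ms, i.e. {worst_hz:.2f}-{best_hz:.2f} Hz")
```

Under these estimates a sequential loop runs at roughly 1 to 2 Hz, with VLM encoding accounting for almost all of the time. That is far below the 50~200 Hz rate of a typical arm controller, which is why the policy output usually feeds a faster low-level controller rather than driving the joints directly.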

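How a robot arm works in practice


Camera1 ─┐
Camera2 ─┼──→ Frame Buffer (continuously updated)
Camera3 ─┘

        ↓ (sampled once per control cycle)
   Build the observation (take the latest frames)
        ↓
   SmolVLA inference (policy inference)
        ↓
   action (6D/7D/Δpose)
        ↓
   Robot-arm controller (50~200Hz)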
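
A continuously updated frame buffer can be as simple as a lock-protected dictionary that keeps only the newest frame per camera. The sketch below uses threads and labelled strings in place of real capture devices; `LatestFrameBuffer` and `camera_worker` are hypothetical names.

```python
import threading
from typing import Dict


class LatestFrameBuffer:
    """Keeps only the newest frame per camera; older frames are overwritten."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._frames: Dict[str, object] = {}

    def update(self, camera: str, frame: object) -> None:
        with self._lock:
            self._frames[camera] = frame

    def snapshot(self) -> Dict[str, object]:
        """Return a copy of the newest frame from every camera."""
        with self._lock:
            return dict(self._frames)


buffer = LatestFrameBuffer()


def camera_worker(name: str, num_frames: int) -> None:
    # Stand-in for a capture thread; frames are just labelled strings here.
    for i in range(num_frames):
        buffer.update(name, f"{name}-frame{i}")


threads = [threading.Thread(target=camera_worker, args=(f"camera{i}", 100)) for i in (1, 2, 3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The control loop samples once per cycle and always sees the latest frames.
observation_images = buffer.snapshot()
print(observation_images)
```

Keeping only the latest frame means a slow policy never works through a backlog of stale images; it simply skips the frames it could not keep up with.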

Test code


#!/usr/bin/env python3
"""SmolVLA benchmark script for amd64 desktop with CUDA GPU.

What this script does:
1) Verifies CPU architecture and CUDA device.
2) Loads local SmolVLA checkpoint with local VLM/tokenizer path.
3) Builds a fixed image observation.
4) Runs warmup + benchmark loops.
5) Reports latency percentiles and throughput.

Example:
  python test_smolvla_rtx4090.py \
    --model-path ./smolvla_base \
    --vlm-path ./SmolVLM2-500M-Video-Instruct \
    --image-path ./smolvla_base_test.png \
    --warmup 20 \
    --iters 200
"""

from __future__ import annotations

import argparse
import contextlib
import io
import importlib
import json
import os
import platform
import shutil
import statistics
import tempfile
import time
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
from PIL import Image


def load_json(path: Path) -> Dict[str, Any]:
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def parse_image_shape(config: Dict[str, Any]) -> Tuple[int, int, int]:
    for _, feature in config.get("input_features", {}).items():
        if feature.get("type") == "VISUAL":
            shape = feature.get("shape", [3, 256, 256])
            if len(shape) != 3:
                raise ValueError(f"Unexpected VISUAL shape: {shape}")
            return int(shape[0]), int(shape[1]), int(shape[2])
    return 3, 256, 256


def parse_state_dim(config: Dict[str, Any]) -> int:
    for _, feature in config.get("input_features", {}).items():
        if feature.get("type") == "STATE":
            shape = feature.get("shape", [6])
            if len(shape) != 1:
                raise ValueError(f"Unexpected STATE shape: {shape}")
            return int(shape[0])
    return 6


def count_cameras(config: Dict[str, Any]) -> int:
    n = 0
    for _, feature in config.get("input_features", {}).items():
        if feature.get("type") == "VISUAL":
            n += 1
    return max(1, n)


def load_image_chw_uint8(image_path: Path, h: int, w: int) -> torch.Tensor:
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
        resized = rgb.resize((w, h), Image.BILINEAR)
        # Use writable bytearray buffer to avoid non-writable buffer warning.
        data = torch.ByteTensor(bytearray(resized.tobytes())).view(h, w, 3)
        return data.permute(2, 0, 1).contiguous()


def build_observation(config: Dict[str, Any], image_path: Path) -> Dict[str, Any]:
    c, h, w = parse_image_shape(config)
    if c != 3:
        raise ValueError(f"Only 3-channel image is supported now, got {c}")

    state_dim = parse_state_dim(config)
    cam_count = count_cameras(config)

    img = load_image_chw_uint8(image_path, h, w)

    obs: Dict[str, Any] = {
        "observation.state": torch.zeros(state_dim, dtype=torch.float32),
        "task": "Pick the red block and place it in the blue region",
    }
    for i in range(1, cam_count + 1):
        obs[f"observation.images.camera{i}"] = img.clone()
    return obs


def create_localized_model_view(model_path: Path, vlm_path: Path) -> Path:
    temp_dir = Path(tempfile.mkdtemp(prefix="smolvla_rtx4090_"))

    for item in model_path.iterdir():
        dst = temp_dir / item.name
        if item.is_file() and item.name.endswith(".json"):
            continue
        try:
            os.symlink(item, dst)
        except OSError:
            if item.is_dir():
                shutil.copytree(item, dst)
            else:
                shutil.copy2(item, dst)

    config = load_json(model_path / "config.json")
    config["vlm_model_name"] = str(vlm_path)
    with (temp_dir / "config.json").open("w", encoding="utf-8") as f:
        json.dump(config, f, ensure_ascii=False, indent=2)

    preproc_path = model_path / "policy_preprocessor.json"
    if preproc_path.exists():
        preproc = load_json(preproc_path)
        for step in preproc.get("steps", []):
            if step.get("registry_name") == "tokenizer_processor":
                step.setdefault("config", {})["tokenizer_name"] = str(vlm_path)
        with (temp_dir / "policy_preprocessor.json").open("w", encoding="utf-8") as f:
            json.dump(preproc, f, ensure_ascii=False, indent=2)

    for name in ["configuration.json", "policy_postprocessor.json", "README.md"]:
        src = model_path / name
        if src.exists() and not (temp_dir / name).exists():
            shutil.copy2(src, temp_dir / name)

    return temp_dir


def sanitize_batch_for_select_action(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    fixed = dict(batch)
    for key, value in list(fixed.items()):
        if not isinstance(value, torch.Tensor):
            continue
        if not key.startswith(("observation.images.camera", "observation.image")):
            continue

        t = value.to(device)
        if t.ndim == 3:
            t = t.unsqueeze(0)
        if t.dtype == torch.uint8:
            t = t.float().div(255.0)
        elif not t.is_floating_point():
            t = t.float()
        fixed[key] = t

    return fixed


def flatten_action_to_list(action: Any) -> List[float]:
    if isinstance(action, dict):
        if "action" in action:
            action = action["action"]
        else:
            for value in action.values():
                if isinstance(value, torch.Tensor):
                    action = value
                    break
    if not isinstance(action, torch.Tensor):
        return []
    return [float(x) for x in action.detach().float().reshape(-1).cpu().tolist()]


def percentile(sorted_vals: List[float], q: float) -> float:
    if not sorted_vals:
        return 0.0
    if q <= 0:
        return sorted_vals[0]
    if q >= 100:
        return sorted_vals[-1]
    idx = int(round((q / 100.0) * (len(sorted_vals) - 1)))
    return sorted_vals[max(0, min(len(sorted_vals) - 1, idx))]


def _pick_timed_modules(policy: Any) -> List[Tuple[str, str, Any]]:
    """Pick coarse modules for phase timing.

    Returns tuples: (bucket, module_name, module)
    bucket in {"vlm", "smolvla"}
    """
    picks: List[Tuple[str, str, Any]] = []
    seen = set()

    for name, module in policy.named_modules():
        if not name:
            continue
        depth = name.count(".")
        if depth > 3:
            continue

        lower = name.lower()
        bucket = None
        if any(k in lower for k in ("vlm", "vision", "language", "text_model", "vision_model")):
            bucket = "vlm"
        elif any(k in lower for k in ("action", "expert", "decoder", "head", "state_proj")):
            bucket = "smolvla"

        if bucket is None:
            continue
        key = (bucket, name)
        if key in seen:
            continue
        seen.add(key)
        picks.append((bucket, name, module))

    # Keep only a small set of coarse modules to reduce nested double counting.
    picks.sort(key=lambda x: (x[0], x[1].count("."), len(x[1])))
    trimmed: List[Tuple[str, str, Any]] = []
    bucket_count = {"vlm": 0, "smolvla": 0}
    for bucket, name, module in picks:
        if bucket_count[bucket] >= 4:
            continue
        trimmed.append((bucket, name, module))
        bucket_count[bucket] += 1
    return trimmed


def _find_primary_vlm_module(policy: Any) -> Optional[Tuple[str, Any]]:
    # Prefer the VLM submodule closest to the root so the backbone timing split is stable.
    candidates: List[Tuple[int, int, str, Any]] = []
    for name, module in policy.named_modules():
        if not name:
            continue
        lower = name.lower()
        if "vlm" in lower or "vision_language" in lower or "smolvlm" in lower:
            candidates.append((name.count("."), len(name), name, module))
    if not candidates:
        return None
    candidates.sort(key=lambda x: (x[0], x[1]))
    _, _, name, module = candidates[0]
    return name, module


def measure_phase_breakdown(
    policy: Any,
    batch: Dict[str, Any],
    device: torch.device,
    use_autocast: bool,
    autocast_dtype: torch.dtype,
) -> Dict[str, Any]:
    """Approximate phase timing breakdown using module hooks + CUDA events.

    Note: values are approximate and may include overlap due to nested modules.
    """
    if device.type != "cuda":
        return {
            "enabled": False,
            "reason": "phase breakdown requires CUDA events",
        }

    # Strategy A (preferred): directly time the primary VLM module forward,
    # then estimate SmolVLA head by subtraction from select_action total.
    primary_vlm = _find_primary_vlm_module(policy)
    if primary_vlm is not None:
        vlm_name, vlm_module = primary_vlm
        vlm_times: List[float] = []

        def _vlm_pre(_module, _inputs):
            start = torch.cuda.Event(enable_timing=True)
            start.record()
            _module.__dict__["_vlm_timer_start"] = start

        def _vlm_post(_module, _inputs, _output):
            start = _module.__dict__.pop("_vlm_timer_start", None)
            if start is None:
                return
            end = torch.cuda.Event(enable_timing=True)
            end.record()
            _module.__dict__["_vlm_timer_end"] = end
            _module.__dict__["_vlm_timer_pair"] = (start, end)

        h1 = vlm_module.register_forward_pre_hook(_vlm_pre)
        h2 = vlm_module.register_forward_hook(_vlm_post)

        t0 = time.perf_counter()
        with torch.inference_mode():
            if use_autocast:
                with torch.autocast(device_type="cuda", dtype=autocast_dtype):
                    _ = policy.select_action(batch)
            else:
                _ = policy.select_action(batch)
        torch.cuda.synchronize(device)
        t1 = time.perf_counter()

        pair = vlm_module.__dict__.pop("_vlm_timer_pair", None)
        h1.remove()
        h2.remove()

        if pair is not None:
            start, end = pair
            vlm_ms = float(start.elapsed_time(end))
            total_ms = (t1 - t0) * 1000.0
            head_ms = max(0.0, total_ms - vlm_ms)
            return {
                "enabled": True,
                "method": "primary_vlm_subtract",
                "note": "approximate, head is computed as total-vlm",
                "primary_vlm_module": vlm_name,
                "select_total_ms": total_ms,
                "smolvlm2_ms": vlm_ms,
                "smolvla_head_ms": head_ms,
                "timed_modules": [[vlm_name, vlm_ms]],
            }

    # Strategy B (fallback): coarse module hooks by name heuristics.
    picks = _pick_timed_modules(policy)
    if not picks:
        return {
            "enabled": False,
            "reason": "no matched modules for phase hooks",
        }

    records: List[Tuple[str, str, torch.cuda.Event, torch.cuda.Event]] = []
    handles = []

    def make_pre(bucket: str, name: str):
        def _pre(_module, _inputs):
            start = torch.cuda.Event(enable_timing=True)
            start.record()
            _module.__dict__["_phase_timer_start"] = (bucket, name, start)
        return _pre

    def _post(module, _inputs, _output):
        start_data = module.__dict__.pop("_phase_timer_start", None)
        if start_data is None:
            return
        bucket, name, start = start_data
        end = torch.cuda.Event(enable_timing=True)
        end.record()
        records.append((bucket, name, start, end))

    for bucket, name, mod in picks:
        handles.append(mod.register_forward_pre_hook(make_pre(bucket, name)))
        handles.append(mod.register_forward_hook(_post))

    t0 = time.perf_counter()
    with torch.inference_mode():
        if use_autocast:
            with torch.autocast(device_type="cuda", dtype=autocast_dtype):
                _ = policy.select_action(batch)
        else:
            _ = policy.select_action(batch)
    torch.cuda.synchronize(device)
    t1 = time.perf_counter()

    for h in handles:
        h.remove()

    if not records:
        return {
            "enabled": False,
            "reason": "hooks registered but no module forward events captured",
        }

    per_bucket = {"vlm": 0.0, "smolvla": 0.0}
    per_module: Dict[str, float] = {}
    for bucket, name, start, end in records:
        ms = float(start.elapsed_time(end))
        per_bucket[bucket] += ms
        per_module[name] = per_module.get(name, 0.0) + ms

    return {
        "enabled": True,
        "method": "coarse_hooks",
        "note": "approximate, may include nested overlap",
        "select_total_ms": (t1 - t0) * 1000.0,
        "smolvlm2_ms": per_bucket["vlm"],
        "smolvla_head_ms": per_bucket["smolvla"],
        "timed_modules": sorted(per_module.items(), key=lambda x: x[1], reverse=True)[:8],
    }


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="SmolVLA benchmark for amd64 + CUDA GPU")
    parser.add_argument("--model-path", type=Path, default=Path("./smolvla_base"), help="Local SmolVLA checkpoint dir")
    parser.add_argument("--vlm-path", type=Path, default=Path("./SmolVLM2-500M-Video-Instruct"), help="Local VLM/tokenizer dir")
    parser.add_argument("--image-path", type=Path, default=Path("./smolvla_base_test.png"), help="Input image path")
    parser.add_argument("--extra-image-paths", type=Path, nargs="*", default=[], help="Extra test image paths; in multi-image mode each image is tested once")
    parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations")
    parser.add_argument("--iters", type=int, default=200, help="Benchmark iterations")
    parser.add_argument("--report-json", type=Path, default=Path("./benchmark_results/rtx4090_report.json"), help="Output benchmark report")
    parser.add_argument("--strict-arch", action="store_true", help="Fail if architecture is not amd64/x86_64")
    parser.add_argument("--gpu-id", type=int, default=0, help="CUDA GPU index, e.g. 0 or 1")
    parser.add_argument("--autocast", type=str, default="fp16", choices=["none", "fp16", "bf16"], help="CUDA autocast mode")
    parser.add_argument("--load-on-cpu", action="store_true", help="Load the model on CPU first, then move it to the GPU to lower peak GPU memory")
    parser.add_argument("--weights-dtype", type=str, default="fp16", choices=["fp32", "fp16", "bf16"], help="Weight precision used when moving the model to the GPU")
    parser.add_argument("--phase-breakdown", action="store_true", help="Print and save an approximate per-phase timing split between SmolVLM2 and SmolVLA")
    parser.add_argument("--verify-vlm", action="store_true", help="Verify that SmolVLM2 takes part: parameters, forward calls, peak GPU memory")
    parser.add_argument("--first-breakdown", action="store_true", help="Break down the first inference (select_action/postprocess/flatten)")
    parser.add_argument("--reset-policy-each-frame", action="store_true", default=True, help="Call policy.reset() before each frame to avoid reusing the cache/action queue")
    parser.add_argument("--no-reset-policy-each-frame", dest="reset_policy_each_frame", action="store_false", help="Disable the per-frame reset and allow reuse of the policy's internal state")
    return parser.parse_args()


def dtype_from_arg(name: str) -> torch.dtype:
    if name == "fp16":
        return torch.float16
    if name == "bf16":
        return torch.bfloat16
    return torch.float32


def print_cuda_memory(tag: str, device: torch.device) -> None:
    if device.type != "cuda":
        return
    allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
    reserved = torch.cuda.memory_reserved(device) / (1024 ** 3)
    total = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
    print(f"[显存] {tag}: allocated={allocated:.2f}GB reserved={reserved:.2f}GB total={total:.2f}GB")


def collect_vlm_param_stats(policy: Any) -> Dict[str, Any]:
    total_params = 0
    total_cuda_params = 0
    vlm_params = 0
    vlm_cuda_params = 0
    vlm_modules: List[str] = []

    for name, _module in policy.named_modules():
        if not name:
            continue
        lower = name.lower()
        if "vlm" in lower or "vision_language" in lower or "smolvlm" in lower:
            vlm_modules.append(name)

    for name, p in policy.named_parameters():
        n = int(p.numel())
        total_params += n
        if p.is_cuda:
            total_cuda_params += n

        lower = name.lower()
        if any(k in lower for k in ("vlm", "vision", "language", "text_model", "vision_model")):
            vlm_params += n
            if p.is_cuda:
                vlm_cuda_params += n

    return {
        "total_params": total_params,
        "total_cuda_params": total_cuda_params,
        "vlm_params": vlm_params,
        "vlm_cuda_params": vlm_cuda_params,
        "vlm_modules_sample": vlm_modules[:12],
    }


def verify_vlm_runtime(policy: Any, batch: Dict[str, Any], device: torch.device, use_autocast: bool, autocast_dtype: torch.dtype) -> Dict[str, Any]:
    if device.type != "cuda":
        return {"enabled": False, "reason": "requires CUDA"}

    primary = _find_primary_vlm_module(policy)
    if primary is None:
        return {"enabled": False, "reason": "no primary vlm module found"}

    module_name, module = primary
    calls = {"n": 0}
    sub_calls = {"n": 0}
    top_sub_calls: Dict[str, int] = {}

    # Best-effort reset to reduce queue/cache fast-path impact.
    if hasattr(policy, "reset") and callable(getattr(policy, "reset")):
        try:
            policy.reset()
        except Exception:
            pass

    def _pre(_m, _in):
        calls["n"] += 1

    def _sub_pre(name: str):
        def _hook(_m, _in):
            sub_calls["n"] += 1
            top_sub_calls[name] = top_sub_calls.get(name, 0) + 1
        return _hook

    handle = module.register_forward_pre_hook(_pre)
    sub_handles = []
    for child_name, child_mod in module.named_modules():
        if not child_name:
            continue
        sub_handles.append(child_mod.register_forward_pre_hook(_sub_pre(child_name)))

    torch.cuda.reset_peak_memory_stats(device)
    before_alloc = torch.cuda.memory_allocated(device)
    before_resv = torch.cuda.memory_reserved(device)

    t0 = time.perf_counter()
    with torch.inference_mode():
        if use_autocast:
            with torch.autocast(device_type="cuda", dtype=autocast_dtype):
                _ = policy.select_action(batch)
        else:
            _ = policy.select_action(batch)
    torch.cuda.synchronize(device)
    t1 = time.perf_counter()

    after_alloc = torch.cuda.memory_allocated(device)
    after_resv = torch.cuda.memory_reserved(device)
    peak_alloc = torch.cuda.max_memory_allocated(device)
    peak_resv = torch.cuda.max_memory_reserved(device)

    handle.remove()
    for h in sub_handles:
        h.remove()

    top_called = sorted(top_sub_calls.items(), key=lambda x: x[1], reverse=True)[:10]

    return {
        "enabled": True,
        "primary_vlm_module": module_name,
        "vlm_forward_calls": int(calls["n"]),
        "vlm_submodule_forward_calls": int(sub_calls["n"]),
        "vlm_top_called_submodules": top_called,
        "select_total_ms": (t1 - t0) * 1000.0,
        "memory_gb": {
            "before_allocated": before_alloc / (1024 ** 3),
            "before_reserved": before_resv / (1024 ** 3),
            "after_allocated": after_alloc / (1024 ** 3),
            "after_reserved": after_resv / (1024 ** 3),
            "peak_allocated": peak_alloc / (1024 ** 3),
            "peak_reserved": peak_resv / (1024 ** 3),
        },
    }


def count_vlm_submodule_calls_once(policy: Any, batch: Dict[str, Any], use_autocast: bool, autocast_dtype: torch.dtype) -> int:
    primary = _find_primary_vlm_module(policy)
    if primary is None:
        return -1

    _, module = primary
    sub_calls = {"n": 0}

    def _sub_pre(_m, _in):
        sub_calls["n"] += 1

    sub_handles = []
    for child_name, child_mod in module.named_modules():
        if not child_name:
            continue
        sub_handles.append(child_mod.register_forward_pre_hook(_sub_pre))

    try:
        with torch.inference_mode():
            if use_autocast:
                with torch.autocast(device_type="cuda", dtype=autocast_dtype):
                    _ = policy.select_action(batch)
            else:
                _ = policy.select_action(batch)
    finally:
        for h in sub_handles:
            h.remove()

    return int(sub_calls["n"])


def main() -> None:
    args = parse_args()

    model_path = args.model_path.resolve()
    vlm_path = args.vlm_path.resolve()
    image_path = args.image_path.resolve()
    extra_image_paths = [p.resolve() for p in args.extra_image_paths]

    if not model_path.exists():
        raise FileNotFoundError(f"Model path not found: {model_path}")
    if not vlm_path.exists():
        raise FileNotFoundError(f"VLM path not found: {vlm_path}")
    if not image_path.exists():
        raise FileNotFoundError(f"Image path not found: {image_path}")
    for p in extra_image_paths:
        if not p.exists():
            raise FileNotFoundError(f"Extra image path not found: {p}")

    print("[INFO] ===== SmolVLA RTX3090 inference test started =====")

    machine = platform.machine().lower()
    is_amd64 = machine in ("x86_64", "amd64")
    if args.strict_arch and not is_amd64:
        raise RuntimeError(f"Expected amd64/x86_64, got: {machine}")

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. This benchmark requires a CUDA GPU.")

    gpu_count = torch.cuda.device_count()
    if args.gpu_id < 0 or args.gpu_id >= gpu_count:
        raise RuntimeError(f"Invalid --gpu-id={args.gpu_id}; number of available GPUs is {gpu_count}")

    device = torch.device(f"cuda:{args.gpu_id}")
    gpu_name = torch.cuda.get_device_name(device)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    print(f"[INFO] Machine architecture: {machine} (amd64={is_amd64})")
    print(f"[INFO] GPU index in use: {args.gpu_id} / total GPUs: {gpu_count}")
    print(f"[INFO] GPU type: {gpu_name}")
    print(f"[INFO] Model path: {model_path}")
    print(f"[INFO] VLM path: {vlm_path}")
    print(f"[INFO] Test image: {image_path}")
    if extra_image_paths:
        print(f"[INFO] Number of extra images: {len(extra_image_paths)}")
        for idx, p in enumerate(extra_image_paths, start=1):
            print(f"[INFO]   Extra image {idx}: {p}")
    print(f"[INFO] Warmup iterations: {args.warmup}")
    print(f"[INFO] Timed iterations: {args.iters}")
    print(f"[INFO] Mixed-precision mode: {args.autocast}")
    print(f"[INFO] Loading strategy: load_on_cpu={args.load_on_cpu} weights_dtype={args.weights_dtype}")
    print(f"[INFO] Force reset on every frame: {args.reset_policy_each_frame}")
    print("[INFO] Input modalities: vision (3 cameras) + language (task) + state (6-D pose)")

    # Force local-only loading.
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["TRANSFORMERS_OFFLINE"] = "1"

    config = load_json(model_path / "config.json")

    localized_model_path = create_localized_model_view(model_path, vlm_path)
    print(f"[INFO] Localized model view: {localized_model_path}")

    factory_mod = importlib.import_module("lerobot.policies.factory")
    smolvla_mod = importlib.import_module("lerobot.policies.smolvla.modeling_smolvla")
    make_pre_post_processors = factory_mod.make_pre_post_processors
    SmolVLAPolicy = smolvla_mod.SmolVLAPolicy

    t0_load = time.perf_counter()
    print("[INFO] Loading SmolVLA model...")
    if device.type == "cuda":
        torch.cuda.empty_cache()
        print_cuda_memory("before load", device)

    target_dtype = dtype_from_arg(args.weights_dtype)
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=r".*torch_dtype.*deprecated.*")
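            # Silence stdout/stderr from third-party libraries during loading so they do not clutter the test log.
            with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
                if args.load_on_cpu:
                    # Load on the CPU first to avoid a peak-memory OOM from deserializing directly on the GPU.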
                    cfg_path = localized_model_path / "config.json"
                    cfg = load_json(cfg_path)
                    cfg["device"] = "cpu"
                    with cfg_path.open("w", encoding="utf-8") as f:
                        json.dump(cfg, f, ensure_ascii=False, indent=2)

                    policy = SmolVLAPolicy.from_pretrained(str(localized_model_path)).eval()
                    policy = policy.to(device=device, dtype=target_dtype).eval()
                else:
                    policy = SmolVLAPolicy.from_pretrained(str(localized_model_path)).to(device=device, dtype=target_dtype).eval()
    except RuntimeError as exc:
        msg = str(exc).lower()
        if "out of memory" in msg or "cuda error" in msg:
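            print("[ERROR] Model loading failed with a CUDA out-of-memory (OOM) or other CUDA error")
            print("[HINT] Try rerunning with the following options:")
            print("  --load-on-cpu --weights-dtype fp16")
            print("  --autocast fp16 --warmup 5 --iters 1")
            print("[HINT] Close other processes that are using GPU memory, then retry")
            if device.type == "cuda":
                print_cuda_memory("at OOM", device)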
        raise

    t1_load = time.perf_counter()
    print(f"[TIME] Model load time: {(t1_load - t0_load):.6f} s")
    if device.type == "cuda":
        print_cuda_memory("after load", device)

    if args.verify_vlm:
        pstats = collect_vlm_param_stats(policy)
        print(f"[VERIFY] Total parameters: {pstats['total_params']}")
        print(f"[VERIFY] Parameters on CUDA: {pstats['total_cuda_params']}")
        print(f"[VERIFY] VLM parameters (matched by name): {pstats['vlm_params']}")
        print(f"[VERIFY] VLM parameters on CUDA: {pstats['vlm_cuda_params']}")

    t0_pp = time.perf_counter()
    preprocess, postprocess = make_pre_post_processors(
        policy.config,
        str(localized_model_path),
        preprocessor_overrides={"device_processor": {"device": str(device)}},
    )
    t1_pp = time.perf_counter()
    print(f"[TIME] Pre/post-processor build time: {(t1_pp - t0_pp):.6f} s")

    image_paths: List[Path] = []
    seen = set()
    for p in [image_path] + extra_image_paths:
        sp = str(p)
        if sp in seen:
            continue
        seen.add(sp)
        image_paths.append(p)

    # Build a baseline batch from the first image for the optional verification steps below.
    t0_pre = time.perf_counter()
    first_observation = build_observation(config, image_paths[0])
    first_batch = preprocess(dict(first_observation))
    first_batch = sanitize_batch_for_select_action(first_batch, device)
    t1_pre = time.perf_counter()
    print(f"[INFO] Observation built; {len(image_paths)} image(s) in total")
    print("[INFO] Task text: Pick the red block and place it in the blue region")
    print(f"[INFO] State dimension: {tuple(torch.zeros(parse_state_dim(config), dtype=torch.float32).shape)}")
    print(f"[TIME] First-image preprocessing time: {(t1_pre - t0_pre):.6f} s")

    use_autocast = args.autocast != "none"
    autocast_dtype = torch.float16 if args.autocast == "fp16" else torch.bfloat16

    def infer_once(batch_input: Dict[str, Any]) -> List[float]:
        if args.reset_policy_each_frame and hasattr(policy, "reset") and callable(getattr(policy, "reset")):
            policy.reset()
        with torch.inference_mode():
            if use_autocast:
                with torch.autocast(device_type="cuda", dtype=autocast_dtype):
                    out = policy.select_action(batch_input)
            else:
                out = policy.select_action(batch_input)
            try:
                out = postprocess(out)
            except Exception:
                # Best effort: keep the raw policy output if postprocessing fails.
                pass
            return flatten_action_to_list(out)

    def infer_once_with_breakdown(batch_input: Dict[str, Any]) -> Tuple[List[float], Dict[str, float]]:
        breakdown: Dict[str, float] = {}

        if args.reset_policy_each_frame and hasattr(policy, "reset") and callable(getattr(policy, "reset")):
            policy.reset()

        with torch.inference_mode():
            if device.type == "cuda":
                sel_start = torch.cuda.Event(enable_timing=True)
                sel_end = torch.cuda.Event(enable_timing=True)
                if use_autocast:
                    with torch.autocast(device_type="cuda", dtype=autocast_dtype):
                        sel_start.record()
                        out = policy.select_action(batch_input)
                        sel_end.record()
                else:
                    sel_start.record()
                    out = policy.select_action(batch_input)
                    sel_end.record()
                torch.cuda.synchronize(device)
                breakdown["select_action_ms"] = float(sel_start.elapsed_time(sel_end))
            else:
                t_sel0 = time.perf_counter()
                if use_autocast:
                    with torch.autocast(device_type="cpu", dtype=autocast_dtype):
                        out = policy.select_action(batch_input)
                else:
                    out = policy.select_action(batch_input)
                t_sel1 = time.perf_counter()
                breakdown["select_action_ms"] = (t_sel1 - t_sel0) * 1000.0

            t_post0 = time.perf_counter()
            try:
                out = postprocess(out)
            except Exception:
                # Best effort: keep the raw policy output if postprocessing fails.
                pass
            if device.type == "cuda":
                torch.cuda.synchronize(device)
            t_post1 = time.perf_counter()
            breakdown["postprocess_ms"] = (t_post1 - t_post0) * 1000.0

            t_flat0 = time.perf_counter()
            action_vals = flatten_action_to_list(out)
            t_flat1 = time.perf_counter()
            breakdown["flatten_ms"] = (t_flat1 - t_flat0) * 1000.0

        breakdown["inside_infer_ms"] = (
            breakdown["select_action_ms"]
            + breakdown["postprocess_ms"]
            + breakdown["flatten_ms"]
        )
        return action_vals, breakdown

    def prepare_batch_for_image(img_path: Path) -> Tuple[Dict[str, Any], float]:
        t_pre0 = time.perf_counter()
        observation = build_observation(config, img_path)
        batch_item = preprocess(dict(observation))
        batch_item = sanitize_batch_for_select_action(batch_item, device)
        t_pre1 = time.perf_counter()
        return batch_item, (t_pre1 - t_pre0) * 1000.0

    # Warmup
    if args.warmup > 0:
        print("[INFO] Starting warmup inference...")
    for _ in range(max(0, args.warmup)):
        _ = infer_once(first_batch)
    torch.cuda.synchronize(device)
    if args.warmup > 0:
        print("[INFO] Warmup finished")

    # Benchmark
    print("[INFO] Starting timed runs...")
    latencies_ms: List[float] = []
    last_action: List[float] = []
    first_call_breakdown: Optional[Dict[str, float]] = None
    per_image_results: List[Dict[str, Any]] = []
    if len(image_paths) > 1:
        print("[INFO] Multi-image mode enabled: each image is tested once")
        for i, img_p in enumerate(image_paths):
            batch_item, pre_ms = prepare_batch_for_image(img_p)
            t0 = time.perf_counter()
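            # Note: with --verify-vlm the timed region also covers the hooked counting pass
            # (and the optional breakdown re-run), so latency_ms is higher than a plain inference.
            if args.verify_vlm: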
                if args.reset_policy_each_frame and hasattr(policy, "reset") and callable(getattr(policy, "reset")):
                    policy.reset()
                sub_calls = count_vlm_submodule_calls_once(
                    policy=policy,
                    batch=batch_item,
                    use_autocast=use_autocast,
                    autocast_dtype=autocast_dtype,
                )
                if device.type == "cuda":
                    torch.cuda.synchronize(device)
                action_vals = infer_once(batch_item)
                if i == 0 and args.first_breakdown:
                    # In the multi-image + verify path, also produce a breakdown for the first image by measuring it once more, separately.
                    _a, first_call_breakdown = infer_once_with_breakdown(batch_item)
            else:
                if i == 0 and args.first_breakdown:
                    action_vals, first_call_breakdown = infer_once_with_breakdown(batch_item)
                else:
                    action_vals = infer_once(batch_item)
            torch.cuda.synchronize(device)
            t1 = time.perf_counter()
            ms = (t1 - t0) * 1000.0
            latencies_ms.append(ms)
            last_action = action_vals
            row = {
                "image_path": str(img_p),
                "preprocess_ms": pre_ms,
                "latency_ms": ms,
                "action": action_vals,
            }
            if args.verify_vlm:
                row["vlm_submodule_forward_calls"] = sub_calls
            per_image_results.append(row)
    else:
        for i in range(max(1, args.iters)):
            if i == 0:
                batch_item = first_batch
            else:
                batch_item, _ = prepare_batch_for_image(image_paths[0])
            t0 = time.perf_counter()
            if i == 0 and args.first_breakdown:
                action_vals, first_call_breakdown = infer_once_with_breakdown(batch_item)
            else:
                action_vals = infer_once(batch_item)
            torch.cuda.synchronize(device)
            t1 = time.perf_counter()
            latencies_ms.append((t1 - t0) * 1000.0)
            last_action = action_vals

    lat_sorted = sorted(latencies_ms)
    total_ms = sum(latencies_ms)
    mean_ms = statistics.fmean(latencies_ms)
    fps = (1000.0 / mean_ms) if mean_ms > 0 else 0.0
    first_call_ms = latencies_ms[0] if latencies_ms else 0.0
    steady_latencies = latencies_ms[1:] if len(latencies_ms) > 1 else []
    steady_mean_ms = statistics.fmean(steady_latencies) if steady_latencies else first_call_ms

    report = {
        "machine": machine,
        "is_amd64": is_amd64,
        "gpu_name": gpu_name,
        "model_path": str(model_path),
        "vlm_path": str(vlm_path),
        "image_path": str(image_path),
        "tested_images": [str(p) for p in image_paths],
        "load_seconds": t1_load - t0_load,
        "warmup": int(args.warmup),
        "iterations": int(len(image_paths) if len(image_paths) > 1 else args.iters),
        "autocast": args.autocast,
        "latency_ms": {
            "mean": mean_ms,
            "first_call": first_call_ms,
            "steady_mean": steady_mean_ms,
            "min": lat_sorted[0] if lat_sorted else 0.0,
            "max": lat_sorted[-1] if lat_sorted else 0.0,
            "p50": percentile(lat_sorted, 50),
            "p90": percentile(lat_sorted, 90),
            "p95": percentile(lat_sorted, 95),
            "p99": percentile(lat_sorted, 99),
            "sum": total_ms,
        },
        "throughput_fps": fps,
        "action_output": {
            "last_action": last_action,
            "action_dim": len(last_action),
        },
    }

    if args.phase_breakdown or int(args.iters) == 1:
        phase = measure_phase_breakdown(
            policy=policy,
            batch=first_batch,
            device=device,
            use_autocast=use_autocast,
            autocast_dtype=autocast_dtype,
        )
        report["phase_breakdown"] = phase

    if args.verify_vlm:
        vlm_verify = verify_vlm_runtime(
            policy=policy,
            batch=first_batch,
            device=device,
            use_autocast=use_autocast,
            autocast_dtype=autocast_dtype,
        )
        report["vlm_verification"] = vlm_verify

    if per_image_results:
        report["per_image_results"] = per_image_results

    if args.first_breakdown and first_call_breakdown is not None and latencies_ms:
        first_call_breakdown["iteration_total_ms"] = latencies_ms[0]
        report["first_call_breakdown"] = first_call_breakdown

    args.report_json.resolve().parent.mkdir(parents=True, exist_ok=True)
    with args.report_json.resolve().open("w", encoding="utf-8") as f:
        json.dump(report, f, ensure_ascii=False, indent=2)

    print("[RESULT] Inference test finished")
    if latencies_ms:
        print(f"[RESULT] Single-inference latency (most recent): {latencies_ms[-1]:.3f} ms")
    if int(args.iters) == 1 and latencies_ms:
        print(f"[RESULT] Single-inference latency (only sample): {latencies_ms[0]:.3f} ms")
    print(f"[RESULT] Mean inference latency: {mean_ms:.3f} ms")
    print(f"[RESULT] First-call inference latency: {first_call_ms:.3f} ms")
    print(f"[RESULT] Steady-state mean (first frame excluded): {steady_mean_ms:.3f} ms")
    print(f"[RESULT] P95 inference latency: {percentile(lat_sorted, 95):.3f} ms")
    print(f"[RESULT] Inference throughput: {fps:.3f} FPS")
    print(f"[RESULT] Action output: {last_action}")
    if per_image_results:
        print("[RESULT] Per-image single-run latency:")
        for item in per_image_results:
            msg = (
                f"[RESULT]   {item['image_path']}: preprocess={float(item['preprocess_ms']):.3f} ms, "
                f"inference={float(item['latency_ms']):.3f} ms"
            )
            if "vlm_submodule_forward_calls" in item:
                msg += f", vlm_sub_calls={int(item['vlm_submodule_forward_calls'])}"
            print(msg)
    first_breakdown_data = report.get("first_call_breakdown")
    if isinstance(first_breakdown_data, dict):
        print("[RESULT] First-call latency breakdown:")
        print(f"[RESULT]   Model forward (select_action): {float(first_breakdown_data.get('select_action_ms', 0.0)):.3f} ms")
        print(f"[RESULT]   Postprocess (postprocess): {float(first_breakdown_data.get('postprocess_ms', 0.0)):.3f} ms")
        print(f"[RESULT]   Action flatten (flatten): {float(first_breakdown_data.get('flatten_ms', 0.0)):.3f} ms")
        print(f"[RESULT]   Total inside inference (inside_infer): {float(first_breakdown_data.get('inside_infer_ms', 0.0)):.3f} ms")
        print(f"[RESULT]   End-to-end iteration total: {float(first_breakdown_data.get('iteration_total_ms', 0.0)):.3f} ms")
    phase_data = report.get("phase_breakdown")
    if isinstance(phase_data, dict) and phase_data.get("enabled"):
        print(f"[RESULT] Phase latency (approx.) SmolVLM2: {phase_data.get('smolvlm2_ms', 0.0):.3f} ms")
        print(f"[RESULT] Phase latency (approx.) SmolVLA head: {phase_data.get('smolvla_head_ms', 0.0):.3f} ms")
        print(f"[RESULT] select_action total (approx.): {phase_data.get('select_total_ms', 0.0):.3f} ms")
        if phase_data.get("method"):
            print(f"[RESULT] Phase timing method: {phase_data.get('method')}")
    elif isinstance(phase_data, dict):
        print(f"[RESULT] Phase breakdown unavailable: {phase_data.get('reason', 'unknown')}")

    vlm_verify_data = report.get("vlm_verification")
    if isinstance(vlm_verify_data, dict):
        if vlm_verify_data.get("enabled"):
            print(f"[VERIFY] VLM forward calls: {vlm_verify_data.get('vlm_forward_calls')} (module: {vlm_verify_data.get('primary_vlm_module')})")
            print(f"[VERIFY] Total VLM submodule forward calls: {vlm_verify_data.get('vlm_submodule_forward_calls', 0)}")
            mem = vlm_verify_data.get("memory_gb", {})
            print(f"[VERIFY] Peak GPU memory for one inference (allocated): {float(mem.get('peak_allocated', 0.0)):.3f} GB")
            print(f"[VERIFY] Peak GPU memory for one inference (reserved): {float(mem.get('peak_reserved', 0.0)):.3f} GB")
        else:
            print(f"[VERIFY] VLM runtime verification unavailable: {vlm_verify_data.get('reason', 'unknown')}")
    print(json.dumps(report, ensure_ascii=False, indent=2))
    print(f"[RESULT] Report file: {args.report_json.resolve()}")
    print("[INFO] ===== SmolVLA RTX3090 inference test finished =====")


if __name__ == "__main__":
    main()
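
The latency summary that main() writes into the report (mean, first call, steady-state mean, P95, FPS) can be reproduced on a plain list of numbers. The sketch below is a minimal, stdlib-only illustration. `summarize_latencies` and `nearest_rank_percentile` are hypothetical helper names, the sample latencies are made up, and the nearest-rank rule is an assumption, since the script's own `percentile` helper may use a different definition.

```python
import math
import statistics
from typing import Dict, List


def nearest_rank_percentile(sorted_vals: List[float], pct: float) -> float:
    """Nearest-rank percentile on an already sorted list (assumed definition)."""
    if not sorted_vals:
        return 0.0
    rank = max(1, math.ceil(pct / 100.0 * len(sorted_vals)))
    return sorted_vals[rank - 1]


def summarize_latencies(latencies_ms: List[float]) -> Dict[str, float]:
    """Compute the same kind of fields the report stores under latency_ms."""
    if not latencies_ms:
        raise ValueError("latencies_ms must not be empty")
    ordered = sorted(latencies_ms)
    mean_ms = statistics.fmean(latencies_ms)
    first_call_ms = latencies_ms[0]
    # The first call often includes one-off costs, so it is excluded from the steady mean.
    steady = latencies_ms[1:]
    steady_mean_ms = statistics.fmean(steady) if steady else first_call_ms
    return {
        "mean": mean_ms,
        "first_call": first_call_ms,
        "steady_mean": steady_mean_ms,
        "p95": nearest_rank_percentile(ordered, 95),
        "fps": 1000.0 / mean_ms if mean_ms > 0 else 0.0,
    }


# Made-up sample: a slow first call followed by three steady calls.
summary = summarize_latencies([900.0, 600.0, 610.0, 590.0])
print(summary)
```

With only a handful of samples the P95 is simply the slowest call, which is why the steady-state mean is usually the more informative number for short runs.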

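SmolVLA vs. Tesla's autonomous driving approach

| | SmolVLA | Tesla autonomous driving |
|---|---|---|
| Nature | Small model + VLM + control head | Very large-scale end-to-end system |
| Goal | Robot-arm tasks (grasping / manipulation) | Autonomous driving (perception + decision-making + control) |
| Real-time performance | Low (~600 ms) | Very high (<50 ms) |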
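
To put the latency row in perspective, a per-inference latency can be converted into the highest control rate it allows. The sketch below assumes exactly one action is produced per inference and uses the approximate figures quoted above. `max_control_rate_hz` is an illustrative name, not part of any library.

```python
def max_control_rate_hz(latency_ms: float) -> float:
    """Upper bound on the control rate if every action needs one full inference."""
    if latency_ms <= 0:
        raise ValueError("latency_ms must be positive")
    return 1000.0 / latency_ms


# Approximate figures from the comparison above.
smolvla_hz = max_control_rate_hz(600.0)  # about 1.67 Hz
tesla_hz = max_control_rate_hz(50.0)     # 20 Hz; "<50 ms" means at least this

print(f"SmolVLA: {smolvla_hz:.2f} Hz, Tesla: {tesla_hz:.2f} Hz")
print(f"Ratio: {tesla_hz / smolvla_hz:.1f}x")
```

Under that assumption the gap is roughly an order of magnitude, which is the practical meaning of "low" versus "very high" real-time performance here.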

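Model structure comparison

| AI paradigm | Tesla autonomous driving | SmolVLA |
|---|---|---|
| Computer vision | ✅ (multi-camera, high-precision perception) | ✅ (as part of the VLM input) |
| Transformer | ✅ (core: temporal + spatial) | ✅ (core: VLM) |
| Temporal model | ✅ (very central, video-level) | ❌ (essentially none, single frame) |
| 3D modeling | ✅ (BEV / Occupancy) | ❌ (no explicit 3D world) |
| Probabilistic prediction | ✅ (trajectory and behavior prediction) | ⚠️ (implicit, not modeled explicitly) |
| Reinforcement learning | ⚠️ (auxiliary during training) | ❌ (essentially unused) |
| Control theory | ✅ (MPC / controllers) | ❌ (pure neural-network output) |
| Multimodal (language) | ❌ | ✅ (core capability) |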

SmolVLA


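Image + text task + state
        ↓
   SmolVLM2 (vision-language model)
        ↓
   Fused features (embedding)
        ↓
   Action Head (small network)
        ↓
   Robot-arm action (6-D)

👉 Characteristics:

  • Heavily dependent on the VLM (about 99% of the compute)
  • "Understand + execute" in a single model
  • Analogous to GPT controlling a robot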

Tesla


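Multiple cameras (8 video streams)
        ↓
Temporal fusion (Video Transformer / Occupancy Network)
        ↓
3D world modeling (BEV space)
        ↓
Trajectory prediction (Planning)
        ↓
Control output (steering / throttle / brake)

👉 Characteristics:

  • No language input
  • Fully vision-based, end to end
  • Strongly temporal (consecutive frames)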

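Input data comparison

| | SmolVLA | Tesla |
|---|---|---|
| Cameras | 1 to 3 images | 8 video streams |
| Frequency | Single frame | 36 FPS |
| Extra input | Text task | - |
| State input | 6-D pose | Vehicle speed / IMU, etc. |

👉 Key differences:

👉 Tesla deals with a "continuous world"
👉 SmolVLA deals with a "static task"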
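
As a rough back-of-the-envelope check on the input comparison, the image rate each system has to consume can be estimated from the figures quoted in this post. This is only a sketch under stated assumptions: SmolVLA is taken at its maximum of 3 images per inference with about 600 ms per inference, Tesla at 8 streams of 36 FPS, and both function names are illustrative.

```python
def images_per_second_per_inference(images: int, latency_ms: float) -> float:
    """Images consumed per second when `images` are processed once per inference."""
    if latency_ms <= 0:
        raise ValueError("latency_ms must be positive")
    return images * 1000.0 / latency_ms


def images_per_second_streaming(streams: int, fps: float) -> float:
    """Images consumed per second when every stream delivers `fps` frames."""
    return streams * fps


# Figures taken from this post; they are approximate.
smolvla_rate = images_per_second_per_inference(images=3, latency_ms=600.0)
tesla_rate = images_per_second_streaming(streams=8, fps=36.0)

print(f"SmolVLA: {smolvla_rate:.1f} images/s")
print(f"Tesla:   {tesla_rate:.1f} images/s")
print(f"Ratio:   {tesla_rate / smolvla_rate:.1f}x")
```

The point of the estimate is the order of magnitude, not the exact ratio: a streaming system sees far more frames per second than a policy that looks at a few images per inference.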

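Inference mechanism (key point)

🟢 SmolVLA


Per frame:
    run the VLM once (600 ms)
    output an action

👉 Problems:

  • Too slow
  • No temporal continuity
  • Every frame "re-understands the world" from scratch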
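
The per-frame pattern described above can be sketched as a stateless loop. This is a toy illustration only: `run_stateless_loop` and `fake_infer` are hypothetical names, and `fake_infer` merely stands in for a full VLM plus action-head forward pass.

```python
from typing import Callable, List, Sequence

Action = List[float]


def run_stateless_loop(
    frames: Sequence[str],
    task: str,
    infer: Callable[[str, str], Action],
) -> List[Action]:
    """Handle every frame independently; nothing is carried over between steps."""
    actions: List[Action] = []
    for frame in frames:
        # One full model call per frame, with no memory of earlier frames.
        actions.append(infer(frame, task))
    return actions


def fake_infer(frame: str, task: str) -> Action:
    # Hypothetical stand-in: returns a 6-D action derived only from the current frame.
    return [float(len(frame))] * 6


actions = run_stateless_loop(["frame_a", "frame_bb"], "pick the red block", fake_infer)
print(actions)
```

Because `infer` only ever sees the current frame, the cost per step stays at one full inference and the output cannot depend on what happened earlier.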

🔴 Tesla


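Per frame:
    use past frames (temporal model)
    incrementally update the world model

👉 Core capabilities:

🔥 1. Temporal memory (crucial)

  • Does not re-understand the scene from scratch
  • Uses historical context

🔥 2. Multi-frame fusion

  • Not image → action
  • But video → world model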
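
The idea of keeping history and updating a state incrementally can be illustrated with a small rolling buffer. This is a toy sketch and not Tesla's actual architecture: `TemporalFusion` is a hypothetical class, per-frame features are plain lists of floats, and "fusion" is reduced to averaging the last few frames.

```python
from collections import deque
from typing import Deque, List


class TemporalFusion:
    """Keep the last `window` frame features and fuse them into one state."""

    def __init__(self, window: int) -> None:
        if window <= 0:
            raise ValueError("window must be positive")
        # deque(maxlen=...) drops the oldest frame automatically.
        self.history: Deque[List[float]] = deque(maxlen=window)

    def update(self, frame_features: List[float]) -> List[float]:
        """Add the newest frame and return the element-wise mean over the window."""
        self.history.append(frame_features)
        count = len(self.history)
        dim = len(frame_features)
        return [sum(f[i] for f in self.history) / count for i in range(dim)]


fusion = TemporalFusion(window=3)
state: List[float] = []
for feats in ([1.0, 0.0], [3.0, 0.0], [5.0, 6.0], [7.0, 0.0]):
    state = fusion.update(feats)

print(state)  # mean of the last three frames: [5.0, 2.0]
```

Each new frame only updates the buffer, so the fused state always reflects recent context instead of being rebuilt from a single image.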