OpenVLA 的核心洞察是将连续的机器人动作离散化为 LLM 的 token,让 VLM 直接像「说下一个词」一样输出动作。这样就能完整复用标准 LLM 的 next-token prediction 训练范式——无需额外的扩散头或回归头,任何 LLM 训练/推理基础设施都能直接用。
不重新发明轮子,而是复用互联网规模预训练的 VLM(视觉-语言模型)作为「大脑」。VLM 已经从海量图文数据中学到了丰富的视觉理解和语言推理能力,这些知识可直接迁移到机器人操控任务。在 VLM 之上嫁接动作输出能力,是 OpenVLA 的基本策略。
此前 VLA 模型(如 RT-2/RT-2-X)由 Google 闭源,公众无法获取权重、无法微调、无法二次开发。OpenVLA 的目标是打造完全开源、可直接微调、在消费级 GPU 上可运行的通用机器人操控策略模型。
构建在 Prismatic VLM 框架之上,由三大组件融合而成:
| 组件 | 配置 | 说明 |
|---|---|---|
| 视觉编码器 | DINOv2 ViT-L + SigLIP-SO400M | 224×224px,patch=14,各出 256 token,序列拼接成 512 token |
| LLM 基座 | Llama-2-7B(base 版) | 非 chat 版,纯 base model |
| 投影层 | 3 层 MLP + GELU | Linear(1024→4096) → GELU → Linear(4096→4096) → GELU → Linear(4096→4096) |
输入:
# 语言指令 prompt 格式
prompt = "In: What action should the robot take to {INSTRUCTION}?\nOut:"
# 例: "In: What action should the robot take to pick up the red block?\nOut:"
输出:7 维连续动作向量(末端执行器增量)
action = [dx, dy, dz, droll, dpitch, dyaw, gripper] # 7-DoF
这是 OpenVLA 的核心设计——不输出连续值,而是将动作离散化为 token:
bins = np.linspace(-1, 1, 256))digitize 映射到 bin indexgenerate(max_new_tokens=7)# 核心逻辑(prismatic/vla/action_tokenizer.py)
self.action_token_begin_idx = tokenizer.vocab_size - (n_bins + 1)
# 编码:连续动作 → token ID
discretized_action = np.digitize(action, self.bins)
token_id = tokenizer.vocab_size - discretized_action
# 解码:token ID → 连续动作
discretized_actions = tokenizer.vocab_size - action_token_ids
continuous_action = bin_centers[discretized_actions - 1]
推理时反归一化:模型输出的归一化动作通过数据集的 q01/q99 分位数还原为真实物理量:
actions = 0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low
# action_high = q99, action_low = q01
| 指标 | 数值 |
|---|---|
| 轨迹总数 | 970K 条真实机器人演示 |
| 数据集组件 | ~26 个 Open X-Embodiment 子数据集 |
| 主要数据集 | BridgeData V2、Google RT-1 (fractal)、DROID、KUKA、BC-Z 等 |
| 数据格式 | RLDS(Reinforcement Learning Datasets) |
| 混合策略 | 加权采样(前 70% 含 DROID 大权重,后 30% 不含 DROID) |
openvla/
├── prismatic/ # 核心包
│ ├── models/
│ │ ├── vlms/prismatic.py # PrismaticVLM 基类(VLM 主体)
│ │ ├── vlas/openvla.py # ★ OpenVLA 类(predict_action)
│ │ └── backbones/
│ │ ├── vision/ # 视觉编码器(DINOv2, SigLIP)
│ │ └── llm/ # LLM 后端(Llama-2)
│ ├── vla/
│ │ ├── action_tokenizer.py # ★ 动作离散化/反离散化核心
│ │ └── datasets/rlds/oxe/ # ★ OXE 数据加载管线
│ │ ├── configs.py # 各数据集观测/动作空间配置
│ │ └── mixtures.py # 数据混合权重定义
│ ├── extern/hf/ # ★ HuggingFace 兼容层
│ │ ├── configuration_prismatic.py
│ │ ├── modeling_prismatic.py # HF格式模型定义
│ │ └── processing_prismatic.py
│ └── conf/vla.py # VLA 训练超参配置
├── vla-scripts/ # ★ 训练/微调/部署脚本
│ ├── train.py # 全量训练(FSDP)
│ ├── finetune.py # LoRA 微调
│ └── deploy.py # REST API 部署
└── experiments/robot/ # 机器人环境评测代码
| 模块 | 作用 |
|---|---|
OpenVLA | PrismaticVLM 子类,封装 predict_action():构建 prompt → 预处理图像 → generate() → 解码 token → 反归一化 |
ActionTokenizer | 连续动作 ↔ 256 离散 token 双向转换,映射到词表末尾 |
OpenVLAForActionPrediction | HF 格式完整模型,支持 AutoModelForVision2Seq 自动加载 |
RLDSDataset | 高性能 RLDS 数据管道,支持 shuffle buffer、图像增强、轨迹变换 |
PrismaticProjector | MLP 投影层,将视觉特征映射到 LLM 隐藏维度 |
torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune.py \
--vla_path "openvla/openvla-7b" \
--data_root_dir <数据集根目录> \
--dataset_name bridge_orig \
--run_root_dir <日志/checkpoint目录> \
--adapter_tmp_dir <LoRA权重临时目录> \
--lora_rank 32 \
--batch_size 16 \
--grad_accumulation_steps 1 \
--learning_rate 5e-4 \
--image_aug True
关键参数:
target_modules="all-linear":LoRA 应用于所有线性层lora_alpha = min(lora_rank, 16)configs.py 和 transforms.py 中添加配置,数据转 RLDS 格式torchrun --standalone --nnodes 1 --nproc-per-node 8 vla-scripts/train.py \
--pretrained_checkpoint <checkpoint路径> \
--vla.type prism-dinosiglip-224px+mx-bridge \
--data_root_dir <数据集目录> \
--run_root_dir <日志目录> \
--is_resume False
全量微调使用 PyTorch FSDP(Fully Sharded Data Parallel)分片训练,完成后需用 convert_openvla_weights_to_hf.py 转换为 HF 格式。
方式一:直接加载推理(最简)
from transformers import AutoModelForVision2Seq, AutoProcessor
import torch
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
"openvla/openvla-7b",
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to("cuda:0")
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
方式二:REST API 部署(适合机器人本体算力不足时远程推理)
# 服务端
python vla-scripts/deploy.py --openvla_path openvla/openvla-7b --port 8000
# 客户端(仅需 numpy + requests)
import requests, json_numpy, numpy as np
json_numpy.patch()
action = requests.post("http://SERVER:8000/act", json={
"image": np.zeros((256, 256, 3), dtype=np.uint8),
"instruction": "pick up the red block",
"unnorm_key": "bridge_orig"
}).json()
# 最小依赖(仅推理)
pip install timm==0.9.10 tokenizers==0.19.1 torch>=2.2.0 torchvision>=0.16.0 transformers==4.40.1
transformers==4.40.1、timm==0.9.10、tokenizers==0.19.1,更高版本有 breaking change。
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
import torch, numpy as np
# 1. 加载 Processor 和模型
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
"openvla/openvla-7b",
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to("cuda:0")
# 2. 准备输入
image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8)) # 替换为相机图像
prompt = "In: What action should the robot take to pick up the red block?\nOut:"
# 3. 推理 → 7维动作向量
inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
print(f"预测动作 (7-DoF): {action}") # [dx,dy,dz,droll,dpitch,dyaw,gripper]
# 4. 发送给机器人执行
# robot.act(action)
| 场景 | GPU 需求 | 备注 |
|---|---|---|
| 推理(bf16) | ≥15GB 显存 | 单卡 RTX 3090/4090 (24GB) 或 A10 |
| 推理(4-bit 量化) | ~8GB 显存 | 论文证明量化推理不损失成功率 |
| LoRA 微调 | ≥27GB(batch=12 需 48GB) | A100 40/80GB 最佳 |
| 全量微调 | 8×A100 (80GB) | FSDP 分片 |
| 从头预训练 | 64×A100 (80GB) | global_batch_size=2048 |
| 指标 | 数值 |
|---|---|
| 模型参数量 | 7.5B |
| 训练数据 | 970K 条真机演示轨迹 |
| 基座 LLM | Llama-2-7B(base) |
| 视觉编码器 | DINOv2 ViT-L + SigLIP-SO400M,224×224px |
| 图像 token 数 | 512(双编码器各 256,序列拼接) |
| 动作表达 | 256-bin 离散化 token(每维度 1 个 token) |
| 动作维度 | 7-DoF(xyz + 旋转 + 夹爪) |
| 词表大小 | 32064(Llama-2 的 32000 + 64 padding) |
| vs RT-2-X (55B) | 绝对成功率高 16.5%,参数少 7× |
| 许可 | MIT(模型继承 Llama Community License) |