← 返回 VLA 技术站

Octo

Open-Source Generalist Robot Policy · 27M~93M 参数 · UC Berkeley / Stanford

核心思路

设计哲学:让通用策略像 BERT 之于 NLP

Octo 的目标做第一个真正开源、轻量级、资源友好的通用机器人策略。让通用机器人策略像 BERT 之于 NLP 一样——成为可被广泛使用的、低门槛的策略初始化(policy initialization)。开发者拿到预训练权重,用少量领域数据微调,就能获得强泛化能力。

要解决的问题

已有的通用策略(如 RT-X)通常闭源、参数量大、对硬件要求高,难以被普通研究者复现和改进。Octo 把「通用策略」拉到了单卡可玩的级别。

设计目标Octo 的做法
开源可复现全部代码、权重、训练配置、800K 数据混合配方开放(MIT 协议)
跨平台泛化在 Open X-Embodiment 800K 轨迹上训练,覆盖 9+ 种机器人平台
轻量/资源友好27M / 93M 参数,推理在单张 4090 上 13-17 it/s
可灵活微调模块化 Transformer 架构,支持新增传感器输入、新动作空间

技术方案特点

模型架构:纯 Transformer + 模块化注意力

Octo 是一个基于 Transformer 的扩散策略,核心是 OctoTransformer。它把整个输入序列组织成三个区域:

[ <task 语言/图像 tokens>, ← prefix(前缀,全局可见) <t=0 image_primary tokens>, <t=0 image_wrist tokens>, <t=0 readout_action tokens>, <t=1 image_primary tokens>, <t=1 image_wrist tokens>, <t=1 readout_action tokens>, ... ]

关键设计 — Blockwise Causal Attention(分块因果注意力)

Readout 机制是 Octo 的灵魂:readout 是若干个「空 token」(仅含位置编码),通过 attention 从序列中「抽取」信息,再交给下游 head。换 head 时不影响 backbone,因此微调新动作空间代价极低

Transformer 规模

名称layersmlp_dimheads参数量
Octo-Small (vit_s)121536627M
Octo-Base (vit_b)1230721293M

输入输出格式

输入(Observations) — 字典形式,历史窗口 history_window=2(当前帧 + 上一帧):

{
    "image_primary": (batch, 2, 256, 256, 3),   # 第三人称/工作区相机
    "image_wrist":   (batch, 2, 128, 128, 3),   # 腕部相机(可选)
    "timestep_pad_mask": (batch, 2),            # True=该帧有效
}

Task(指令) — 支持三种条件方式:

输出(Actions)(batch, action_horizon=4, action_dim=7) — 一次预测未来 4 步动作(action chunking)。

动作表达方式:Diffusion Action Head

Octo 默认使用 DiffusionActionHead

为什么用 diffusion head?多模态动作分布(同一观测下多种合理动作)下,回归 head 会取平均导致动作模糊;diffusion 能建模多峰分布,更适合复杂操作。

仓库还提供其他 head 供微调时替换:

训练数据:Open X-Embodiment 800K

指标数值
轨迹总数800K 条
数据集数量20 个 OXE 子数据集混合
预处理大小~1.2 TB
主要数据集Fractal (RT-1) 17%、Kuka 17%、Bridge V2 17%、BC-Z 9.1% 等
语言增强用 GPT-3.5 对语言指令做改写(paraphrasing)

代码框架

仓库结构

octo/
├── octo/                          # 核心库
│   ├── model/
│   │   ├── octo_model.py          # ★ OctoModel:加载/保存/推理统一入口
│   │   ├── octo_module.py         # ★ OctoTransformer(架构核心)
│   │   └── components/
│   │       ├── action_heads.py    # Diffusion/Continuous/Discrete Head
│   │       ├── tokenizers.py      # Image/Language/Lowdim Tokenizer
│   │       ├── transformer.py     # BlockTransformer + 注意力规则
│   │       └── vit_encoders.py    # SmallStem16 等 ViT 编码器
│   ├── data/
│   │   ├── dataset.py             # ★ 单/多数据集构建 + 交错采样
│   │   ├── traj_transforms.py     # 轨迹级变换(窗口、action chunk)
│   │   └── oxe/                   # Open X-Embodiment 数据配置
│   └── utils/
│       └── gym_wrappers.py        # Gym 环境包装(HistoryWrapper)
├── scripts/
│   ├── train.py                   # 预训练主脚本
│   ├── finetune.py                # ★ 微调主脚本
│   └── configs/
│       ├── config.py              # 基础超参
│       └── finetune_config.py     # 微调配置(冻结策略)
├── examples/                      # ★ 6 个示例
│   ├── 01_inference_pretrained.ipynb
│   ├── 02_finetune_new_observation_action.py
│   ├── 03_eval_finetuned.py
│   └── 05_dataloading.ipynb
└── tests/                         # 调试用数据集 + 单测

核心模块

模块关键方法
OctoModelload_pretrainedsample_actionscreate_taskssave_pretrained
OctoTransformer__call__(observations, tasks, timestep_pad_mask)
DiffusionActionHeadlosspredict_action
TrainState封装 params + optimizer,freeze_weights 按 key 前缀冻结

微调流程

三种冻结模式(命令行 --config=finetune_config.py:<mode>,<task>):

模式冻结范围适用场景
full不冻结,全模型训练数据较多
head_only冻结 backbone,只训 head数据少、动作空间相近
head_mlp_only冻结 backbone + head attention数据极少、最快

命令行微调

python scripts/finetune.py \
    --config.pretrained_path=hf://rail-berkeley/octo-small-1.5 \
    --config=finetune_config.py:full,image_conditioned

Python API 微调(改观测/动作空间)

核心步骤(详见 examples/02_finetune_new_observation_action.py):

  1. OctoModel.load_pretrained() 加载预训练
  2. 修改 config["model"]["observation_tokenizers"](如删 wrist、加 proprio)
  3. 修改 config["model"]["heads"]["action"](如换成 L1ActionHead)
  4. OctoModel.from_config() 重建模型 → merge_params() 合入预训练权重
  5. freeze_weights(tx, ...) 冻结指定 key
  6. 自定义 train_step(jax.jit + value_and_grad)跑训练
模块化换头:这是 Octo 最大的工程优势。你可以保留预训练的 backbone,只换 head 适配新的动作空间(如从 7-DoF 换成 14-DoF 双臂),backbone 权重不用重训。

推理部署

from octo.model.octo_model import OctoModel

model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

# 语言指令 → 动作
task = model.create_tasks(texts=["pick up the spoon"])
action = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0),
    unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"])

sample_actions 已用 @jax.jit 装饰,推理图编译后缓存(首次有编译开销)。部署到真机可用 HistoryWrapper 包装环境,自动管理历史窗口。

快速上手

安装

conda create -n octo python=3.10
conda activate octo
git clone https://github.com/octo-models/octo.git && cd octo
pip install -e .
pip install -r requirements.txt
# GPU 版(CUDA 11)
pip install --upgrade "jax[cuda11_pip]==0.4.20" \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 验证安装
python scripts/finetune.py \
  --config.pretrained_path=hf://rail-berkeley/octo-small-1.5 --debug
⚠ 版本依赖:JAX 0.4.20 + Flax 0.7.5 + Optax 0.1.5 + TensorFlow 2.15(仅用于数据加载)。numpy==1.24.3ml_dtypes==0.2.0 需严格对齐。

加载模型 + 推理(最短路径)

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import jax, numpy as np
from octo.model.octo_model import OctoModel

# 1) 加载(自动从 HuggingFace 下载)
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
print(model.get_pretty_spec())

# 2) 准备观测(batch=1, history=1, 256×256 RGB)
img = np.random.randint(0, 255, (1, 1, 256, 256, 3), dtype=np.uint8)
observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}

# 3) 构造任务(语言 或 目标图像二选一)
task = model.create_tasks(texts=["pick up the fork"])
# task = model.create_tasks(goals={"image_primary": goal_img[None]})

# 4) 采样动作(需传反归一化统计量以得到真实物理动作)
action = model.sample_actions(
    observation, task,
    unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
    rng=jax.random.PRNGKey(0),
)
print(action.shape)  # (1, 4, 7) → [batch, action_horizon, action_dim]

硬件需求(资源友好是最大卖点)

场景硬件说明
推理单张 RTX 4090(消费级)Octo-Base: 13 it/s;Octo-Small: 17 it/s
微调单张消费级 GPU论文明确:「几小时内在标准消费级 GPU 上」
预训练(复现)TPUv4-128 podOcto-Small: 8 小时;Octo-Base: 14 小时
对比 RT-2(55B / 需多 TPU / 几乎不可复现),Octo 把通用策略拉到了单卡可玩级别——这是它能火的核心原因。

关键数字

指标数值
Octo-Small 参数量27M(≈ ViT-S)
Octo-Base 参数量93M(≈ ViT-B)
训练数据Open X-Embodiment 800K 轨迹,20 个数据集
支持机器人平台实验验证 9+ 种(WidowX, RT-1, Kuka, Franka, UR5 等)
预训练步数300,000 steps,batch 512
观测历史窗口2(当前 + 上一帧)
动作输出action chunking = 4,默认 action_dim = 7
扩散步数20(cosine schedule)
图像分辨率primary 256×256 / wrist 128×128,patch 16×16
推理速度 (4090)13-17 it/s
微调耗时单消费级 GPU / 数小时
许可MIT
💡 v1.5 关键改进:① 语言 token 在每个时间步重复(repeat_task_tokens=True),增强视觉-语言交叉注意力;② GPT-3.5 语言改写增强;③ Bug 修复(diffusion dropout、attention mask off-by-one)。