Octo 的目标做第一个真正开源、轻量级、资源友好的通用机器人策略。让通用机器人策略像 BERT 之于 NLP 一样——成为可被广泛使用的、低门槛的策略初始化(policy initialization)。开发者拿到预训练权重,用少量领域数据微调,就能获得强泛化能力。
已有的通用策略(如 RT-X)通常闭源、参数量大、对硬件要求高,难以被普通研究者复现和改进。Octo 把「通用策略」拉到了单卡可玩的级别。
| 设计目标 | Octo 的做法 |
|---|---|
| 开源可复现 | 全部代码、权重、训练配置、800K 数据混合配方开放(MIT 协议) |
| 跨平台泛化 | 在 Open X-Embodiment 800K 轨迹上训练,覆盖 9+ 种机器人平台 |
| 轻量/资源友好 | 仅 27M / 93M 参数,推理在单张 4090 上 13-17 it/s |
| 可灵活微调 | 模块化 Transformer 架构,支持新增传感器输入、新动作空间 |
Octo 是一个基于 Transformer 的扩散策略,核心是 OctoTransformer。它把整个输入序列组织成三个区域:
关键设计 — Blockwise Causal Attention(分块因果注意力):
| 名称 | layers | mlp_dim | heads | 参数量 |
|---|---|---|---|---|
| Octo-Small (vit_s) | 12 | 1536 | 6 | 27M |
| Octo-Base (vit_b) | 12 | 3072 | 12 | 93M |
输入(Observations) — 字典形式,历史窗口 history_window=2(当前帧 + 上一帧):
{
"image_primary": (batch, 2, 256, 256, 3), # 第三人称/工作区相机
"image_wrist": (batch, 2, 128, 128, 3), # 腕部相机(可选)
"timestep_pad_mask": (batch, 2), # True=该帧有效
}
Task(指令) — 支持三种条件方式:
输出(Actions):(batch, action_horizon=4, action_dim=7) — 一次预测未来 4 步动作(action chunking)。
Octo 默认使用 DiffusionActionHead:
仓库还提供其他 head 供微调时替换:
ContinuousActionHead / L1ActionHead:tanh 压缩 + 回归(常用于 ALOHA 双臂 14 维)DiscreteActionHead:动作离散化为 256 bin(类 RT-1 风格)| 指标 | 数值 |
|---|---|
| 轨迹总数 | 800K 条 |
| 数据集数量 | 20 个 OXE 子数据集混合 |
| 预处理大小 | ~1.2 TB |
| 主要数据集 | Fractal (RT-1) 17%、Kuka 17%、Bridge V2 17%、BC-Z 9.1% 等 |
| 语言增强 | 用 GPT-3.5 对语言指令做改写(paraphrasing) |
octo/
├── octo/ # 核心库
│ ├── model/
│ │ ├── octo_model.py # ★ OctoModel:加载/保存/推理统一入口
│ │ ├── octo_module.py # ★ OctoTransformer(架构核心)
│ │ └── components/
│ │ ├── action_heads.py # Diffusion/Continuous/Discrete Head
│ │ ├── tokenizers.py # Image/Language/Lowdim Tokenizer
│ │ ├── transformer.py # BlockTransformer + 注意力规则
│ │ └── vit_encoders.py # SmallStem16 等 ViT 编码器
│ ├── data/
│ │ ├── dataset.py # ★ 单/多数据集构建 + 交错采样
│ │ ├── traj_transforms.py # 轨迹级变换(窗口、action chunk)
│ │ └── oxe/ # Open X-Embodiment 数据配置
│ └── utils/
│ └── gym_wrappers.py # Gym 环境包装(HistoryWrapper)
├── scripts/
│ ├── train.py # 预训练主脚本
│ ├── finetune.py # ★ 微调主脚本
│ └── configs/
│ ├── config.py # 基础超参
│ └── finetune_config.py # 微调配置(冻结策略)
├── examples/ # ★ 6 个示例
│ ├── 01_inference_pretrained.ipynb
│ ├── 02_finetune_new_observation_action.py
│ ├── 03_eval_finetuned.py
│ └── 05_dataloading.ipynb
└── tests/ # 调试用数据集 + 单测
| 模块 | 关键方法 |
|---|---|
OctoModel | load_pretrained、sample_actions、create_tasks、save_pretrained |
OctoTransformer | __call__(observations, tasks, timestep_pad_mask) |
DiffusionActionHead | loss、predict_action |
TrainState | 封装 params + optimizer,freeze_weights 按 key 前缀冻结 |
三种冻结模式(命令行 --config=finetune_config.py:<mode>,<task>):
| 模式 | 冻结范围 | 适用场景 |
|---|---|---|
full | 不冻结,全模型训练 | 数据较多 |
head_only | 冻结 backbone,只训 head | 数据少、动作空间相近 |
head_mlp_only | 冻结 backbone + head attention | 数据极少、最快 |
python scripts/finetune.py \
--config.pretrained_path=hf://rail-berkeley/octo-small-1.5 \
--config=finetune_config.py:full,image_conditioned
核心步骤(详见 examples/02_finetune_new_observation_action.py):
OctoModel.load_pretrained() 加载预训练config["model"]["observation_tokenizers"](如删 wrist、加 proprio)config["model"]["heads"]["action"](如换成 L1ActionHead)OctoModel.from_config() 重建模型 → merge_params() 合入预训练权重freeze_weights(tx, ...) 冻结指定 keytrain_step(jax.jit + value_and_grad)跑训练from octo.model.octo_model import OctoModel
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
# 语言指令 → 动作
task = model.create_tasks(texts=["pick up the spoon"])
action = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0),
unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"])
sample_actions 已用 @jax.jit 装饰,推理图编译后缓存(首次有编译开销)。部署到真机可用 HistoryWrapper 包装环境,自动管理历史窗口。
conda create -n octo python=3.10
conda activate octo
git clone https://github.com/octo-models/octo.git && cd octo
pip install -e .
pip install -r requirements.txt
# GPU 版(CUDA 11)
pip install --upgrade "jax[cuda11_pip]==0.4.20" \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 验证安装
python scripts/finetune.py \
--config.pretrained_path=hf://rail-berkeley/octo-small-1.5 --debug
numpy==1.24.3、ml_dtypes==0.2.0 需严格对齐。
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import jax, numpy as np
from octo.model.octo_model import OctoModel
# 1) 加载(自动从 HuggingFace 下载)
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
print(model.get_pretty_spec())
# 2) 准备观测(batch=1, history=1, 256×256 RGB)
img = np.random.randint(0, 255, (1, 1, 256, 256, 3), dtype=np.uint8)
observation = {"image_primary": img, "timestep_pad_mask": np.array([[True]])}
# 3) 构造任务(语言 或 目标图像二选一)
task = model.create_tasks(texts=["pick up the fork"])
# task = model.create_tasks(goals={"image_primary": goal_img[None]})
# 4) 采样动作(需传反归一化统计量以得到真实物理动作)
action = model.sample_actions(
observation, task,
unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
rng=jax.random.PRNGKey(0),
)
print(action.shape) # (1, 4, 7) → [batch, action_horizon, action_dim]
| 场景 | 硬件 | 说明 |
|---|---|---|
| 推理 | 单张 RTX 4090(消费级) | Octo-Base: 13 it/s;Octo-Small: 17 it/s |
| 微调 | 单张消费级 GPU | 论文明确:「几小时内在标准消费级 GPU 上」 |
| 预训练(复现) | TPUv4-128 pod | Octo-Small: 8 小时;Octo-Base: 14 小时 |
| 指标 | 数值 |
|---|---|
| Octo-Small 参数量 | 27M(≈ ViT-S) |
| Octo-Base 参数量 | 93M(≈ ViT-B) |
| 训练数据 | Open X-Embodiment 800K 轨迹,20 个数据集 |
| 支持机器人平台 | 实验验证 9+ 种(WidowX, RT-1, Kuka, Franka, UR5 等) |
| 预训练步数 | 300,000 steps,batch 512 |
| 观测历史窗口 | 2(当前 + 上一帧) |
| 动作输出 | action chunking = 4,默认 action_dim = 7 |
| 扩散步数 | 20(cosine schedule) |
| 图像分辨率 | primary 256×256 / wrist 128×128,patch 16×16 |
| 推理速度 (4090) | 13-17 it/s |
| 微调耗时 | 单消费级 GPU / 数小时 |
| 许可 | MIT |
repeat_task_tokens=True),增强视觉-语言交叉注意力;② GPT-3.5 语言改写增强;③ Bug 修复(diffusion dropout、attention mask off-by-one)。