AlphaFold3 Technical Report
date
Mar 11, 2025
slug
AlphaFold3_Reading_Guide
status
Published
tags
Deep Learning
summary
type
Post
AlphaFold3架构的视觉导览,包含比你可能需要的更多细节和图表。
Part 0: 引言
0-1 谁应该阅读本文?
你想了解AlphaFold 3是如何工作的吗?它的架构相当复杂,论文中的描述可能会让人感到不知所措,因此我们制作了一个更友好的(但同样详细的!)视觉导览。
本文主要面向机器学习(Machine Learning,ML)受众,我们假定读者熟悉注意力机制(Attention)的步骤。如果你对这些知识有些陌生,可以参考 Jay Alammar 的图解Transformer,那份博客有详细的插图解释。那篇博客是对模型架构在单个矩阵操作层面的最佳解释之一,也是本文图表和命名的灵感来源。
目前已经有许多关于蛋白质结构预测的动机、CASP 竞赛、模型失败模式、评估争论以及生物技术影响的优秀解释,在这里我们并不关注这些。相反,我们探索如何实现AlphaFold 3这种模型。
分子在模型中是如何表示的?将它们转换为预测结构的所有操作是什么?
这可能比大多数人所需要的了解的更详尽,但如果你想了解所有细节并且喜欢通过图表学习,这应该会对你有所帮助 :)
0-2 架构概述
首先指出,该模型的目标与之前的 AlphaFold 2 模型略有不同:仅从序列出发,它不仅可以预测单个蛋白质序列(AlphaFold 2)或蛋白质复合物(AF-multimer)的结构,还可以预测蛋白质(或与其它蛋白质、核酸或小分子复合)的结构。虽也正是如此,之前的 AlphaFold 2 模型只需要表示标准氨基酸的序列,而 AlphaFold 3 需要表示更加复杂的输入类型,也因此其有更复杂的特征化/分词方案。分词将在单独的部分描述,但现在只需要知道,当我们说「token」时,它要么代表单个氨基酸(对于蛋白质),要么代表核苷酸(对于 DNA/RNA),要么代表单个原子(如果该原子不属于标准氨基酸/核苷酸)。

该模型可以分解为 3 个主要部分:
- 输入准备。用户提供分子的序列以预测其结构,这些序列会被转换为数值张量。此外,模型会检索一组新的分子,这些分子具有与用户提供分子相似的结构。在输入准备阶段,模型也会将这些检索得到的新的分子转换为输入特征的一部份。
- 表征学习。给定在输入准备阶段获得的 single 和 pair 张量,使用多种注意力变体更新这些表征。
- 结构预测。使用第 2 部分表征学习阶段更新后得到的表征以及在第 1 部分输入准备阶段中计算得到的原始输入特征,通过条件扩散来预测结构。
此外额外章节也描述 4. 损失函数、置信头和其他相关训练细节 和 5. 从机器学习视角对AlphaFold 3模型的一些思考。
0-3 关于变量和图表的说明
在整个模型中,蛋白质复合物主要以两种形式表示:「single 表征」:表示蛋白质复合物中的所有 token;「pair 表征」:代表复合物中所有氨基酸/原子对之间的关系(例如距离、潜在相互作用)。这些都可以以 atom 级别或 token 级别表示,并将始终以这些名称(根据 AlphaFold 3 论文中的规定)和颜色显示:

- 图表抽象了模型权重,仅可视化激活张量形状的变化。
- 激活张量始终标有论文中使用的维度名称,图表旨在跟随这些维度的增长/缩小。隐藏维度的名称通常以「c」开头,表示「channel」。参考主要使用的维度是 =128, =64, =128, =16, =768, =384。
- 只要可能,在此(和每个)图表中张量上方的名称与 AlphaFold 3 补充材料中使用的张量名称相匹配。通常,一个张量在模型中经过时保持其名称。然而,在某些情况下,会使用不同的名称来区分张量在处理不同阶段的版本。例如,在 atom 级别的 single 表征中, c代表初始的atom 级别 single 表征,而q代表在 Atom Transformer 中进行时此表征的更新版本。
- 为了简单起见,在介绍中我们忽略了大多数
LayerNorm
模块,尽管该模块在AlphaFold 3中随处可见。
Part 1: 输入准备

用户提供给 AlphaFold 3 的实际输入是一个蛋白质的序列和可选的额外分子/配体(ligand)。本节的目的是将这些序列转换为一系列 6 个张量,这些张量将作为模型主干(truck)的输入,如本图所示。这些张量是:
- token 级别:token 级别 single 表征 ;token 级别 pair 表征 ;
- atom 级别: atom 级别 single 表征 ;atom 级别 pair 表征 ;
- 额外信息:MSA 表征 ;模板 (template) 表征 。
本节包含的主要内容概述如下:
- 分词: 描述了分子如何被分词,并阐明了 atom 级别和 token 级别之间的区别。
- 检索 (创建 MSA 和模板): 解释了为什么以及如何将 MSA () 和结构模板 () 的特征输入到模型中。
- 创建 atom 级别表征: 创建了初始 atom 级别表征 (single) 和 (pair),并包含关于分子生成构象的信息。
- 更新 atom 级别表征(使用
Atom Transformer
): 是主要的"Input Embedder"块,它被重复 3 次并更新 atom 级别 single 表征 ()。这里介绍的构建块 (Adaptive LayerNorm
,Attention with Pair Bias
,Conditioned Gating
, 和Conditioned Transition
) 在模型的后面部分也很重要。
- 聚合 atom 级别表征汇总至 token 级别表征: 取 atom 级别表征 (, ) 并聚合所有属于多原子 token 的原子,以创建 token 级别表征 (single) 和 (pair),并包含来自 MSA () 和用户提供的关于涉及配体的已知键的信息。
1-1 分词(Tokenization)

在 AlphaFold 2 中,由于模型只表示具有固定氨基酸类型(20种天然氨基酸+1种表示未知氨基酸)的蛋白质,每个氨基酸都用自己的 token 表示。这种表示方式在 AlphaFold 3 中得以保留,但 AlphaFold 3 为了可以处理其他分子类型,引入了额外的 token 表示:
- 标准氨基酸:1 个 token (与 AlphaFold 2 相同)。
- 标准核苷酸:1 个 token。
- 非标准氨基酸或核苷酸(甲基化核苷酸、具有翻译后修饰的氨基酸等):每个原子 1 个 token。
- 其他分子:每个原子 1 个 token。
因此,我们可以认为一些 token (如氨基酸的 token) 与多个原子相关联,而其他 token (如配体中的原子的 token) 仅与单个原子相关联。因此,一个具有 35 个标准氨基酸 (可能 >600 个原子) 的蛋白质将由 35 个 token 表示,而一个具有 35 个原子的配体也将由 35 个 token 表示。
1-2 检索(创建 MSA 和模板)

AlphaFold 3 中的一个关键早期步骤类似于语言模型中的检索增强生成 RAG。找到与我们感兴趣的蛋白质/RNA 序列相似的序列 (收集到多序列比对"MSA"中),以及与这些序列相关的任何结构 (称为"模板"),然后将它们作为额外的输入 和 包含到模型中。

为什么我们要包括 MSA 和模板?
不同物种中发现的同一蛋白质的版本在结构和序列上可能非常相似。通过将这些序列对齐到一个多序列比对(MSA)中,我们可以观察蛋白质序列中单个位置在进化过程中的变化。你可以把给定蛋白质的 MSA 想象成一个矩阵,其中每一行是来自不同物种的类似蛋白质的序列。研究表明,蛋白质中特定位置的列的保守模式可以反映该位置存在某些氨基酸的重要性,不同列之间的关系反映了氨基酸之间的关系(即,如果两个氨基酸在物理上相互作用,它们在进化过程中的氨基酸变化很可能会相关)。因此,MSA 经常被用来丰富单个蛋白质的表征。
类似地,如果这些蛋白质中的任何一个有已知结构,这些结构也可能为该蛋白质的结构提供信息。不是搜索完整的结构,而是使用蛋白质的单个链。这类似于同源建模的做法,其中查询蛋白质的结构是基于假定相似的已知蛋白质结构的模板建模的。
那么这些序列和结构是如何检索的?
首先,进行基因搜索,寻找与任何输入蛋白质或 RNA 链相似的任何蛋白质或 RNA 链。这不涉及任何训练,依赖于现有的基于隐马尔可夫模型(HMM)的方法。具体来说,他们使用 jackhmmer、HHBlits 和 nhmmer 扫描多个蛋白质数据库和 RNA 数据库以获取相关命中。然后将这些序列相互对齐,以构建一个具有 N个MSA 序列的 MSA。由于模型的计算复杂性随着 NMSA 的增加而增加,他们将 NMSA 限制在 <214。通常,MSA 是从单个蛋白质链构建的,但如 AF-multimer 中所述,不是简单地将单独的 MSA 连接成块对角矩阵,而是可以将来自同一物种的某些链「配对」,如此处所述。这样,MSA 不必那么大和稀疏,并且可以学习关于链之间关系的进化信息。
然后,对于每个蛋白质链,他们使用另一种基于 HMM 的方法(hmmsearch)在蛋白质数据库(PDB)中查找与构建的 MSA 相似的序列。选择最高质量的结构,并从中采样多达 4 个作为「模板」包含。
与 AF-multimer 相比,这些检索步骤中唯一的更新部分是,我们现在除了对蛋白质序列进行检索外,还对 RNA 序列进行检索。注意,这传统上不称为「检索」,因为在 RAG 这个术语出现之前,使用结构模板来指导蛋白质结构建模的做法在同源建模领域中就已经很常见了。然而,尽管 AlphaFold 没有明确地将此过程称为检索,但它确实与现在流行的 RAG 非常相似。
1-2-1 我们如何表示这些模板?
在模板搜索中,每个模板都有一个 3D 结构,以及可以从当前模版中获得关于哪些 token 在哪些链中的信息。首先,计算给定模板中所有 token 对之间的欧几里得距离。对于与多个原子相关的 token,使用一个代表性的「中心原子」来计算距离。对于氨基酸,这是原子,对于标准核苷酸,这是原子。

每个模板会生成一个 的矩阵。然而,不是将每个距离表示为数值,而是将距离离散化为「Distogram 距离直方图」。具体来说,值被分到 到 之间的 38 个 bin 中,并且有一个额外的 bin 用于任何大于此的距离。
对于每个 distogram,随后附加关于每个 token 所属的链的元数据。在分子复合物中,链指一个独立的分子或分子的部分。这可以是蛋白质链(氨基酸序列)、DNA 或 RNA 链(核苷酸序列)或其他生物分子。AlphaFold 使用链信息来区分复合物的不同部分,帮助预测这些部分如何相互作用形成整体结构。我们还附加关于该 token 是否在晶体结构中被解析的信息,以及关于每个氨基酸内部局部距离的信息。然后,我们掩码该矩阵,使得我们只能看到每条链内部的距离(例如,我们忽略链 A 和链 B 之间的距离),因为他们「没有尝试选择模板…以获取关于链间相互作用的信息」。
1-3 创建 atom 级别表征

为了创建 atom 级别的 single 表征 ,需要提取所有 atom 级别的特征。第一步是为每个氨基酸、核苷酸和配体计算一个「参考构象」 (Reference Conformer)。虽然暂时还不知道整个复合物的结构,但我们对每个单独组分的局部结构有很强的先验知识。构象 (构象异构体 的简称) 是分子中原子的一种 3D 排列,通过对单键的旋转采样生成。每个氨基酸都有一个「标准构象」,这只是该氨基酸可以存在的低能量构象之一,可以使用 RDKit 的 ETKDGv3 生成,这是一种结合实验数据和扭转角偏好来产生 3D 构象的算法。每个小分子都需要生成对应的构象。然后,我们将来自该构象的信息 (相对位置
ref_pos
)与每个原子的电荷(ref_charge
)、原子序数(ref_element
)和其他标识符(ref_mask
, ref_atom_name_chars
)等信息拼接起来,并用矩阵存储序列中所有原子的这类信息。
在 AlphaFold 3 的补充材料中,atom 级别的矩阵 (和) 通常以其向量形式 (例如或) 提及,其中和用于索引原子。AlphaFold 3 使用初始化 atom 级别 pair 表征,以存储原子之间的相对距离。因为暂时只知道每个 token 内部的参考距离,我们使用一个掩码 () 来确保这个初始距离矩阵只计算生成构象中的相对距离。此外还包括距离的平方倒数的线性嵌入,将其与和的投影相加,并通过带有残差连接的线性层更新这个表征。
AlphaFold 3 原论文中并没有真正澄清为什么执行这个额外的倒数距离步骤,也没有包含对其影响的消融实验;因此我们只能假设它们在经验上被证明是有用的。
在 AlphaFold 3 补充材料中,张量通常以其向量形式和(表示原子和原子之间的关系) 提及。最后,同时得到 atom 级别 single 表征。矩阵是接下来将要更新的表征,但被保存并在后面继续使用。
1-4 更新 atom 级别表征(Atom Transformer)

在生成(single原子的表征信息) 和(pair原子的表征信息) 之后,我们想根据每个原子附近的其它原子的表征更新该原子的表征。每当 AlphaFold 3 在 atom 级别应用注意力时,都会使用一个称为
Atom Transformer
的模块。Atom Transformer
是一系列块的组合,通过使用注意力机制来更新,同时也会使用到和的原始表征。由于不会被 Attention Transformer
更新,它可以被视为起始表征的残差连接。Atom Transformer
主要遵循标准的 Transformer 结构,使用层归一化、注意力机制,然后是 MLP 层。然而每个步骤都经过调整,从而将来自和的额外输入包含在更新的步骤中 (有时被称为"conditioning")。在注意力和 MLP 块之间还有一个"gating"步骤。我们在此详细地介绍这 4 个步骤:1-4-1 Adaptive LayerNorm

Adaptive LayerNorm (
AdaNorm
) 是 LayerNorm 的一种变体,有一个简单的扩展。对于给定的输入矩阵,传统的 LayerNorm 学习两个参数 (缩放因子 和偏置因子 ),用来调整矩阵中每个通道的均值(mean)和标准差(std error)。而 AdaNorm
不是学习固定的 和 参数,而是学习一个函数,根据输入矩阵自适应地生成 和 。然而,不是基于正在被重新缩放的输入 (在 Atom Transformer
中这是) 生成参数,而是使用次要输入,即conditioning (在 Atom Transformer
中是) 来预测重新缩放的均值和标准差的 和 。
1-4-2 Attention with Pair Bias

Atom-Level Attention with pair-Bias 可以被视为自注意力的扩展。与正常的自注意力机制一样,查询(query)、键(key)和值(value)都来自同一个 1D 序列 (atom 级别 single 表征)。然而,有 3 个不同之处:
- pair-biasing:在计算 query 和 key 的点积后,pair 表征的线性投影被添加为偏置(bias),以缩放注意力权重。注意,该操作不涉及任何来自的信息被用来更新,只是从 pair 表征到的单向流动。
这样做的原因是由于具有更强成对关系的原子应该更强烈地相互关注,而实际上已经编码了一个注意力图。
- Gating:除了query、key和value外,我们还额外创建了的另一个投影(),通过
Sigmoid
函数,将值压缩到 0 和 1 之间。所有输出头在被重新组合之前都与这个"gate"相乘。
这种 gating 在 AlphaFold 3 中频繁出现。简要说明一下,这实际上迫使模型忽略在注意力过程中学到的一些信息。
由于模型不断将每个部分的输出添加到残差流中,gating 机制可以被视为模型指定哪些信息被保存或不被保存在这个残差流中的方式。它可能以 LSTM 中的类似"gates"命名,LSTM 使用
Sigmoid
来学习一个过滤器,决定哪些输入被添加到运行的 cell state 中。
- Sparse attention:由于原子的数量可能远大于 token 的数量,在这一步不会运行完全版本的注意力运算,而是使用一种稀疏注意力 (称为 Sequence-local atom attention),其中注意力实际上在局部组中运行,每次 32 个原子的组可以全部关注 128 个其他原子。稀疏注意力模式在互联网上的其他地方有更详细的描述。
1-4-3 Conditioned Gating

我们对数据应用另一个 gate 运算,但这次 gate 是从我们的原始 atom 级别 single 矩阵生成的。与许多步骤一样,不清楚为什么这样做,以及基于原始表征的 conditioning 与从主要的 single 表征学习 gate 相比有什么好处。
1-4-4 Conditioned Transition
这一步等同于 Transformer 中的 MLP 层,之所以称为"conditioned",是因为 MLP 被夹在 Adaptive LayerNorm (Atom Transformer 的第 1 步) 和 Conditional Gating (Atom Transformer 的第 3 步) 之间,这两者都依赖于。
本节中唯一值得注意的其他部分是,AlphaFold 3 在 transition 块中使用
SwiGLU
而不是 ReLU
。从 ReLU
到 SwiGLU
的转换发生在 AlphaFold 2 到 AlphaFold 3 之间,并且在许多最近的架构中是一个常见的变化,所以我们在这里可视化它。使用基于
ReLU
的 transition 层 (如 AlphaFold 2 中),我们取激活,将它们投影到 4 倍大小,应用 ReLU
,然后将它们投影回原始大小。使用 SwiGLU
(在 AlphaFold 3 中),输入激活创建两个中间的 up-projection,其中一个通过 swish 非线性 (ReLU
的改进变体),然后在 down-projecting 之前将它们相乘。下面的图表显示了差异:
1-5 聚合Atom级别表示到Token级别表示

虽然迄今为止的数据都存储在atom 级别,但从该模块开始,AlphaFold 3 的表征学习部分会在 token 级别上运行。为创建这些 token 级别表征,首先会将 atom 级别表征投影到一个更大的维度 (=128, =384)。然后对分配给同一个 token 的所有原子取平均值。注意,这只适用于与标准氨基酸和核苷酸相关的原子 (通过对附着在同一个 token 上的所有原子取平均值),而其余的保持不变。AlphaFold 3 将这些分子类型描述为每个 token 有一个代表性原子 (中心原子)。回想一下,对于氨基酸,这是 原子,对于标准核苷酸,这是 原子。因此,虽然我们主要将这个减少的表征视为"token 空间",我们也可以认为每个 token 代表一个单一原子 (要么是代表性的 /原子,要么是单个原子)。

对于 token 空间,首先将 token 级别特征和来自 MSA 的统计信息 (如果可用) 连接起来。例如,氨基酸类型 (dim=32,20种天然氨基酸+1种未知氨基酸+4种天然核糖核苷酸+1种未知核糖核苷酸+4种天然脱氧核糖核苷酸+1种未知脱氧核糖核苷酸),以及来自 MSA 中表示该 token 的删除均值 (dim=1)。注意,对于不与 MSA 相关的配体原子,这些值将为零。矩阵 由于拼接后有所增长,会被投影回 维度,并称为 ,该表征将在表征学习部分中更新。注意,在表征学习部分中更新,但 被保存以在结构预测部分中使用。

现在我们已经创建了初始 single 表征 ,下一步是初始化 pair 表征 。pair 表征是一个三维张量,但最容易将其视为一个类似热图的 2D 矩阵,具有隐含的深度维度 =384 个通道。因此,pair 表征的条目 是一个 维向量,旨在存储关于 token 序列中第 和第 个token之间关系的信息。由于已经创建了一个类似的 atom 级别pair表征 ,在这里的 token 级别上遵循类似的计算流程。
为了初始化 ,我们使用线性投影使序列表征的通道维度与 pair 表征的通道维度相匹配 (384→128),并将得到的 和 相加。然后,添加一个相对位置编码 。相对位置编码由 ,,这三部分组成。其中:
- 定义了两个token在同一链(chain)内的相对位置偏移量的热编码(one-hot encoding),若两个token属于同一链(如蛋白质的一条多肽链或DNA的一条链),则计算它们在序列中的位置差(offset);如果两个token 不属于同一链,或它们的偏移量超过阈值(例如65),则将编码值设为65(或对应的one-hot位置)。该特征主要是为了捕捉同一链内两个token的局部空间关系,例如相邻或间隔较近的氨基酸。
- 定义了两个token在同一残基(amino acid/nucleotide)内的相对位置偏移量的一位热编码。若两个token属于同一残基(如同一氨基酸或核苷酸的组成部分),则计算它们的偏移量;如果两个token属于不同残基,则将编码值设为65。该特征可以捕捉同一残基内不同位置(如蛋白质的侧链原子)的精细结构关系。
- 定义了两个token所在链(chain)之间的相对偏移量编码。若两个token属于不同链(如蛋白质的两条链或DNA的双链),则编码它们所属链的索引差(例如链A和链B的索引差为1)。该特征可以表示不同链之间的全局拓扑关系,例如链间的距离或排列顺序。
我们相对位置编码的信息同样投影到 的维度。如果用户还指定了 token 之间的特定键,这些键在这里被线性嵌入并添加到 pair 表征中的该条目。现在我们已经成功创建并嵌入将在模型的其余部分中使用的所有输入:

对于第 2 步,我们将把 atom 级别表征 (, , ) 放在一边,专注于在下一节中更新我们的 token 级别表征 和 (在 和 的帮助下)。
Part 2: 表征学习

本节是模型的主干部分,通常被称为「trunk」,这也是模型大部分计算完成的模块。称其为表征学习部分,其目标是为了学习 token 级别 single() 和 pair() 张量的更新表征,这些张量在第一部分已完成初始化。本节具体包含:
- 模板模块: 使用结构模板 更新 。
- MSA 模块: 首先更新 MSA 的表征,随后将其添加到 token 级别的 pair 表征 中。主要关注两个操作:
- 外积均值:使 能够影响 。
- 仅使用 pair bias 的 MSA 行向 gated 自注意力: 基于 更新 ,是 attention with pair-bias 的简化版本 (专为 MSA 设计)。
- Pairformer: 采用三角注意力更新 和 。主要讨论了:
- 为什么使用三角操作?解释了三角操作的一些直觉。
- 三角更新和三角注意力: 受三角不等式的启发的自注意力的方法更新 。
- Single Attention With Pair Bias: 基于 更新 ,是 attention with pair-bias 的 token 级别等价算法 (专为单个序列设计)。
每个单独的块被重复多次,然后整个部分的输出再次作为输入反馈给自己,重复这个过程 (这称为 recycling)。
2-1 模板模块


每个模板 (图中 =2) 通过一个线性投影,与 pair 表征 () 的线性投影相加。这个新组合的矩阵通过一系列称为
Pairformer Stack
的操作 (稍后详细描述)。最后,所有模板的特征通过另一个线性层被平均在一起,。有趣的是,最后的线性层使用 ReLU
作为非线性激活函数,这本来不值得注意,但它是 AlphaFold 3 中仅有的两个使用 ReLU
作为非线性的地方之一。2-2 MSA 模块


该模块类似于 AlphaFold 2 中的「Evoformer」,其目标是同时改进 MSA 和 pair 的表征。它对这两个表征独立地进行一系列操作,使得它们之间能够进行交叉交流。在 AlphaFold 3 中该模块进行了如下更新:
- 对 MSA 行进行子采样,并非使用之前生成的所有 MSA 行 (可能多达 16 k),将 single 表征的投影张量添加到这个子采样的 MSA 中。
2-2-1 外积均值

MSA 包含多个同源序列的对齐结果,用于捕捉序列间的共进化信息(如两个位置在进化过程中是否倾向于共同突变)。Pair 表征表示序列中任意两个位置 () 之间的关系(如空间距离、共进化强度等)。取 MSA 表征并通过「外积均值」将其纳入 pair 表征。MSA 中的两个列揭示了序列中两个位置之间的关系信息 (例如,这两个位置在进化过程中序列的相关性如何)。对于每一对 token 索引 和 ,遍历所有 MSA 序列,计算 和 的外积,然后对所有 MSA 序列取平均。然后,将这个外积展平,投影回较低维度,并将其添加到 pair 表征 (完整细节见图表)。虽然每个外积只比较给定序列 的值,但取这些的均值时混合了跨序列的信息。这是模型中唯一一个在进化序列之间共享信息的地方。也是为了减少 AlphaFold 2 中 Evoformer 计算复杂性而做的一个重大改变。
为了将 MSA 中的信息(单序列的表征)转化为成对标表征(pair representation),并跨进化序列共享信息。对 MSA 中的第 个序列 计算其第 和第 位的表征向量的外积(outer product)。: 这里 表示外积(张量积),用于将两个向量扩展为矩阵。若要对所有序列求平均,则需要将所有序列的外积结果在序列维度上求平均值,得到最终的成对标表征:
2-2-2 仅使用 pair bias 的 MSA 行式门控自注意力

在基于 MSA 更新 pair 表征后,模型接下来基于 pair 表征更新 MSA。这种特定的更新模式称为仅使用 pair bias 的 MSA 行式门控自注意力,是 self attention with pair bias 的简化版本,在
Atom Transformer
部分讨论过,独立地应用于 MSA 中的每个序列 (行)。它受注意力的启发,但不是使用查询和键来确定每个 token 应该关注哪些其他位置,而是直接使用 pair 表征 中存储的 token 之间的现有关系。
在 pair 表征中,每个 是一个包含 token 和 之间关系信息的向量。当张量 被投影到一个矩阵时,每个 向量变成一个标量,可以用来确定 token 应该关注 token 的程度。在应用行向
softmax
后,这些现在等同于注意力分数,用于创建值的加权平均,就像典型的注意力图那样。请注意,由于它独立地对每行运行,MSA 中的进化序列之间没有共享信息。2-2-3 Pair 表征的更新
MSA 模块的最后一步是通过一系列称为
triangle updates
和 attention 的步骤更新 pair 表征。还有一些 transition 块使用 SwiGLU
来 up/down project 矩阵,就像在 Atom Transformer
中做的那样。2-3 Pairformer 模块


在使用模板和 MSA 更新 pair 表征后,更新后的 pair 表征 () 和 single 表征 () 进入 Pairformer 并相互更新。由于 transition 块已经描述过,本节重点关注
triangle updates
和 triangle attention
,然后简要解释 Single Attention with Pair Bias 与之前描述的变体有何不同。这些基于 triangle 的层是在 AlphaFold 2 中首次引入的,不仅在 AlphaFold 3 中保留,而且现在在架构中更加重要。2-3-1 为什么使用三角形?
这里的指导原则是三角不等式的思想:「三角形的任意两边之和大于或等于第三边」。回想一下,pair 张量中的每个 编码序列中位置 和 之间的关系。虽然它不是字面上的 token 对之间的物理距离,但如果想象每个 是两个氨基酸之间的距离,同时知道 和 ,根据三角不等式, 不能大于 2。知道其中两个距离给我们一个强烈的信念,关于第三个距离必须是什么。
triangle updates
和 triangle attention
的目标是尝试将这些几何约束编码到模型中。模型中没有强制执行三角不等式,而是通过确保每个位置 在更新时一次查看所有可能的位置三元组 () 来鼓励它。因此, 基于 和 进行更新。因为 代表这些 token 之间的复杂物理关系,而不仅仅是它们的距离,这些关系可以是有方向的。因此,对于 ,还想鼓励与 和 (对于所有原子 ) 的一致性。如果将原子视为一个图, 作为有向邻接矩阵,那么 AlphaFold 称这些为「outgoing edges」和「incoming edges」是有道理的。
考虑这个邻接矩阵的行 ,假设想更新 ,它已被紫色突出显示。更新背后的想法是,如果知道 0→1 和 2→1 之间的距离,这给一些关于 0→2 可以是什么的约束。类似地,如果知道 0→3 和 2→3 之间的距离,这也给一个关于 0→2 的约束。这适用于所有原子。

因此,在
triangle updates
和 triangle attention
中,有效地查看这个图中 3 个节点的所有有向路径 (即三角形,因此得名!)。
2-3-2 三角形更新
从图论的角度仔细看了 triangle 操作后,可以看到这是如何通过张量操作实现的。在 outgoing update 中,pair 表征中的每个位置 独立地基于同一行中其他元素的加权组合进行更新,其中每个 的权重基于其 outgoing edge triangle 中的第三个元素 。

实际上,取 的三个线性投影 (称为)。为了更新 ,取 的第 行和 的第 行的元素 wise multiplication。然后,对所有这些行 ( 的不同值) 求和,并用的 投影进行 gating。
此时你可能注意到,gating 在整个架构中都被使用!

对于 incoming update,有效地做同样的事情,但将行和列翻转,因此为了更新 ,取同一列中其他元素的加权和 ,其中每个 的权重基于其 outgoing edge triangle 中的第三个元素 ()。在创建相同的线性投影后,取 a 的第 列和 b 的 列的元素 wise multiplication,并对这个矩阵的所有行求和。你会发现这些操作完全镜像了上面描述的图论邻接视图。
三角形注意力
在的两个 triangle update 步骤之后,还使用 triangle attention 更新每个 ,用于 outgoing edges 和 incoming edges。AlphaFold 3 论文将「outgoing edges」称为「around starting node」的 attention,将「incoming edges」称为「around ending node」的 attention。

为了构建 triangle attention,从 1D 序列上的典型自注意力开始可能会有所帮助。回想一下,查询、键和值都是原始 1D 序列的转换。一种称为 axial attention 的注意力变体通过对 2D 矩阵的不同轴 (行,然后列) 应用独立的 1D 自注意力来扩展到矩阵。Triangle attention 将之前讨论的 triangle 原则添加到这个中,通过合并 和 (对于所有原子 ) 来更新 。具体来说,在「starting node」情况下,为了计算沿行 i 的注意力分数 (以确定 应该受 影响多少),像往常一样对 和 进行 query-key 比较,然后基于 偏置注意力,如上所示。

对于「ending node」情况,再次将行和列交换。对于 ,键和值都来自 的 column i,而偏置来自 column j。因此,在比较 query 和 key 时,基于 偏置那个注意力分数。然后,一旦有关于所有 的注意力分数,使用来自 column i 的值向量。
Single Attention with Pair Bias

现在已经用这四个 triangle 步骤更新了的 pair 表征,将 pair 表征通过如上所述的 Transition 块。最后,想使用这个新的更新 pair 表征 () 来更新的 single 表征 (),因此将使用 single attention with pair bias,如下图所示。这与在
Atom Transformer
部分中描述的 Single Attention with Pair Bias 相同。在 AlphaFold 3 补充材料中,Single Attention with Pair Bias 也被称为 「Attention Pair Bias」,但在 token 级别上。由于它在 token 级别上操作,它使用完全注意力,而不是在 atom 级别上操作时使用的块状稀疏模式。重复 Pairformer 48 个块,最终创建 和 。
3. 结构预测
Diffusion 的基础
现在,有了这些精炼的表征,准备使用 和 来预测复合物的结构。AlphaFold 3 引入的变化之一是整个结构预测基于 atom 级别的 diffusion。现有帖子更彻底地解释了 diffusion 的直觉和数学,但 Diffusion Model 的基本思想是从真实数据开始,向数据中添加随机噪声,然后训练一个模型来预测添加了什么噪声。噪声在 个时间步内迭代地添加到数据中,为每个数据点创建 个变体的序列。称原始数据点为 ,完全噪声版本为 。在训练期间,在时间步 ,模型被给予 ,并预测在 和 之间添加了什么噪声。对预测的噪声与实际添加的噪声进行梯度步。
然后,在推理时,简单地从随机噪声开始,这等同于 。对于每个时间步,预测模型认为已经添加的噪声,并去除那个预测的噪声。经过预先指定数量的时间步后,最终得到一个完全「去噪」的数据点,它应该类似于数据集中的原始数据。
Conditional Diffusion 让模型在某些输入上 condition 这些去噪预测。实际上,这意味着模型的每一步都接受三个输入:
- 当前生成的迭代噪声
- 当前所处的时间步表征
- 想要 condition 的信息 (这可以是生成图像的标题,或蛋白质的属性)。
因此,最终生成的结果不仅仅是一个类似于训练数据分布的随机示例,而是应该具体匹配这个 conditioning 向量所代表的信息。
在 AlphaFold 3 中,学习去噪的数据是一个矩阵 ,包含序列中所有原子的 坐标。在训练期间,向这些坐标添加高斯噪声,直到它们完全随机。然后在推理时,从随机坐标开始。在每个时间步,首先随机旋转和平移整个预测的复合物。这种数据增强教模型的复合物的任何旋转和平移都是同样有效的,并取代了 AlphaFold 2 中使用的更复杂的
Invariant Point Attention
。AlphaFold 2 开发了一个复杂的架构,称为 Invariant Point Attention
,旨在强制对平移和旋转的 equivariance。这引发了关于 IPA 在 AlphaFold 2 成功中的重要性的激烈辩论。在 AlphaFold 3 中,IPA 模块被放弃,转而采用更简单的方法:应用随机旋转和平移作为数据增强,以帮助模型自然地学习这种 equivariance。所以在这里,简单地围绕当前生成的中心 (所有原子坐标的均值) 随机旋转所有原子的坐标,并从 高斯分布中随机采样每个维度 的平移。从算法中看,平移是普遍的,即对当前生成中的每个原子应用相同的平移。这种数据增强在 CNN 中很受欢迎,但在过去的几年里,像 IPA 这样的 equivariant 架构被认为是一种更有效和优雅的方法来解决相同的问题。因此,当 AlphaFold 3 用数据增强替换 equivariant attention 时,引发了很多讨论。然后,向坐标添加少量噪声,以鼓励更多异质的生成,模型生成几个略有不同的变体是有益的。在推理时,可以使用的 confidence head 对每个变体评分,并只返回得分最高的生成。最后,使用 Diffusion Module 预测一个去噪步骤。在下面更详细地介绍这个模块:
Diffusion 模块


在每个去噪 diffusion 步骤中,根据输入序列的多个表征 condition 的预测:
- Trunk 的输出 (经过 Pairformer 更新后的 和 ,现在称为 和 )
- 在 input embedder 中创建的序列的初始原子和 token 级别表征,这些表征没有经过 trunk ( , )
AlphaFold 3 论文将其 diffusion 过程分解为 4 个步骤,涉及从 token 到原子,再回到 token,再回到原子:
- 准备 token 级别的 conditioning 张量
- 准备 atom 级别的 conditioning 张量,使用 Atom Transformer 更新它们,并将它们聚合回 token 级别
- 在 token 级别应用注意力,并投影回原子
- 在 atom 级别应用注意力以预测 atom 级别的噪声更新
1. 准备 token 级别的 conditioning 张量


为了初始化的 token 级别 conditioning 表征,将 连接到相对位置编码,然后将这个更大的表征投影回较低维度,并通过几个残差连接 transition 块。
类似地,对于的 token 级别 single 表征,连接模型开始时创建的输入的第一个表征 (sinputs) 和当前的表征 (),然后将其投影回原始大小。然后,基于当前的 diffusion 时间步创建傅里叶嵌入。更具体地说,与这个时间步在 Noise Schedule 中关联的噪声量,将其添加到的 single 表征中,并将这个组合通过几个 Transition 块。通过在这里的 conditioning 输入中包括 diffusion 时间步,它确保模型在进行去噪预测时知道 diffusion 过程中的时间步,从而为这个时间步预测正确规模的噪声去除。
2. 准备 atom 级别张量,应用 atom 级别注意力,并聚合回 token 级别
此时,得到 conditioning 向量在每个 token 级别存储信息,但也想在 atom 级别运行注意力。为了解决这个问题,取在 Embedding 部分中创建的输入的初始 atom 级别表征 ( 和),并基于当前的 token 级别表征更新它们,以创建 atom 级别的 conditioning 张量。


接下来,将原子的当前坐标 () 缩放为数据的方差,有效地创建具有单位方差的「无量纲」坐标 (称为 )。然后,基于 更新 ,使得 现在知道原子的当前位置。最后,用
Atom Transformer
更新 (它也接受 pair 表征作为输入),并将原子聚合回 token,如之前看到的。回想一下,在输入准备部分,Atom Transformer 在原子上运行稀疏注意力,所有步骤 (layer norm, attention, gating) 都 condition on conditioning tensor 。
在本步骤结束时,返回:
- :在纳入关于原子当前坐标的信息后更新的原子表征
- : 的 token 级别聚合形式,捕获坐标和序列信息
- :基于 trunk 的 conditioning 的原子表征
- :更新的 conditioning 的原子对表征
3. 在 token 级别应用注意力

本步骤的目标是应用注意力来更新对原子坐标和序列信息的 token 级别表征 。本步骤使用在输入准备期间可视化的 Diffusion Transformer,它与 Atom Transformer 类似,但针对 token。
4. 在 atom 级别应用注意力以预测 atom 级别的噪声更新
现在,回到原子空间。使用更新的 (基于当前「中心原子」位置的 token 级别表征) 来更新 (基于当前位置的所有原子的 atom 级别表征),使用 Atom Transformer。像在第 3 步中做的那样,广播的 token 表征以匹配开始时的原子数量 (有选择地复制代表多个原子的 token),并运行 Atom Transformer。最重要的是,最后一个线性层将这个 atom 级别表征 映射回 。这是关键步骤:已经使用所有这些 conditioning 表征为所有原子生成了坐标更新 。现在,因为在「无量纲」空间 中生成了这些更新,仔细地重新缩放。这种仔细的缩放涉及数据的方差和基于当前时间步的噪声计划,使得的更新在去噪过程中越来越小。将 的更新重新缩放到它们具有非单位方差的形式 ,并将更新应用到 。
至此,完成了对 AlphaFold 3 主要架构的导览!现在提供一些关于损失函数、辅助 confidence heads 和训练细节的额外信息。
4. 损失函数和其他训练细节
损失函数和 confidence heads
损失是 3 个项的加权和:
- : 评估 token 级别的预测 distogram 的准确性
- : 评估 atom 级别的预测 distogram 的准确性。它查看所有成对距离,然后包括额外的项以优先考虑附近原子和参与蛋白质-配体键的原子之间的距离。
- : 评估模型对自己关于哪些结构可能不准确的自我意识
模型的输出是 atom 级别的坐标,可以轻松地用于创建 atom 级别的 distogram。回想一下,distograms 最初是通过对原子之间的成对距离进行 binning 创建的。然而,这个损失评估的是 token 级别的 distogram。为了获得 token 的 坐标,只使用「中心原子」的坐标。由于这些 distogram 距离是分类的,预测的 distogram 通过交叉熵与真实的 distogram 进行比较。
Diffusion 损失本身是三个项的加权和,每个项都是在原子位置上计算的,此外还由 noiset 进行缩放:
- 是刚刚讨论的 distogram 损失的一个版本,但针对所有原子而不仅仅是「中心原子」(并且 DNA、RNA 和配体原子被上调权重)。此外,它查看位置之间的均方误差,而不是将它们 binning 到 distogram 中。
- 旨在通过对参与蛋白质-配体键的原子对的预测和 ground-truth distograms 的差异添加额外的 MSE 损失,确保蛋白质-配体键的键长的准确性。有各种训练阶段,在初始阶段 αbond 被设置为 0,因此这个项只在后面引入。
- (smoothed local distance difference test)
(平滑局部距离差异测试)是距离图损失的另一种变体,它试图捕捉局部距离的准确性。如果原子对的预测距离在原子对真实距离的给定阈值内,则该原子对“通过测试”。为了使该指标平滑且可微,我们将预测和真实距离图之间的差异通过一个以测试阈值为中心的 sigmoid 函数。我们可以将其视为生成该原子对通过测试的概率(介于 0 和 1 之间)。我们采用四个“测试”的平均值,这些测试具有越来越严格的阈值(4、2、1 和 0.5 Å)。使用这种损失鼓励模型降低每次测试失败的概率。最后,为了使测试“局部化”,如果原子对的真实距离很大,我们会忽略该原子对的损失,因为我们只希望模型专注于准确预测原子与其附近原子的距离。

该损失的目标不是提高结构的准确性,而是教会模型预测其自身的准确性。该损失是 4 个项的加权和,每个项对应于评估预测结构质量的一种方法:
- IDDT 原子级“局部距离差异测试”,捕获原子预测的到附近原子的距离的预期准确性。
- PAE 预测的标记 i 的预测位置和真实位置之间的对齐误差。我们首先将预测的标记 i 和真实标记 i 旋转和平移到标记 j 的坐标系中。也就是说,如果我们暂时假设标记 j 恰好在其真实位置,我们会根据标记 i 与标记 j 的关系来预测标记 i 与其应在位置的接近程度。
- PDE 标记之间的预测距离误差,捕获所有标记对之间预测差异的准确性。
- 实验解析预测 模型预测哪些原子是实验解析的(并非每个原子在每个晶体结构中都是实验解析的)。
为了获得每个指标的这些置信度损失,AF3 会预测这些误差指标的值,然后在预测的结构上计算这些误差指标,并且损失基于这两者之间的差异。因此,即使结构确实不正确并且 PAE 很高,如果预测的 PAE 也很高,则 将很低。

这些 confidence predictions 在 diffusion 过程的中间生成。在选定的 diffusion 步骤 t,预测的坐标 rt 用于更新在表征学习 trunk 中创建的 single 和 pair 表征。预测的误差然后从更新的 pair 表征 (对于 PAE 和 PDE) 或这个更新的 single 表征 (pLDDT 和 experimentally resolved) 的线性投影中计算。然后,基于相同的生成原子坐标计算实际的误差度量 (如果感兴趣,过程在下面描述) 以进行比较。
虽然这些项包含在 confidence head loss 中,但来自这些项的梯度只用于更新 confidence prediction heads,不影响模型的其余部分。
实际的误差度量是如何计算的?
pLDDT: 原子 l 的 LDDT 通过以下方式计算:在当前的预测结构中,计算原子 l 和一组由 m 索引的原子 R 之间的距离,并与 ground truth 等价物进行比较。要在这个集合中,原子 m 必须是聚合物链的一部分,距离 l 在 15 或 30 Å 内,具体取决于 m 所属的分子,并且是 token 的中心原子。然后,计算四个具有越来越严格阈值 (4, 2, 1 和 .5 Å) 的二元距离测试,并取平均通过率,并对 R 中的原子求和。将这个百分比 binning 到 0 和 1 之间的 50 个 bin 中。
在推理时,有一个 pLDDT head。这个 head 取给定 token 的 single 表征,将其重复到这个 token 「attached」 的所有原子 (技术上,是任何 token 附着的最大原子数,以便可以堆叠张量),并将所有这些 atom 级别表征投影到的 pLDDT_l 的 50 个 bin 中。将这些视为 50 个「类」的 logits,使用 softmax 转换为概率,并对 bin 取多类分类损失。
Predicted Alignment Error (PAE): 每个 token 被认为有一个 frame,即由涉及该 token 的三个原子 (称为 a, b, c) 创建的 3 D 坐标 frame。这三个原子中的原子 b 在这个 frame 中形成原点。在每个 token 有一个「attached」原子的情况下,frame 的中心原子是 token 的单个原子,相同实体 (例如,相同配体) 的另外两个最近的 token 形成 frame 的基础。对于每个 token 对 (i, j),使用 token_j 的 frame 重新表达 token_i 的中心原子的预测坐标。对 token_i 的中心原子的 ground-truth 坐标做同样的事情。这些转换后的 token_i 的中心原子的真实和预测坐标之间的欧几里得距离是的对齐误差,binning 到 64 个 bin 中。从 pair 表征 zi, j 预测这个对齐误差,将其投影到 64 个维度,将其视为 logits 并用 softmax 转换为概率。用分类损失训练这个 head,每个 bin 作为一个类。有关更多细节,请参见此处。
第三,AlphaFold 3 预测 token 之间的距离误差 (PDE)。真实的距离误差是通过取每个 token 对的中心原子之间的距离,并将这些距离 binning 到从 0 Å 到 32 Å 的 64 个均匀大小的 bin 中。预测的距离误差来自将 pair 表征 zi, j 加上 pair 表征 zj, i 投影到 64 个维度,再次将其视为 logits,并再次用 softmax 转换为概率。
最后,AlphaFold 3 预测每个原子是否在 ground-truth 结构中被实验解析。类似于 pLDDT head,将 si single 表征重复到这个 token 代表的原子数量,并投影到 2 个维度,使用二元分类损失。
其他训练细节
现在架构已经覆盖,最后的部分是一些额外的训练细节。
Recycling
如 AF2 中引入的,AF3 recycling 其权重;也就是说,不是使模型更深,而是重复使用模型权重并将输入多次运行通过模块,以持续改进表征。Diffusion 在推理时固有地使用 recycling,因为模型被训练为纳入时间步信息并对每个时间步使用相同的模型权重。
Cross-distillation
AF3 使用由自身(通过 self-distillation)和 AF2(通过 cross-distillation)生成的合成训练数据的混合。具体来说,作者指出,通过切换到基于 diffusion 的生成模块,模型停止了产生 AF2 用户视觉上识别低置信度和可能无序区域的特征「spaghetti」区域。只是视觉上查看基于 diffusion 的生成,所有区域都显得同样高置信度,使得更难以识别潜在的幻觉。
为了解决这个问题,他们在 AF3 的训练数据中纳入了 AF2 和 AF-Multimer 的生成,允许模型学习到,当 AF2 对其预测不自信时,它应该输出这些未折叠的区域,并「指导」AF3 做同样的事情。蒸馏数据集中的核酸和小分子必须被移除,因为 AF2 和 AF-multimer 无法处理它们。然而,一旦先前的模型生成了新的预测结构,并且这些结构与原始结构对齐,被移除的分子被添加回来。如果添加回来造成了新的原子冲突,整个结构被排除,以避免意外地教模型接受冲突。
(图表来自 AF3 论文)
Cropping 和训练阶段
虽然模型的任何部分都没有对输入序列长度的明确限制,但内存和计算需求随着序列长度的增加而显著增加(回想多个 O(Ntokens3 operations))。因此,为了效率,蛋白质被随机裁剪。如 AF-multimer 中引入的,因为我们想建模多个链之间的相互作用,随机裁剪需要包括所有这些。他们使用 3 种裁剪方法,这 3 种方法根据训练数据(例如:PDB 晶体结构 vs 无序 PDB 复合物 vs 蒸馏等)以不同比例使用:
- Contiguous cropping:为每条链选择连续的氨基酸序列
- Spatial cropping:基于到参考原子的距离选择氨基酸(通常这个原子是特定链或感兴趣的结合界面的一部分)
- Spatial interface cropping:类似于 spatial cropping,但基于到专门在结合界面上的原子的距离。
虽然在 384 的随机裁剪上训练的模型可以应用于更长的序列,但为了提高模型处理这些序列的能力,它在更大的序列长度上进行迭代微调。数据集的混合和其他训练细节也在每个训练阶段中变化,如下表所示。
(表来自 AF3 补充材料)
Clashing
作者指出,AF3 的损失不包括重叠原子的冲突惩罚。虽然切换到基于 diffusion 的结构模块意味着模型理论上可以预测两个原子在同一位置,但在训练后这种情况似乎很少见。也就是说,AF3 在对生成的结构进行排名时确实使用了冲突惩罚。
Batch sizes
虽然 diffusion 过程听起来相当复杂,但它仍然比模型的 trunk 计算成本低得多。因此,AF3 作者发现,从训练的角度来看,在 trunk 之后扩展模型的 batch size 更有效。所以对于每个输入结构,它通过 embedding 和 trunk 运行,然后应用 48 个独立的数据增强版本的结构,这 48 个结构都并行训练。
训练过程就到这里! 还有一些其他小细节,但这可能已经超出了你需要的范围,如果你已经读到这里,其余的部分应该很容易从阅读 AF3 补充材料中掌握。
ML Musings
如此彻底地走过 AF3 的架构及其与 AF2 的比较后,作者所做的选择如何融入更广泛的机器学习趋势是很有趣的。
AlphaFold 作为 Retrieval-Augmented Generation
在 AF2 发布时,在推理时从训练集中包含检索并不常见。在 AF 的情况下,利用 MSA 和模板搜索。基于 MSA 的方法被用于蛋白质建模,但这种检索在深度学习的其他领域中较少使用(例如,ResNets 在计算机视觉中分类新图像时不在推理时嵌入相关的训练图像)。虽然与 AF2 相比,AF3 减少了对 MSA 的强调(它不再在 Evoformer/Pairformer 的 48 个块中被操作和更新),但他们仍然纳入了 MSA 和模板,即使其他蛋白质预测模型如 ESMFold 已经放弃了检索,转而采用完全参数化的推理。
有趣的是,一些最大和最成功的深度学习模型现在经常在推理时包括类似的额外信息。虽然检索系统的细节并不总是公开的,大型语言模型通常在推理时使用 Retrieval Augmented Generation 系统,如传统的网络搜索,以将模型导向相关信息(即使该信息可能已经在其训练数据中),这应该指导推理。看到在推理时使用直接相关示例的做法在未来如何发展将很有趣。
Pair-Bias Attention
AF2 的一个主要组件,在 AF3 中更加突出的是 Pair-Bias Attention。即,注意力中查询、键和值都来自同一来源(像自注意力一样),但有一个来自另一个来源的偏置项被添加到注意力图中。这实际上作为信息共享的轻触版本,而没有完全的交叉注意力。Pair-Bias Attention 几乎出现在每个模块中。虽然这种注意力类型现在在其他蛋白质建模架构中使用,我们还没有在其他领域看到这种特定的交叉偏置的使用(尽管这并不意味着它没有被使用!)。也许它在这里效果很好,因为 pair-representation 自然地类似于自注意力图,但它是纯自注意力或纯交叉注意力的一个有趣的替代方案。
Self-supervised training
像 ESM 这样的 self-supervised 模型已经能够在预测蛋白质结构方面取得令人印象深刻的结果,通过使用 self-supervised 预训练将 MSA 嵌入替换为「probabilistic MSA」。在 AF2 中,模型有一个额外的任务,从 MSA 中预测 masked token,实现了类似的 self-supervision,但这在 AF3 中被移除了。我们还没有看到作者关于为什么他们没有在 MSA 上使用任何 self-supervised 语言建模预训练方法的评论,事实上,他们减少了用于处理 MSA 的计算。self-supervised learning 没有被用来初始化 MSA 嵌入的三个可能原因是 1) 他们认为大规模预训练阶段是对计算的次优使用 2) 他们尝试过,并发现包括一个小 MSA 模块的表现优于预训练嵌入,并且值得额外的推理时间成本 3) 利用氨基酸 token 的预训练嵌入和 DNA/RNA/配体的随机初始化嵌入的混合将不兼容或表现不如在他们的混合原子-token 结构上的完全监督训练。通过专注于 self-supervision 任务,ESM 家族中的模型也比 AF3 简单得多(尽管它们不处理 DNA/RNA/配体,并且有略微不同的目标。)有趣的是,观察到当一些模型旨在最大化架构的简单性时,AlphaFold 仍然如此复杂!
Classification vs. Regression
与 AF2 一样,AF3 继续使用 MSE 和 binning 分类损失的混合。分类组件很有趣,因为如果模型预测的 distogram bin 只差一,它不会因为接近而不是遥远而得到「credit」。不清楚是什么影响了这个设计决策,但也许作者发现梯度比处理几个不同的 MSE 损失更稳定,并且也许每个原子的损失看到了如此多的梯度步骤,以至于来自连续损失的额外信号不会证明是有益的。
与 Recurrent Architectures(例如 LSTMs)的相似性
AF3 的架构融入了几个让人联想到 recurrent neural networks 的设计元素,这些元素在传统的 transformers 中通常不出现:
- Extensive Gating:AF3 在其架构中广泛使用 gating 机制来控制残差流中的信息流动。这更类似于 LSTMs 或 GRUs 中的 gating,而不是正常 transformer 层的标准前馈性质。
- Iterative Processing with Weight Reuse:AF3 多次应用相同的权重来逐步完善其预测。这个过程,涉及 recycling 和 diffusion 模型,类似于 recurrent networks 如何使用共享的权重集在时间步上处理序列数据。这与标准 transformers 不同,标准 transformers 通常在单次前向传递中进行预测。这种方法允许 AF3 迭代地改进其蛋白质结构预测,而不增加参数数量。
- Adaptive Computation:recycling 也类似于 diffusion 中使用的迭代更新,并且与 adaptive compute time (ACT) 的思想相当相关,ACT 最初被引入以动态确定为 RNNs 使用多少计算,最近在 Mixture-of-Depths 中用于实现与 transformers 类似的目标。这与标准 transformers 的固定深度形成对比,理论上将允许模型对具有挑战性的输入应用更多处理。
在 AF2 的消融实验中,recycling 被证明是重要的,但关于 gating 的重要性几乎没有讨论。推测它像在 LSTMs 中一样有助于训练稳定性,但有趣的是,它在这里如此普遍,但在许多其他基于 transformer 的架构中却不是这样。
Cross-distillation
使用 AF2 的生成来重新引入其独特的风格,特别是对于低置信度区域,这非常有趣。如果这里有一个教训,它可能是最实际的:如果你的先前模型在某个特定方面比你的新模型做得更好,你可以尝试 cross-distillation 来获得两全其美的效果!