TABM: 通过参数高效集成推进表格深度学习
用于表格数据监督学习的深度学习架构从简单的多层感知器(MLP)到复杂的变换器和检索增强方法不等。本研究强调了一个重要但迄今为止被忽视的机会,以显著改善表格MLP:即参数高效集成——一种将多个模型作为一个模型生成多个预测的集成实现范式。我们首先开发了TabM——一个基于MLP的简单模型以及我们对BatchEnsemble(现有技术)的变体。然后,我们在公共基准上对表格深度学习架构进行大规模评估,考虑
TABM: 通过参数高效集成推进表格深度学习
Github:https://github.com/yandex-research/tabm
论文:https://arxiv.org/abs/2410.24210
文章目录
摘要
用于表格数据监督学习的深度学习架构从简单的多层感知器(MLP)到复杂的变换器和检索增强方法不等。本研究强调了一个重要但迄今为止被忽视的机会,以显著改善表格MLP:即参数高效集成——一种将多个模型作为一个模型生成多个预测的集成实现范式。我们首先开发了TabM——一个基于MLP的简单模型以及我们对BatchEnsemble(现有技术)的变体。然后,我们在公共基准上对表格深度学习架构进行大规模评估,考虑任务性能和效率,从而以新的视角呈现表格深度学习的全貌。一般而言,我们展示了包括TabM在内的MLP形成了一系列比基于注意力和检索的架构更强大且更实用的模型。特别是,我们发现TabM在表格深度学习模型中表现最佳。最后,我们对TabM的集成特性进行了实证分析。例如,我们观察到TabM的多个预测单独来看较弱,但整体上却非常强大。总体而言,我们的工作为表格深度学习带来了一个有影响力的技术,分析了其行为,并通过TabM推进了性能与效率的权衡——为研究人员和从业者提供了一个简单而强大的基线。代码可在以下位置获取:https://github.com/yandex-research/tabm.
1 引言
对于表格数据的监督学习是广泛工业应用中的一种普遍机器学习(ML)场景。在经典的非深度学习方法中,针对此类任务的最先进解决方案是梯度提升决策树(GBDT)(Prokhorenkova et al., 2018; Chen & Guestrin, 2016; Ke et al., 2017)。而针对表格数据的深度学习(DL)模型据报道正在改善,最近的研究声称在学术基准上与GBDT表现相当,甚至超越GBDT(Hollmann et al., 2023; Chen et al., 2023b;a; Gorishniy et al., 2024)。
然而,从实际的角度来看,目前尚不清楚表格深度学习是否提供了任何明显的基准,超越简单的多层感知器(MLP)架构。首先,文献中对新方法相对于简单的类似MLP基准的性能提升的规模和一致性并不总是明确分析。因此,人们必须从众多每个数据集的性能得分中推断这些统计数据,这使得推理进展变得困难。同时,由于表格数据集的极端多样性,一致性对于假设的基准来说是一个特别有价值且难以实现的特性。其次,与效率相关的属性,例如训练时间,尤其是推理吞吐量,有时受到的关注较少。虽然这些方法在小到中等规模的数据集上通常同样可负担(例如 < 100 K < {100}\mathrm{\;K} <100K 对象),但它们在更大数据集上的适用性仍然不确定。第三,一些近期的研究普遍表明,学术基准上的进展可能并不太好地转移到现实世界的任务中(Rubachev 等,2024)。考虑到以上所有因素,在本研究中,我们彻底评估了现有的表格深度学习方法,并发现非MLP模型尚未提供对MLP的令人信服的替代方案。
与此同时,我们识别出一条先前被忽视的路径,以实现更强大、可靠且合理高效的表格深度学习模型。简而言之,我们发现深度集成的参数高效方法,其中大多数权重在集成成员之间共享,能够将简单而强大的表格模型构建于普通的MLP之上。例如,结合BatchEnsemble(Wen 等,2020)的MLP——一种长期存在的方法——立即超越了流行的基于注意力的模型,如FT-Transformer(Gorishniy 等,2021),同时更简单且更高效。仅这一结果就表明,参数高效的集成是表格深度学习中的一个低垂的果实。
我们的工作基于上述观察,并提供了 TabM——一个为研究人员和从业者设计的新型强大且实用的模型。与 GBDT(决策树的集成)进行非正式类比,TabM 也可以被视为一个简单的基础模型(MLP)与类似集成的技术相结合,同时提供高性能和简单实现。
主要贡献。我们将主要贡献总结如下:
-
我们提出了 TabM——一个用于表格数据监督学习的简单深度学习架构。TabM 基于 MLP 和与 BatchEnsemble(Wen et al., 2020)密切相关的参数高效集成技术。特别地,TabM 为每个对象生成多个预测。TabM 容易与 GBDT 竞争,并超越先前的表格深度学习模型,同时比基于注意力和检索的深度学习架构更高效。
-
我们在大规模评估中提供了对表格深度学习模型的新视角,涵盖四个维度:性能排名、性能得分分布、训练时间和推理吞吐量。我们的发现之一是,MLP,包括 TabM,达到了一个令人满意的性能-效率权衡,而基于注意力和检索的模型则没有。
-
从经验上看,我们展示了 TabM 的多个预测在个体上是弱且过拟合的,而它们的平均值则是强且具有可推广性的。TabM 的训练梯度可以被视为来自多个预测的多样梯度的“集成”。
2 相关工作
基于决策树的模型。梯度提升决策树(GBDT)(Chen & Guestrin, 2016; Ke et al., 2017; Prokhorenkova et al., 2018)是表格任务的一个强大且高效的基准。GBDT 是一种经典的机器学习模型,具体来说,是决策树的集成。我们的模型 TabM 是一种深度学习模型,具体来说,是一种参数高效的多层感知器(MLP)集成。
表格深度学习架构。近年来,已经提出了大量用于表格数据的深度学习架构。这包括基于注意力的架构(Song et al., 2019; Gorishniy et al., 2021; Somepalli et al., 2021; Kossen et al., 2021; Yan et al., 2023)、增强检索的架构(Somepalli et al., 2021; Kossen et al., 2021; Gorishniy et al., 2024; Ye et al., 2024)、类似于多层感知器(MLP)的模型(Gorishniy et al., 2021; Klambauer et al., 2017; Wang et al., 2020)及其他(Arik & Pfister, 2020; Popov et al., 2020; Chen et al., 2023b; Marton et al., 2024; Hollmann et al., 2023)。与之前的工作相比,我们的模型TabM的关键区别在于其计算流程,其中一个TabM通过产生多个独立训练的预测来模仿多个MLP的集成。之前尝试将集成元素引入表格深度学习(Badirli et al., 2020; Popov et al., 2020)并未取得令人满意的结果(Gorishniy et al., 2021)。此外,作为一个简单的前馈基于MLP的模型,TabM在效率上显著优于一些先前的工作。与基于注意力的模型相比,TabM在数据集维度方面不受二次计算复杂度的影响。与基于检索的模型相比,TabM易于应用于大型数据集。
改进表格 MLP 类模型。最近的多项研究通过应用架构修改(Gorishniy et al., 2022)、正则化(Kadra et al., 2021; Jeffares et al., 2023a; Holzmüller et al., 2024)和自定义训练技术(Bahri et al., 2021; Rubachev et al., 2022),在表格任务上实现了与 MLP 类架构的竞争性表现。因此,表格 MLP 似乎具有良好的潜力,但必须处理过拟合和优化问题以揭示这一潜力。我们的模型 TabM 以不同的方式实现了与 MLP 的高性能,即在 BatchEsnsemble(Wen et al., 2020)精神下,将其作为参数高效集成的基础骨干。我们的方法与上述训练技术和架构进展是正交的。
深度集成。在本文中,深度集成指的是多个相同架构的深度学习模型在不同随机种子下独立训练(Jeffares et al., 2023b)以完成相同任务(即具有不同初始化、训练批次序列等)。深度集成的预测是其成员预测的平均值。深度集成通常显著优于相同架构的单一深度学习模型(Fort et al., 2020),并且在不确定性估计或分布外检测等其他任务中表现出色(Lakshminarayanan et al., 2017)。观察到深度集成的个体成员能够学习从输入中提取多样化信息,而深度集成的能力依赖于这种多样性(Allen-Zhu & Li, 2023)。深度集成的主要缺点是训练和使用多个模型的成本和不便。
参数高效的深度“集成”。为了以更低的成本实现深度集成的性能,多项研究提出了通过一个模型生成多个预测来模仿集成的架构(Lee et al., 2015; Zhang et al., 2020; Wen et al., 2020; Havasi et al., 2021; Antorán et al., 2020; Turkoglu et al., 2022)。这样的模型可以被视为“集成”,其中隐式集成成员共享大量权重。还有一些非架构方法用于高效集成,例如 FGE(Garipov et al., 2018),但我们不对此进行探讨,因为我们特别关注架构技术。在本文中,我们强调参数高效的集成作为表格深度学习的一个重要范式。特别地,我们描述了两种对表格 MLP 非常有效的 BatchEnsemble(Wen et al., 2020)的简单变体。一种变体使用了更高效的参数化,另一种则使用了改进的初始化。
3 TABM
在本节中,我们介绍 TabM - 一种进行多重预测的表格模型。
3.1 预备知识
记号。我们考虑表格数据上的分类和回归任务。 x x x 和 y y y 分别表示给定数据集中一个对象的特征和标签。机器学习模型将 x x x 作为输入,并生成 y ^ \widehat{y} y 作为 y . N ∈ N y.N \in \mathbb{N} y.N∈N 的预测, d ∈ N d \in \mathbb{N} d∈N 分别表示给定神经网络的“深度”(例如,块的数量)和“宽度”(例如,潜在表示的大小)。 d y ∈ N {d}_{y} \in \mathbb{N} dy∈N 是输出表示的大小(例如,回归任务的 d y = 1 {d}_{y} = 1 dy=1 ,而分类任务的 d y {d}_{y} dy 等于类别的数量)。
数据集。我们的基准包含46个在先前工作中使用的公开可用数据集,包括 Grinsztajn 等(2022);Gorishniy 等(2024);Rubachev 等(2024)。我们的基准的主要属性总结在表1中,更多细节见附录C。
表1:我们基准的概述。“拆分类型”属性在文本中进行了说明。

领域感知拆分。我们特别关注我们称之为“领域感知”拆分的数据集,包括来自 Rubachev 等(2024)的八个数据集和微软数据集(Qin & Liu, 2013)。对于这些数据集,它们的原始现实世界拆分是可用的,例如,Rubachev 等(2024)中的时间感知拆分。这些数据集被证明对某些方法具有挑战性,因为它们在训练和测试部分之间自然表现出一定程度的分布变化(Rubachev 等,2024)。其余37个数据集的随机拆分继承自先前的工作。
实验设置。我们使用 Gorishniy 等人(2024)提出的设置,并在 D.2 小节中详细描述。最重要的是,在每个数据集上,给定模型在验证集上进行超参数调优,然后在多个随机种子下从头开始训练调优后的模型,随机种子上测试指标的平均值成为模型在数据集上的最终得分。
指标。我们对回归任务使用 RMSE(均方根误差),对于分类任务则根据数据集来源使用准确率或 ROC-AUC。有关详细信息,请参见 D.3 小节。
此外,在整篇论文中,我们经常使用模型相对于 MLP 的相对性能作为关键指标。该指标为所有任务提供了统一的视角,并允许推理相对于简单基线(MLP)的改进规模。形式上,在给定数据集上,该指标定义为 ( score baseline − 1 ) ⋅ 100 % \left( {\frac{\text{score}}{\text{baseline}} - 1}\right) \cdot {100}\% (baselinescore−1)⋅100% ,其中“得分”是给定模型的指标,“基线”是 MLP 的指标。在此计算中,对于回归任务,我们将原始指标从 RMSE 转换为 R 2 {R}^{2} R2 以更好地对齐分类和回归指标的尺度。
3.2 批集成的快速介绍。
对于给定的架构,我们考虑其中的任何线性层 l l l : l ( x ) = W x + b l\left( x\right) = {Wx} + b l(x)=Wx+b ,其中 x ∈ R d 1 x \in {\mathbb{R}}^{{d}_{1}} x∈Rd1 , W ∈ R d 2 × d 1 , b ∈ R d 2 W \in {\mathbb{R}}^{{d}_{2} \times {d}_{1}}, b \in {\mathbb{R}}^{{d}_{2}} W∈Rd2×d1,b∈Rd2 。为了简化符号,设 d 1 = d 2 = d {d}_{1} = {d}_{2} = d d1=d2=d 。在传统的深度集成中,第 i i i 个成员有其自己的权重集 W i , b i {W}_{i},{b}_{i} Wi,bi 用于这个线性层: l i ( x i ) = W i x i + b i {l}_{i}\left( {x}_{i}\right) = {W}_{i}{x}_{i} + {b}_{i} li(xi)=Wixi+bi ,其中 x i {x}_{i} xi 是第 i i i 个成员中的对象表示。相比之下,在 BatchEnsemble 中,这个线性层要么 (1) 在所有成员之间完全共享,要么 (2) 大部分共享: l i ( x i ) = s i ⊙ ( W ( r i ⊙ x i ) ) + b i {l}_{i}\left( {x}_{i}\right) = {s}_{i} \odot \left( {W\left( {{r}_{i} \odot {x}_{i}}\right) }\right) + {b}_{i} li(xi)=si⊙(W(ri⊙xi))+bi ,其中 ⊙ \odot ⊙ 是逐元素相乘, W ∈ R d × d W \in {\mathbb{R}}^{d \times d} W∈Rd×d 在所有成员之间共享,而 r i , s i , b i ∈ R d {r}_{i},{s}_{i},{b}_{i} \in {\mathbb{R}}^{d} ri,si,bi∈Rd 在成员之间不共享。这相当于将第 i i i 个权重矩阵定义为 W i = W ⊙ ( r i s i T ) {W}_{i} = W \odot \left( {{r}_{i}{s}_{i}^{T}}\right) Wi=W⊙(risiT) 。为了确保集成成员的多样性,所有成员的 r i {r}_{i} ri 和 s i {s}_{i} si 都是随机初始化的,使用 ± 1 \pm 1 ±1 。BatchEnsemble 的所有其他层在成员之间完全共享。
所描述的参数化允许将所有集成成员打包在一个模型中,该模型同时接收 k k k 个对象的副本作为输入,并并行应用所有 k k k 个隐式成员,而无需显式地实现每个成员。这是通过用其 BatchEnsemble 版本替换原始神经网络的一个或多个线性层来实现的: l B E ( X ) = ( ( X ⊙ R ) W ) ⊙ S + B {l}_{\mathrm{{BE}}}\left( X\right) = \left( {\left( {X \odot R}\right) W}\right) \odot S + B lBE(X)=((X⊙R)W)⊙S+B ,其中 X ∈ R k × d X \in {\mathbb{R}}^{k \times d} X∈Rk×d 存储相同输入对象的 k k k 个表示(每个成员一个),而 R , S , B ∈ R d R, S, B \in {\mathbb{R}}^{d} R,S,B∈Rd 存储子模型的非共享权重 ( r i , s i , b i ) \left( {{r}_{i},{s}_{i},{b}_{i}}\right) (ri,si,bi) ,如图 1 左下部分所示。
模型大小的开销。使用 BatchEnsemble,添加一个新的集成成员意味着仅需在每个矩阵 R , S R, S R,S 和 B B B 中添加一行,这导致每层增加 3 d {3d} 3d 个新参数。对于典型的 d d d 值,这对原始层大小 d 2 + d {d}^{2} + d d2+d 来说是一个微不足道的开销。
运行时的开销。得益于现代硬件,大量共享权重和 k k k 前向传播的并行执行,BatchEnsemble 的运行时开销可以显著低于 × k \times k ×k (Wen et al., 2020)。直观上,如果原始工作负载未充分利用硬件,那么支付低于 × k \times k ×k 的开销的机会就会增多。
术语。在本文中,我们将 r i , s i , b i , R , S {r}_{i},{s}_{i},{b}_{i}, R, S ri,si,bi,R,S 和 B B B 称为适配器,而将参数高效集成(例如 BatchEnsemble)的隐式成员称为隐式子模型或简单地称为子模型。
3.3 TABM MINI & {}_{\text{MINI }}\& MINI & TABM
我们的模型 TabMmini 和 TabM 基于多层感知器(MLP)和参数高效集成方法,与在 3.2 小节中介绍的 BatchEnsemble (Wen et al., 2020) 有着紧密的联系。在 A.1 小节中,我们解释了选择 BatchEnsemble 作为基线高效集成方法的原因,因为它在性能和易用性之间具有良好的平衡,而使用 MLP 作为基础模型至关重要,因为它具有出色的效率。我们通过几个步骤从基本基线获得我们的模型。我们始终使用集成大小 k = 32 k = {32} k=32 并在 5.3 小节中分析这个超参数。
MLP。我们将 MLP 定义为一系列 N N N 简单块,后跟线性预测头: MLP ( x ) = Linear ( Block N ( … ( Block 1 ( x ) ) ) , \operatorname{MLP}\left( x\right) = \operatorname{Linear}\left( {{\operatorname{Block}}_{N}\left( {\ldots \left( {{\operatorname{Block}}_{1}\left( x\right) }\right) }\right) ,}\right. MLP(x)=Linear(BlockN(…(Block1(x))), 其中 Block i ( x ) = Dropout ( ReLU ( Linear ( ( x ) ) ) ) {\operatorname{Block}}_{i}\left( x\right) = \operatorname{Dropout}\left( {\operatorname{ReLU}\left( {\operatorname{Linear}\left( \left( x\right) \right) }\right) }\right) Blocki(x)=Dropout(ReLU(Linear((x)))) 。 M L P × k = {\mathbf{{MLP}}}^{\times k} = MLP×k= MLP + 深度集成。我们将传统的深度集成的 k k k 独立训练的 MLP 表示为 M L P × k {\mathrm{{MLP}}}^{\times k} MLP×k 。该方法在图 1 中进行了说明,其性能在图 2 中报告(超参数调优是在一个 MLP 上进行的,之后将调优后的 MLP 进行集成)。有趣的是,结果已经比 FT-Transformer(Gorishniy 等,2021)——这一流行的基于注意力的基线——更好且更稳定。此外,考虑到 MLP 的显著更高效率(后面将展示), M L P × k {\mathrm{{MLP}}}^{\times k} MLP×k 实际上可能并不比 FT-Transformer 更不实用,特别是结合像 Packed-Ensembles(Laurent 等,2023)这样的附加技术。也就是说,我们将继续探索更高效的方法。 T a b M naive = M L P + {\mathbf{{TabM}}}_{\text{naive }} = \mathbf{{MLP}} + TabMnaive =MLP+ BatchEnsemble。现在,我们不再使用深度集成,而是天真地将 BatchEnsemble 应用于 MLP 的主干,同时保持预测头的独立。这给我们带来了 T a b M naive {\mathrm{{TabM}}}_{\text{naive }} TabMnaive - TabM 的初步次优版本。实际上, T a b M naive {\mathrm{{TabM}}}_{\text{naive }} TabMnaive 的架构(但不是初始化)已经等同于 TabM 的架构,因此图 1 是适用的。图 2 中显示的 T a b M naive {\mathrm{{TabM}}}_{\text{naive }} TabMnaive 的性能有两个重要原因。首先,Tarst,TabMaive - M L P × k {\mathrm{{MLP}}}^{\times k} MLP×k 的高效版本 - 显著优于 M L P × k {\mathrm{{MLP}}}^{\times k} MLP×k 本身,这一点令人感兴趣。我们不知道 BatchEnsemble 的类似结果,并在 A.2 小节中分享了对此现象的一些思考。其次, T a b M naive {\mathrm{{TabM}}}_{\text{naive }} TabMnaive 立即超越了 FT-Transformer,这证明了参数高效集成对 MLP 的巨大潜力。这激励了进一步的探索。

图1: (左上角) 实现一个 k k k MLP 集合的模板。图的其余部分是三种不同的 k k k MLP 主干的参数化,均在3.3节中描述。在所有情况下,每个 k k k MLP 主干独立处理其输入对象的副本。(右上角) M L P × k {\mathrm{{MLP}}}^{\times k} MLP×k 是一个传统的深度集成,由 k k k 完全独立的 MLP 组成。(左下角) TabM 是通过在一个 MLP 的每个 N N N 线性层中注入三个非共享适配器 R , S , B R, S, B R,S,B 获得的 (* 初始化与 Wen et al. (2020) 不同)。 (右下角) TabM 仅保留 TabM 的第一个适配器 R R R ,并移除其余的 3 N − 1 {3N} - 1 3N−1 适配器。因此,Tab M mini {\mathrm{M}}_{\text{mini }} Mmini 将相同的共享 MLP 应用于 k k k 对象表示,仅有两个非共享元素确保预测的多样性:随机初始化的乘法适配器 R R R 和 k k k 预测头。(细节) 输入转换,例如独热编码、特征嵌入 (Gorishniy et al., 2022) 等,为了简化省略。在实践中,它们在 Clone 模块之前应用(并且结果被展平)。Drop 表示 dropout (Srivastava et al., 2014)。

图 2:在表 1 中的 46 个数据集上,子节 3.3 中描述的模型的性能;左侧还有几个基线。对于给定模型,抖动图上的一个点描述了在 46 个数据集中的一个性能得分。箱形图描述了抖动图的百分位数:箱体描述了第 25、第 50 和第 75 百分位数,须描述了第 10 和第 90 百分位数。离群值被裁剪。底部的数字是抖动图的均值和标准差。对于每个模型,超参数都经过调优。 T a b M mini = M L P + {\mathbf{{TabM}}}_{\text{mini }} = \mathbf{{MLP}} + TabMmini =MLP+ 最小集成。根据构造,刚刚讨论的 TabM naive {\operatorname{TabM}}_{\text{naive }} TabMnaive (在图 1 中表示为 “TabM”)具有 3 N {3N} 3N 个适配器:在每个 N N N 块中有 R , S R, S R,S 和 B B B 。在 3 N {3N} 3N 个适配器中,第一个适配器 R R R 在第一个线性层中负责将输入的 k k k 个相等副本(在图 1 中打包为 X X X )转换为 k k k 个不同的表示,然后表格特征首次与 @ W @W @W 混合。一个简单的实验表明,这个适配器是至关重要的。首先,我们将其从 T a b M naive {\mathrm{{TabM}}}_{\text{naive }} TabMnaive 中移除,并保持其余的 3 N − 1 {3N} - 1 3N−1 个适配器不变,这使得我们得到的 T a b M bad {\mathrm{{TabM}}}_{\text{bad }} TabMbad 性能更差,如图 2 所示。然后,我们做相反的事情:我们仅保留 T a b M naive {\mathrm{{TabM}}}_{\text{naive }} TabMnaive 的第一个适配器,移除其余的 3 N − 1 {3N} - 1 3N−1 个适配器,这使我们得到 T a b M mini {\mathrm{{TabM}}}_{\text{mini }} TabMmini - TabM 的最小版本。图 1 中展示了 TabMmi,我们非正式地称所描述的方法为 “最小集成”。也许令人惊讶的是,图 2 显示 T a b M mini {\mathrm{{TabM}}}_{\text{mini }} TabMmini 的表现优于 T a b M naive {\mathrm{{TabM}}}_{\text{naive }} TabMnaive ,尽管它只有一个适配器而不是 3 N {3N} 3N 个适配器。
TabM = MLP + BatchEnsemble + 更好的初始化。刚获得的结果激励着下一步。我们回到 T a b M naive {\mathrm{{TabM}}}_{\text{naive }} TabMnaive 的架构,使用所有 3 N {3N} 3N 适配器,但将所有乘法适配器 R R R 和 S S S (除了第一个)以确定性方式初始化为 1。因此,在初始化时,确定性初始化的适配器没有影响,模型表现得像 TabM mini {\operatorname{TabM}}_{\text{mini }} TabMmini ,但这些适配器在训练期间可以自由增加更多的表现力。这给我们带来了 TabM,如图 1 所示。图 2 显示 TabM 是迄今为止表现最好的变体。 TabM mini † & TabM † {\operatorname{TabM}}_{\text{mini }}^{ \dagger }\& {\operatorname{TabM}}^{ \dagger } TabMmini †&TabM† 。非线性特征嵌入(Gorishniy 等,2022)已知能提升许多表格模型的性能,尤其是 MLP。我们将 T a b M mini {\mathrm{{TabM}}}_{\text{mini }} TabMmini 和带有非线性特征嵌入的 TabM 分别表示为 TabM mini † {\operatorname{TabM}}_{\text{mini }}^{ \dagger } TabMmini † 和 TabM † {\operatorname{TabM}}^{ \dagger } TabM† 。默认情况下,我们建议使用分段线性嵌入(Gorishniy 等,2022)。在小节 A.3 中,我们提供了额外的实现细节,例如稍微不同的初始化。图 2 显示 TabM mini † {\operatorname{TabM}}_{\text{mini }}^{ \dagger } TabMmini † 与 T a b M † {\mathrm{{TabM}}}^{ \dagger } TabM† 具有竞争力,因此我们将使用 T a b M mini † {\mathrm{{TabM}}}_{\text{mini }}^{ \dagger } TabMmini † 以简化处理。
直觉。为了给 TabM 提供额外的直觉,我们做出以下观察:
-
设置 k = 1 k = 1 k=1 使 TabM 与一个普通的 MLP 相同。
-
将 k k k 增加一个会向 TabM 添加微不足道的新参数数量。
-
将 TabM 视为一个单一模型,可以受益于深度集成,见图 2 中的 TabM mini † × 5 {\operatorname{TabM}}_{\text{mini }}^{\dagger \times 5} TabMmini †×5 。
-
在类似于 Transformer (Vaswani et al., 2017) 和类似于 Mixer (Tolstikhin et al., 2021) 的模型中:(a) 潜在表示的形状为 m × d m \times d m×d ,其中 m m m 是表格特征的数量, d d d 是嵌入大小;(b) m m m 嵌入在注意力或线性层中相互混合,© 每个嵌入的变换(FFN 层)对所有嵌入都是相同的。相比之下,在 TabM 中:(a) 形状仅为 k × d , ( b ) k \times d,\left( b\right) k×d,(b) , k k k 嵌入之间从不相互作用,© 每个嵌入的变换包含嵌入特定的权重(适配器)。
超参数。与 MLP 相比,TabM 唯一的新超参数是 k k k - 隐式子模型的数量。我们启发式地设置 k = 32 k = {32} k=32 ,并不调整该值。我们在 5.3 小节中分析 k k k 的影响。我们还注意到,TabM 的平均最佳学习率高于 MLP,这在 A.4 小节中进行了说明。
限制和实际考虑在 A.5 小节中进行了评论。
下一步。图 2 中 TabM 的表现使其成为一个极具前景的模型。这激励我们对之前的表格模型进行全面的实证比较(第 4 节)以及对 TabM 行为的详细分析(第 5 节)。
4 评估表格深度学习架构
现在,我们对许多表格模型进行实证比较,包括在第 3 节中介绍的 TabM。模型的实现细节在附录 D 中提供。
4.1 基线
在主要文本中,我们使用以下基线:MLP(经典的多层感知器)、FT-Transformer(简称“FT-T”,来自Gorishniy等人(2021)的基于注意力的模型)、SAINT(来自Somepalli等人(2021)的基于注意力和检索的模型)、T2G-Former(简称“T2G”,来自Yan等人(2023)的基于注意力的模型)、ExcelFormer(简称“Excel”,来自Chen等人(2023a)的基于注意力的模型)、TabR(来自Gorishniy等人(2024)的基于检索的模型)、ModernNCA(简称“MNCA”,来自Ye等人(2024)的基于检索的模型)以及三种GBDT实现:XGBoost(Chen & Guestrin, 2016)、LightGBM(Ke等人,2017)和CatBoost(Prokhorenkova等人,2018)。MLP † {}^{ \dagger } † 、TabR † {}^{ \dagger } † 和 MNCA † {}^{ \dagger } † 表示具有非线性特征嵌入的相应模型(Gorishniy等人,2022)。实际上,一些其他基线,例如Excel(Chen等人,2023a),已经使用了自定义的非线性特征嵌入。
我们在附录B中提供更多基线的结果。

图 3:表格模型在表 1 中的 46 个数据集上的任务表现。(左)所有数据集上的性能排名的均值和标准差总结了模型在所有数据集上的逐对比较。(中间和右侧)相对于普通多层感知器(MLP)的相对表现允许推理相对于这一简单基线的改进规模和一致性。抖动图中的一个点对应于模型在 46 个数据集中的一个数据集上的表现。箱线图可视化了抖动图的第 10、第 25、第 50、第 75 和第 90 百分位数。异常值被裁剪。随机和领域感知数据集划分的分离在 3.1 小节中进行了说明。

图 4:图 3 中模型的训练时间(左)和推理吞吐量(右)。一个点代表一个数据集上的测量。
表 2:两个大型数据集上的 RMSE(上行)和训练时间(下行)。最佳值用粗体表示。模型颜色的含义遵循图 3。

4.2 任务表现
我们按照 3.1 小节中公布的协议评估所有模型,并在图 3 中报告结果(另见图 9 中的关键差异图)。我们做出以下观察:
-
性能排名使 TabM 成为顶级深度学习模型。
-
图 3 的中间和右侧部分为每个数据集的指标提供了新的视角。TabM 在深度学习模型中保持其领导地位。同时,许多深度学习方法在相当数量的数据集上表现并不比 MLP 更好,甚至更差,这使它们成为不太可靠的解决方案,并改变了排名,特别是在领域感知划分(右侧)。
-
模型的一个重要特征是其性能轮廓中最弱的部分(例如中间图中的第10或第25百分位数),因为它显示了模型在“困难”数据集上的可靠性。从这个角度来看, M L P † {\mathrm{{MLP}}}^{ \dagger } MLP† 似乎是一个在简单的 MLP 和 TabM 之间的不错实用选择,尤其是考虑到它相对于基于检索的替代方案(如 TabR 和 ModernNCA)的简单性和效率。
摘要。TabM 自信地展示了在表格深度学习模型中最佳的性能,并可以作为一个可靠的深度学习基准。这对于基于注意力和检索的模型并不适用。总体而言,类似 MLP 的模型,包括 TabM,形成了一组具有代表性的表格深度学习基准。
4.3 效率
现在,我们从训练和推理效率的角度评估表格模型,这对某些方法来说成为了一个严峻的现实检验。我们基准测试正是图3中所展示的模型的超参数配置(见子节 B.3 以获取动机)。 T a b M mini † ∗ {\mathbf{{TabM}}}_{\text{mini }}^{\dagger * } TabMmini †∗ 。此外,我们还包括 T a b M mini † ∗ {\mathrm{{TabM}}}_{\text{mini }}^{\dagger * } TabMmini †∗ ,它是 T a b M mini † {\mathrm{{TabM}}}_{\text{mini }}^{ \dagger } TabMmini † ,并增强了两个在 PyTorch 中开箱即用的与效率相关的插件(Paszke 等,2019):自动混合精度(AMP)和 torch.compile(Ansel 等,2024)。 T a b M mini † ∗ {\mathrm{{TabM}}}_{\text{mini }}^{\dagger * } TabMmini †∗ 的目的是展示现代硬件和软件对强大表格深度学习模型的潜力,并且不应直接与其他深度学习模型进行比较。然而,TabM 的实现简单性起着重要作用,因为它促进了上述 PyTorch 插件的无缝集成。
训练时间。我们关注较大数据集上的训练时间,因为在小数据集上,所有方法几乎都变得同样可负担,无论正式的相对差异如何。然而,在图10中,我们也提供了小数据集上的测量结果。图4的左侧显示,TabM提供了实用的训练时间。相比之下,基于注意力和检索的模型的长训练时间成为这些方法的另一个限制。
推理吞吐量。图4的右侧基本上讲述了与左侧相同的故事。在小节B.3中,我们还报告了在大批量情况下GPU上的推理吞吐量。
对大数据集的适用性。在表2中,我们报告了两个大数据集上的指标。如预期的那样,基于注意力和检索的模型面临困难,导致极长的训练时间,或者在没有额外努力的情况下根本无法应用。有关实现细节,请参见小节D.4。
参数数量。大多数表格网络总体上是紧凑的。这一点尤其适用于TabM,因为其大小在设计上与MLP相当。我们在小节B.3中报告模型大小。
摘要。简单的MLP是最快的深度学习模型,而TabM是亚军。基于注意力和检索的模型显著较慢。总体而言,包括TabM在内的类似MLP的模型形成了一组具有代表性的实用和可访问的表格深度学习基准。
5 分析
5.1 各个子模型的性能和训练动态
回想一下,TabM的预测被定义为其 k k k 隐式子模型的平均预测。这些子模型几乎共享所有权重,并且是同时训练的。在本节中,我们将更仔细地观察这些子模型的个体性能和训练动态。
在下一个实验中,我们故意简化了设置,如 D.5 小节中详细描述的那样。最重要的是,所有模型的深度为 3,宽度为 512,并且在训练过程中没有早停,即训练超出了最佳的纪元。我们使用图 1 中的 T a b M mini {\mathrm{{TabM}}}_{\text{mini }} TabMmini ,其中 k = 32 k = {32} k=32 表示为 TabM mini k = 32 {\operatorname{TabM}}_{\operatorname{mini}}^{k = {32}} TabMminik=32 。我们使用 TabM mini k = 1 {\operatorname{TabM}}_{\operatorname{mini}}^{k = 1} TabMminik=1 (即基本上是一个普通的 MLP)作为 T a b M mini k = 32 {\mathrm{{TabM}}}_{\text{mini }}^{k = {32}} TabMmini k=32 子模型的自然基线,因为每个 32 个子模型的架构为 T a b M mini k = 1 {\mathrm{{TabM}}}_{\text{mini }}^{k = 1} TabMmini k=1 。
我们在图 5 中可视化了四个不同数据集(两个分类问题和两个不同规模的回归问题)的训练曲线。提醒一下, k k k 个体损失的平均值是在 T a b M mini {\mathrm{{TabM}}}_{\text{mini }} TabMmini 的训练过程中明确优化的,而集体均值预测的损失对应于 T a b M mini {\mathrm{{TabM}}}_{\text{mini }} TabMmini 在推断时使用的内容(见图 1)。
在图 5 的上排中,子模型的集体均值预测在训练和测试损失方面都优于它们的个体预测。在初始纪元之后,基线 MLP 的训练损失低于集体和个体预测的损失。
在图 5 的中间行中,我们看到子模型的个体表现与集体表现之间的明显对比。与基线 MLP 相比,子模型在个体上看起来过拟合,而它们的集体预测表现出显著更好的泛化能力。这个结果严格证明了子模型的非平凡多样性:如果没有这一点,它们的集体测试性能将与其个体测试性能相似。此外,我们在图 6 中报告了 TabM 的最佳子模型在多个数据集上的表现,称为 TabM[B]。因此,单独来看,即使是 TabM 的最佳子模型也不比一个简单的 MLP 更好。
图5的下方行分析了 T a b M mini k = 32 {\mathrm{{TabM}}}_{\text{mini }}^{k = {32}} TabMmini k=32 的梯度结构。提醒一下,由于 k = 32 k = {32} k=32 子模型之间的同时训练和权重共享,大多数 T a b M mini {\mathrm{{TabM}}}_{\text{mini }} TabMmini 的权重在每个训练步骤中接收每个对象的 k k k 梯度的均值。图5下方行中的绿色线条显示这些 k k k 梯度之间的余弦相似度接近于零。这可能解释了图5第一行中 T a b M min k = 32 {\mathrm{{TabM}}}_{\min }^{k = {32}} TabMmink=32 的训练损失高于 T a b M min k = 1 {\mathrm{{TabM}}}_{\min }^{k = 1} TabMmink=1 的原因:也许,权重共享结合多样化的梯度阻止了 T a b M mini {\mathrm{{TabM}}}_{\text{mini }} TabMmini 对训练任务的(过度)优化。在图5的同一行下方,我们对 TabM mini {\operatorname{TabM}}_{\operatorname{mini}} TabMmini 中子模型多样性的两个来源进行了消融研究:适配器 R R R 中的随机初始化和 k k k 预测头中的随机初始化。当所有 R R R 的行(即图1中的 r i {r}_{i} ri )都接收相同的初始化,而 k k k 预测头则完全随机初始化(橙色线条)时,子模型梯度是相关的,任务性能较差。相比之下,当 k k k 预测头接收相同的初始化,而 R R R 的初始化完全随机(紫色线条)时,这一问题不那么明显,尽管这也可能会影响性能。因此,第一个适配器似乎是梯度多样性更具影响力的来源。总体而言,我们将梯度多样性视为一个需要更多探索的实验指标。

图5:如5.1小节所述的 T a b M mini k = 32 {\mathrm{{TabM}}}_{\text{mini }}^{k = {32}} TabMmini k=32 和 T a b M mini k = 1 {\mathrm{{TabM}}}_{\text{mini }}^{k = 1} TabMmini k=1 的训练轮廓。(上)训练曲线。 k = 32 [ i ] k = {32}\left\lbrack i\right\rbrack k=32[i] 表示32个子模型的平均个体损失。(中)与第一行相同,但在训练-测试坐标中:每个点代表第一行中的某个时期,训练通常是从左到右进行。这允许通过比较给定训练损失值的测试损失值来推理过拟合情况。(下) k k k 的个体梯度与默认初始化(绿色)和5.1小节中描述的两个次优初始化之间的平均成对余弦相似度。形式上: 2 n ⋅ k ( k − 1 ) ∑ l , i , j ( i < j ) cos ( g i l , g j l ) \frac{2}{n \cdot k\left( {k - 1}\right) }\mathop{\sum }\limits_{{l, i, j\left( {i < j}\right) }}\cos \left( {{g}_{i}^{l},{g}_{j}^{l}}\right) n⋅k(k−1)2l,i,j(i<j)∑cos(gil,gjl) ,其中 g i l {g}_{i}^{l} gil 是由第 l l l 个 n = 1000 n = {1000} n=1000 训练对象引起的第 i i i 个子模型的梯度。有关详细信息,请参见D.5小节。如果使用了提前停止,图例中包含测试分数。
摘要。TabM的力量来源于弱但多样的子模型的集体预测。
5.2 训练后选择子模型
TabM的设计允许在训练后根据任何标准选择仅一部分子模型,只需修剪额外的预测头和相应的适配器矩阵行。为了展示这一机制,在训练后,我们贪婪地构建了在验证集上表现最佳的TabM子模型的子集,并将这个“修剪”的TabM称为TabM[G]。图6中报告的性能显示,TabM[G]略微落后于原始TabM。在46个数据集上的平均结果中,贪婪的子模型选择导致从最初的 k = 32 k = {32} k=32 中选择了 8.8 ± 6.6 {8.8} \pm {6.6} 8.8±6.6 个子模型,这可以导致更快的推理。有关实现细节,请参见D.6小节。
5.3 TABM 的性能如何依赖于 k k k ?

为了回答标题中的问题,我们选择了具有 3 层和宽度为 512 的 TabM,分别调整每个 k k k 的学习率,并在图 7 中报告性能。根据该图以及第 4 节的结果,我们建议在整篇论文中使用的 k = 32 k = {32} k=32 是一个合理的默认值,具有良好的性能与效率平衡。此外,从图 7 中可以看出,TabM 比 T a b M mini {\mathrm{{TabM}}}_{\text{mini }} TabMmini 更容易容纳大量子模型。也许,TabM 中更多的子模型适配器提供了重要的额外权重容量,以便在给定大小的模型中适配更多的子模型。实施细节见 D.7 小节。
6 结论与未来工作
在本研究中,我们展示了表格多层感知器(MLP)在参数高效集成方面的巨大收益。基于这一见解,我们开发了 TabM——一个基于 MLP 的简单模型,具有最先进的性能。在与许多表格深度学习模型的大规模比较中,我们证明了 TabM 准备好作为一个新的强大且高效的表格深度学习基线。最后,我们分析了 TabM 背后的隐式子模型的特性。
未来工作的一个想法是将(参数)高效集成的优势引入其他具有优化相关挑战的非表格领域,并理想地使用轻量级基础模型。另一个想法是评估 TabM 在表格数据上的不确定性估计和分布外(OOD)检测,这受到 Lakshminarayanan 等人(2017)工作的启发。
可重复性声明。代码已提供在以下仓库:链接。它包含了 TabM 的实现、超参数调优脚本、评估脚本、包含超参数的配置文件(位于 exp/ 目录中的 TOML 文件)以及包含主要指标的报告文件(位于 exp/ 目录中的 JSON 文件)。在论文中,模型在第 3 节中描述,实施细节在附录 D 中提供。
更多推荐
所有评论(0)