ESMM论文总结
通过全空间多任务学习,解决CVR的样本偏差等问题
摘要
CVR任务,以点击数据为训练数据,但应用在曝光场景。因此存在样本选择偏差的问题。同时还存在数据稀疏的问题,导致模型拟合困难。论文利用用户的行为顺序进行建模。
ESMM模型通过以下两点,可以消除上面的两个问题:
- 直接在全空间对CVR建模。
- 对特征表示采用迁移学习策略(employing a feature representation transfer learning strategy)。
1 介绍
背景介绍:
- 曝光:将产品推荐给用户
- 点击:用户对产品进行点击
- 转化:用户点击后进行注册/激活等操作
- CTR:Click Through Rate,点击率
- CVR:Click Value Rate,转化率
从曝光到点击的过程,计算的就是点击率。比如有1000个人看到这个按钮,其中100个点了,那么点击率就是0.1,点击率就是click-through rate,ctr。
从点击到转化(下单,消费)的过程,计算的是转化率。比如100个点击了产品详情页,最后完成下单的是5个人,那么转化率是0.05,转化率就是conversion rate,cvr。
传统CVR存在两个关键问题:
1.样本选择偏差问题(sample selection bias,SSB)
应用场景:曝光
建模数据:点击之后的数据集(点击后才有是否转化的标签)
2.数据稀疏问题(data sparsity,DS)
CVR的训练数据量远低于曝光数据(4%),模型拟合困难
为了解决上述问题,提出了全空间多任务模型(ESMM).具体细节在后面介绍。
2 模型介绍
2.1 点击与转化关系
CTR:点击
CVR:转化
CTCVR: 点击然后转化
2.2 模型结构
模型分成2个网络:CVR网络和CTR网络
CTCVR把CVR输出和CTR输出的乘积作为最后的输出。
模型有以下几个重点:
Modeling over entire space
基于公式1,可得
从公式中可以看到,不论是计算CTR(),CVR(),还是CTCVR(),都用到了整个曝光数据x。
损失函数为:
整个损失函数包含两部分:CTR的loss和CTCVR的loss。注意:loss项中没有直接评估CVR任务的loss。
两项loss中的损失函数为交叉熵。
因为整个训练过程中使用的是曝光阶段的数据,因此避免了前面提到的样本偏差的问题。
Feature representation transfer
模型中的embedding层的作用是对高纬输入进行降维。这一层占据了网络中的大部分参数,需要较多的数据进行训练。
CVR网络CTR网络的Embedding层共享,因此解决了前面提到的数据稀疏的问题。
2.3 代码示例
class FeatureExtractor(nn.Module):
def __init__(self, embedding_sizes):
super(FeatureExtractor, self).__init__()
self.embedding_layers = nn.ModuleList([nn.Embedding(categories, size) for categories, size in embedding_sizes])
def forward(self, category_inputs):
h = [embedding_layer(category_inputs[:, i]) for i, embedding_layer in enumerate(self.embedding_layers)]
h = torch.cat(h, dim=1)
return h
class CtrNetwork(nn.Module):
def __init__(self, input_dim):
super(CtrNetwork, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(in_features=input_dim, out_features=128),
nn.ReLU(),
nn.Linear(in_features=128, out_features=1),
)
self.sigmoid = nn.Sigmoid()
def forward(self, inputs):
p = self.mlp(inputs)
return self.sigmoid(p)
class CvrNetwork(nn.Module):
def __init__(self, input_dim):
super(CvrNetwork, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(in_features=input_dim, out_features=128),
nn.ReLU(),
nn.Linear(in_features=128, out_features=1),
)
self.sigmoid = nn.Sigmoid()
def forward(self, inputs):
p = self.mlp(inputs)
return self.sigmoid(p)
class ESMM(nn.Module):
"""ESMM"""
def __init__(self, feature_extractor: FeatureExtractor, ctr_network: CtrNetwork, cvr_network: CvrNetwork):
super(ESMM, self).__init__()
self.feature_extractor = feature_extractor
self.ctr_network = ctr_network
self.cvr_network = cvr_network
def forward(self, inputs):
h = self.feature_extractor(inputs) # encode
# Predict pCTR
p_ctr = self.ctr_network(h)
# Predict pCVR
p_cvr = self.cvr_network(h)
# Predict pCTCVR
p_ctcvr = torch.mul(p_ctr, p_cvr)
return p_ctr, p_ctcvr
3 实验
论文中对比了多种其他方法,具体对比结果可以查看论文对应章节。
标签分布:
标签值域 | 是否合法 |
---|---|
y=0 & z=0 | 合法 |
y=0 & z=1 | 非法 |
y=1 & z=0 | 合法 |
y=1 & z=1 | 合法 |
注意:y=0的用户,虽然没有了后续操作,z也被标为了0
参考:https://tianchi.aliyun.com/dataset/dataDetail?dataId=408&userId=1
评估指标:
-
pCVR
传统CVR模型的评估
-
pCTCVR
在整个样本空间上的评估
参考文献
Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate