ESMM论文总结

通过全空间多任务学习,解决CVR的样本偏差等问题

摘要

CVR任务,以点击数据为训练数据,但应用在曝光场景。因此存在样本选择偏差的问题。同时还存在数据稀疏的问题,导致模型拟合困难。论文利用用户的行为顺序进行建模。

ESMM模型通过以下两点,可以消除上面的两个问题:

  1. 直接在全空间对CVR建模。
  2. 对特征表示采用迁移学习策略(employing a feature representation transfer learning strategy)。

1 介绍

背景介绍

  • 曝光:将产品推荐给用户
  • 点击:用户对产品进行点击
  • 转化:用户点击后进行注册/激活等操作
  • CTR:Click Through Rate,点击率
  • CVR:Click Value Rate,转化率

从曝光到点击的过程,计算的就是点击率。比如有1000个人看到这个按钮,其中100个点了,那么点击率就是0.1,点击率就是click-through rate,ctr。

从点击到转化(下单,消费)的过程,计算的是转化率。比如100个点击了产品详情页,最后完成下单的是5个人,那么转化率是0.05,转化率就是conversion rate,cvr。

传统CVR存在两个关键问题:

1.样本选择偏差问题(sample selection bias,SSB)

image-20211118164012057

应用场景:曝光

建模数据:点击之后的数据集(点击后才有是否转化的标签)

2.数据稀疏问题(data sparsity,DS)

CVR的训练数据量远低于曝光数据(4%),模型拟合困难

为了解决上述问题,提出了全空间多任务模型(ESMM).具体细节在后面介绍。

2 模型介绍

2.1 点击与转化关系

p(y=1,z=1x)pCTCVR=p(y=1x)pCTR×p(z=1y=1,x)pCVR\underbrace{p(y=1, z=1 \mid x)}_{p C T C V R}=\underbrace{p(y=1 \mid x)}_{p C T R} \times \underbrace{p(z=1 \mid y=1, x)}_{p C V R}

CTR:点击

CVR:转化

CTCVR: 点击然后转化

2.2 模型结构

image-20211118175236345

模型分成2个网络:CVR网络CTR网络

CTCVR把CVR输出和CTR输出的乘积作为最后的输出。

模型有以下几个重点:

Modeling over entire space

基于公式1,可得

p(z=1y=1,x)=p(y=1,z=1x)p(y=1x)p(z=1 \mid y=1, \boldsymbol{x})=\frac{p(y=1, z=1 \mid \boldsymbol{x})}{p(y=1 \mid \boldsymbol{x})}

从公式中可以看到,不论是计算CTR(p(y=1x)p(y=1|x)),CVR(p(z=1y=1,x)p(z=1|y=1,x)),还是CTCVR(p(y=1,z=1x)p(y=1,z=1|x)),都用到了整个曝光数据x。

损失函数为:

L(θcvr,θctr)=i=1Nl(yi,f(xi;θctr))+i=1Nl(yi&zi,f(xi;θctr)×f(xi;θcvr))\begin{aligned} L\left(\theta_{c v r}, \theta_{c t r}\right) &=\sum_{i=1}^{N} l\left(y_{i}, f\left(x_{i} ; \theta_{c t r}\right)\right) \\ &+\sum_{i=1}^{N} l\left(y_{i} \& z_{i}, f\left(x_{i} ; \theta_{c t r}\right) \times f\left(x_{i} ; \theta_{c v r}\right)\right) \end{aligned}

整个损失函数包含两部分:CTR的loss和CTCVR的loss。注意:loss项中没有直接评估CVR任务的loss。

两项loss中的损失函数为交叉熵。

因为整个训练过程中使用的是曝光阶段的数据,因此避免了前面提到的样本偏差的问题。

Feature representation transfer

模型中的embedding层的作用是对高纬输入进行降维。这一层占据了网络中的大部分参数,需要较多的数据进行训练。

CVR网络CTR网络的Embedding层共享,因此解决了前面提到的数据稀疏的问题。

2.3 代码示例

class FeatureExtractor(nn.Module):
    def __init__(self, embedding_sizes):
        super(FeatureExtractor, self).__init__()
        self.embedding_layers = nn.ModuleList([nn.Embedding(categories, size) for categories, size in embedding_sizes])

    def forward(self, category_inputs):
        h = [embedding_layer(category_inputs[:, i]) for i, embedding_layer in enumerate(self.embedding_layers)]
        h = torch.cat(h, dim=1)  
        return h


class CtrNetwork(nn.Module):
    def __init__(self, input_dim):
        super(CtrNetwork, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=1),
        )
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, inputs):
        p = self.mlp(inputs)
        return self.sigmoid(p)


class CvrNetwork(nn.Module):

    def __init__(self, input_dim):
        super(CvrNetwork, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=1),
        )
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, inputs):
        p = self.mlp(inputs)
        return self.sigmoid(p)


class ESMM(nn.Module):
    """ESMM"""
    def __init__(self, feature_extractor: FeatureExtractor, ctr_network: CtrNetwork, cvr_network: CvrNetwork):
        super(ESMM, self).__init__()
        self.feature_extractor = feature_extractor
        self.ctr_network = ctr_network
        self.cvr_network = cvr_network
    
    def forward(self, inputs):
        h = self.feature_extractor(inputs)  # encode
        # Predict pCTR
        p_ctr = self.ctr_network(h)
        # Predict pCVR
        p_cvr = self.cvr_network(h)
        # Predict pCTCVR
        p_ctcvr = torch.mul(p_ctr, p_cvr)
        return p_ctr, p_ctcvr

3 实验

论文中对比了多种其他方法,具体对比结果可以查看论文对应章节。

标签分布:

标签值域 是否合法
y=0 & z=0 合法
y=0 & z=1 非法
y=1 & z=0 合法
y=1 & z=1 合法

注意:y=0的用户,虽然没有了后续操作,z也被标为了0

参考:https://tianchi.aliyun.com/dataset/dataDetail?dataId=408&userId=1

评估指标:

  • pCVR

    传统CVR模型的评估

  • pCTCVR

    在整个样本空间上的评估

参考文献

Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate