The Spatial-Aware Self-Attention Mechanism in LayoutLMv2

date: Jan 27, 2022
slug: LayoutLMv2
status: Published
tags: Multi-Modal, Deep Learning
summary: It doesn't seem to help much
type: Post
The paper's stated motivation for introducing the Spatial-Aware Self-Attention Mechanism (SASAM):
However, the original self-attention mechanism can only implicitly capture the relationship between the input tokens with the absolute position hints. In order to efficiently model local invariance in the document layout, it is necessary to insert relative position information explicitly.
The pre-softmax attention weight in the original self-attention is computed as:

$$\alpha_{ij} = \frac{1}{\sqrt{d_{head}}}\,(\mathbf{x}_i \mathbf{W}^Q)(\mathbf{x}_j \mathbf{W}^K)^{\top}$$
SASAM adds three relative position encodings on top of this score as bias terms:
  • semantic relative position (b^(1D)): relative position encoding of the semantic (1D) information, generated from the absolute positions via relative_position_bucket
  • spatial relative position (b^(2D_x), b^(2D_y)): relative position encodings of the spatial information along the x and y directions, generated from the normalized x/y coordinates (0–1000) via relative_position_bucket
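Putting these together, the full SASAM score adds the three biases to the original score; this is a reconstruction following the notation of the LayoutLMv2 paper:

```latex
\alpha'_{ij} = \alpha_{ij}
             + \mathbf{b}^{(1D)}_{j-i}
             + \mathbf{b}^{(2D_x)}_{x_j - x_i}
             + \mathbf{b}^{(2D_y)}_{y_j - y_i}
```

Each bias is a learned per-head scalar indexed by the bucketed relative distance, so tokens at the same relative offset share the same bias regardless of their absolute positions.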
Taking `_cal_1d_pos_emb` as an example, here is the code with some tensor-dimension annotations added to help understanding:
```python
def _cal_1d_pos_emb(self, hidden_states, position_ids):
    # method of the self-attention module; F is torch.nn.functional
    """
    position_ids: [batch_size, 512 + 49 = 561]
        512 text positions   -> [0, 1, 2, ..., 511]
        49 visual positions  -> [0, 1, 2, ..., 48]   (7x7 visual feature map)
    position_ids.unsqueeze(-2): [batch_size, 1, 561]
    position_ids.unsqueeze(-1): [batch_size, 561, 1]
    rel_pos_mat: [batch_size, 561, 561]
        [
          [  0,   1,   2, ..., 511,   0,   1, ..., 48],
          [ -1,   0,   1, ..., 510,  -1,   0, ..., 47],
          ...
          [-48, -47, -46, ..., 463, -48, -47, ...,  0],
        ]
    """
    rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
    """
    rel_pos: [batch_size, 561, 561]
    # the first 512 rows look like
    [
      [ 0, 17, 18, ..., 30, 31, 31, ..., 31],
      [ 1,  0, 17, 18, ..., 30, 31, 31, ..., 31],
      ...
      [13, 12, 11, 10, ..., 0, 17, 18, ..., 30, 31, 31],
    ]
    """
    rel_pos = relative_position_bucket(
        rel_pos_mat,
        num_buckets=self.rel_pos_bins,
        max_distance=self.max_rel_pos,
    )
    # one-hot bucket indices, then a linear layer produces one bias per head
    rel_pos = F.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states)
    # [batch_size, num_attention_heads, 561, 561]
    rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2)
    rel_pos = rel_pos.contiguous()
    return rel_pos
```
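The broadcasting step that builds `rel_pos_mat` can be spot-checked with a minimal pure-Python sketch (a toy sequence length instead of the real 561):

```python
# Toy illustration of rel_pos_mat[i][j] = position_ids[j] - position_ids[i],
# mirroring position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) above.
position_ids = [0, 1, 2, 3]  # toy positions; the real model uses 561

rel_pos_mat = [[pj - pi for pj in position_ids] for pi in position_ids]

print(rel_pos_mat[0])  # [0, 1, 2, 3]    distances seen from token 0
print(rel_pos_mat[2])  # [-2, -1, 0, 1]  distances seen from token 2
```

Row i holds the signed distance from token i to every other token, which is exactly what the bucketing function consumes next.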
```python
import math

import torch


def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """
    Args:
        relative_position: [batch_size, 512+49, 512+49]
        bidirectional: use separate bucket ranges for positive and negative offsets
        num_buckets: total number of buckets
        max_distance: offsets at or beyond this all share the last bucket
    Returns:
        bucket indices with the same shape as relative_position
    """
    ret = 0
    if bidirectional:
        num_buckets //= 2  # 16
        # ret: 16 wherever the relative position is positive (upper triangle), else 0
        ret += (relative_position > 0).long() * num_buckets
        # n: symmetric matrix of absolute distances
        n = torch.abs(relative_position)
    else:
        n = torch.max(-relative_position, torch.zeros_like(relative_position))
    # now n is in the range [0, inf)

    # half of the buckets are for exact increments in positions:
    # an "attention window" of up to max_exact - 1 tokens on each side
    max_exact = num_buckets // 2
    is_small = n < max_exact

    # the other half of the buckets are for logarithmically bigger bins up to max_distance
    val_if_large = max_exact + (
        torch.log(n.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

    ret += torch.where(is_small, n, val_if_large)
    return ret
```
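To get a feel for the bucketing, here is a scalar pure-Python rendition of the same logic for the bidirectional case (`bucket` is a hypothetical helper name for illustration, not part of the original code):

```python
import math


def bucket(rel_pos, num_buckets=32, max_distance=128):
    # Scalar rendition of relative_position_bucket (bidirectional=True)
    # for spot-checking individual values.
    num_buckets //= 2                     # 16: one half per sign direction
    ret = num_buckets if rel_pos > 0 else 0
    n = abs(rel_pos)
    max_exact = num_buckets // 2          # 8: exact buckets for small distances
    if n < max_exact:
        return ret + n
    # logarithmic buckets for distances in [max_exact, max_distance)
    val = max_exact + int(
        math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return ret + min(val, num_buckets - 1)


print(bucket(0))     # 0  -> same position
print(bucket(5))     # 21 -> small positive offset keeps an exact bucket: 16 + 5
print(bucket(-5))    # 5  -> small negative offset: no sign offset added
print(bucket(100))   # 31 -> large positive offsets saturate at 16 + 15
```

Distances below 8 each keep their own bucket, while larger distances are squashed logarithmically, so distant tokens share increasingly coarse bias values.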
In LayoutXLM, although the authors draw SASAM in Figure 1, the configuration of the officially released LayoutXLM pre-trained model does not actually enable the parameters related to the Spatial-Aware Self-Attention Mechanism, and someone has raised an issue about this on the official repository:

Mismatches between paper descriptions and codes (updated Sep 28, 2021)

In a fine-tuning test on business data (an NER task), enabling the `has_spatial_attention_bias` and `has_relative_attention_bias` parameters actually dropped F1 by 1%; it is unclear whether this is related to the pre-trained model not having these two parameters enabled.
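For reference, these are the SASAM-related fields as they would appear in a LayoutLMv2-style `config.json`. The values shown are the LayoutLMv2-base defaults; treat the exact names and numbers as an assumption to verify against the `config.json` of the checkpoint you actually load:

```json
{
  "has_relative_attention_bias": true,
  "has_spatial_attention_bias": true,
  "rel_pos_bins": 32,
  "max_rel_pos": 128,
  "rel_2d_pos_bins": 64,
  "max_rel_2d_pos": 256
}
```

If the released checkpoint was pre-trained with both flags set to `false`, flipping them on at fine-tuning time initializes the bias layers from scratch, which may explain the F1 drop observed above.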

© PanicByte 2021 - 2022