理解广播(Broadcasting)规则

5 min

广播(Broadcasting) 是一种强大的机制,使得 NumPy 在进行算术运算时能够处理不同形状的数组。广播规则的妙处在于,它完全没有像传统矩阵运算那样严格要求矩阵形状,而且可加可减可乘可除,但正因如此就给广播机制的理解增加了难度,如果不熟悉的话,可能会误判断支持广播的运算后的结果。

这篇文章会使用 PyTorch 进行演示,通过对张量(tensor)的操作来辅助理解广播规则。

运算步骤

判断

首先,PyTorch 会判断两个张量是否可被广播,要求为:

  1. 每个数组至少有一个维度。
  2. 从尾部维数开始迭代维数大小时,以下三者任意满足其一:(a)维数大小必须相等(b)其中一个为 1 (c)其中一个不存在。

在以下示例中,两个张量xy可以被广播的:

# 完全相同的张量可以被广播
x = torch.empty(5,7,3)
y = torch.empty(5,7,3)

# 为数不同的张量也可以被广播,但有相应的条件。
x = torch.empty(5,3,4,1)
y = torch.empty(  3,1,1)
# x y 倒数第一个维度都为 1,满足 2(a)
# x y 倒数第二个维度有一个为 1,满足 2(b)
# x y 倒数第三个维度都为 3,满足 2(a)
# y 倒数第四个维度不存在,满足 2(c)

在以下示例中,两个张量xy不可被广播的:

x=torch.empty(5,2,4,1)
y=torch.empty(  3,1,1)
# 问题出在倒数第三个维度,二者维度均存在且均不为1,维数大小不相等

广播运算后张量的 Size

总体计算规则如下:

  1. 如果 x 和 y 的维数不相等,则在维数较少的张量的维数前加 1,使其长度相等。
  2. 然后,对于每个维度的大小,得到的维度大小是 x 和 y 在该维度上的最大值。

通过这一规则,可以预期到目标张量的 size:

x=torch.empty(5,1,4,1)
y=torch.empty(  3,1,1)
(x+y).size # 结果为 torch.Size([5, 3, 4, 1])

还是从最后一个维度开始:

  1. 倒数第一个维度相等
  2. 倒数第二个维度取二者较大值 4
  3. 倒数第三个维度取二者较大值 3
  4. 倒数第四个维度,先给 y 填充上1,使得秩相等,此后取二者较大值 5

具体运算

简单的广播

当然,困难的是如何想象出广播变换。我们先从一个二维张量开始:

x = torch.tensor([
    [0., 30., 600.], 
    [1., 10., 200.], 
    [-1., 20., 400.]
])

y = torch.tensor([0,20,400])

z = x + y

这两个张量形状为(3,3) (1,3),显然是可广播的。可以预期到,y 在广播后的形状为(3,3),这意味着 y 的行数将由 1 变为 3,那么实际就是将当前的一行进行复制:

y_broadcasting = torch.tensor([
    [0,20,400],
    [0,20,400],
    [0,20,400]
    ])

此后再进行运算就很简单了。

更高维数的广播

接下来是一个更为复杂的:

x = torch.tensor([[[[1, 2],
                    [3, 4]]],
                  
                  [[[5, 6],
                    [7, 8]]]])

y = torch.tensor([[[1],
                   [2]],
                  
                  [[3],
                   [4]],
                  
                  [[5],
                   [6]]])

y 的形状为 (2,1,2,2),y 的形状为 (3,2,1),可以与 x 进行广播,可以预期到,y 在广播后的形状为(2,3,2,2)

按照最末端开始进行变化,y 的最末尾 1 维变 2 维,现在是torch.Size([3,2,2])

y_broadcast_step1 = torch.tensor([[[1., 1.],
                                   [2., 2.]],

                                    [[3., 3.],
                                     [4., 4.]],

                                    [[5., 5.],
                                     [6., 6.]]])

为 y 的倒数第四个维度添加 2

y_broadcast = torch.tensor([[[[1., 1.],
                              [2., 2.]],
                             
                             [[3., 3.],
                              [4., 4.]],
                             
                             [[5., 5.],
                              [6., 6.]]],
                            
                           [[[1., 1.],
                             [2., 2.]],
                            [[3., 3.],
                             [4., 4.]],
                            [[5., 5.],
                             [6., 6.]]]])

至此,y 变化完毕。x 的倒数第三个维度 1 需要变成3,要想构造(2,3,2,2),需要 [1, 2], [3, 4][5, 6], [7, 8] 在各自维度复制两次,得到:

x_broadcast = torch.tensor([[[[1, 2],
          [3, 4]],

         [[1, 2],
          [3, 4]],

         [[1, 2],
          [3, 4]]],


        [[[5, 6],
          [7, 8]],

         [[5, 6],
          [7, 8]],

         [[5, 6],
          [7, 8]]]])

现在两者形状完全相同,就可以逐元素进行任意运算了。

reference: Broadcasting semantics