理解广播(Broadcasting)规则
5 min
广播(Broadcasting) 是一种强大的机制,使得 NumPy 在进行算术运算时能够处理不同形状的数组。广播规则的妙处在于,它完全没有像传统矩阵运算那样严格要求矩阵形状,而且可加可减可乘可除,但正因如此就给广播机制的理解增加了难度,如果不熟悉的话,可能会误判断支持广播的运算后的结果。
这篇文章会使用 PyTorch 进行演示,通过对张量(tensor)的操作来辅助理解广播规则。
运算步骤
判断
首先,PyTorch 会判断两个张量是否可被广播,要求为:
- 每个数组至少有一个维度。
- 从尾部维数开始迭代维数大小时,以下三者任意满足其一:(a)维数大小必须相等(b)其中一个为 1 (c)其中一个不存在。
在以下示例中,两个张量x和y是可以被广播的:
# 完全相同的张量可以被广播
x = torch.empty(5,7,3)
y = torch.empty(5,7,3)
# 为数不同的张量也可以被广播,但有相应的条件。
x = torch.empty(5,3,4,1)
y = torch.empty( 3,1,1)
# x y 倒数第一个维度都为 1,满足 2(a)
# x y 倒数第二个维度有一个为 1,满足 2(b)
# x y 倒数第三个维度都为 3,满足 2(a)
# y 倒数第四个维度不存在,满足 2(c)在以下示例中,两个张量x和y是不可被广播的:
x=torch.empty(5,2,4,1)
y=torch.empty( 3,1,1)
# 问题出在倒数第三个维度,二者维度均存在且均不为1,维数大小不相等广播运算后张量的 Size
总体计算规则如下:
- 如果 x 和 y 的维数不相等,则在维数较少的张量的维数前加 1,使其长度相等。
- 然后,对于每个维度的大小,得到的维度大小是 x 和 y 在该维度上的最大值。
通过这一规则,可以预期到目标张量的 size:
x=torch.empty(5,1,4,1)
y=torch.empty( 3,1,1)
(x+y).size # 结果为 torch.Size([5, 3, 4, 1])还是从最后一个维度开始:
- 倒数第一个维度相等
- 倒数第二个维度取二者较大值 4
- 倒数第三个维度取二者较大值 3
- 倒数第四个维度,先给 y 填充上1,使得秩相等,此后取二者较大值 5
具体运算
简单的广播
当然,困难的是如何想象出广播变换。我们先从一个二维张量开始:
x = torch.tensor([
[0., 30., 600.],
[1., 10., 200.],
[-1., 20., 400.]
])
y = torch.tensor([0,20,400])
z = x + y这两个张量形状为(3,3) (1,3),显然是可广播的。可以预期到,y 在广播后的形状为(3,3),这意味着 y 的行数将由 1 变为 3,那么实际就是将当前的一行进行复制:
y_broadcasting = torch.tensor([
[0,20,400],
[0,20,400],
[0,20,400]
])此后再进行运算就很简单了。
更高维数的广播
接下来是一个更为复杂的:
x = torch.tensor([[[[1, 2],
[3, 4]]],
[[[5, 6],
[7, 8]]]])
y = torch.tensor([[[1],
[2]],
[[3],
[4]],
[[5],
[6]]])y 的形状为 (2,1,2,2),y 的形状为 (3,2,1),可以与 x 进行广播,可以预期到,y 在广播后的形状为(2,3,2,2)。
按照最末端开始进行变化,y 的最末尾 1 维变 2 维,现在是torch.Size([3,2,2])
y_broadcast_step1 = torch.tensor([[[1., 1.],
[2., 2.]],
[[3., 3.],
[4., 4.]],
[[5., 5.],
[6., 6.]]])为 y 的倒数第四个维度添加 2
y_broadcast = torch.tensor([[[[1., 1.],
[2., 2.]],
[[3., 3.],
[4., 4.]],
[[5., 5.],
[6., 6.]]],
[[[1., 1.],
[2., 2.]],
[[3., 3.],
[4., 4.]],
[[5., 5.],
[6., 6.]]]])至此,y 变化完毕。x 的倒数第三个维度 1 需要变成3,要想构造(2,3,2,2),需要 [1, 2], [3, 4]和[5, 6], [7, 8] 在各自维度复制两次,得到:
x_broadcast = torch.tensor([[[[1, 2],
[3, 4]],
[[1, 2],
[3, 4]],
[[1, 2],
[3, 4]]],
[[[5, 6],
[7, 8]],
[[5, 6],
[7, 8]],
[[5, 6],
[7, 8]]]])现在两者形状完全相同,就可以逐元素进行任意运算了。
reference: Broadcasting semantics