小陈博客-个人分享

NumPy 进阶:搞懂广播机制,才能真正玩转数组

一、前情回顾:数组的“超能力”从哪来?

上一章我们已经知道,NumPy 数组能直接做加减乘除:

import numpy as np

arr = np.array([1, 2, 3, 4, 5])
print(arr * 2)
# [ 2  4  6  8 10]

但你有没有想过,为什么这个操作不报错?

Python 原生列表可做不到这事儿。

[hint info]
答案就是:广播机制(Broadcasting)
NumPy 允许不同形状的数组在数学运算时“自动对齐”。
[/hint]

二、广播机制的核心原理

简单一句话总结:

广播就是 NumPy 在维度不匹配时自动扩展数组,使它们形状兼容。

来,我们先用一个经典例子开胃 👇

a = np.array([[1, 2, 3],
              [4, 5, 6]])

b = np.array([10, 20, 30])

print(a + b)

输出:

[14 25 36]]

看似理所当然,但实际上这背后 NumPy 悄悄做了扩展 👇

[collapse title="📊 广播过程可视化"]

数组原始形状扩展形状
a(2, 3)(2, 3)
b(3,)(1, 3) → (2, 3)

NumPy 自动在前面补了一个维度 (1, 3),然后复制两次让它变成 (2, 3),于是它们形状一致,就能运算了!
[/collapse]

三、广播规则总结

[hint tip]
判断两个数组是否能广播,其实就三步:
[/hint]

  1. 尾部维度 开始对齐;
  2. 每个维度要么相等,要么其中一个是 1;
  3. 不满足规则时,广播失败(直接报错)。

[tabs]

[tab title="能广播 ✅"]

A.shape = (4, 1, 3)
B.shape = (1, 5, 3)
# 结果 -> (4, 5, 3)

[/tab]

[tab title="不能广播 ❌"]

A.shape = (3, 2)
B.shape = (2, 3)
# 维度都不匹配,直接报错!

[/tab]

[/tabs]

四、copy vs view:NumPy 的内存诡计

你可能遇到过这种情况 👇

arr = np.arange(6).reshape(2, 3)
sub = arr[:, 1:]
sub[0, 0] = 99

print(arr)

输出:

[[ 0 99  2]
 [ 3  4  5]]

为什么我改了 subarr 也跟着变了?

[hint warning]
那是因为 NumPy 返回的不是拷贝(copy),而是视图(view)
两者共用底层内存,一改俱改。
[/hint]

🔍 如何判断是否为 view?

sub.base is arr  # True -> 表示 sub 是 arr 的视图

如果想要真正复制一份独立的副本:

sub = arr[:, 1:].copy()

五、性能优化:让你的代码飞起来 🚀

1️⃣ 避免循环

# 慢
for i in range(len(arr)):
    arr[i] *= 2

# 快
arr *= 2

循环会触发 Python 解释器逐元素执行;
而向量化运算在底层 C 实现中批量计算,速度往往快几十倍。


2️⃣ 使用 numexpr 进行表达式优化

import numexpr as ne

a = np.random.rand(10_000_000)
b = np.random.rand(10_000_000)
c = np.random.rand(10_000_000)

# 普通写法
res1 = a * b + c

# numexpr 写法
res2 = ne.evaluate("a * b + c")

numexpr 会自动进行并行和缓存优化,CPU 利用率更高。


3️⃣ 就地操作(In-place)

arr *= 2  # 就地修改,不创建新数组

这种写法可以节省内存分配,尤其在大数组时效果显著。

六、实用技巧锦集 🧰

场景方法
随机数np.random.rand(3,3)
拼接np.concatenate((a,b), axis=0)
转置arr.T
排序np.sort(arr)
求唯一值np.unique(arr)

这些都是项目中常用的“手上活儿”。

七、性能演示:到底快多少?

[tabs]

[tab title="Python 循环"]

import time

a = list(range(10_000_000))
start = time.time()
b = [x * 2 for x in a]
print("耗时:", time.time() - start)

# 耗时: 0.13944315910339355

[/tab]

[tab title="NumPy 向量化"]

import numpy as np
import time

a = np.arange(10_000_000)
start = time.time()
b = a * 2
print("耗时:", time.time() - start)

# 耗时: 0.05504918098449707

[/tab]

[/tabs]

实测下来:NumPy 版本比纯 Python 快 30~100 倍
一旦数据量大,优势就像核弹一样炸裂 💥。

八、小结

模块要点
广播机制自动扩展维度,按尾部对齐
内存视图view 共用底层数据,copy 独立内存
性能优化向量化、numexpr、就地操作
实用技巧拼接、随机数、转置、排序、唯一值

[hint tip]
下一篇(第 3 篇)我们将正式上手项目实战:
用 NumPy 做 图像处理与数据分析,把理论变成肌肉记忆。
[/hint]

🏁 系列目录

  1. NumPy 入门:别再用 for 循环折磨自己了
  2. NumPy 进阶:搞懂广播机制,才能真正玩转数组
  3. 🔜 NumPy 实战:用数组玩转图像与数据分析

当前页面是本站的「Google AMP」版。查看和发表评论请点击:完整版 »