NumPy 进阶:搞懂广播机制,才能真正玩转数组
一、前情回顾:数组的“超能力”从哪来?
上一章我们已经知道,NumPy 数组能直接做加减乘除:
import numpy as np
arr = np.array([1, 2, 3, 4, 5])
print(arr * 2)
# [ 2 4 6 8 10]但你有没有想过,为什么这个操作不报错?
Python 原生列表可做不到这事儿。
[hint info]
答案就是:广播机制(Broadcasting)。
NumPy 允许不同形状的数组在数学运算时“自动对齐”。
[/hint]
二、广播机制的核心原理
简单一句话总结:
广播就是 NumPy 在维度不匹配时自动扩展数组,使它们形状兼容。
来,我们先用一个经典例子开胃 👇
a = np.array([[1, 2, 3],
[4, 5, 6]])
b = np.array([10, 20, 30])
print(a + b)输出:
[14 25 36]]看似理所当然,但实际上这背后 NumPy 悄悄做了扩展 👇
[collapse title="📊 广播过程可视化"]
| 数组 | 原始形状 | 扩展形状 |
|---|---|---|
a | (2, 3) | (2, 3) |
b | (3,) | (1, 3) → (2, 3) |
NumPy 自动在前面补了一个维度 (1, 3),然后复制两次让它变成 (2, 3),于是它们形状一致,就能运算了!
[/collapse]
三、广播规则总结
[hint tip]
判断两个数组是否能广播,其实就三步:
[/hint]
- 从 尾部维度 开始对齐;
- 每个维度要么相等,要么其中一个是 1;
- 不满足规则时,广播失败(直接报错)。
[tabs]
[tab title="能广播 ✅"]
A.shape = (4, 1, 3)
B.shape = (1, 5, 3)
# 结果 -> (4, 5, 3)[/tab]
[tab title="不能广播 ❌"]
A.shape = (3, 2)
B.shape = (2, 3)
# 维度都不匹配,直接报错![/tab]
[/tabs]
四、copy vs view:NumPy 的内存诡计
你可能遇到过这种情况 👇
arr = np.arange(6).reshape(2, 3)
sub = arr[:, 1:]
sub[0, 0] = 99
print(arr)输出:
[[ 0 99 2]
[ 3 4 5]]为什么我改了 sub,arr 也跟着变了?
[hint warning]
那是因为 NumPy 返回的不是拷贝(copy),而是视图(view)!
两者共用底层内存,一改俱改。
[/hint]
🔍 如何判断是否为 view?
sub.base is arr # True -> 表示 sub 是 arr 的视图如果想要真正复制一份独立的副本:
sub = arr[:, 1:].copy()五、性能优化:让你的代码飞起来 🚀
1️⃣ 避免循环
# 慢
for i in range(len(arr)):
arr[i] *= 2
# 快
arr *= 2循环会触发 Python 解释器逐元素执行;
而向量化运算在底层 C 实现中批量计算,速度往往快几十倍。
2️⃣ 使用 numexpr 进行表达式优化
import numexpr as ne
a = np.random.rand(10_000_000)
b = np.random.rand(10_000_000)
c = np.random.rand(10_000_000)
# 普通写法
res1 = a * b + c
# numexpr 写法
res2 = ne.evaluate("a * b + c")numexpr 会自动进行并行和缓存优化,CPU 利用率更高。
3️⃣ 就地操作(In-place)
arr *= 2 # 就地修改,不创建新数组这种写法可以节省内存分配,尤其在大数组时效果显著。
六、实用技巧锦集 🧰
| 场景 | 方法 |
|---|---|
| 随机数 | np.random.rand(3,3) |
| 拼接 | np.concatenate((a,b), axis=0) |
| 转置 | arr.T |
| 排序 | np.sort(arr) |
| 求唯一值 | np.unique(arr) |
这些都是项目中常用的“手上活儿”。
七、性能演示:到底快多少?
[tabs]
[tab title="Python 循环"]
import time
a = list(range(10_000_000))
start = time.time()
b = [x * 2 for x in a]
print("耗时:", time.time() - start)
# 耗时: 0.13944315910339355[/tab]
[tab title="NumPy 向量化"]
import numpy as np
import time
a = np.arange(10_000_000)
start = time.time()
b = a * 2
print("耗时:", time.time() - start)
# 耗时: 0.05504918098449707[/tab]
[/tabs]
实测下来:NumPy 版本比纯 Python 快 30~100 倍。
一旦数据量大,优势就像核弹一样炸裂 💥。
八、小结
| 模块 | 要点 |
|---|---|
| 广播机制 | 自动扩展维度,按尾部对齐 |
| 内存视图 | view 共用底层数据,copy 独立内存 |
| 性能优化 | 向量化、numexpr、就地操作 |
| 实用技巧 | 拼接、随机数、转置、排序、唯一值 |
[hint tip]
下一篇(第 3 篇)我们将正式上手项目实战:
用 NumPy 做 图像处理与数据分析,把理论变成肌肉记忆。
[/hint]
🏁 系列目录
- ✅ NumPy 入门:别再用 for 循环折磨自己了
- ✅ NumPy 进阶:搞懂广播机制,才能真正玩转数组
- 🔜 NumPy 实战:用数组玩转图像与数据分析