Fluent Python 笔记 —— 可迭代对象、迭代器和生成器

迭代是数据处理的基石。扫描内存中放不下的数据集时,通常需要一种惰性获取数据项的方式,即按需一次获取一个数据项。这就是迭代器模式。

在 Python 中,所有序列类型都支持迭代。在语言内部,迭代器用于支持以下操作:

  • for 循环
  • 构建和扩展序列类型
  • 逐行遍历文本文件
  • 列表推导、字典推导和集合推导
  • 元组拆包
  • 调用函数时,使用 * 拆包实参

可迭代对象

以下代码实现了一个 Sentence 类,通过索引从文本中提取单词:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import re
import reprlib

RE_WORD = re.compile('\w+')

class Sentence:
def __init__(self, text):
self.text = text
self.words = RE_WORD.findall(text)

def __getitem__(self, index):
return self.words[index]

def __len__(self):
return len(self.words)

def __repr__(self):
return f'Sentence({reprlib.repr(self.text)})'

效果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> from sentence import Sentence
>>> s = Sentence('"The time has come," the Walrus said,')
>>> s
Sentence('"The time ha... Walrus said,')
>>> for word in s:
... print(word)
...
The
time
has
come
the
Walrus
said
>>> list(s)
['The', 'time', 'has', 'come', 'the', 'Walrus', 'said']

上面创建的 Sentence 实例是可迭代的。因此该实例对象可被 for 循环调用、可以用于构建列表等。

迭代的机制

Python 解释器需要迭代对象 x 时,会自动执行 iter(x)。其作用如下:

  • 检查对象是否实现了 __iter__ 方法,如已实现则调用 __iter__,返回一个迭代器对象
  • 若对象没有实现 __iter__ 方法,但实现了 __getitem__ 方法,Python 会创建一个迭代器,尝试按顺序(从索引 0 开始)获取元素
  • 若上述尝试失败,抛出 TypeError 异常(X object is not iterable

所有 Python 序列都实现了 __iter__ 方法,因此都支持迭代操作。

可迭代对象与迭代器的对比

可迭代对象指通过 iter 函数调用可以获取迭代器的对象。即对象实现了能够返回迭代器的 __iter__ 方法,该对象就是可迭代的;或者实现了 __getitem__ 方法,且其参数是从 0 开始的索引,则对象也可以迭代。

一个简单的 for 循环背后也是有迭代器的作用的:

1
2
3
4
5
6
7
>>> s = 'ABC'
>>> for char in s:
... print(char)
...
A
B
C

使用 while 循环模拟效果如下:

1
2
3
4
5
6
7
8
9
10
11
12
>>> s = 'ABC'
>>> it = iter(s)
>>> while True:
... try:
... print(next(it))
... except StopIteration:
... del it
... break
...
A
B
C

  • 使用可迭代的对象(字符串 s)创建迭代器 it
  • 不断在迭代器 it 上调用 next 函数,获取下一个字符
  • 若已获取到最后一个字符,迭代器抛出 StopIteration 异常
  • 捕获 StopIteration 异常,释放 it 对象,退出循环

Python 语言内部会自动处理 for 循环和其他迭代上下文(如列表推导等)中的 StopIteration 异常。

迭代器(如前面的 it)实现了无参数的 __next__ 方法,返回序列中的下一个元素;若没有元素了,则抛出 StopIteration 异常。Python 中的迭代器还实现了 __iter__ 方法,返回该迭代器本身(即确保迭代器本身也是可迭代对象)

典型的迭代器

关于可迭代对象与迭代器之间的区别,可以参考如下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import re
import reprlib

RE_WORD = re.compile('\w+')

class Sentence:
def __init__(self, text):
self.text = text
self.words = RE_WORD.findall(text)

def __repr__(self):
return f'Sentence({reprlib.repr(self.text)})'

def __iter__(self):
return SentenceIterator(self.words)


class SentenceIterator:
def __init__(self, words):
self.words = words
self.index = 0

def __next__(self):
try:
word = self.words[self.index]
except IndexError:
raise StopIteration()
self.index += 1
return word

def __iter__(self):
return self

根据迭代器协议,可迭代对象 Sentence 中的 __iter__ 方法会实例化并返回一个迭代器(SentenceIterator),而 SentenceIterator 作为迭代器实现了 __next____iter__ 方法。

构建可迭代对象时出现错误的原因经常是混淆了可迭代对象与迭代器。可迭代对象通过内部的 __iter__ 方法返回一个实例化的迭代器对象;而迭代器要实现 __next__ 方法返回单个元素,此外还需要实现 __iter__ 方法返回迭代器本身

生成器函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import re
import reprlib

RE_WORD = re.compile('\w+')

class Sentence:
def __init__(self, text):
self.text = text
self.words = RE_WORD.findall(text)

def __repr__(self):
return f'Sentence({reprlib.repr(self.text)})'

def __iter__(self):
return SentenceIterator(self.words)


class SentenceIterator:
def __init__(self, words):
self.words = words
self.index = 0

def __next__(self):
try:
word = self.words[self.index]
except IndexError:
raise StopIteration()
self.index += 1
return word

def __iter__(self):
return self

实现可迭代对象,相较于之前的代码,符合 Python 习惯的方式是用生成器函数替换手动实现的迭代器 SentenceIterator 类。

只要 Python 函数的定义体中有 yield 关键字,则该函数就是生成器函数。调用生成器函数会返回一个生成器对象。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
>>> def gen_123():
... yield 1
... yield 2
... yield 3
...
>>> gen_123
<function gen_123 at 0x7f63e57f0f80>
>>> gen_123()
<generator object gen_123 at 0x7f63e57dc950>
>>> for i in gen_123():
... print(i)
...
1
2
3
>>> g = gen_123()
>>> next(g)
1
>>> next(g)
2
>>> next(g)
3
>>> next(g)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
StopIteration

把生成器对象传递给 next() 函数时,其行为与迭代器一致。

惰性求值

re.finditerre.findall 函数的惰性版本,返回的不是结果列表而是一个生成器,按需生成 re.MatchObject 实例。即只在需要时才生成下一个单词。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import re
import reprlib

RE_WORD = re.compile('\w+')

class Sentence:
def __init__(self, text):
self.text = text

def __repr__(self):
return f'Sentence({reprlib.repr(self.text)})'

def __iter__(self):
for match in RE_WORD.finditer(self.text):
yield match.group()

finditer 函数返回一个迭代器,包含 self.text 中匹配 RE_WORD 的单词,产出 MatchObject 实例。match.group() 方法从 MatchObject 实例中提取匹配正则表达式的具体文本。

生成器函数已极大地简化了代码,但使用生成器表达式能够把代码变得更为简短。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
>>> def gen_AB():
... print('start')
... yield 'A'
... print('continue')
... yield 'B'
... print('end.')
...
>>> res = (x * 3 for x in gen_AB())
>>> res
<generator object <genexpr> at 0x7f4619324ad0>
>>> for i in res:
... print('-->', i)
...
start
--> AAA
continue
--> BBB
end.

可以看出,生成器表达式会产出生成器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import re
import reprlib

RE_WORD = re.compile('\w+')

class Sentence:
def __init__(self, text):
self.text = text

def __repr__(self):
return f'Sentence({reprlib.repr(self.text)})'

def __iter__(self):
return (match.group() for match in RE_WORD.finditer(self.text))

标准库中的生成器函数

用于过滤的生成器函数

模块 函数 说明
itertools compress(it, selector_it) 并行处理两个可迭代对象。若 selector_it 中的元素是真值,产出 it 中对应的元素
itertools dropwhile(predicate, it) 把可迭代对象 it 中的元素传给 predicate,跳过 predicate(item) 为真值的元素,在 predicate(item) 为假时停止,产出剩余(未跳过)的所有元素(不再继续检查)
内置 filter(predicate, it) it 中的各个元素传给 predicate,若 predicate(item) 返回真值,产出对应元素
itertools filterfalse(predicate, it) filter 函数类似,不过 predicate(item) 返回假值时产出对应元素
itertools takewhile(predicate, it) predicate(item) 返回真值时产出对应元素,然后立即停止不再继续检查
itertools islice(it, stop)islice(it, start, stop, step=1) 产出 it 的切片,作用类似于 s[:stop]s[start:stop:step,不过 it 可以是任何可迭代对象,且实现的是惰性操作
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> def vowel(c):
... return c.lower() in 'aeiou'
...
>>> list(filter(vowel, 'Aardvark'))
['A', 'a', 'a']
>>> import itertools
>>> list(itertools.filterfalse(vowel, 'Aardvark'))
['r', 'd', 'v', 'r', 'k']
>>> list(itertools.dropwhile(vowel, 'Aardvark'))
['r', 'd', 'v', 'a', 'r', 'k']
>>> list(itertools.compress('Aardvark', (1,0,1,1,0,1)))
['A', 'r', 'd', 'a']
>>> list(itertools.islice('Aardvark', 4))
['A', 'a', 'r', 'd']
>>> list(itertools.islice('Aardvark', 1, 7, 2))
['a', 'd', 'a']

用于映射的生成器函数

模块 函数 说明
itertools accumulate(it, [func]) 产出累积的总和。若提供了 func,则把 it 中的前两个元素传给 func,再把计算结果连同下一个元素传给 func,以此类推,产出结果
内置 enumerate(it, start=0) 产出由两个元素构成的元组,结构是 (index, item)。其中 indexstart 开始计数,item 则从 it 中获取
内置 map(func, it1, [it2, ..., itN]) it 中的各个元素传给 func,产出结果;若传入 N 个可迭代对象,则 func 必须能接受 N 个参数,且并行处理各个可迭代对象
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> list(enumerate('albatroz', 1))
[(1, 'a'), (2, 'l'), (3, 'b'), (4, 'a'), (5, 't'), (6, 'r'), (7, 'o'), (8, 'z')]
>>> import operator
>>> list(map(operator.mul, range(11), range(11)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
>>> list(map(operator.mul, range(11), [2, 4, 8]))
[0, 4, 16]
>>> list(map(lambda a, b: (a, b), range(11), [2, 4, 8]))
[(0, 2), (1, 4), (2, 8)]
>>> import itertools
>>> sample = [5, 4, 2, 8, 7, 6, 3, 0, 9, 1]
>>> list(itertools.accumulate(sample))
[5, 9, 11, 19, 26, 32, 35, 35, 44, 45]
>>> list(itertools.accumulate(sample, max))
[5, 5, 5, 8, 8, 8, 8, 8, 9, 9]

合并多个可迭代对象的生成器函数

模块 函数 说明
itertools chain(it1, ..., itN) 先产出 it1 中的所有元素,然后产出 it2 中的所有元素,以此类推,无缝连接
itertools chain.from_iterable(it) 产出 it 生成的各个可迭代对象中的元素,一个接一个无缝连接;it 中的元素应该为可迭代对象(即 it 是嵌套了可迭代对象的可迭代对象)
itertools product(it1, ..., itN, repeat=1) 计算笛卡尔积。从输入的各个可迭代对象中获取元素,合并成 N 个元素组成的元组,与嵌套的 for 循环效果一样。repeat 指明重复处理多少次输入的可迭代对象
内置 zip(it1, ..., itN) 并行从输入的各个可迭代对象中获取元素,产出由 N 个元素组成的元组。只要其中任何一个可迭代对象到头了,就直接停止
itertools zip_longest(it1, ..., itN, fillvalue=None) 并行从输入的各个可迭代对象中获取元素,产出由 N 个元素组成的元组,等到最长的可迭代对象到头后才停止。空缺的值用 fillvalue 填充
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> import itertools
>>> list(itertools.chain('ABC', range(2)))
['A', 'B', 'C', 0, 1]
>>> list(itertools.chain(enumerate('ABC')))
[(0, 'A'), (1, 'B'), (2, 'C')]
>>> list(itertools.chain.from_iterable(enumerate('ABC')))
[0, 'A', 1, 'B', 2, 'C']
>>> list(zip('ABC', range(5)))
[('A', 0), ('B', 1), ('C', 2)]
>>> list(zip('ABC', range(5), [10, 20, 30, 40]))
[('A', 0, 10), ('B', 1, 20), ('C', 2, 30)]
>>> list(itertools.zip_longest('ABC', range(5)))
[('A', 0), ('B', 1), ('C', 2), (None, 3), (None, 4)]
>>> list(itertools.zip_longest('ABC', range(5), fillvalue='?'))
[('A', 0), ('B', 1), ('C', 2), ('?', 3), ('?', 4)]
1
2
3
4
5
6
7
8
9
>>> list(itertools.product('ABC', range(2)))
[('A', 0), ('A', 1), ('B', 0), ('B', 1), ('C', 0), ('C', 1)]
>>> suits = 'spades hearts diamonds clubs'.split()
>>> list(itertools.product('AK', suits))
[('A', 'spades'), ('A', 'hearts'), ('A', 'diamonds'), ('A', 'clubs'), ('K', 'spades'), ('K', 'hearts'), ('K', 'diamonds'), ('K', 'clubs')]
>>> list(itertools.product('ABC'))
[('A',), ('B',), ('C',)]
>>> list(itertools.product('ABC', repeat=2))
[('A', 'A'), ('A', 'B'), ('A', 'C'), ('B', 'A'), ('B', 'B'), ('B', 'C'), ('C', 'A'), ('C', 'B'), ('C', 'C')]

把输入的各个元素扩展成多个输出元素的生成器函数
|模块|函数|说明|
|-|-|-|
|itertools|combinations(it, out_len)|把可迭代对象 it 产出的 out_len 个元素组合在一起产出|
|itertools|combinations_with_replacement(it, out_len)|把 it 产出的 out_len 个元素组合在一起产出,包含相同元素的组合|
|itertools|count(start=0, step=1)|从 start 开始不断产出数字,按 step 指定的步幅增加|
|itertools|cycle(it)|从 it 中产出各个元素,存储各个元素的副本,然后按顺序重复不断地产出各个元素|
|itertools|permutations(it, out_len=None)|把 out_lenit 产出的元素排列在一起,然后产出这些排列;out_len 的默认值等于 len(list(it))|
|itertools|repeat(item, [times])|重复不断地产出指定的元素,除非提供 times 指定次数|

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
>>> import itertools
>>> ct = itertools.count()
>>> next(ct)
0
>>> next(ct), next(ct), next(ct)
(1, 2, 3)
>>> list(itertools.islice(itertools.count(1, .3), 3))
[1, 1.3, 1.6]
>>> cy = itertools.cycle('ABC')
>>> next(cy)
'A'
>>> list(itertools.islice(cy, 7))
['B', 'C', 'A', 'B', 'C', 'A', 'B']
>>> rp = itertools.repeat(7)
>>> next(rp), next(rp)
(7, 7)
>>> list(itertools.repeat(8, 4))
[8, 8, 8, 8]
1
2
3
4
5
6
7
8
9
>>> import itertools
>>> list(itertools.combinations('ABC', 2))
[('A', 'B'), ('A', 'C'), ('B', 'C')]
>>> list(itertools.combinations_with_replacement('ABC', 2))
[('A', 'A'), ('A', 'B'), ('A', 'C'), ('B', 'B'), ('B', 'C'), ('C', 'C')]
>>> list(itertools.permutations('ABC', 2))
[('A', 'B'), ('A', 'C'), ('B', 'A'), ('B', 'C'), ('C', 'A'), ('C', 'B')]
>>> list(itertools.product('ABC', repeat=2))
[('A', 'A'), ('A', 'B'), ('A', 'C'), ('B', 'A'), ('B', 'B'), ('B', 'C'), ('C', 'A'), ('C', 'B'), ('C', 'C')]

用于重新排列元素的生成器函数
|模块|函数|说明|
|-|-|-|
|itertools|groupby(it, key=None)|产出由两个元素组成的元素,形式为 (key, group),其中 key 是分组标准,group 是生成器,用于产出分组里的元素|
|内置|reversed(seq)|从后向前,倒序产出 seq 中的元素;seq 必须是序列,或者实现了 __reversed__ 特殊方法的对象|
|itertools|tee(it, n=2)|产出一个有 n 个生成器组成的元组,每个生成器都可以独立地产出输入的可迭代对象中的元素|

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
>>> import itertools
>>> animals = ['duck', 'eagle', 'rat', 'giraffe', 'bear', 'bat', 'dolphin', 'shark', 'lion']
>>> animals.sort(key=len)
>>> animals
['rat', 'bat', 'duck', 'bear', 'lion', 'eagle', 'shark', 'giraffe', 'dolphin']
>>> for length, group in itertools.groupby(animals, len):
... print(length, '->', list(group))
...
3 -> ['rat', 'bat']
4 -> ['duck', 'bear', 'lion']
5 -> ['eagle', 'shark']
7 -> ['giraffe', 'dolphin']
>>>
>>> g1, g2 = itertools.tee('ABC')
>>> next(g1)
'A'
>>> next(g2)
'A'
>>> next(g2)
'B'
>>> list(g1)
['B', 'C']
>>> list(g2)
['C']
>>> list(zip(*itertools.tee('ABC')))
[('A', 'A'), ('B', 'B'), ('C', 'C')]

PSitertools.groupby 假定输入的可迭代对象已按照分组标准完成排序

读取迭代器,返回单个值的函数

模块 函数 说明
内置 all(it) it 中的所有元素都为真值时返回 True,否则返回 False;all([]) 返回 True
内置 any(it) 只要 it 中有元素为真值就返回 True,否则返回 False;any([]) 返回 False
内置 max(it, [key=], [default=]) 返回 it 中值最大的元素;key 是排序函数,与 sorted 中的一样;若可迭代对象为空,返回 default
内置 min(it, [key=], [default=]) 返回 it 中值最小的元素;key 是排序函数;若可迭代对象为空,返回 default
functools reduce(func, it, [initial]) 把前两个元素传给 func,然后把计算结果和第三个元素传给 func,以此类推,返回最后的结果。若提供了 initial,则将其作为第一个元素传入
内置 sum(it, start=0) it 中所有元素的总和,若提供可选的 start,会把它也加上
1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> all([1, 2, 3])
True
>>> all([1, 0, 3])
False
>>> all([])
True
>>> any([1, 2, 3])
True
>>> any([1, 0, 3])
True
>>> any([0, 0.0])
False
>>> any([])
False
1
2
3
4
5
6
>>> import functools
>>> functools.reduce(lambda a, b: a * b, range(1, 6))
120
>>> import operator
>>> functools.reduce(operator.mul, range(1, 6))
120

参考资料

Fluent Python