基本算法问题的 Python 解法(递归与搜索)

一、斐波那契数列

递归函数
1
2
3
4
5
6
7
8
9
def fib(n: int) -> int:
if n < 2:
return n
return fib(n - 2) + fib(n - 1)

if __name__ == '__main__':
for i in range(10):
print(fib(i), end=' ')
# => 0 1 1 2 3 5 8 13 21 34

不要尝试调用 fib(50)(这个我试过了,因为等得不耐烦 Ctrl+C 掉了)。在上面版本的程序中,每次对 fib() 函数的调用都会导致额外的两次对 fib() 自身的调用(即 fib(n -2)fib(n - 1))。这个行为模式会一直传递下去,直到每一个新的调用分支都满足 n < 2
比如 fib(4) 最终会执行以下函数:

1
2
3
4
5
6
7
8
9
fib(4) -> fib(3), fib(2)
fib(3) -> fib(2), fib(1)
fib(2) -> fib(1), fib(0)
fib(2) -> fib(1), fib(0)
fib(1) -> 1
fib(1) -> 1
fib(1) -> 1
fib(0) -> 0
fib(0) -> 0

计算 fib(4) 最终会调用 9 次 fib() 函数,fib(5) 调用 15 次,fib(10) 调用 177 次,计算 fib(20) 则需要执行整整 21891 次。
换句话说,函数的调用树会以指数级的速度扩展。

Memoization

Memoization 是指将计算任务执行后得到的结果保存在某个地方,后面需要用到时就可以直接取出使用,而不必再次计算。

1
2
3
4
5
6
7
8
9
10
11
from typing import Dict
memo: Dict[int, int] = {0: 0, 1: 1} # base cases

def fib2(n: int) -> int:
if n not in memo:
memo[n] = fib2(n -1) + fib2(n - 2) # memoization
return memo[n]

if __name__ == '__main__':
print(fib2(50))
# => 12586269025

lru_cache

Python 提供了一个内置的装饰器 ``@functools.lru_cache()用来自动记录(缓存)某个函数的执行结果。因此上面的fib2() 可以改为如下形式:

1
2
3
4
5
6
7
8
9
10
11
from functools import lru_cache

@lru_cache(maxsize=None)
def fib3(n: int) -> int:
if n < 2:
return n
return fib3(n - 2) + fib3(n - 1)

if __name__ == '__main__':
print(fib3(50))
# => 12586269025

迭代式方案
1
2
3
4
5
6
7
8
9
10
11
def fib4(n: int) -> int:
if n == 0: return n
last: int = 0
next: int = 1
for _ in range(1, n):
last, next = next, last + next
return next

if __name__== '__main__':
print(fib4(50))
# => 12586269025

上面的程序算是这几个方案中性能最好的一个,for 循环最多只执行 n-1 次。递归式方案借助逆向思维,而迭代式方案则是正向的逻辑。
在某些情况下,原始的递归计算方式会带来过多的性能消耗。但是任何可以通过递归计算解决的问题,同样可以用迭代的方式解决。

生成器
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import Generator

def fib5(n: int) -> Generator[int, None, None]:
yield 0
if n > 0: yield 1
last: int = 0
next: int = 1
for _ in range(1, n):
last, next = next, last + next
yield next

if __name__ == '__main__':
for i in fib5(50):
print(i, end=' ')
# => 0 1 1 2 3 5 8 13 21 34 55 89 144 233 377 610 987 1597 2584 4181 6765 10946 17711 28657 46368 75025 121393 196418 317811 514229 832040 1346269 2178309 3524578 5702887 9227465 14930352 24157817 39088169 63245986 102334155 165580141 267914296 433494437 701408733 1134903170 1836311903 2971215073 4807526976 7778742049 12586269025

二、汉诺塔问题

关于汉诺塔问题的简单描述:

  • 有三根柱子,柱子 A 上摞有 n 个不同大小的圆盘,需要将所有圆盘借助柱子 B 转移到柱子 C
  • 每次只能转移一个圆盘
  • 大的圆盘必须放置在小的圆盘下面

Hanoi Tower

递归解法的逻辑如下:

  • 利用 C 作中转,移动 A 顶部的 n-1 个圆盘到柱子 B
  • 移动 A 底部剩下的最大号圆盘到柱子 C
  • 利用 A 作中转,移动 B 上的 n-1 个圆盘到柱子 C(此时问题由最初的将 n 个圆盘从 A 转移到 C,变成将 n-1 个圆盘从 B 转移到 C)
  • 重复以上步骤

实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import TypeVar, Generic, List
T = TypeVar('T')

class Stack(Generic[T]):
def __init__(self) -> None:
self._container: List[T] = []

def push(self, item: T) -> None:
self._container.append(item)

def pop(self) -> T:
return self._container.pop()

def __repr__(self) -> str:
return repr(self._container)


def hanobi(begin: Stack[int], end: Stack[int], temp: Stack[int], n: int) -> None:
if n == 1:
end.push(begin.pop())
else:
hanobi(begin, temp, end, n - 1)
hanobi(begin, end, temp, 1)
hanobi(temp, end, begin, n - 1)

if __name__ == '__main__':
num_discs: int = 10

tower_a: Stack[int] = Stack()
tower_b: Stack[int] = Stack()
tower_c: Stack[int] = Stack()
for i in range(1, num_discs + 1):
tower_a.push(i)

hanobi(tower_a, tower_c, tower_b, num_discs)
print(tower_a) # []
print(tower_b) # []
print(tower_c) # [1, 2, 3]

这只是一个简单的演示程序,Stack 类抽象出单根柱子与圆盘的模型,hanobi 函数则以递归的方式定义了圆盘移动的过程。
如果是真人来操作的话,每次移动一个圆盘花费 1 秒钟时间。解决由 50 个圆盘组成的汉诺塔问题,共需要耗费多长时间?
约等于 3570.2 万年

三、DNA 检索

在计算机软件中,基因(Gene)一般通过由 A、C、G、T 组成的字符序列表示。其中每个字符代表一个核苷酸(nucleotide),每相邻的三个核苷酸(碱基)组成一个密码子(codon)。

基因的数据模型如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from enum import IntEnum
from typing import Tuple, List

Nucleotide: IntEnum = IntEnum('Nucleotide', ('A', 'C', 'G', 'T'))
Codon = Tuple[Nucleotide, Nucleotide, Nucleotide]
Gene = List[Codon]

def string_to_gene(s: str) -> Gene:
gene: Gene = []
for i in range(0, len(s), 3):
if (i + 2) >= len(s):
return gene
codon: Codon = (Nucleotide[s[i]], Nucleotide[s[i + 1]],
Nucleotide[s[i + 2]])
gene.append(codon)
return gene

gene_str: str = "ACGTGGCTCTCTAACGTACGTACGTACGGGGTTTATATATACCCTAGGACTCCCTTT"
my_gene: Gene = string_to_gene(gene_str)
for conda in my_gene:
print(conda)
# => (<Nucleotide.A: 1>, <Nucleotide.C: 2>, <Nucleotide.G: 3>)
# => (<Nucleotide.T: 4>, <Nucleotide.G: 3>, <Nucleotide.G: 3>)
# => (<Nucleotide.C: 2>, <Nucleotide.T: 4>, <Nucleotide.C: 2>)
# => (<Nucleotide.T: 4>, <Nucleotide.C: 2>, <Nucleotide.T: 4>)
# => (<Nucleotide.A: 1>, <Nucleotide.A: 1>, <Nucleotide.C: 2>)
# => (<Nucleotide.G: 3>, <Nucleotide.T: 4>, <Nucleotide.A: 1>)
# => (<Nucleotide.C: 2>, <Nucleotide.G: 3>, <Nucleotide.T: 4>)
# => (<Nucleotide.A: 1>, <Nucleotide.C: 2>, <Nucleotide.G: 3>)
# => (<Nucleotide.T: 4>, <Nucleotide.A: 1>, <Nucleotide.C: 2>)
# => (<Nucleotide.G: 3>, <Nucleotide.G: 3>, <Nucleotide.G: 3>)
# => (<Nucleotide.G: 3>, <Nucleotide.T: 4>, <Nucleotide.T: 4>)
# => (<Nucleotide.T: 4>, <Nucleotide.A: 1>, <Nucleotide.T: 4>)
# => (<Nucleotide.A: 1>, <Nucleotide.T: 4>, <Nucleotide.A: 1>)
# => (<Nucleotide.T: 4>, <Nucleotide.A: 1>, <Nucleotide.C: 2>)
# => (<Nucleotide.C: 2>, <Nucleotide.C: 2>, <Nucleotide.T: 4>)
# => (<Nucleotide.A: 1>, <Nucleotide.G: 3>, <Nucleotide.G: 3>)
# => (<Nucleotide.A: 1>, <Nucleotide.C: 2>, <Nucleotide.T: 4>)
# => (<Nucleotide.C: 2>, <Nucleotide.C: 2>, <Nucleotide.C: 2>)
# => (<Nucleotide.T: 4>, <Nucleotide.T: 4>, <Nucleotide.T: 4>)

线性搜索会按照原始数据结构中的元素排列顺序,逐个检查搜索空间中的每一个元素,是最简单直观的搜索方式,复杂度为 O(n)。

1
2
3
4
5
6
7
8
9
10
def linear_contains(gene: Gene, key_codon: Codon) -> bool:
for codon in gene:
if codon == key_codon:
return True
return False

acg: Codon = (Nucleotide.A, Nucleotide.C, Nucleotide.G)
gat: Codon = (Nucleotide.G, Nucleotide.A, Nucleotide.T)
print(linear_contains(my_gene, acg)) # True
print(linear_contains(my_gene, gat)) # False

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def binary_contains(gene: Gene, key_codon: Codon) -> bool:
low: int = 0
high: int = len(gene) - 1
while low <= high:
mid: int = (low + high)
if gene[mid] < key_codon:
low = mid + 1
elif gene[mid] > key_codon:
high = mid - 1
else:
return True
return False

my_sorted_gene: Gene = sorted(my_gene)
print(binary_contains(my_sorted_gene, acg))
print(binary_contains(my_sorted_gene, gat))

Binary Search 将搜索对象与有序列表中间位置的值进行比对,根据大小关系确定下一次搜索在序列的前半或者后半部分继续进行。依照此规则持续进行检索,每一次比较都会将下一次的搜索空间减小一半(类似猜数字游戏)。

Binary Search 在最坏情况下的复杂度为 O(lg n),前提是序列中的所有元素已经根据大小排序。最好的排序算法复杂度为 O(n lg n)
因此在需要对无序列表执行多次搜索的场景下,可以先对列表元素排序再执行 Binary Search。

通用示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# generic_search.py
from typing import TypeVar, Iterable, Sequence, Generic, List, Callable, Set, Deque, Dict, Any, Optional
from typing import Protocol
from heapq import heappush, heappop

T = TypeVar('T')

def linear_contains(iterable: Iterable[T], key: T) -> bool:
for item in iterable:
if item == key:
return True
return False

C = TypeVar("C", bound="Comparable")

class Comparable(Protocol):
def __eq__(self, other: Any) -> bool:
return self == other

def __lt__(self, other: C) -> bool:
return self < other

def __gt__(self: C, other: C) -> bool:
return (not self < other) and self != other

def __le__(self: C, other: C) -> bool:
return self < other or self == other

def __ge__(self: C, other: C) -> bool:
return not self < other

def binary_contains(sequence: Sequence[C], key: C) -> bool:
low: int = 0
high: int = len(sequence) - 1
while low <= high:
mid: int = (low + high) // 2
if sequence[mid] < key:
low = mid + 1
elif sequence[mid] > key:
high = mid - 1
else:
return True
return False

if __name__ == '__main__':
print(linear_contains([1, 5, 15, 15, 15, 15, 20], 5))
print(binary_contains(["a", "d", "e", "f", "z"], "f"))
print(binary_contains(["join", "mark", "ronald", "sarah"], "sheila"))

四、迷宫问题

迷宫建模:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# maze.py
from enum import Enum
from typing import List, NamedTuple, Callable, Optional
import random
from math import sqrt

class Cell(str, Enum):
EMPTY = " "
BLOCKED = "X"
START = "S"
GOAL = "G"
PATH = "*"


class MazeLocation(NamedTuple):
row: int
column: int


class Maze:
def __init__(self, rows: int = 10, columns: int = 10, sparseness: float = 0.2,
start: MazeLocation = MazeLocation(0, 0),
goal: MazeLocation = MazeLocation(9, 9)) -> None:
self._rows: int = rows
self._columns: int = columns
self.start: MazeLocation = start
self.goal: MazeLocation = goal
self._grid: List[List[Cell]] = [[Cell.EMPTY for c in range(columns)]
for r in range(rows)]
self._randomly_fill(rows, columns, sparseness)
self._grid[start.row][start.column] = Cell.START
self._grid[goal.row][goal.column] = Cell.GOAL

def _randomly_fill(self, rows: int, columns: int, sparseness: float):
for row in range(rows):
for column in range(columns):
if random.uniform(0, 1.0) < sparseness:
self._grid[row][column] = Cell.BLOCKED

def __str__(self) -> str:
output: str = ""
for row in self._grid:
output += "".join([c.value for c in row]) + "\n"
return output


maze: Maze = Maze()
print(maze)

1
2
3
4
5
6
7
8
9
10
SX X    X
XX

X X X
X
X
XX X
X X X
X
XX X X G

Maze 类创建如下两个方法,goal_test 用于确定当前 Cell 是否为目标地点;successors 用于判断与当前 Cell 相邻的哪些 Cell 可以作为下一步的落脚点,返回其列表:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# maze.py continued
def goal_test(self, ml: MazeLocation) -> bool:
return ml == self.goal

def successors(self, ml: MazeLocation) -> List[MazeLocation]:
locations: List[MazeLocation] = []
if ml.row + 1 < self._rows and self._grid[ml.row + 1][ml.column] != Cell.BLOCKED:
locations.append(MazeLocation(ml.row + 1, ml.column))
if ml.row - 1 >= 0 and self._grid[ml.row - 1][ml.column] != Cell.BLOCKED:
locations.append(MazeLocation(ml.row - 1, ml.column))
if ml.column + 1 < self._columns and self._grid[ml.row][ml.column + 1] != Cell.BLOCKED:
locations.append(MazeLocation(ml.row, ml.column + 1))
if ml.column - 1 >= 0 and self._grid[ml.row][ml.column - 1] != Cell.BLOCKED:
locations.append(MazeLocation(ml.row, ml.column - 1))
return locations

DFS(depth-first search)算法

我理解的 DFS 算法,就是“一条道儿走到黑”。只要当前的路径能走通就一直走下去,不去考虑途中其他可能的方案。一旦遇到障碍走入“死胡同”,则回退到上一个节点尝试之前未选择的路径。
最终到达目标地点后停止,或者一路回退到起点,则证明不存在可行的方案。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# generic_search.py continued
class Stack(Generic[T]):
def __init__(self) -> None:
self._container: List[T] = []

@property
def empty(self) -> bool:
return not self._container # not is true for empty container

def push(self, item: T) -> None:
self._container.append(item)

def pop(self) -> T:
return self._container.pop() # LIFO

def __repr__(self) -> str:
return repr(self._container)


class Node(Generic[T]):
def __init__(self, state: T, parent: Optional[Node], cost: float = 0.0, heuristic: float = 0.0) -> None:
self.state: T = state
self.parent: Optional[Node] = parent
self.cost: float = cost
self.heuristic: float = heuristic

def __lt__(self, other: Node) -> bool:
return (self.cost + self.heuristic) < (other.cost + other.heuristic)


def dfs(initial: T, goal_test: Callable[[T], bool], successors: Callable[[T], List[T]]) -> Optional[Node[T]]:
# frontier is where we've yet to go
frontier: Stack[Node[T]] = Stack()
frontier.push(Node(initial, None))
# explored is where we've been
explored: Set[T] = {initial}

# keep going while there is more to explore
while not frontier.empty:
current_node: Node[T] = frontier.pop()
current_state: T = current_node.state
# if we found the goal, we're done
if goal_test(current_state):
return current_node
# check where we can go next and haven't explored
for child in successors(current_state):
if child in explored: # skip children we already explored
continue
explored.add(child)
frontier.push(Node(child, current_node))
return None # went through everything and never found goal


def node_to_path(node: Node[T]) -> List[T]:
path: List[T] = [node.state]
# work backwards from end to front
while node.parent is not None:
node = node.parent
path.append(node.state)
path.reverse()
return path

其中 Node 类可以算作 MazeLocation 上的又一层封装,为其添加 parent-child 关系方便后续连接为路径。其 costheuristic 属性会在后面的 A* 算法中用到。

dfs 函数中的 frontier 使用 Stack 数据结构记录每一步选择中出现的可行路径节点,其 pop 方法用于回溯到上一个节点。若 frontier 为空则说明已回溯到起点,未找到合适路径。
explored 使用 Set 数据结构记录所有已经尝试过的路径选择,避免回溯时重复之前的路径。

node_to_path 函数则可以从终点开始反向查找,将 Node 节点(即带有 parent-child 关系的 MazeLocation)逐步扩展连接成完整路径存放在列表中。

补充 maze.py 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# maze.py continued
from generic_search import dfs, node_to_path, Node
def mark(self, path: List[MazeLocation]):
for maze_location in path:
self._grid[maze_location.row][maze_location.column] = Cell.PATH
self._grid[self.start.row][self.start.column] = Cell.START
self._grid[self.goal.row][self.goal.column] = Cell.GOAL

def clear(self, path: List[MazeLocation]):
for maze_location in path:
self._grid[maze_location.row][maze_location.column] = Cell.EMPTY
self._grid[self.start.row][self.start.column] = Cell.START
self._grid[self.goal.row][self.goal.column] = Cell.GOAL


if __name__ == '__main__':
m: Maze = Maze()
print(m)

# test DFS
solution1: Optional[Node[MazeLocation]] = dfs(m.start, m.goal_test,
m.successors)
if solution1 is None:
print("No solution found using depth-first search")
else:
path1: List[MazeLocation] = node_to_path(solution1)
m.mark(path1)
print(m)
m.clear(path1)

结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
SX   XX X
X
X
X X
X X
X
XX X

X X
XX G

SX XX X
**X
*X
X*X
** X X
* *****X
****XX *X
***
X X*
XX****G

BFS(breadth-first search)算法寻找最短路径

BFS 算法更倾向于一步一个脚印,每一步都要尝试过所有可能的方案,层层递进直到其中任何一种方案走通。
假如共有 100 种路径都可以到达迷宫终点,各条路径或长或短。BFS 每次都会在所有 100 种路径上各走一步,直到其中某一条路径刚好到达终点后停止。
由于对所有可行路径的尝试是同步进行的,因此该算法找到的路径总是最短的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# generic_search.py continued
class Queue(Generic[T]):
def __init__(self) -> None:
self._container: Deque[T] = Deque()

@property
def empty(self) -> bool:
return not self._container # not is true for empty container

def push(self, item: T) -> None:
self._container.append(item)

def pop(self) -> T:
return self._container.popleft() # FIFO

def __repr__(self) -> str:
return repr(self._container)


def bfs(initial: T, goal_test: Callable[[T], bool], successors: Callable[[T], List[T]]) -> Optional[Node[T]]:
# frontier is where we've yet to go
frontier: Queue[Node[T]] = Queue()
frontier.push(Node(initial, None))
# explored is where we've been
explored: Set[T] = {initial}

# keep going while there is more to explore
while not frontier.empty:
current_node: Node[T] = frontier.pop()
current_state: T = current_node.state
# if we found the goal, we're done
if goal_test(current_state):
return current_node
# check where we can go next and haven't explored
for child in successors(current_state):
if child in explored: # skip children we already explored
continue
explored.add(child)
frontier.push(Node(child, current_node))
return None # went through everything and never found goal

其中 Queue 数据结构底层使用 Deque 而不是 List,目的是提升 popleft 方法的效率。
相对于 DFS 中的 Stack 数据结构,Queue 的 pop 方法是从序列的最左侧(Stack 是从最右侧)弹出项目。
由此使得 BFS 可以从起点开始最大程度地尝试所有可行方案,即广度优先原则。

1
2
3
4
5
6
7
8
9
10
# maze.py continued
# test BFS
solution2: Optional[Node[MazeLocation]] = bfs(m.start, m.goal_test, m.successors)
if solution2 is None:
print("No solution found using breadth-first search!")
else:
path2: List[MazeLocation] = node_to_path(solution2)
m.mark(path2)
print(m)
m.clear(path2)

执行结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
S*******
X X X *XX
********
* X X
**********
X *
XX *******
* X X
XXX*X ****
X****XXG

S
*X X X XX
*
* X X
*
*** X
XX*****
X* X
XXX X ****
X XXG

A* 算法寻找最优解

粗略地说,DFS 可以较快地找到合适路径,BFS 则可以寻求最短路径(但时间花费较高)。

A* 算法则引入了 heuristic 概念,大概就是在做出决策前先根据一定的指导原则确定路径选择的优先级。即优先选择可能最短(离终点最近)的路径作为下一步的节点,而不是随机尝试所有可行的节点,以此来降低 BFS 导致的时间损耗。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# generic_search.py continued
class PriorityQueue(Generic[T]):
def __init__(self) -> None:
self._container: List[T] = []

@property
def empty(self) -> bool:
return not self._container # not is true for empty container

def push(self, item: T) -> None:
heappush(self._container, item) # in by priority

def pop(self) -> T:
return heappop(self._container) # out by priority

def __repr__(self) -> str:
return repr(self._container)


def astar(initial: T, goal_test: Callable[[T], bool], successors: Callable[[T], List[T]], heuristic: Callable[[T], float]) -> Optional[Node[T]]:
# frontier is where we've yet to go
frontier: PriorityQueue[Node[T]] = PriorityQueue()
frontier.push(Node(initial, None, 0.0, heuristic(initial)))
# explored is where we've been
explored: Dict[T, float] = {initial: 0.0}

# keep going while there is more to explore
while not frontier.empty:
current_node: Node[T] = frontier.pop()
current_state: T = current_node.state
# if we found the goal, we're done
if goal_test(current_state):
return current_node
# check where we can go next and haven't explored
for child in successors(current_state):
new_cost: float = current_node.cost + 1 # 1 assumes a grid, need a cost function for more sophisticated apps

if child not in explored or explored[child] > new_cost:
explored[child] = new_cost
frontier.push(Node(child, current_node, new_cost, heuristic(child)))
return None # went through everything and never found goal

其中 PriorityQueue 数据结构使用 heappop 方法确保每次从队列中取出的数据项都是(对算法而言)优先级最高的。

下面 maze.py 中添加的 manhattan_distance 函数则以意向节点与目标节点的路线距离作为下一步路径选择的依据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# maze.py continued
def manhattan_distance(goal: MazeLocation) -> Callable[[MazeLocation], float]:
def distance(ml: MazeLocation) -> float:
xdist: int = abs(ml.column - goal.column)
ydist: int = abs(ml.row - goal.row)
return (xdist + ydist)
return distance
# ...

# Test A*
distance: Callable[[MazeLocation], float] = manhattan_distance(m.goal)
solution3: Optional[Node[MazeLocation]] = astar(m.start, m.goal_test, m.successors, distance)
if solution3 is None:
print("No solution found using A*!")
else:
path3: List[MazeLocation] = node_to_path(solution3)
m.mark(path3)
print(m)

执行结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
S******X
X ****
***** *
*X X******
**** X
XX*X
X** XX***
X * ****X*
XX***XX *
X G

S X
*X
*
*X X
**** X
XX*X
X * XX
X * X
XX * XX
X******G

S X
*X
*
*X X
*******X
XX X **
X XX***
X X*
XX XX *
X G

参考资料

Classic Computer Science Problems in Python
davecom/ClassicComputerScienceProblemsInPython