# add reduce_only option to prefix_sum_ops
# [nmutil.git] / src / nmutil / prefix_sum.py
# SPDX-License-Identifier: LGPL-3-or-later
# Copyright 2022 Jacob Lifshay programmerjake@gmail.com

# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
6
7 from collections import defaultdict
8 from dataclasses import dataclass
9 import operator
10
11
@dataclass(frozen=True, order=True, unsafe_hash=True)
class Op:
    """One application of the associative operator within a prefix-sum.

    Executing an `Op` means:
    `items[self.out] = fn(items[self.lhs], items[self.rhs])`.

    `fn` is only required to be associative; nothing here relies on it being
    commutative, so `lhs` and `rhs` must not be swapped.
    """
    out: int
    """index of the item that receives the result"""
    lhs: int
    """index of the item supplying the left-hand-side operand"""
    rhs: int
    """index of the item supplying the right-hand-side operand"""
    row: int
    """which row of the prefix-sum diagram this operation belongs to"""
26
27
def prefix_sum_ops(item_count, *, work_efficient=False, reduce_only=False):
    """ Generate the schedule of associative operations that computes a
    parallel prefix-sum over `item_count` items.

    Every yielded `Op` combines two items and stores the result over its
    right-hand-side input (`out == rhs`).  The operation is never assumed to
    be commutative.

    The schedule has a depth of `O(log(N))`; it contains `O(N)` operations
    when `work_efficient` is true and `O(N*log(N))` operations otherwise.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True to use the work-efficient algorithm -- about twice the depth,
        but only `O(N)` operations in total instead of `O(N*log(N))`.
    reduce_only: bool
        True to make the work-efficient algorithm stop once its initial
        tree-reduction step is done.  Has no effect when `work_efficient`
        is false.
    Returns: Iterable[Op]
        the associative operations, yielded row by row.
    """
    assert isinstance(item_count, int)

    # Up-sweep: build partial sums with a set of binary trees.  This is the
    # entire non-work-efficient algorithm, and the first half of the
    # work-efficient one.
    span = 1
    level = 0
    while span < item_count:
        if work_efficient:
            # only the last item of every `2 * span`-wide group is updated
            first, stride = 2 * span - 1, 2 * span
        else:
            # every item that has a partner `span` places to its left
            first, stride = span, 1
        # highest index first, same order as iterating the range reversed
        for target in range(first, item_count, stride)[::-1]:
            yield Op(out=target, lhs=target - span, rhs=target, row=level)
        span *= 2
        level += 1

    if reduce_only or not work_efficient:
        return

    # Down-sweep: express the remaining outputs in terms of the partial sums
    # that the up-sweep left behind.
    span //= 2
    while span > 0:
        for target in range(3 * span - 1, item_count, 2 * span)[::-1]:
            yield Op(out=target, lhs=target - span, rhs=target, row=level)
        level += 1
        span //= 2
75
76
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """ Return the prefix-sum of `items`, combining them with the associative
    operator `fn` rather than plain addition.

    The work is done by replaying the schedule from `prefix_sum_ops` on a
    copy of the input, so `items` itself is left untouched.

    The schedule has a depth of `O(log(N))`; it contains `O(N)` operations
    when `work_efficient` is true and `O(N*log(N))` operations otherwise.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        operation used in place of addition.  Assumed to be associative,
        not necessarily commutative.
    work_efficient: bool
        True to use the work-efficient algorithm -- about twice the depth,
        but only `O(N)` operations in total instead of `O(N*log(N))`.
    Returns: list[_T]
        output items.
    """
    results = list(items)
    schedule = prefix_sum_ops(len(results), work_efficient=work_efficient)
    for step in schedule:
        left_value = results[step.lhs]
        right_value = results[step.rhs]
        results[step.out] = fn(left_value, right_value)
    return results
105
106
@dataclass()
class _Cell:
    """Drawing state of a single cell in the rendered prefix-sum diagram."""
    slant: bool
    """a diagonal line passes through this cell without connecting"""
    plus: bool
    """an operation writes its result here; a diagonal ends in this cell"""
    tee: bool
    """an operation reads its left-hand-side here; a diagonal starts here"""
112
113
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """draws the prefix-sum diagram for the ops that `prefix_sum_ops` yields.

    Each item is a vertical line; each op is a diagonal that leaves the
    left-hand-side item's line and ends in a `plus` on the output item's line.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True to draw the work-efficient algorithm -- about twice the depth,
        but only `O(N)` operations in total instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used where a vertical line connects to a line going from
        the center of this character to the bottom right
    no_connect: str
        character used where a vertical line and a top-left to bottom-right
        diagonal cross without connecting
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram
    """
    assert isinstance(item_count, int)
    assert isinstance(padding, int)

    # bucket the ops by diagram row, rejecting anything that can't be drawn
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        bucket = ops_by_row[op.row]
        assert op not in bucket, f"duplicate op: {op}"
        bucket.add(op)

    def empty_grid_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    def center_glyph(cell):
        # precedence when a cell has several flags set: plus, tee, slant
        if cell.plus:
            return plus
        if cell.tee:
            return connect
        if cell.slant:
            return no_connect
        return vbar

    # Layout: every diagram row gets as many grid rows as its longest
    # diagonal needs.  A diagonal ends (plus) in the newest grid row and is
    # traced back up and to the left until it reaches its start (tee).
    grid = [empty_grid_row()]
    for row in sorted(ops_by_row):
        bucket = ops_by_row[row]
        tallest = max(op.rhs - op.lhs for op in bucket)
        for _ in range(tallest):
            grid.append(empty_grid_row())
        bottom = len(grid) - 1
        for op in bucket:
            assert op.lhs < op.rhs and op.out == op.rhs, \
                f"can't draw op: {op}"
            reach = op.out - op.lhs
            grid[bottom][op.out].plus = True
            for offset in range(1, reach):
                grid[bottom - offset][op.out - offset].slant = True
            grid[bottom - reach][op.lhs].tee = True

    # Text: each grid row becomes `2 * padding + 1` lines of text; each cell
    # contributes `2 * padding + 1` characters to every one of those lines.
    line_count = 2 * padding + 1
    blank_pad = [sp] * padding
    text_lines = []
    for grid_row in grid:
        strips = [[] for _ in range(line_count)]
        for cell in grid_row:
            # a diagonal arrives from the top left / departs to the bottom
            # right of this cell
            enters = cell.plus or cell.slant
            leaves = cell.tee or cell.slant
            for k in range(padding):
                # k-th line above the center: diagonal in the left padding
                above = strips[k]
                above.extend(slant if enters and j == k else sp
                             for j in range(padding))
                above.append(vbar)
                above.extend(blank_pad)
                # k-th line below the center: diagonal in the right padding
                below = strips[padding + 1 + k]
                below.extend(blank_pad)
                below.append(vbar)
                below.extend(slant if leaves and j == k else sp
                             for j in range(padding))
            middle = strips[padding]
            middle.extend(blank_pad)
            middle.append(center_glyph(cell))
            middle.extend(blank_pad)
        text_lines.extend("".join(strip).rstrip() for strip in strips)

    return "\n".join(text_lines)
223
224
if __name__ == "__main__":
    # Demo: print both 16-item diagrams so they can be compared with the
    # corresponding figures on wikipedia.
    demo_item_count = 16

    print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
          "\n"
          "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg"
          "\n\n")
    print(render_prefix_sum_diagram(demo_item_count, work_efficient=False))
    for _ in range(2):
        print()

    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print()
    print(render_prefix_sum_diagram(demo_item_count, work_efficient=True))
    for _ in range(2):
        print()