23eca36e2bb748c296c5a7ca88b9fa578258c653
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from collections import defaultdict
8 import operator
9 from nmigen.hdl.ast import Value, Const
10 from nmutil.plain_data import plain_data
@plain_data(order=True, unsafe_hash=True, frozen=True)
class Op:
    """One step of a prefix-sum: `items[out] = fn(items[lhs], items[rhs])`.

    `fn` is required to be associative but is NOT assumed to be
    commutative, so the `lhs` operand is always passed first.
    """
    __slots__ = "out", "lhs", "rhs", "row"

    def __init__(self, out, lhs, rhs, row):
        #: index of the item to output to
        self.out = out
        #: index of the item the left-hand-side input comes from
        self.lhs = lhs
        #: index of the item the right-hand-side input comes from
        self.rhs = rhs
        #: row in the prefix-sum diagram
        self.row = row
def prefix_sum_ops(item_count, *, work_efficient=False):
    """Generate the associative operations of a parallel prefix-sum over
    `item_count` items.

    The combining function is only assumed to be associative, never
    commutative, so every yielded `Op` keeps its operands in order.

    Depth is `O(log(N))`. The operation count is `O(N)` when
    `work_efficient` is true and `O(N*log(N))` otherwise.

    The algorithms come from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        select the work-efficient algorithm -- roughly twice the depth,
        but only `O(N)` operations in total.

    Returns: Iterable[Op]
        the operations in the order they must be applied. Every op has
        `out == rhs` and `lhs < rhs`.
    """
    assert isinstance(item_count, int)
    row = 0

    # Phase 1: combine items `span` apart, doubling `span` each row.
    # The non-work-efficient algorithm consists of this phase alone; the
    # work-efficient one only touches the roots of a set of binary trees.
    span = 1
    while span < item_count:
        if work_efficient:
            first, stride = 2 * span - 1, 2 * span
        else:
            first, stride = span, 1
        for idx in reversed(range(first, item_count, stride)):
            yield Op(out=idx, lhs=idx - span, rhs=idx, row=row)
        span *= 2
        row += 1

    if not work_efficient:
        return

    # Phase 2 (work-efficient only): fill in every remaining output from
    # the partial sums computed above, halving `span` each row.
    span //= 2
    while span > 0:
        for idx in reversed(range(3 * span - 1, item_count, 2 * span)):
            yield Op(out=idx, lhs=idx - span, rhs=idx, row=row)
        row += 1
        span //= 2
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """Compute the parallel prefix-sum of `items`, using the associative
    operator `fn` in place of addition.

    Depth is `O(log(N))`. The operation count is `O(N)` when
    `work_efficient` is true and `O(N*log(N))` otherwise.

    The algorithms come from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        operation used instead of addition (defaults to `operator.add`).
        It is assumed to be associative but not necessarily commutative.
    work_efficient: bool
        select the work-efficient algorithm -- roughly twice the depth,
        but only `O(N)` operations in total.

    Returns: list[_T]
        a new list whose element `i` is the combination of input
        items `0..=i`.
    """
    results = list(items)
    ops = prefix_sum_ops(len(results), work_efficient=work_efficient)
    for op in ops:
        # operand order matters: fn need not be commutative
        lhs_value = results[op.lhs]
        rhs_value = results[op.rhs]
        results[op.out] = fn(lhs_value, rhs_value)
    return results
@plain_data()
class _Cell:
    """One grid cell of the diagram built by `render_prefix_sum_diagram`.

    Each flag records a feature drawn in the cell.
    """
    __slots__ = "slant", "plus", "tee"

    def __init__(self, slant, plus, tee):
        #: a diagonal operand line passes through this cell
        self.slant = slant
        #: an operation writes its output at this cell
        self.plus = plus
        #: an operation reads its left-hand input from this cell
        self.tee = tee
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Each item is drawn as a vertical line. Each `Op` is drawn as a diagonal
    line that starts at a `connect` mark in column `op.lhs` and ends at a
    `plus` mark in column `op.out`, several grid rows further down. Every
    grid cell renders as a `(2 * padding + 1)` by `(2 * padding + 1)` block
    of characters.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram, with trailing whitespace stripped from every line.
    """
    # group the ops by the diagram row they belong to
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    def blank_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    cells = [blank_row()]

    for row in sorted(ops_by_row.keys()):
        ops = ops_by_row[row]
        # the longest diagonal in this row decides how many grid rows it needs
        max_distance = max(op.rhs - op.lhs for op in ops)
        cells.extend(blank_row() for _ in range(max_distance))
        for op in ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            # walk the diagonal up and to the left, from the output cell on
            # the last grid row back to the left-hand input column
            y = len(cells) - 1
            x = op.out
            cells[y][x].plus = True
            x -= 1
            y -= 1
            while op.lhs < x:
                cells[y][x].slant = True
                x -= 1
                y -= 1
            cells[y][x].tee = True

    lines = []
    for cells_row in cells:
        row_text = [[] for _ in range(2 * padding + 1)]
        for cell in cells_row:
            # top section: text lines 0 .. padding - 1
            for y in range(padding):
                # top left corner: a diagonal enters from the upper left
                for x in range(padding):
                    is_slant = x == y and (cell.plus or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
                # top vertical bar
                row_text[y].append(vbar)
                # top right corner
                for x in range(padding):
                    row_text[y].append(sp)
            # middle left
            for x in range(padding):
                row_text[padding].append(sp)
            # center
            center = vbar
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            row_text[padding].append(center)
            # middle right
            for x in range(padding):
                row_text[padding].append(sp)
            # bottom section: text lines padding + 1 .. 2 * padding
            for y in range(padding + 1, 2 * padding + 1):
                # bottom left corner
                for x in range(padding):
                    row_text[y].append(sp)
                # bottom vertical bar
                row_text[y].append(vbar)
                # bottom right corner: a diagonal leaves to the lower right
                for x in range(padding + 1, 2 * padding + 1):
                    is_slant = x == y and (cell.tee or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
        for line in row_text:
            lines.append("".join(line))

    return "\n".join(map(str.rstrip, lines))
def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
    """Get the associative operations of a parallel prefix-sum over
    `len(needed_outputs)` items, keeping only the operations that
    contribute to the requested outputs.

    The combining function is only assumed to be associative, never
    commutative.

    Depth is `O(log(N))`. The operation count is `O(N)` when
    `work_efficient` is true and `O(N*log(N))` otherwise.

    The algorithms come from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    needed_outputs: Iterable[bool]
        one entry per input/output item; its length is the item count.
        An entry is truthy when that output is needed. Outputs that are
        not needed end up with unspecified values.
    work_efficient: bool
        select the work-efficient algorithm -- roughly twice the depth,
        but only `O(N)` operations in total.

    Returns: Iterable[Op]
        the surviving operations, in application order.
    """
    # build a fresh list so the caller's iterable is never modified
    live = [bool(flag) for flag in needed_outputs]
    all_ops = list(prefix_sum_ops(item_count=len(live),
                                  work_efficient=work_efficient))
    keep = [False] * len(all_ops)

    # backwards liveness pass: an op is needed iff its output is live at
    # that point, and a needed op makes both of its inputs live.
    for index in range(len(all_ops) - 1, -1, -1):
        op = all_ops[index]
        needed = live[op.out]
        live[op.out] = False
        live[op.lhs] |= needed
        live[op.rhs] |= needed
        keep[index] = needed

    yield from (op for op, wanted in zip(all_ops, keep) if wanted)
def tree_reduction_ops(item_count):
    """Get the associative operations that combine all `item_count` items
    into a single result, left in the last item.

    This is a prefix-sum pruned (via `partial_prefix_sum_ops`) so that only
    the last output is needed.

    Parameters:
    item_count: int
        number of input items; must be at least 1.

    Returns: Iterable[Op]
        the operations, in application order.
    """
    assert item_count >= 1
    last = item_count - 1
    return partial_prefix_sum_ops(i == last for i in range(item_count))
def tree_reduction(items, fn=operator.add):
    """Combine all of `items` into one value using the associative operator
    `fn` (addition by default), applying the ops from `tree_reduction_ops`.

    `fn` is assumed to be associative but not necessarily commutative, so
    operands are always passed in item order.

    Parameters:
    items: Iterable[_T]
        input items; at least one item is required (`tree_reduction_ops`
        asserts `item_count >= 1`).
    fn: Callable[[_T, _T], _T]
        operation used instead of addition.

    Returns: _T
        the reduction of all items.
    """
    items = list(items)
    for op in tree_reduction_ops(len(items)):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    # tree_reduction_ops only keeps the ops feeding the last output
    return items[-1]
if __name__ == "__main__":
    # each entry: (heading, reference diagram URL, work_efficient flag,
    #              number of blank lines printed before the diagram)
    demos = [
        ("the non-work-efficient algorithm, matches the diagram in wikipedia:",
         "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg",
         False, 2),
        ("the work-efficient algorithm, matches the diagram in wikipedia:",
         "https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg",
         True, 1),
    ]
    for heading, url, efficient, gap in demos:
        print(heading)
        print(url)
        for _ in range(gap):
            print()
        print(render_prefix_sum_diagram(16, work_efficient=efficient))
        print()
        print()