1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
from collections import defaultdict
from dataclasses import dataclass
import operator

from nmigen.hdl.ast import Value, Const
@dataclass(order=True, unsafe_hash=True, frozen=True)
class Op:
    """An associative operation in a prefix-sum.
    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.

    Instances are immutable, hashable and ordered field-by-field (in the
    order `out`, `lhs`, `rhs`, `row`), so they can be stored in sets and
    sorted.
    """
    # index of the item to output to
    out: int
    # index of the item the left-hand-side input comes from
    lhs: int
    # index of the item the right-hand-side input comes from
    rhs: int
    # row in the prefix-sum diagram
    row: int
def prefix_sum_ops(item_count, *, work_efficient=False):
    """ Get the associative operations needed to compute a parallel prefix-sum
    of `item_count` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations, row by row. Applying them one at a
        time, in the yielded order and in place, computes the prefix-sum.
    """
    assert isinstance(item_count, int)
    # compute the partial sums using a set of binary trees
    # this is the first half of the work-efficient algorithm and the whole of
    # the non-work-efficient algorithm.
    dist = 1
    row = 0
    while dist < item_count:
        start = dist * 2 - 1 if work_efficient else dist
        step = dist * 2 if work_efficient else 1
        # descending order: an op's lhs (lower index) is then not yet
        # overwritten by this row, so sequential in-place application gives
        # the same result as evaluating the whole row in parallel.
        for i in reversed(range(start, item_count, step)):
            yield Op(out=i, lhs=i - dist, rhs=i, row=row)
        dist <<= 1
        row += 1
    if work_efficient:
        # express all output items in terms of the computed partial sums.
        dist >>= 1
        while dist >= 1:
            for i in reversed(range(dist * 3 - 1, item_count, dist * 2)):
                yield Op(out=i, lhs=i - dist, rhs=i, row=row)
            row += 1
            dist >>= 1
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """ Compute the parallel prefix-sum of `items`, using associative operator
    `fn` instead of addition.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items. The input is not modified; a new list is returned.
    fn: Callable[[_T, _T], _T]
        Operation to use for the prefix-sum algorithm instead of addition.
        Assumed to be associative not necessarily commutative.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: list[_T]
        output items: element `i` is `fn` applied over `items[0]` through
        `items[i]`, in order.
    """
    # copy so arbitrary iterables are accepted and the caller's data is not
    # overwritten by the in-place evaluation below.
    items = list(items)
    for op in prefix_sum_ops(len(items), work_efficient=work_efficient):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    return items
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram, one line per text row, trailing spaces stripped.
    """
    assert isinstance(item_count, int)
    assert isinstance(padding, int)

    @dataclass
    class _Cell:
        # a diagonal (lhs -> out) line passes through this cell
        slant: bool
        # an operation's result is produced in this cell
        plus: bool
        # an operation's lhs input branches off the vertical line here
        tee: bool

    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    def blank_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    cells = [blank_row()]

    for row in sorted(ops_by_row.keys()):
        ops = ops_by_row[row]
        # add one cell row per column of horizontal distance, so every
        # diagonal moves exactly one column per cell row.
        max_distance = max(op.rhs - op.lhs for op in ops)
        cells.extend(blank_row() for _ in range(max_distance))
        for op in ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            # start at the result in the last cell row, then walk the
            # diagonal up and to the left until reaching the lhs column.
            y = len(cells) - 1
            x = op.out
            cells[y][x].plus = True
            x -= 1
            y -= 1
            while op.lhs < x:
                cells[y][x].slant = True
                x -= 1
                y -= 1
            cells[y][x].tee = True

    lines = []
    for cells_row in cells:
        # each cell is drawn as a square of (2 * padding + 1) characters
        row_text = [[] for _ in range(2 * padding + 1)]
        for cell in cells_row:
            # top padding
            for y in range(padding):
                # top left padding
                for x in range(padding):
                    is_slant = x == y and (cell.plus or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
                # top vertical bar
                row_text[y].append(vbar)
                # top right padding
                for _ in range(padding):
                    row_text[y].append(sp)
            # center left padding
            for _ in range(padding):
                row_text[padding].append(sp)
            # center: an operation result takes precedence over a branch,
            # which takes precedence over a crossing diagonal.
            center = vbar
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            row_text[padding].append(center)
            # center right padding
            for _ in range(padding):
                row_text[padding].append(sp)
            # bottom padding
            for y in range(padding + 1, 2 * padding + 1):
                # bottom left padding
                for _ in range(padding):
                    row_text[y].append(sp)
                # bottom vertical bar
                row_text[y].append(vbar)
                # bottom right padding
                for x in range(padding + 1, 2 * padding + 1):
                    is_slant = x == y and (cell.tee or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
        for line in row_text:
            lines.append("".join(line))

    return "\n".join(map(str.rstrip, lines))
def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
    """ Get the associative operations needed to compute a parallel prefix-sum
    of `len(needed_outputs)` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    needed_outputs: Iterable[bool]
        The length is the number of input/output items.
        Each item is True if that corresponding output is needed.
        Unneeded outputs have unspecified value.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations: the subset of `prefix_sum_ops` that
        contributes to a needed output, in the same order.
    """

    # flags must be real `bool`s (not merely truthy), so passing e.g. a list
    # of ints by mistake is caught.
    def assert_bool(v):
        assert isinstance(v, bool)
        return v

    items_live_flags = [assert_bool(i) for i in needed_outputs]
    ops = list(prefix_sum_ops(item_count=len(items_live_flags),
                              work_efficient=work_efficient))
    ops_live_flags = [False] * len(ops)
    # backwards liveness pass: an op is needed iff the item it writes is
    # still needed after it. The value before the op is a different value,
    # so clear the flag for `out`, then mark both inputs as needed (`rhs` is
    # `out` again, so it is re-set from `out_live`).
    for i in reversed(range(len(ops))):
        op = ops[i]
        out_live = items_live_flags[op.out]
        items_live_flags[op.out] = False
        items_live_flags[op.lhs] |= out_live
        items_live_flags[op.rhs] |= out_live
        ops_live_flags[i] = out_live
    for op, live_flag in zip(ops, ops_live_flags):
        if live_flag:
            yield op
def tree_reduction_ops(item_count):
    """ Get the associative operations needed to reduce `item_count` items
    into the last item (index `item_count - 1`).

    Parameters:
    item_count: int
        number of input items, must be at least 1.
    Returns: Iterable[Op]
        output associative operations.
    """
    assert isinstance(item_count, int) and item_count >= 1
    last_index = item_count - 1
    # of the full prefix-sum, only the final output (the total) is needed
    return partial_prefix_sum_ops(
        index == last_index for index in range(item_count))
def tree_reduction(items, fn=operator.add):
    """ Reduce `items` to a single value using associative operator `fn`
    instead of addition, evaluated as a tree of `O(log(N))` depth.

    Parameters:
    items: Iterable[_T]
        input items, must be non-empty. The input is not modified.
    fn: Callable[[_T, _T], _T]
        Operation to use instead of addition.
        Assumed to be associative not necessarily commutative.
    Returns: _T
        `fn` applied over all items, in order.
    """
    # copy so arbitrary iterables are accepted and the caller's data is not
    # overwritten by the in-place evaluation below.
    items = list(items)
    for op in tree_reduction_ops(len(items)):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    # the reduction result is accumulated into the last item
    return items[-1]
def pop_count(v, *, width=None, process_temporary=lambda v: v):
    """ Count the number of set bits in `v`.

    Parameters:
    v: Value | int
        value whose set bits are counted. For an nmigen `Value` the result is
        an nmigen expression; otherwise `v` must be an `int`.
    width: int | None
        number of low bits of `v` to count. For a `Value` it defaults to
        `len(v)` and must equal it; for an `int` it is required.
    process_temporary: Callable
        called on the result of every addition performed while summing the
        bits; defaults to the identity function.
    Returns: Value | int
        number of set bits.
    """
    if isinstance(v, Value):
        if width is None:
            width = len(v)
        assert width == len(v)
        bits = [v[i] for i in range(width)]
        if len(bits) == 0:
            # tree_reduction requires at least one item
            return Const(0)
    else:
        assert isinstance(width, int) and width >= 0
        assert isinstance(v, int)
        # 0/1 ints rather than bools, so the result is an `int` even when
        # `width == 1` and no additions happen.
        bits = [(v >> i) & 1 for i in range(width)]
        if len(bits) == 0:
            # tree_reduction requires at least one item
            return 0
    return tree_reduction(bits, fn=lambda a, b: process_temporary(a + b))
if __name__ == "__main__":
    # print both algorithms' diagrams for 16 items, for comparison against
    # the linked Wikipedia figures.
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=False))
    print()
    print()
    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=True))
    print()
    print()