# 23eca36e2bb748c296c5a7ca88b9fa578258c653
# SPDX-License-Identifier: LGPL-3-or-later
# Copyright 2022 Jacob Lifshay programmerjake@gmail.com

# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
from collections import defaultdict
import operator

from nmigen.hdl.ast import Value, Const
from nmutil.plain_data import plain_data
@plain_data(order=True, unsafe_hash=True, frozen=True)
class Op:
    """An associative operation in a prefix-sum.
    The operation is `items[self.out] = fn(items[self.lhs], items[self.rhs])`.
    The operation is not assumed to be commutative.

    The `plain_data` flags request ordering, hashing (`unsafe_hash=True`) and
    immutability; hashing matters because ops are collected into sets when
    rendering diagrams.
    """
    __slots__ = "out", "lhs", "rhs", "row"

    def __init__(self, out, lhs, rhs, row):
        self.out = out
        "index of the item to output to"

        self.lhs = lhs
        "index of the item the left-hand-side input comes from"

        self.rhs = rhs
        "index of the item the right-hand-side input comes from"

        self.row = row
        "row in the prefix-sum diagram"
def prefix_sum_ops(item_count, *, work_efficient=False):
    """Get the associative operations needed to compute a parallel prefix-sum
    of `item_count` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations, in an order where applying them one
        after another (as `prefix_sum` does) yields the prefix-sum.
    """
    assert isinstance(item_count, int)
    # compute the partial sums using a set of binary trees
    # this is the first half of the work-efficient algorithm and the whole of
    # the non-work-efficient algorithm.
    dist = 1
    row = 0
    while dist < item_count:
        start = dist * 2 - 1 if work_efficient else dist
        step = dist * 2 if work_efficient else 1
        # emit right-to-left so that, when the ops are applied sequentially,
        # each op reads `items[i - dist]` before any op in this same row
        # overwrites it.
        for i in reversed(range(start, item_count, step)):
            yield Op(out=i, lhs=i - dist, rhs=i, row=row)
        dist <<= 1
        row += 1
    if work_efficient:
        # express all output items in terms of the computed partial sums.
        dist >>= 1
        while dist >= 1:
            for i in reversed(range(dist * 3 - 1, item_count, dist * 2)):
                yield Op(out=i, lhs=i - dist, rhs=i, row=row)
            row += 1
            dist >>= 1
def prefix_sum(items, fn=operator.add, *, work_efficient=False):
    """Compute the parallel prefix-sum of `items`, using associative operator
    `fn` instead of addition.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    items: Iterable[_T]
        input items.
    fn: Callable[[_T, _T], _T]
        Operation to use for the prefix-sum algorithm instead of addition.
        Assumed to be associative not necessarily commutative.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: list[_T]
        output items; element `i` is `fn` folded over `items[0..=i]`.
    """
    # copy into a fresh list: accepts any iterable and never mutates the
    # caller's sequence in place.
    items = list(items)
    for op in prefix_sum_ops(len(items), work_efficient=work_efficient):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    return items
@plain_data()
class _Cell:
    """One grid cell of a rendered prefix-sum diagram.

    All three flags start out False and are set to True by
    `render_prefix_sum_diagram` while it lays out the operations.
    """
    __slots__ = "slant", "plus", "tee"

    def __init__(self, slant, plus, tee):
        # a diagonal line passes through this cell without connecting
        self.slant = slant
        # this cell holds an operation's output (drawn as the operator)
        self.plus = plus
        # a diagonal line starts here, branching off the vertical line
        self.tee = tee
def render_prefix_sum_diagram(item_count, *, work_efficient=False,
                              sp=" ", vbar="|", plus="⊕",
                              slant="\\", connect="●", no_connect="X",
                              padding=1,
                              ):
    """renders a prefix-sum diagram, matches `prefix_sum_ops`.

    Parameters:
    item_count: int
        number of input items.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    sp: str
        character used for blank space
    vbar: str
        character used for a vertical bar
    plus: str
        character used for the addition operation
    slant: str
        character used to draw a line from the top left to the bottom right
    connect: str
        character used to draw a connection between a vertical line and a line
        going from the center of this character to the bottom right
    no_connect: str
        character used to draw two lines crossing but not connecting, the lines
        are vertical and diagonal from top left to the bottom right
    padding: int
        amount of padding characters in the output cells.
    Returns: str
        rendered diagram, one text column group per item; trailing
        whitespace is stripped from every line.
    """
    ops_by_row = defaultdict(set)
    for op in prefix_sum_ops(item_count, work_efficient=work_efficient):
        assert op.out == op.rhs, f"can't draw op: {op}"
        assert op not in ops_by_row[op.row], f"duplicate op: {op}"
        ops_by_row[op.row].add(op)

    def blank_row():
        return [_Cell(slant=False, plus=False, tee=False)
                for _ in range(item_count)]

    cells = [blank_row()]

    for row in sorted(ops_by_row):
        ops = ops_by_row[row]
        # each diagonal climbs one cell row per column it spans, so reserve
        # enough cell rows for the longest op in this diagram row.
        max_distance = max(op.rhs - op.lhs for op in ops)
        cells.extend(blank_row() for _ in range(max_distance))
        for op in ops:
            assert op.lhs < op.rhs and op.out == op.rhs, f"can't draw op: {op}"
            # the operator sits in the newest cell row at the output column
            y = len(cells) - 1
            x = op.out
            cells[y][x].plus = True
            # a diagonal runs up and to the left from the operator...
            x -= 1
            y -= 1
            while op.lhs < x:
                cells[y][x].slant = True
                x -= 1
                y -= 1
            # ...until it branches off the lhs input's vertical line
            cells[y][x].tee = True

    lines = []
    for cells_row in cells:
        # every cell row becomes `2 * padding + 1` lines of text
        row_text = [[] for _ in range(2 * padding + 1)]
        for cell in cells_row:
            # top half: a diagonal may enter from the top left, ending at the
            # operator (`plus`) or passing through (`slant`)
            for y in range(padding):
                for x in range(padding):
                    is_slant = x == y and (cell.plus or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
                row_text[y].append(vbar)
                row_text[y].append(sp * padding)
            # center line: the vertical bar, or a symbol replacing it
            if cell.plus:
                center = plus
            elif cell.tee:
                center = connect
            elif cell.slant:
                center = no_connect
            else:
                center = vbar
            row_text[padding].append(sp * padding + center + sp * padding)
            # bottom half: a diagonal may leave toward the bottom right,
            # starting at a branch (`tee`) or passing through (`slant`)
            for y in range(padding + 1, 2 * padding + 1):
                row_text[y].append(sp * padding)
                row_text[y].append(vbar)
                for x in range(padding + 1, 2 * padding + 1):
                    is_slant = x == y and (cell.tee or cell.slant)
                    row_text[y].append(slant if is_slant else sp)
        lines.extend("".join(line) for line in row_text)

    return "\n".join(map(str.rstrip, lines))
def partial_prefix_sum_ops(needed_outputs, *, work_efficient=False):
    """Get the associative operations needed to compute a parallel prefix-sum
    of `len(needed_outputs)` items.

    The operations aren't assumed to be commutative.

    This has a depth of `O(log(N))` and an operation count of `O(N)` if
    `work_efficient` is true, otherwise `O(N*log(N))`.

    The algorithms used are derived from:
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_1:_Shorter_span,_more_parallel
    https://en.wikipedia.org/wiki/Prefix_sum#Algorithm_2:_Work-efficient

    Parameters:
    needed_outputs: Iterable[bool]
        The length is the number of input/output items.
        Each item is True if that corresponding output is needed.
        Unneeded outputs have unspecified value.
    work_efficient: bool
        True if the algorithm used should be work-efficient -- has a larger
        depth (about twice as large) but does only `O(N)` operations total
        instead of `O(N*log(N))`.
    Returns: Iterable[Op]
        output associative operations: the subset of `prefix_sum_ops` that
        contributes to a needed output, in the same order.
    """
    # needed_outputs is an iterable, we need to construct a new list so we
    # don't modify the passed-in value
    items_live_flags = [bool(i) for i in needed_outputs]
    ops = list(prefix_sum_ops(item_count=len(items_live_flags),
                              work_efficient=work_efficient))
    ops_live_flags = [False] * len(ops)
    # backward liveness pass: walk the ops last-to-first, tracking which item
    # slots still hold a value that a needed output depends on.
    for i in reversed(range(len(ops))):
        op = ops[i]
        # the op is only needed if the value it writes is still live
        out_live = items_live_flags[op.out]
        # the value held in `op.out` before this op runs gets overwritten
        # here, so it is dead unless this op reads it itself (covered by the
        # `rhs` update below, since `op.rhs` may equal `op.out`)
        items_live_flags[op.out] = False
        items_live_flags[op.lhs] |= out_live
        items_live_flags[op.rhs] |= out_live
        ops_live_flags[i] = out_live
    for op, live_flag in zip(ops, ops_live_flags):
        if live_flag:
            yield op
def tree_reduction_ops(item_count):
    """Get the associative operations that reduce `item_count` items into
    the last item (index `item_count - 1`).

    Parameters:
    item_count: int
        number of input items; must be at least 1.
    Returns: Iterable[Op]
        the prefix-sum operations needed for the final output only.
    """
    assert item_count >= 1
    # only the final prefix-sum output (the reduction of every item) matters,
    # so every op not feeding into it gets pruned.
    needed_outputs = [False] * (item_count - 1) + [True]
    return partial_prefix_sum_ops(needed_outputs)
def tree_reduction(items, fn=operator.add):
    """Reduce `items` to a single value using the associative operator `fn`,
    combining items in a tree rather than a linear chain.

    The operations come from `tree_reduction_ops`, a pruned prefix-sum, so
    `fn` is assumed associative but not necessarily commutative.

    Parameters:
    items: Iterable[_T]
        input items; must be non-empty.
    fn: Callable[[_T, _T], _T]
        Operation to use for the reduction instead of addition.
    Returns: _T
        the reduction of all items.
    """
    # copy into a fresh list: accepts any iterable and never mutates the
    # caller's sequence in place.
    items = list(items)
    for op in tree_reduction_ops(len(items)):
        items[op.out] = fn(items[op.lhs], items[op.rhs])
    # the only needed output is the last one, which holds the full reduction
    return items[-1]
if __name__ == "__main__":
    # print both diagram variants so they can be compared by eye with the
    # referenced Wikipedia figures.
    print("the non-work-efficient algorithm, matches the diagram in wikipedia:"
          "\n"
          "https://commons.wikimedia.org/wiki/File:Hillis-Steele_Prefix_Sum.svg"
          "\n\n")
    print(render_prefix_sum_diagram(16, work_efficient=False))
    print()
    print()
    print("the work-efficient algorithm, matches the diagram in wikipedia:")
    print("https://en.wikipedia.org/wiki/File:Prefix_sum_16.svg")
    print()
    print(render_prefix_sum_diagram(16, work_efficient=True))
    print()
    print()