Allow the formal engine to perform a same-cycle result in the ALU
[soc.git] / src / soc / fu / logical / popcount.py
1 """Popcount: a successive (cascading) sum-reduction algorithm for counting bits
2
3 starting from single-bit adds and reducing down to one final answer:
4 the total number of bits set to "1" in the input.
5
6 Unfortunately there is a bit of a "trick" going on which you have to
7 watch out for: whilst the first entry added to pc is a single Signal (the
8 input, a), subsequent rows in the cascade are lists of partial results,
9 yet it turns out that indexing them with the exact same start/end
10 positions works perfectly. This comes down to nmigen's transparent use of
11 Python features to make Signals indexable like sequences.
12 """
13
14 from nmigen import (Elaboratable, Module, Signal, Cat, Const)
15
16
def array_of(count, bitwidth):
    """Create ``count`` new Signals, each ``bitwidth`` bits wide.

    Every Signal is created reset-less and explicitly named
    ``pop_<bitwidth>_<index>``.  The Signals are returned as a list in
    index order; a ``count`` of zero (or less) yields an empty list.
    """
    return [
        Signal(bitwidth, reset_less=True, name=f"pop_{bitwidth}_{idx}")
        for idx in range(count)
    ]
23
24
class Popcount(Elaboratable):
    """Combinatorial bit counter built as a tree of pairwise adders.

    Signals created by ``__init__``:

    * ``a``        -- 64-bit input whose set bits are counted.
    * ``b``        -- 64-bit signal; created here but not read by
      ``elaborate``.
    * ``data_len`` -- 4-bit selector; 1 selects the per-byte result,
      4 the per-32-bit-word result, anything else the full-width result.
    * ``o``        -- 64-bit output.
    """

    def __init__(self):
        # Keep one plain attribute assignment per Signal, in this order.
        self.a = Signal(64, reset_less=True)
        self.b = Signal(64, reset_less=True)
        self.data_len = Signal(4, reset_less=True)  # data len up to... err.. 8?
        self.o = Signal(64, reset_less=True)

    def elaborate(self, platform):
        """Build and return the Module holding the adder tree and output mux."""
        m = Module()
        comb = m.d.comb
        a, data_len, o = self.a, self.data_len, self.o

        # (number of sums, width of each sum in bits) for every tree level:
        # 32 two-bit sums of single-bit pairs, then 16 three-bit sums, ...
        levels = [(32, 2), (16, 3), (8, 4), (4, 5), (2, 6), (1, 7)]

        # rows[0] is the raw input Signal; rows[k] is the list of partial
        # sums produced by level k.  Both kinds of row are indexed the
        # same way below.
        rows = [a]
        rows.extend(array_of(n_sums, width) for n_sums, width in levels)

        per_byte = rows[3]   # 8 counts, one per 8 input bits (popcntb)
        per_word = rows[5]   # 2 counts, one per 32 input bits (popcntw)
        total = rows[-1]     # 1 count over all 64 input bits (popcntd)

        # adder tree: each entry is the sum of two adjacent entries from
        # the row beneath it, each operand widened by one zero bit
        for lower, upper in zip(rows, rows[1:]):
            for k, partial in enumerate(upper):
                lhs = Cat(lower[2 * k], Const(0, 1))
                rhs = Cat(lower[2 * k + 1], Const(0, 1))
                comb += partial.eq(lhs + rhs)

        # choose the output layout from data_len (compared against 1 and 4)
        with m.If(data_len == 1):
            # popcntb: one 4-bit count placed in each 8-bit output field
            for k, count in enumerate(per_byte):
                comb += o[8 * k:8 * (k + 1)].eq(count)
        with m.Elif(data_len == 4):
            # popcntw: one 6-bit count placed in each 32-bit output field
            for k, count in enumerate(per_word):
                comb += o[32 * k:32 * (k + 1)].eq(count)
        with m.Else():
            # popcntd: the single 7-bit count fills the 64-bit output
            comb += o.eq(total[0])

        return m