Keep the valid signal from the formal engine ALU stable, until read
[soc.git] / src / soc / logical / main_stage.py
1 # This stage is intended to do most of the work of executing Logical
2 # instructions. This is OR, AND, XOR, POPCNT, PRTY, CMPB, BPERMD, CNTLZ
3 # however input and output stages also perform bit-negation on input(s)
4 # and output, as well as carry and overflow generation.
5 # This module however should not gate the carry or overflow, that's up
6 # to the output stage
7
8 from nmigen import (Module, Signal, Cat, Repl, Mux, Const, Array)
9 from nmutil.pipemodbase import PipeModBase
10 from soc.logical.pipe_data import ALUInputData
11 from soc.alu.pipe_data import ALUOutputData
12 from ieee754.part.partsig import PartitionedSignal
13 from soc.decoder.power_enums import InternalOp
14 from soc.countzero.countzero import ZeroCounter
15
16 from soc.decoder.power_fields import DecodeFields
17 from soc.decoder.power_fieldsn import SignalBitRange
18
19
def array_of(count, bitwidth):
    """Build a list of `count` fresh Signals, each `bitwidth` bits wide.

    Every Signal is created with reset_less=True.  A plain Python list
    is returned (not an nmigen Array), so it can only be indexed with
    Python integers at elaboration time.
    """
    return [Signal(bitwidth, reset_less=True) for _ in range(count)]
25
26
class LogicalMainStage(PipeModBase):
    """Main combinatorial stage of the Logical pipeline.

    Selects on ``op.insn_type`` and drives ``self.o.o`` with the result
    of AND, OR, XOR, cmpb, popcount, parity or count-zeros applied to
    the ``a`` / ``b`` inputs.  The sticky-overflow flag and the pipeline
    context are forwarded unchanged from input to output.

    bpermd is not implemented here (see the TODO in elaborate).
    """

    def __init__(self, pspec):
        super().__init__(pspec, "main")
        # Instruction field decoder over the raw instruction word; used
        # by the count-zeros case to read the X-Form XO field.
        self.fields = DecodeFields(SignalBitRange, [self.i.ctx.op.insn])
        self.fields.create_specs()

    def ispec(self):
        """Return the input data specification for this stage."""
        return ALUInputData(self.pspec)

    def ospec(self):
        """Return the output data specification for this stage."""
        return ALUOutputData(self.pspec)  # TODO: ALUIntermediateData

    def elaborate(self, platform):
        """Build and return the nmigen Module implementing this stage."""
        m = Module()
        comb = m.d.comb
        op, a, b, o = self.i.ctx.op, self.i.a, self.i.b, self.o.o

        ##########################
        # main switch for logic ops AND, OR and XOR, cmpb, parity,
        # popcount and count-zeros

        with m.Switch(op.insn_type):

            ###### AND, OR, XOR #######
            with m.Case(InternalOp.OP_AND):
                comb += o.eq(a & b)
            with m.Case(InternalOp.OP_OR):
                comb += o.eq(a | b)
            with m.Case(InternalOp.OP_XOR):
                comb += o.eq(a ^ b)

            ###### cmpb #######
            with m.Case(InternalOp.OP_CMPB):
                # each output byte is all-ones when the corresponding
                # bytes of a and b are equal, all-zeros otherwise
                eq_bytes = []
                for i in range(8):
                    slc = slice(i*8, (i+1)*8)
                    eq_bytes.append(Repl(a[slc] == b[slc], 8))
                comb += o.eq(Cat(*eq_bytes))

            ###### popcount #######
            with m.Case(InternalOp.OP_POPCNT):
                # starting from a, perform successive addition-reductions
                # creating arrays big enough to store the sum, each time
                pc = [a]
                # (quantity, width) per level: QTY32 2-bit (to take 2x
                # 1-bit sums) etc.  The last level must be 7 bits wide:
                # a full 64-bit input has a count of 64, which does not
                # fit in 6 bits.
                work = [(32, 2), (16, 3), (8, 4), (4, 5), (2, 6), (1, 7)]
                # NOTE: loop variables deliberately do not reuse the
                # names "a"/"b", which are the operand Signals above.
                for count, width in work:
                    pc.append(array_of(count, width))
                pc8 = pc[3]      # array of 8 4-bit per-byte counts (popcntb)
                pc32 = pc[5]     # array of 2 6-bit per-word counts (popcntw)
                popcnt = pc[-1]  # array of 1 7-bit 64-bit count (popcntd)
                # cascade-tree of adds: each destination entry is the sum
                # of two adjacent source entries, one bit wider
                for idx, (count, _width) in enumerate(work):
                    src, dst = pc[idx], pc[idx+1]
                    for i in range(count):
                        stt, end = i*2, i*2+1
                        comb += dst[i].eq(Cat(src[stt], Const(0, 1)) +
                                          Cat(src[end], Const(0, 1)))
                # decode operation length
                with m.If(op.data_len[2:4] == 0b00):
                    # popcntb - one count per output byte (zero-extended)
                    for i in range(8):
                        comb += o[i*8:(i+1)*8].eq(pc8[i])
                with m.Elif(op.data_len[3] == 0):
                    # popcntw - one count per output word.  The whole
                    # 32-bit lane is assigned so that a count of 32
                    # (6 bits) is not truncated.
                    for i in range(2):
                        comb += o[i*32:(i+1)*32].eq(pc32[i])
                with m.Else():
                    # popcntd - put the single 7-bit answer into output
                    comb += o.eq(popcnt[0])

            ###### parity #######
            with m.Case(InternalOp.OP_PRTY):
                # strange instruction which XORs together the LSBs of
                # each byte: par0 covers the low word, par1 the high word
                par0 = Signal(reset_less=True)
                par1 = Signal(reset_less=True)
                comb += par0.eq(Cat(a[0], a[8], a[16], a[24]).xor())
                # bytes 4..7: bits 32, 40, 48 and 56
                comb += par1.eq(Cat(a[32], a[40], a[48], a[56]).xor())
                with m.If(op.data_len[3] == 1):
                    # doubleword form: single parity over all 8 bytes
                    comb += o.eq(par0 ^ par1)
                with m.Else():
                    # word form: one parity bit at the bottom of each word
                    comb += o[0].eq(par0)
                    comb += o[32].eq(par1)

            ###### cntlz / cnttz #######
            with m.Case(InternalOp.OP_CNTZ):
                x_fields = self.fields.instrs['X']
                XO = Signal(x_fields['XO'][0:-1].shape())
                # XO must actually be driven from the instruction field,
                # otherwise XO[-1] below is a constant 0 and the counter
                # can never be switched to count-right.
                # NOTE(review): assumes the field range is usable as an
                # nmigen Value (it already provides .shape()) - confirm.
                comb += XO.eq(x_fields['XO'][0:-1])
                m.submodules.countz = countz = ZeroCounter()
                comb += countz.rs_i.eq(a)
                comb += countz.is_32bit_i.eq(op.is_32bit)
                # top bit of XO selects the counting direction
                comb += countz.count_right_i.eq(XO[-1])
                comb += o.eq(countz.result_o)

            ###### bpermd #######
            # TODO with m.Case(InternalOp.OP_BPERM): - not in microwatt

        ###### sticky overflow and context, both pass-through #####

        comb += self.o.so.eq(self.i.so)
        comb += self.o.ctx.eq(self.i.ctx)

        return m