Add Count Leading Zeros module to nmutil
authorMichael Nolan <mtnolan2640@gmail.com>
Tue, 5 May 2020 15:15:43 +0000 (11:15 -0400)
committerMichael Nolan <mtnolan2640@gmail.com>
Tue, 5 May 2020 15:15:43 +0000 (11:15 -0400)
src/nmutil/clz.py [new file with mode: 0644]
src/nmutil/formal/proof_clz.py [new file with mode: 0644]
src/nmutil/test/test_clz.py [new file with mode: 0644]

diff --git a/src/nmutil/clz.py b/src/nmutil/clz.py
new file mode 100644 (file)
index 0000000..fed98cf
--- /dev/null
@@ -0,0 +1,83 @@
+from nmigen import Module, Signal, Elaboratable, Cat, Repl
+import math
+
+class CLZ(Elaboratable):
+    def __init__(self, width):
+        self.width = width
+        self.sig_in = Signal(width, reset_less=True)
+        out_width = math.ceil(math.log2(width+1))
+        self.lz = Signal(out_width)
+
+    def generate_pairs(self, m):
+        comb = m.d.comb
+        pairs = []
+        for i in range(0, self.width, 2):
+            if i+1 >= self.width:
+                pair = Signal(1, name="cnt_1_%d" % (i/2))
+                comb += pair.eq(~self.sig_in[i])
+                pairs.append((pair, 1))
+            else:
+                pair = Signal(2, name="pair%d" % i)
+                comb += pair.eq(self.sig_in[i:i+2])
+
+                pair_cnt = Signal(2, name="cnt_1_%d" % (i/2))
+                with m.Switch(pair):
+                    with m.Case(0):
+                        comb += pair_cnt.eq(2)
+                    with m.Case(1):
+                        comb += pair_cnt.eq(1)
+                    with m.Default():
+                        comb += pair_cnt.eq(0)
+                pairs.append((pair_cnt, 2))  # append pair, max_value
+        return pairs
+
+    def combine_pairs(self, m, iteration, pairs):
+        comb = m.d.comb
+        length = len(pairs)
+        ret = []
+        for i in range(0, length, 2):
+            if i+1 >= length:
+                right, mv = pairs[i]
+                width = right.width
+                new_pair = Signal(width, name="cnt_%d_%d" % (iteration, i))
+                comb += new_pair.eq(Cat(right, 0))
+                ret.append((new_pair, mv))
+            else:
+                left, lv = pairs[i+1]
+                right, rv = pairs[i]
+                width = right.width + 1
+                new_pair = Signal(width, name="cnt_%d_%d" %
+                                  (iteration, i))
+                if rv == lv:
+                    with m.If(left[-1] == 1):
+                        with m.If(right[-1] == 1):
+                            comb += new_pair.eq(Cat(Repl(0, width-1), 1))
+                        with m.Else():
+                            comb += new_pair.eq(Cat(right[0:-1], 0b01))
+                    with m.Else():
+                        comb += new_pair.eq(Cat(left, 0))
+                else:
+                    with m.If(left == lv):
+                        comb += new_pair.eq(right + left)
+                    with m.Else():
+                        comb += new_pair.eq(left)
+
+
+                ret.append((new_pair, lv+rv))
+        return ret
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+
+        pairs = self.generate_pairs(m)
+        i = 2
+        while len(pairs) > 1:
+            pairs = self.combine_pairs(m, i, pairs)
+            i += 1
+
+        comb += self.lz.eq(pairs[0][0])
+
+        return m
+
+
diff --git a/src/nmutil/formal/proof_clz.py b/src/nmutil/formal/proof_clz.py
new file mode 100644 (file)
index 0000000..209658e
--- /dev/null
@@ -0,0 +1,62 @@
+from nmigen import Module, Signal, Elaboratable, Mux, Const
+from nmigen.asserts import Assert, AnyConst, Assume
+from nmigen.test.utils import FHDLTestCase
+from nmigen.cli import rtlil
+
+from nmutil.clz import CLZ
+import unittest
+
+
+# This defines a module to drive the device under test and assert
+# properties about its outputs
+class Driver(Elaboratable):
+    def __init__(self):
+        # inputs and outputs
+        pass
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        width = 10
+
+        m.submodules.dut = dut = CLZ(width)
+        sig_in = Signal.like(dut.sig_in)
+        count = Signal.like(dut.lz)
+
+
+        m.d.comb += [
+            sig_in.eq(AnyConst(width)),
+            dut.sig_in.eq(sig_in),
+            count.eq(dut.lz)]
+
+        result = Const(width)
+        for i in range(width):
+            print(result)
+            result_next = Signal.like(count, name="count_%d" % i)
+            with m.If(sig_in[i] == 1):
+                comb += result_next.eq(width-i-1)
+            with m.Else():
+                comb += result_next.eq(result)
+            result = result_next
+
+        result_sig = Signal.like(count)
+        comb += result_sig.eq(result)
+
+        comb += Assert(result_sig == count)
+        
+        # setup the inputs and outputs of the DUT as anyconst
+
+        return m
+
+class CLZTestCase(FHDLTestCase):
+    def test_proof(self):
+        module = Driver()
+        self.assertFormal(module, mode="bmc", depth=4)
+    def test_ilang(self):
+        dut = Driver()
+        vl = rtlil.convert(dut, ports=[])
+        with open("clz.il", "w") as f:
+            f.write(vl)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/src/nmutil/test/test_clz.py b/src/nmutil/test/test_clz.py
new file mode 100644 (file)
index 0000000..1d066b4
--- /dev/null
@@ -0,0 +1,46 @@
+from nmigen import Module, Signal
+from nmigen.back.pysim import Simulator, Delay
+from nmigen.test.utils import FHDLTestCase
+
+from nmutil.clz import CLZ
+import unittest
+import math
+import random
+
+
+class CLZTestCase(FHDLTestCase):
+    def run_test(self, inputs, width=8):
+
+        m = Module()
+
+        m.submodules.dut = dut = CLZ(width)
+        sig_in = Signal.like(dut.sig_in)
+        count = Signal.like(dut.lz)
+
+
+        m.d.comb += [
+            dut.sig_in.eq(sig_in),
+            count.eq(dut.lz)]
+
+        sim = Simulator(m)
+
+        def process():
+            for i in inputs:
+                yield sig_in.eq(i)
+                yield Delay(1e-6)
+        sim.add_process(process)
+        with sim.write_vcd("clz.vcd", "clz.gtkw", traces=[
+                sig_in, count]):
+            sim.run()
+
+    def test_selected(self):
+        inputs = [0, 15, 10, 127]
+        self.run_test(iter(inputs), width=8)
+
+    def test_non_power_2(self):
+        inputs = [0, 128, 512]
+        self.run_test(iter(inputs), width=11)
+
+
+if __name__ == "__main__":
+    unittest.main()