Add module to handle partitioned eq, gt, and ge comparisons
authorMichael Nolan <mtnolan2640@gmail.com>
Wed, 5 Feb 2020 15:37:14 +0000 (10:37 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Wed, 5 Feb 2020 15:49:05 +0000 (10:49 -0500)
src/ieee754/part_cmp/eq_gt_ge.py [new file with mode: 0644]
src/ieee754/part_cmp/experiments/gt_combiner.py
src/ieee754/part_cmp/formal/proof_eq_gt_ge.py [new file with mode: 0644]

diff --git a/src/ieee754/part_cmp/eq_gt_ge.py b/src/ieee754/part_cmp/eq_gt_ge.py
new file mode 100644 (file)
index 0000000..c1d223d
--- /dev/null
@@ -0,0 +1,89 @@
+# SPDX-License-Identifier: LGPL-2.1-or-later
+# See Notices.txt for copyright information
+
+"""
+Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
+
+dynamically-partitionable "comparison" class, directly equivalent
+to Signal.__eq__ except SIMD-partitionable
+
+See:
+
+* http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/eq
+* http://bugs.libre-riscv.org/show_bug.cgi?id=132
+"""
+
+from nmigen import Signal, Module, Elaboratable, Cat, C, Mux, Repl
+from nmigen.cli import main
+
+from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_cmp.experiments.gt_combiner import GTCombiner
+
+
+class PartitionedEqGtGe(Elaboratable):
+
+    # Expansion of the partitioned equals module to handle Greater
+    # Than and Greater than or Equal to. The function being evaluated
+    # is selected by the opcode signal, where:
+    # opcode 0x00 - EQ
+    # opcode 0x01 - GT
+    # opcode 0x02 - GE
+    def __init__(self, width, partition_points):
+        """Create a ``PartitionedEq`` operator
+        """
+        self.width = width
+        self.a = Signal(width, reset_less=True)
+        self.b = Signal(width, reset_less=True)
+        self.opcode = Signal(2)
+        self.partition_points = PartitionPoints(partition_points)
+        self.mwidth = len(self.partition_points)+1
+        self.output = Signal(self.mwidth, reset_less=True)
+        if not self.partition_points.fits_in_width(width):
+            raise ValueError("partition_points doesn't fit in width")
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        m.submodules.gtc = gtc = GTCombiner(self.mwidth)
+
+        # make a series of "eqs" and "gts", splitting a and b into partition chunks
+        eqs = Signal(self.mwidth, reset_less=True)
+        eql = []
+        gts = Signal(self.mwidth, reset_less=True)
+        gtl = []
+
+        keys = list(self.partition_points.keys()) + [self.width]
+        start = 0
+        for i in range(len(keys)):
+            end = keys[i]
+            eql.append(self.a[start:end] == self.b[start:end])
+            gtl.append(self.a[start:end] > self.b[start:end])
+            start = end # for next time round loop
+        comb += eqs.eq(Cat(*eql))
+        comb += gts.eq(Cat(*gtl))
+
+        # Signal to control the constant injected into the partition next to a closed gate
+        aux_input = Signal()
+        # Signal to enable or disable the gt input for the gt partition combiner
+        gt_en = Signal()
+
+        with m.Switch(self.opcode):
+            with m.Case(0b00):   # equals
+                comb += aux_input.eq(1)
+                comb += gt_en.eq(0)
+            with m.Case(0b01):   # greater than
+                comb += aux_input.eq(0)
+                comb += gt_en.eq(1)
+            with m.Case(0b10):   # greater than or equal to
+                comb += aux_input.eq(1)
+                comb += gt_en.eq(1)
+        
+        comb += gtc.gates.eq(self.partition_points.as_sig())
+        comb += gtc.eqs.eq(eqs)
+        comb += gtc.gts.eq(gts)
+        comb += gtc.aux_input.eq(aux_input)
+        comb += gtc.gt_en.eq(gt_en)
+        comb += self.output.eq(gtc.outputs)
+        
+
+        return m
index 421a39beab1fc32a62f61fb691b775877a511c3f..a8c9c85e2cf45cd2cdf9ab16235b9b665bf074e8 100644 (file)
@@ -32,8 +32,20 @@ class GTCombiner(Elaboratable):
 
     def __init__(self, width):
         self.width = width
+
+        # These two signals allow this module to do more than just a
+        # partitioned greater than comparison.
+        # - If aux_input is set to 0 and gt_en is set to 1, then this
+        #   module performs a partitioned greater than comparision
+        # - If aux_input is set to 1 and gt_en is set to 0, then this
+        #   module is functionally equivalent to the eq_combiner
+        #   module.
+        # - If aux_input is set to 1 and gt_en is set to 1, then this
+        #   module performs a partitioned greater than or equals
+        #   comparison
         self.aux_input = Signal(reset_less=True)  # right hand side mux input
         self.gt_en = Signal(reset_less=True)      # enable or disable the gt signal
+        
         self.eqs = Signal(width, reset_less=True) # the flags for EQ
         self.gts = Signal(width, reset_less=True) # the flags for GT
         self.gates = Signal(width-1, reset_less=True)
diff --git a/src/ieee754/part_cmp/formal/proof_eq_gt_ge.py b/src/ieee754/part_cmp/formal/proof_eq_gt_ge.py
new file mode 100644 (file)
index 0000000..497e234
--- /dev/null
@@ -0,0 +1,121 @@
+# Proof of correctness for partitioned equals module
+# Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
+
+from nmigen import Module, Signal, Elaboratable, Mux, Cat
+from nmigen.asserts import Assert, AnyConst, Assume
+from nmigen.test.utils import FHDLTestCase
+from nmigen.cli import rtlil
+
+from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
+import unittest
+
+
+# This defines a module to drive the device under test and assert
+# properties about its outputs
+class EqualsDriver(Elaboratable):
+    def __init__(self):
+        # inputs and outputs
+        pass
+
+    def get_intervals(self, signal, points):
+        start = 0
+        interval = []
+        keys = list(points.keys()) + [signal.width]
+        for key in keys:
+            end = key
+            interval.append(signal[start:end])
+            start = end
+        return interval
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        width = 24
+        mwidth = 3
+
+        # setup the inputs and outputs of the DUT as anyconst
+        a = Signal(width)
+        b = Signal(width)
+        points = PartitionPoints()
+        gates = Signal(mwidth-1)
+        opcode = Signal(2)
+        for i in range(mwidth-1):
+            points[i*8+8] = gates[i]
+        out = Signal(mwidth)
+
+        
+        comb += [a.eq(AnyConst(width)),
+                 b.eq(AnyConst(width)),
+                 opcode.eq(AnyConst(opcode.width)),
+                 gates.eq(AnyConst(mwidth-1))]
+
+        m.submodules.dut = dut = PartitionedEqGtGe(width, points)
+
+        a_intervals = self.get_intervals(a, points)
+        b_intervals = self.get_intervals(b, points)
+
+        with m.If(opcode == 0b00):
+            with m.Switch(gates):
+                with m.Case(0b00):
+                    comb += Assert(out[-1] == (a == b))
+                with m.Case(0b01):
+                    comb += Assert(out[2] == ((a_intervals[1] == b_intervals[1]) &
+                                              (a_intervals[2] == b_intervals[2])))
+                    comb += Assert(out[0] == (a_intervals[0] == b_intervals[0]))
+                with m.Case(0b10):
+                    comb += Assert(out[1] == ((a_intervals[0] == b_intervals[0]) &
+                                              (a_intervals[1] == b_intervals[1])))
+                    comb += Assert(out[2] == (a_intervals[2] == b_intervals[2]))
+                with m.Case(0b11):
+                    for i in range(mwidth-1):
+                        comb += Assert(out[i] == (a_intervals[i] == b_intervals[i]))
+        with m.If(opcode == 0b01):
+            with m.Switch(gates):
+                with m.Case(0b00):
+                    comb += Assert(out[-1] == (a > b))
+                with m.Case(0b01):
+                    comb += Assert(out[0] == (a_intervals[0] > b_intervals[0]))
+                                            
+                    comb += Assert(out[1] == 0)
+                    comb += Assert(out[2] == (Cat(*a_intervals[1:3]) > Cat(*b_intervals[1:3])))
+                with m.Case(0b10):
+                    comb += Assert(out[0] == 0)
+                    comb += Assert(out[1] == (Cat(*a_intervals[0:2]) > Cat(*b_intervals[0:2])))
+                    comb += Assert(out[2] == (a_intervals[2] > b_intervals[2]))
+                with m.Case(0b11):
+                    for i in range(mwidth-1):
+                        comb += Assert(out[i] == (a_intervals[i] > b_intervals[i]))
+        with m.If(opcode == 0b10):
+            with m.Switch(gates):
+                with m.Case(0b00):
+                    comb += Assert(out[-1] == (a >= b))
+                with m.Case(0b01):
+                    comb += Assert(out[0] == (a_intervals[0] >= b_intervals[0]))
+                                            
+                    comb += Assert(out[1] == 0)
+                    comb += Assert(out[2] == (Cat(*a_intervals[1:3]) >= Cat(*b_intervals[1:3])))
+                with m.Case(0b10):
+                    comb += Assert(out[0] == 0)
+                    comb += Assert(out[1] == (Cat(*a_intervals[0:2]) >= Cat(*b_intervals[0:2])))
+                    comb += Assert(out[2] == (a_intervals[2] >= b_intervals[2]))
+                with m.Case(0b11):
+                    for i in range(mwidth-1):
+                        comb += Assert(out[i] == (a_intervals[i] >= b_intervals[i]))
+                
+
+
+        comb += [dut.a.eq(a),
+                 dut.b.eq(b),
+                 dut.opcode.eq(opcode),
+                 out.eq(dut.output)]
+        return m
+
+class PartitionedEqTestCase(FHDLTestCase):
+    def test_eq(self):
+        module = EqualsDriver()
+        self.assertFormal(module, mode="bmc", depth=4)
+
+if __name__ == "__main__":
+    unittest.main()
+