add beginnings of shift unit test for partsig
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 14 Feb 2020 10:48:37 +0000 (10:48 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 14 Feb 2020 10:48:37 +0000 (10:48 +0000)
src/ieee754/part/partsig.py
src/ieee754/part/test/test_partsig.py

index 5977bca64a04875930725220cba1cae046aca80b..0ab603def3c8103062c8d33109013eadf2443ff4 100644 (file)
@@ -18,6 +18,7 @@ nmigen.Case, or other constructs: only Mux and other logic.
 
 from ieee754.part_mul_add.adder import PartitionedAdder
 from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
+from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
 from ieee754.part_mul_add.partpoints import make_partition
 from operator import or_, xor, and_, not_
 
@@ -41,7 +42,7 @@ class PartitionedSignal:
         # create partition points
         self.partpoints = make_partition(mask, width)
         self.modnames = {}
-        for name in ['add', 'eq', 'gt', 'ge']:
+        for name in ['add', 'eq', 'gt', 'ge', 'ls']:
             self.modnames[name] = 0
 
     def set_module(self, m):
@@ -90,9 +91,22 @@ class PartitionedSignal:
     # TODO: detect if the 2nd operand is a Const, a Signal or a
     # PartitionedSignal.  if it's a Const or a Signal, a global shift
     # can occur.  if it's a PartitionedSignal, that's much more interesting.
+    def ls_op(self, op1, op2, carry):
+        op1 = getsig(op1)
+        op2 = getsig(op2)
+        shape = op1.shape()
+        pa = PartitionedDynamicShift(shape[0], self.partpoints)
+        setattr(self.m.submodules, self.get_modname('ls'), pa)
+        comb = self.m.d.comb
+        comb += pa.a.eq(op1)
+        comb += pa.b.eq(op2)
+        # XXX TODO: carry-in, carry-out
+        #comb += pa.carry_in.eq(carry)
+        return (pa.output, 0)
+
     def __lshift__(self, other):
-        raise NotImplementedError
-        return Operator("<<", [self, other])
+        result, _ = self.ls_op(self, other, carry=0)
+        return result
 
     def __rlshift__(self, other):
         raise NotImplementedError
index 699a34f9caaf001a14c73c60bced4a8d1916c3b4..0375f47175a98f13bd70f4683b33b49179ef0448 100644 (file)
@@ -39,6 +39,7 @@ class TestAddMod(Elaboratable):
         self.a = PartitionedSignal(partpoints, width)
         self.b = PartitionedSignal(partpoints, width)
         self.add_output = Signal(width)
+        self.ls_output = Signal(width) # left shift
         self.sub_output = Signal(width)
         self.eq_output = Signal(len(partpoints)+1)
         self.gt_output = Signal(len(partpoints)+1)
@@ -58,6 +59,7 @@ class TestAddMod(Elaboratable):
         comb = m.d.comb
         self.a.set_module(m)
         self.b.set_module(m)
+        # compares
         comb += self.lt_output.eq(self.a < self.b)
         comb += self.ne_output.eq(self.a != self.b)
         comb += self.le_output.eq(self.a <= self.b)
@@ -69,11 +71,15 @@ class TestAddMod(Elaboratable):
                                            self.carry_in)
         comb += self.add_output.eq(add_out)
         comb += self.add_carry_out.eq(add_carry)
+        # sub
         sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                            self.carry_in)
         comb += self.sub_output.eq(sub_out)
         comb += self.sub_carry_out.eq(sub_carry)
+        # neg
         comb += self.neg_output.eq(-self.a)
+        # left shift
+        comb += self.ls_output.eq(self.a << self.b)
         ppts = self.partpoints
         comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
 
@@ -96,11 +102,22 @@ class TestPartitionPoints(unittest.TestCase):
 
         def async_process():
 
+            def test_ls_fn(carry_in, a, b, mask):
+                # TODO: carry
+                carry_in = 0
+                lsb = mask & ~(mask-1) if carry_in else 0
+                sum = (a & mask) << (b & mask) + lsb
+                result = mask & sum
+                carry = (sum & mask) != sum
+                print(a, b, sum, mask)
+                return result, carry
+
             def test_add_fn(carry_in, a, b, mask):
                 lsb = mask & ~(mask-1) if carry_in else 0
                 sum = (a & mask) + (b & mask) + lsb
                 result = mask & sum
                 carry = (sum & mask) != sum
+                carry = 0
                 print(a, b, sum, mask)
                 return result, carry
 
@@ -156,6 +173,7 @@ class TestPartitionPoints(unittest.TestCase):
             for (test_fn, mod_attr) in ((test_add_fn, "add"),
                                         (test_sub_fn, "sub"),
                                         (test_neg_fn, "neg"),
+                                        (test_ls_fn, "ls"),
                                         ):
                 yield part_mask.eq(0)
                 yield from test_op("16-bit", 1, test_fn, mod_attr, 0xFFFF)