fhdl.ast, back.pysim: implement shifts.
authorwhitequark <cz@m-labs.hk>
Sat, 15 Dec 2018 09:58:30 +0000 (09:58 +0000)
committerwhitequark <cz@m-labs.hk>
Sat, 15 Dec 2018 09:58:30 +0000 (09:58 +0000)
nmigen/back/pysim.py
nmigen/back/rtlil.py
nmigen/fhdl/ast.py
nmigen/test/test_fhdl_ast.py
nmigen/test/test_sim.py

index 6a15b7de821fc722875a6397f572f2c56219a52e..69ddf12228da6b205f6edc9ad84df4ce9a80727d 100644 (file)
@@ -75,15 +75,23 @@ class _RHSValueCompiler(ValueTransformer):
         elif len(value.operands) == 2:
             lhs, rhs = map(self, value.operands)
             if value.op == "+":
-                return lambda state: normalize(lhs(state) + rhs(state), shape)
+                return lambda state: normalize(lhs(state) +  rhs(state), shape)
             if value.op == "-":
-                return lambda state: normalize(lhs(state) - rhs(state), shape)
+                return lambda state: normalize(lhs(state) -  rhs(state), shape)
             if value.op == "&":
-                return lambda state: normalize(lhs(state) & rhs(state), shape)
+                return lambda state: normalize(lhs(state) &  rhs(state), shape)
             if value.op == "|":
-                return lambda state: normalize(lhs(state) | rhs(state), shape)
+                return lambda state: normalize(lhs(state) |  rhs(state), shape)
             if value.op == "^":
-                return lambda state: normalize(lhs(state) ^ rhs(state), shape)
+                return lambda state: normalize(lhs(state) ^  rhs(state), shape)
+            if value.op == "<<":
+                def sshl(lhs, rhs):
+                    return lhs << rhs if rhs >= 0 else lhs >> -rhs
+                return lambda state: normalize(sshl(lhs(state), rhs(state)), shape)
+            if value.op == ">>":
+                def sshr(lhs, rhs):
+                    return lhs >> rhs if rhs >= 0 else lhs << -rhs
+                return lambda state: normalize(sshr(lhs(state), rhs(state)), shape)
             if value.op == "==":
                 return lambda state: normalize(lhs(state) == rhs(state), shape)
             if value.op == "!=":
index a8dd7d39d7b4b2bb8c9f92a7af31e00cf5524e72..fc7420fcf24a65ccf1cf22954d862327c6e4ed2f 100644 (file)
@@ -206,8 +206,8 @@ class _ValueTransformer(xfrm.ValueTransformer):
         (2, "/"):    "$div",
         (2, "%"):    "$mod",
         (2, "**"):   "$pow",
-        (2, "<<<"):  "$sshl",
-        (2, ">>>"):  "$sshr",
+        (2, "<<"):   "$sshl",
+        (2, ">>"):   "$sshr",
         (2, "&"):    "$and",
         (2, "^"):    "$xor",
         (2, "|"):    "$or",
index d8789687761e29f7c01f6286cff6e090e2581849..c8c13ae4af4a73602405186acb8cb0084061fd76 100644 (file)
@@ -75,13 +75,13 @@ class Value(metaclass=ABCMeta):
     def __rdiv__(self, other):
         return Operator("/", [other, self])
     def __lshift__(self, other):
-        return Operator("<<<", [self, other])
+        return Operator("<<", [self, other])
     def __rlshift__(self, other):
-        return Operator("<<<", [other, self])
+        return Operator("<<", [other, self])
     def __rshift__(self, other):
-        return Operator(">>>", [self, other])
+        return Operator(">>", [self, other])
     def __rrshift__(self, other):
-        return Operator(">>>", [other, self])
+        return Operator(">>", [other, self])
     def __and__(self, other):
         return Operator("&", [self, other])
     def __rand__(self, other):
@@ -306,15 +306,15 @@ class Operator(Value):
                 return 1, False
             if self.op in ("&", "^", "|"):
                 return self._bitwise_binary_shape(*op_shapes)
-            if self.op == "<<<":
+            if self.op == "<<":
                 if b_sign:
-                    extra = 2**(b_bits - 1) - 1
+                    extra = 2 ** (b_bits - 1) - 1
                 else:
-                    extra = 2**b_bits - 1
+                    extra = 2 ** (b_bits)     - 1
                 return a_bits + extra, a_sign
-            if self.op == ">>>":
+            if self.op == ">>":
                 if b_sign:
-                    extra = 2**(b_bits - 1)
+                    extra = 2 ** (b_bits - 1)
                 else:
                     extra = 0
                 return a_bits + extra, a_sign
index 433b609181a14f404160be30ffdddd6535c6dded..fd1d58f25b2839109c0a0141f9aa35bfd2133ec7 100644 (file)
@@ -182,6 +182,20 @@ class OperatorTestCase(FHDLTestCase):
         v5 = 10 ^ Const(0, 4)
         self.assertEqual(v5.shape(), (4, False))
 
+    def test_shl(self):
+        v1 = Const(1, 4) << Const(4)
+        self.assertEqual(repr(v1), "(<< (const 4'd1) (const 3'd4))")
+        self.assertEqual(v1.shape(), (11, False))
+        v2 = Const(1, 4) << Const(-3)
+        self.assertEqual(v2.shape(), (7, False))
+
+    def test_shr(self):
+        v1 = Const(1, 4) >> Const(4)
+        self.assertEqual(repr(v1), "(>> (const 4'd1) (const 3'd4))")
+        self.assertEqual(v1.shape(), (4, False))
+        v2 = Const(1, 4) >> Const(-3)
+        self.assertEqual(v2.shape(), (8, False))
+
     def test_lt(self):
         v = Const(0, 4) < Const(0, 6)
         self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))")
index ec4198612aec6220505e8dd52cf728d5a1821e3b..ef7b18fba90797411d72381411bfc3d74dcbdeb7 100644 (file)
@@ -71,6 +71,18 @@ class SimulatorUnitTestCase(FHDLTestCase):
         stmt = lambda a, b: a ^ b
         self.assertOperator(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b0110, 4))
 
+    def test_shl(self):
+        stmt = lambda a, b: a << b
+        self.assertOperator(stmt, [C(0b1001, 4), C(0)],  C(0b1001,    5))
+        self.assertOperator(stmt, [C(0b1001, 4), C(3)],  C(0b1001000, 7))
+        self.assertOperator(stmt, [C(0b1001, 4), C(-2)], C(0b10,      7))
+
+    def test_shr(self):
+        stmt = lambda a, b: a >> b
+        self.assertOperator(stmt, [C(0b1001, 4), C(0)],  C(0b1001,    4))
+        self.assertOperator(stmt, [C(0b1001, 4), C(2)],  C(0b10,      4))
+        self.assertOperator(stmt, [C(0b1001, 4), C(-2)], C(0b100100,  5))
+
     def test_eq(self):
         stmt = lambda a, b: a == b
         self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(1))