use straight << and >> operator instead of multi-level Mux
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 18 Feb 2019 07:00:56 +0000 (07:00 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 18 Feb 2019 07:00:56 +0000 (07:00 +0000)
src/add/fpbase.py
src/add/test_multishift.py

index 48e93a01fa1decedbc3b24e0be7f700171cfc59a..8c10d7821c96dd0f1592d05df7b5e415e6f796a8 100644 (file)
@@ -2,11 +2,26 @@
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
-from nmigen import Signal, Cat, Const, Mux
+from nmigen import Signal, Cat, Const, Mux, Module
 from math import log
 from operator import or_
 from functools import reduce
 
+class MultiShiftR:
+
+    def __init__(self, width):
+        self.width = width
+        self.smax = int(log(width) / log(2))
+        self.i = Signal(width)
+        self.s = Signal(self.smax)
+        self.o = Signal(width)
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.o.eq(self.i >> self.s)
+        return m
+
+
 class MultiShift:
     """ Generates variable-length single-cycle shifter from a series
         of conditional tests on each bit of the left/right shift operand.
@@ -23,6 +38,8 @@ class MultiShift:
         self.smax = int(log(width) / log(2))
 
     def lshift(self, op, s):
+        res = op << s
+        return res[:len(op)]
         res = op
         for i in range(self.smax):
             zeros = [0] * (1<<i)
@@ -30,6 +47,8 @@ class MultiShift:
         return res
 
     def rshift(self, op, s):
+        res = op >> s
+        return res[:len(op)]
         res = op
         for i in range(self.smax):
             zeros = [0] * (1<<i)
index 0486d33f2c3fd2e2ea4ba5918a306d527018f414..5fa649ef83f79dcbce7d3338da418d02dd9b9b2d 100644 (file)
@@ -2,7 +2,7 @@ from random import randint
 from nmigen import Module, Signal
 from nmigen.compat.sim import run_simulation
 
-from fpbase import MultiShift
+from fpbase import MultiShift, MultiShiftR
 
 class MultiShiftModL:
     def __init__(self, width):
@@ -32,6 +32,23 @@ class MultiShiftModR:
 
         return m
 
+class MultiShiftModRMod:
+    def __init__(self, width):
+        self.ms = MultiShiftR(width)
+        self.a = Signal(width)
+        self.b = Signal(self.ms.smax)
+        self.x = Signal(width)
+
+    def get_fragment(self, platform=None):
+
+        m = Module()
+        m.submodules += self.ms
+        m.d.comb += self.ms.i.eq(self.a)
+        m.d.comb += self.ms.s.eq(self.b)
+        m.d.comb += self.x.eq(self.ms.o)
+
+        return m
+
 def check_case(dut, width, a, b):
     yield dut.a.eq(a)
     yield dut.b.eq(b)
@@ -65,6 +82,9 @@ def testbenchr(dut):
             yield from check_caser(dut, 32, a, i)
 
 if __name__ == '__main__':
+    dut = MultiShiftModRMod(width=32)
+    run_simulation(dut, testbenchr(dut), vcd_name="test_multishift.vcd")
+
     dut = MultiShiftModR(width=32)
     run_simulation(dut, testbenchr(dut), vcd_name="test_multishift.vcd")