from nmigen import (Signal, Cat, Const, Mux, Module, Elaboratable, Array,
                     Value, Shape)
-from math import log
+from nmigen.utils import bits_for
 from operator import or_
 from functools import reduce
 
         self.assertEqual(i, True)
 
 
-class MultiShiftR:
+class MultiShiftR(Elaboratable):
 
     def __init__(self, width):
         self.width = width
-        self.smax = int(log(width) / log(2))
+        self.smax = bits_for(width - 1)
         self.i = Signal(width, reset_less=True)
         self.s = Signal(self.smax, reset_less=True)
         self.o = Signal(width, reset_less=True)
 
     def __init__(self, width):
         self.width = width
-        self.smax = int(log(width) / log(2))
+        self.smax = bits_for(width - 1)
 
     def lshift(self, op, s):
         res = op << s
 
     def __init__(self, width, s_max=None):
         if s_max is None:
-            s_max = int(log(width) / log(2))
+            s_max = bits_for(width - 1)
         self.smax = Shape.cast(s_max)
         self.m = Signal(width, reset_less=True)
         self.inp = Signal(width, reset_less=True)