Fix gearing
[gram.git] / gram / compat.py
index a76227edd79867f98ec9ec99a6a2e12ac4a4eddb..49ae71013312da80091197c800e51b072b916aa4 100644 (file)
@@ -1,11 +1,15 @@
 # This file is Copyright (c) 2020 LambdaConcept <contact@lambdaconcept.com>
 
+import unittest
+
 from nmigen import *
 from nmigen import tracer
 from nmigen.compat import Case
+from nmigen.back.pysim import *
 
 __ALL__ = ["delayed_enter", "RoundRobin", "Timeline", "CSRPrefixProxy"]
 
+
 def delayed_enter(m, src, dst, delay):
     assert delay > 0
 
@@ -23,9 +27,8 @@ def delayed_enter(m, src, dst, delay):
         with m.State(statename):
             m.next = deststate
 
-# Original nMigen implementation by HarryHo90sHK
 class RoundRobin(Elaboratable):
-    """A round-robin scheduler.
+    """A round-robin scheduler. (HarryHo90sHK)
     Parameters
     ----------
     n : int
@@ -39,6 +42,7 @@ class RoundRobin(Elaboratable):
     stb : Signal()
         Strobe signal to enable granting access to the next device requesting. Externally driven.
     """
+
     def __init__(self, n):
         self.n = n
         self.request = Signal(n)
@@ -48,19 +52,23 @@ class RoundRobin(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
-        with m.If(self.stb):
-            with m.Switch(self.grant):
-                for i in range(self.n):
-                    with m.Case(i):
-                        for j in reversed(range(i+1, i+self.n)):
-                            # If i+1 <= j < n, then t == j;     (after i)
-                            # If n <= j < i+n, then t == j - n  (before i)
-                            t = j % self.n
-                            with m.If(self.request[t]):
-                                m.d.sync += self.grant.eq(t)
+        if self.n == 1:
+            m.d.comb += self.grant.eq(0)
+        else:
+            with m.If(self.stb):
+                with m.Switch(self.grant):
+                    for i in range(self.n):
+                        with m.Case(i):
+                            for j in reversed(range(i+1, i+self.n)):
+                                # If i+1 <= j < n, then t == j;     (after i)
+                                # If n <= j < i+n, then t == j - n  (before i)
+                                t = j % self.n
+                                with m.If(self.request[t]):
+                                    m.d.sync += self.grant.eq(t)
 
         return m
 
+
 class Timeline(Elaboratable):
     def __init__(self, events):
         self.trigger = Signal()
@@ -98,6 +106,7 @@ class Timeline(Elaboratable):
 
         return m
 
+
 class CSRPrefixProxy:
     def __init__(self, bank, prefix):
         self._bank = bank
@@ -111,4 +120,4 @@ class CSRPrefixProxy:
 
         prefixed_name = "{}_{}".format(self._prefix, name)
         return self._bank.csr(width=width, access=access, addr=addr,
-            alignment=alignment, name=prefixed_name)
+                              alignment=alignment, name=prefixed_name)