Check write register number too
[soc.git] / src / soc / alu / test / test_pipe_caller.py
index 5fe4cbc02c2761899c6c3d484b4b45bcea1d490e..d3ec23abe7336f1cde9934077624a1fa02f5e4bd 100644 (file)
@@ -103,8 +103,12 @@ class ALUTestCase(FHDLTestCase):
                     vld = yield alu.n.valid_o
                 yield
                 alu_out = yield alu.n.data_o.o
-                print(f"expected {simulator.gpr(3).value:x}, actual: {alu_out:x}")
-                self.assertEqual(simulator.gpr(3).value, alu_out)
+                out_reg_valid = yield pdecode2.e.write_reg.ok
+                if out_reg_valid:
+                    write_reg_idx = yield pdecode2.e.write_reg.data
+                    expected = simulator.gpr(write_reg_idx).value
+                    print(f"expected {expected:x}, actual: {alu_out:x}")
+                    self.assertEqual(expected, alu_out)
 
         sim.add_sync_process(process)
         with sim.write_vcd("simulator.vcd", "simulator.gtkw",
@@ -152,7 +156,6 @@ class ALUTestCase(FHDLTestCase):
             with Program(lst) as program:
                 sim = self.run_tst_program(program, initial_regs)
 
-    unittest.skip("broken")
     def test_shift(self):
         insns = ["slw", "sld", "srw", "srd", "sraw", "srad"]
         for i in range(20):
@@ -165,15 +168,7 @@ class ALUTestCase(FHDLTestCase):
             with Program(lst) as program:
                 sim = self.run_tst_program(program, initial_regs)
 
-    unittest.skip("broken")
-    def test_shift_imm(self):
-        lst = ["sradi 3, 1, 5"]
-        initial_regs = [0] * 32
-        initial_regs[1] = random.randint(0, (1<<64)-1)
-        with Program(lst) as program:
-            sim = self.run_tst_program(program, initial_regs)
 
-    @unittest.skip("broken")
     def test_shift_arith(self):
         lst = ["sraw 3, 1, 2"]
         initial_regs = [0] * 32