From d978dd18bd710c9a3d21e595e590fe084480ecb8 Mon Sep 17 00:00:00 2001
From: whitequark
Date: Wed, 26 Aug 2020 13:26:38 +0000
Subject: [PATCH] sim._pyrtl: optimize uses of reflexive operators.

When a literal is used on the left-hand side of a numeric operator,
Python is able to constant-fold some expressions:

    >>> dis.dis(lambda x: 0 + 0 + x)
      1           0 LOAD_CONST               1 (0)
                  2 LOAD_FAST                0 (x)
                  4 BINARY_ADD
                  6 RETURN_VALUE

If a literal is used on the right-hand side such that the left-hand
side is variable, this doesn't happen:

    >>> dis.dis(lambda x: x + 0 + 0)
      1           0 LOAD_FAST                0 (x)
                  2 LOAD_CONST               1 (0)
                  4 BINARY_ADD
                  6 LOAD_CONST               1 (0)
                  8 BINARY_ADD
                 10 RETURN_VALUE

PyRTL generates fairly redundant code due to the pervasive masking,
and because of that, transforming expressions into the former form,
where possible, improves runtime by about 10% on Minerva SRAM SoC.
---
 nmigen/sim/_pyrtl.py | 46 ++++++++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/nmigen/sim/_pyrtl.py b/nmigen/sim/_pyrtl.py
index c2e9367..baf9a5b 100644
--- a/nmigen/sim/_pyrtl.py
+++ b/nmigen/sim/_pyrtl.py
@@ -103,7 +103,7 @@ class _RHSValueCompiler(_ValueCompiler):
     def on_Operator(self, value):
         def mask(value):
             value_mask = (1 << len(value)) - 1
-            return f"({self(value)} & {value_mask})"
+            return f"({value_mask} & {self(value)})"
 
         def sign(value):
             if value.shape().signed:
@@ -120,9 +120,9 @@ class _RHSValueCompiler(_ValueCompiler):
             if value.operator == "b":
                 return f"bool({mask(arg)})"
             if value.operator == "r|":
-                return f"({mask(arg)} != 0)"
+                return f"(0 != {mask(arg)})"
             if value.operator == "r&":
-                return f"({mask(arg)} == {(1 << len(arg)) - 1})"
+                return f"({(1 << len(arg)) - 1} == {mask(arg)})"
             if value.operator == "r^":
                 # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                 return f"(format({mask(arg)}, 'b').count('1') % 2)"
@@ -172,20 +172,20 @@ class _RHSValueCompiler(_ValueCompiler):
         raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:
 
     def on_Slice(self, value):
-        return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"
+        return f"({(1 << len(value)) - 1} & ({self(value.value)} >> {value.start}))"
 
     def on_Part(self, value):
         offset_mask = (1 << len(value.offset)) - 1
-        offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
-        return f"({self(value.value)} >> {offset} & " \
-               f"{(1 << value.width) - 1})"
+        offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
+        return f"({(1 << value.width) - 1} & " \
+               f"{self(value.value)} >> {offset})"
 
     def on_Cat(self, value):
         gen_parts = []
         offset = 0
         for part in value.parts:
             part_mask = (1 << len(part)) - 1
-            gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
+            gen_parts.append(f"(({part_mask} & {self(part)}) << {offset})")
             offset += len(part)
         if gen_parts:
             return f"({' | '.join(gen_parts)})"
@@ -193,7 +193,7 @@ class _RHSValueCompiler(_ValueCompiler):
 
     def on_Repl(self, value):
         part_mask = (1 << len(value.value)) - 1
-        gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
+        gen_part = self.emitter.def_var("repl", f"{part_mask} & {self(value.value)}")
         gen_parts = []
         offset = 0
         for _ in range(value.count):
@@ -205,15 +205,15 @@ class _RHSValueCompiler(_ValueCompiler):
 
     def on_ArrayProxy(self, value):
         index_mask = (1 << len(value.index)) - 1
-        gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
+        gen_index = self.emitter.def_var("rhs_index", f"{index_mask} & {self(value.index)}")
         gen_value = self.emitter.gen_var("rhs_proxy")
         if value.elems:
             gen_elems = []
             for index, elem in enumerate(value.elems):
                 if index == 0:
-                    self.emitter.append(f"if {gen_index} == {index}:")
+                    self.emitter.append(f"if {index} == {gen_index}:")
                 else:
-                    self.emitter.append(f"elif {gen_index} == {index}:")
+                    self.emitter.append(f"elif {index} == {gen_index}:")
                 with self.emitter.indent():
                     self.emitter.append(f"{gen_value} = {self(elem)}")
                 self.emitter.append(f"else:")
@@ -253,9 +253,9 @@ class _LHSValueCompiler(_ValueCompiler):
         def gen(arg):
             value_mask = (1 << len(value)) - 1
             if value.shape().signed:
-                value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
+                value_sign = f"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
             else: # unsigned
-                value_sign = f"{arg} & {value_mask}"
+                value_sign = f"{value_mask} & {arg}"
             self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
         return gen
 
@@ -267,17 +267,17 @@ class _LHSValueCompiler(_ValueCompiler):
             width_mask = (1 << (value.stop - value.start)) - 1
             self(value.value)(f"({self.lrhs(value.value)} & " \
                               f"{~(width_mask << value.start)} | " \
-                              f"(({arg} & {width_mask}) << {value.start}))")
+                              f"(({width_mask} & {arg}) << {value.start}))")
         return gen
 
     def on_Part(self, value):
         def gen(arg):
             width_mask = (1 << value.width) - 1
             offset_mask = (1 << len(value.offset)) - 1
-            offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
+            offset = f"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
             self(value.value)(f"({self.lrhs(value.value)} & " \
                               f"~({width_mask} << {offset}) | " \
-                              f"(({arg} & {width_mask}) << {offset}))")
+                              f"(({width_mask} & {arg}) << {offset}))")
         return gen
 
     def on_Cat(self, value):
@@ -287,7 +287,7 @@ class _LHSValueCompiler(_ValueCompiler):
             offset = 0
             for part in value.parts:
                 part_mask = (1 << len(part)) - 1
-                self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
+                self(part)(f"({part_mask} & ({gen_arg} >> {offset}))")
                 offset += len(part)
         return gen
 
@@ -302,9 +302,9 @@ class _LHSValueCompiler(_ValueCompiler):
                 gen_elems = []
                 for index, elem in enumerate(value.elems):
                     if index == 0:
-                        self.emitter.append(f"if {gen_index} == {index}:")
+                        self.emitter.append(f"if {index} == {gen_index}:")
                     else:
-                        self.emitter.append(f"elif {gen_index} == {index}:")
+                        self.emitter.append(f"elif {index} == {gen_index}:")
                     with self.emitter.indent():
                         self(elem)(arg)
                     self.emitter.append(f"else:")
@@ -332,7 +332,7 @@ class _StatementCompiler(StatementVisitor, _Compiler):
 
     def on_Switch(self, stmt):
         gen_test = self.emitter.def_var("test",
-            f"{self.rhs(stmt.test)} & {(1 << len(stmt.test)) - 1}")
+            f"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
         for index, (patterns, stmts) in enumerate(stmt.cases.items()):
             gen_checks = []
             if not patterns:
@@ -342,10 +342,10 @@ class _StatementCompiler(StatementVisitor, _Compiler):
                     if "-" in pattern:
                         mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
                         value = int("".join("0" if b == "-" else b for b in pattern), 2)
-                        gen_checks.append(f"({gen_test} & {mask}) == {value}")
+                        gen_checks.append(f"{value} == ({mask} & {gen_test})")
                     else:
                         value = int(pattern, 2)
-                        gen_checks.append(f"{gen_test} == {value}")
+                        gen_checks.append(f"{value} == {gen_test}")
             if index == 0:
                 self.emitter.append(f"if {' or '.join(gen_checks)}:")
             else:
-- 
2.30.2