got fed up of hard-coded names, allow pywriter.py to take arguments
[soc.git] / src / soc / decoder / pseudo / parser.py
index 6abe1913a204d3bf7b6683e28fda08ae94fc4840..6ccbe96c94b9fcfbe323391bd28547d54966aaf0 100644 (file)
@@ -23,7 +23,7 @@ import ast
 # Helper function
 
 
-def Assign(left, right):
+def Assign(left, right, iea_mode):
     names = []
     print("Assign", left, right)
     if isinstance(left, ast.Name):
@@ -106,7 +106,13 @@ def make_eq_compare(arg):
     return ast.Compare(left, [ast.Eq()], [right])
 
 
+def make_ne_compare(arg):
+    (left, right) = arg
+    return ast.Compare(left, [ast.NotEq()], [right])
+
+
 binary_ops = {
+    "^": ast.BitXor(),
     "&": ast.BitAnd(),
     "|": ast.BitOr(),
     "+": ast.Add(),
@@ -119,6 +125,7 @@ binary_ops = {
     "<": make_lt_compare,
     ">": make_gt_compare,
     "=": make_eq_compare,
+    "!=": make_ne_compare,
 }
 unary_ops = {
     "+": ast.UAdd(),
@@ -139,17 +146,20 @@ def check_concat(node):  # checks if the comparison is already a concat
     return node.args
 
 
-# identify SelectableInt pattern
+# identify SelectableInt pattern [something] * N
+# must return concat(something, repeat=N)
 def identify_sint_mul_pattern(p):
-    if not isinstance(p[3], ast.Constant):
+    if p[2] != '*': # multiply
+        return False
+    if not isinstance(p[3], ast.Constant): # rhs = Num
         return False
-    if not isinstance(p[1], ast.List):
+    if not isinstance(p[1], ast.List): # lhs is a list
         return False
     l = p[1].elts
-    if len(l) != 1:
+    if len(l) != 1: # lhs is a list of length 1
         return False
-    elt = l[0]
-    return isinstance(elt, ast.Constant)
+    return True # yippee!
+
 
 def apply_trailer(atom, trailer):
     if trailer[0] == "TLIST":
@@ -207,18 +217,29 @@ def apply_trailer(atom, trailer):
 class PowerParser:
 
     precedence = (
-        ("left", "EQ", "GT", "LT", "LE", "GE", "LTU", "GTU"),
+        ("left", "EQ", "NE", "GT", "LT", "LE", "GE", "LTU", "GTU"),
         ("left", "BITOR"),
+        ("left", "BITXOR"),
         ("left", "BITAND"),
         ("left", "PLUS", "MINUS"),
         ("left", "MULT", "DIV", "MOD"),
         ("left", "INVERT"),
     )
 
-    def __init__(self):
+    def __init__(self, form):
         self.gprs = {}
+        form = self.sd.sigforms[form]
+        print (form)
+        formkeys = form._asdict().keys()
         for rname in ['RA', 'RB', 'RC', 'RT', 'RS']:
             self.gprs[rname] = None
+        self.available_op_fields = set()
+        for k in formkeys:
+            if k not in self.gprs:
+                if k == 'SPR': # sigh, lower-case to not conflict
+                    k = k.lower()
+                self.available_op_fields.add(k)
+        self.op_fields = OrderedSet()
         self.read_regs = OrderedSet()
         self.uninit_regs = OrderedSet()
         self.write_regs = OrderedSet()
@@ -318,7 +339,8 @@ class PowerParser:
     # augassign: ('+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^=' |
     #             '<<=' | '>>=' | '**=' | '//=')
     def p_expr_stmt(self, p):
-        """expr_stmt : testlist ASSIGN testlist
+        """expr_stmt : testlist ASSIGNEA testlist
+                     | testlist ASSIGN testlist
                      | testlist """
         print("expr_stmt", p)
         if len(p) == 2:
@@ -326,15 +348,17 @@ class PowerParser:
             #p[0] = ast.Discard(p[1])
             p[0] = p[1]
         else:
+            iea_mode = p[2] == '<-iea'
             name = None
             if isinstance(p[1], ast.Name):
                 name = p[1].id
             elif isinstance(p[1], ast.Subscript):
-                name = p[1].value.id
-                if name in self.gprs:
-                    # add to list of uninitialised
-                    self.uninit_regs.add(name)
-            elif isinstance(p[1], ast.Call) and p[1].func.id == 'GPR':
+                if isinstance(p[1].value, ast.Name):
+                    name = p[1].value.id
+                    if name in self.gprs:
+                        # add to list of uninitialised
+                        self.uninit_regs.add(name)
+            elif isinstance(p[1], ast.Call) and p[1].func.id in ['GPR', 'SPR']:
                 print(astor.dump_tree(p[1]))
                 # replace GPR(x) with GPR[x]
                 idx = p[1].args[0]
@@ -354,7 +378,7 @@ class PowerParser:
             print("expr assign", name, p[1])
             if name and name in self.gprs:
                 self.write_regs.add(name)  # add to list of regs to write
-            p[0] = Assign(p[1], p[3])
+            p[0] = Assign(p[1], p[3], iea_mode)
 
     def p_flow_stmt(self, p):
         "flow_stmt : return_stmt"
@@ -368,6 +392,7 @@ class PowerParser:
     def p_compound_stmt(self, p):
         """compound_stmt : if_stmt
                          | while_stmt
+                         | switch_stmt
                          | for_stmt
                          | funcdef
         """
@@ -379,9 +404,9 @@ class PowerParser:
         p[0] = ast.Break()
 
     def p_for_stmt(self, p):
-        """for_stmt : FOR test EQ test TO test COLON suite
+        """for_stmt : FOR atom EQ test TO test COLON suite
+                    | DO atom EQ test TO test COLON suite
         """
-        p[0] = ast.While(p[2], p[4], [])
         # auto-add-one (sigh) due to python range
         start = p[4]
         end = ast.BinOp(p[6], ast.Add(), ast.Constant(1))
@@ -397,6 +422,93 @@ class PowerParser:
         else:
             p[0] = ast.While(p[3], p[5], p[8])
 
+    def p_switch_smt(self, p):
+        """switch_stmt : SWITCH LPAR atom RPAR COLON NEWLINE INDENT switches DEDENT
+        """
+        switchon = p[3]
+        print("switch stmt")
+        print(astor.dump_tree(p[1]))
+
+        cases = []
+        current_cases = [] # for deferral
+        for (case, suite) in p[8]:
+            print ("for", case, suite)
+            if suite is None:
+                for c in case:
+                    current_cases.append(ast.Num(c))
+                continue
+            if case == 'default': # last
+                break
+            for c in case:
+                current_cases.append(ast.Num(c))
+            print ("cases", current_cases)
+            compare = ast.Compare(switchon, [ast.In()],
+                                  [ast.List(current_cases)])
+            current_cases = []
+            cases.append((compare, suite))
+
+        print ("ended", case, current_cases)
+        if case == 'default':
+            if current_cases:
+                compare = ast.Compare(switchon, [ast.In()],
+                                      [ast.List(current_cases)])
+                cases.append((compare, suite))
+            cases.append((None, suite))
+
+        cases.reverse()
+        res = []
+        for compare, suite in cases:
+            print ("after rev", compare, suite)
+            if compare is None:
+                assert len(res) == 0, "last case should be default"
+                res = suite
+            else:
+                if not isinstance(res, list):
+                    res = [res]
+                res = ast.If(compare, suite, res)
+        p[0] = res
+
+    def p_switches(self, p):
+        """switches : switch_list switch_default
+                    | switch_default
+        """
+        if len(p) == 3:
+            p[0] = p[1] + [p[2]]
+        else:
+            p[0] = [p[1]]
+
+    def p_switch_list(self, p):
+        """switch_list : switch_case switch_list
+                       | switch_case
+        """
+        if len(p) == 3:
+            p[0] = [p[1]] + p[2]
+        else:
+            p[0] = [p[1]]
+
+    def p_switch_case(self, p):
+        """switch_case : CASE LPAR atomlist RPAR COLON suite
+        """
+        # XXX bad hack
+        if isinstance(p[6][0], ast.Name) and p[6][0].id == 'fallthrough':
+            p[6] = None
+        p[0] = (p[3], p[6])
+
+    def p_switch_default(self, p):
+        """switch_default : DEFAULT COLON suite
+        """
+        p[0] = ('default', p[3])
+
+    def p_atomlist(self, p):
+        """atomlist : atom COMMA atomlist
+                    | atom
+        """
+        assert isinstance(p[1], ast.Constant), "case must be numbers"
+        if len(p) == 4:
+            p[0] = [p[1].value] + p[3]
+        else:
+            p[0] = [p[1].value]
+
     def p_if_stmt(self, p):
         """if_stmt : IF test COLON suite ELSE COLON if_stmt
                    | IF test COLON suite ELSE COLON suite
@@ -432,6 +544,7 @@ class PowerParser:
                       | comparison DIV comparison
                       | comparison MOD comparison
                       | comparison EQ comparison
+                      | comparison NE comparison
                       | comparison LE comparison
                       | comparison GE comparison
                       | comparison LTU comparison
@@ -439,6 +552,7 @@ class PowerParser:
                       | comparison LT comparison
                       | comparison GT comparison
                       | comparison BITOR comparison
+                      | comparison BITXOR comparison
                       | comparison BITAND comparison
                       | PLUS comparison
                       | comparison MINUS
@@ -454,7 +568,7 @@ class PowerParser:
             elif p[2] == '||':
                 l = check_concat(p[1]) + check_concat(p[3])
                 p[0] = ast.Call(ast.Name("concat"), l, [])
-            elif p[2] in ['<', '>', '=', '<=', '>=']:
+            elif p[2] in ['<', '>', '=', '<=', '>=', '!=']:
                 p[0] = binary_ops[p[2]]((p[1], p[3]))
             elif identify_sint_mul_pattern(p):
                 keywords=[ast.keyword(arg='repeat', value=p[3])]
@@ -487,11 +601,15 @@ class PowerParser:
 
     def p_atom_name(self, p):
         """atom : NAME"""
-        p[0] = ast.Name(id=p[1], ctx=ast.Load())
+        name = p[1]
+        if name in self.available_op_fields:
+            self.op_fields.add(name)
+        p[0] = ast.Name(id=name, ctx=ast.Load())
 
     def p_atom_number(self, p):
         """atom : BINARY
                 | NUMBER
+                | HEX
                 | STRING"""
         p[0] = ast.Constant(p[1])
 
@@ -517,9 +635,10 @@ class PowerParser:
         print(astor.dump_tree(p[2]))
 
         if isinstance(p[2], ast.Name):
-            print("tuple name", p[2].id)
-            if p[2].id in self.gprs:
-                self.read_regs.add(p[2].id)  # add to list of regs to read
+            name = p[2].id
+            print("tuple name", name)
+            if name in self.gprs:
+                self.read_regs.add(name)  # add to list of regs to read
                 #p[0] = ast.Subscript(ast.Name("GPR"), ast.Str(p[2].id))
                 # return
             p[0] = p[2]
@@ -643,20 +762,20 @@ class PowerParser:
 
 
 class GardenSnakeParser(PowerParser):
-    def __init__(self, lexer=None):
-        PowerParser.__init__(self)
+    def __init__(self, lexer=None, debug=False, form=None):
+        self.sd = create_pdecode()
+        PowerParser.__init__(self, form)
+        self.debug = debug
         if lexer is None:
             lexer = IndentLexer(debug=0)
         self.lexer = lexer
         self.tokens = lexer.tokens
         self.parser = yacc.yacc(module=self, start="file_input_end",
-                                debug=False, write_tables=False)
-
-        self.sd = create_pdecode()
+                                debug=debug, write_tables=False)
 
     def parse(self, code):
         # self.lexer.input(code)
-        result = self.parser.parse(code, lexer=self.lexer, debug=False)
+        result = self.parser.parse(code, lexer=self.lexer, debug=self.debug)
         return ast.Module(result)
 
 
@@ -665,8 +784,8 @@ class GardenSnakeParser(PowerParser):
 #from compiler import misc, syntax, pycodegen
 
 class GardenSnakeCompiler(object):
-    def __init__(self):
-        self.parser = GardenSnakeParser()
+    def __init__(self, debug=False, form=None):
+        self.parser = GardenSnakeParser(debug=debug, form=form)
 
     def compile(self, code, mode="exec", filename="<string>"):
         tree = self.parser.parse(code)