style: move style verifiers into classes
authorNathan Binkert <nate@binkert.org>
Fri, 15 Apr 2011 17:43:47 +0000 (10:43 -0700)
committerNathan Binkert <nate@binkert.org>
Fri, 15 Apr 2011 17:43:47 +0000 (10:43 -0700)
util/style.py

index 4283ac4a33d2c6f92a9f23e64efd7234bcf15ab4..68f1cdffe52c5dfc9727d024eb20a7af8a37262a 100644 (file)
@@ -1,6 +1,6 @@
 #! /usr/bin/env python
 # Copyright (c) 2006 The Regents of The University of Michigan
-# Copyright (c) 2007 The Hewlett-Packard Development Company
+# Copyright (c) 2007,2011 The Hewlett-Packard Development Company
 # All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 #
 # Authors: Nathan Binkert
 
-import re
+import heapq
 import os
+import re
 import sys
 
-sys.path.insert(0, os.path.dirname(__file__))
+from os.path import dirname, join as joinpath
+from itertools import count
+from mercurial import bdiff, mdiff
+
+current_dir = dirname(__file__)
+sys.path.insert(0, current_dir)
+sys.path.insert(1, joinpath(dirname(current_dir), 'src', 'python'))
 
+from m5.util import neg_inf, pos_inf, Region, Regions
 from file_types import lang_type
 
+all_regions = Region(neg_inf, pos_inf)
+
 tabsize = 8
 lead = re.compile(r'^([ \t]+)')
 trail = re.compile(r'([ \t]+)$')
 any_control = re.compile(r'\b(if|while|for)[ \t]*[(]')
 good_control = re.compile(r'\b(if|while|for) [(]')
 
-whitespace_types = set(('C', 'C++', 'swig', 'python', 'asm', 'isa', 'scons'))
 format_types = set(('C', 'C++'))
 
+def modified_regions(old_data, new_data):
+    regions = Regions()
+    beg = None
+    for pbeg, pend, fbeg, fend in bdiff.blocks(old_data, new_data):
+        if beg is not None and beg != fbeg:
+            regions.append(beg, fbeg)
+        beg = fend
+    return regions
+
+def modregions(wctx, fname):
+    fctx = wctx.filectx(fname)
+    pctx = fctx.parents()
+
+    file_data = fctx.data()
+    lines = mdiff.splitnewlines(file_data)
+    if len(pctx) in (1, 2):
+        mod_regions = modified_regions(pctx[0].data(), file_data)
+        if len(pctx) == 2:
+            m2 = modified_regions(pctx[1].data(), file_data)
+            # only the lines that are new in both
+            mod_regions &= m2
+    else:
+        mod_regions = Regions()
+        mod_regions.add(0, len(lines))
+
+    return mod_regions
+
 class UserInterface(object):
     def __init__(self, verbose=False, auto=False):
         self.auto = auto
@@ -77,67 +113,106 @@ class StdioUI(UserInterface):
     def write(self, string):
         sys.stdout.write(string)
 
-def checkwhite_line(line):
-    match = lead.search(line)
-    if match and match.group(1).find('\t') != -1:
-        return False
+class Region(object):
+    def __init__(self, asdf):
+        self.regions = Foo
 
-    match = trail.search(line)
-    if match:
-        return False
+class Verifier(object):
+    def __init__(self, ui, repo=None):
+        self.ui = ui
+        self.repo = repo
+        if repo is None:
+            self.wctx = None
+
+    def __getattr__(self, attr):
+        if attr in ('prompt', 'write'):
+            return getattr(self.ui, attr)
+
+        if attr == 'wctx':
+            try:
+                wctx = repo.workingctx()
+            except:
+                from mercurial import context
+                wctx = context.workingctx(repo)
+            self.wctx = wctx
+            return wctx
+
+        raise AttributeError
+
+    def open(self, filename, mode):
+        if self.repo:
+            filename = self.repo.wjoin(filename)
+
+        try:
+            f = file(filename, mode)
+        except OSError, msg:
+            print 'could not open file %s: %s' % (filename, msg)
+            return None
+
+        return f
+
+    def skip(self, filename):
+        return lang_type(filename) not in self.languages
+
+    def check(self, filename, regions=all_regions):
+        f = self.open(filename, 'r')
+
+        errors = 0
+        for num,line in enumerate(f):
+            if num not in regions:
+                continue
+            if not self.check_line(line):
+                self.write("invalid %s in %s:%d\n" % \
+                               (self.test_name, filename, num + 1))
+                if self.ui.verbose:
+                    self.write(">>%s<<\n" % line[-1])
+                errors += 1
+        return errors
 
-    return True
+    def fix(self, filename, regions=all_regions):
+        f = self.open(filename, 'r+')
 
-def checkwhite(filename):
-    if lang_type(filename) not in whitespace_types:
-        return
+        lines = list(f)
 
-    try:
-        f = file(filename, 'r+')
-    except OSError, msg:
-        print 'could not open file %s: %s' % (filename, msg)
-        return
+        f.seek(0)
+        f.truncate()
 
-    for num,line in enumerate(f):
-        if not checkwhite_line(line):
-            yield line,num + 1
-
-def fixwhite_line(line):
-    if lead.search(line):
-        newline = ''
-        for i,c in enumerate(line):
-            if c == ' ':
-                newline += ' '
-            elif c == '\t':
-                newline += ' ' * (tabsize - len(newline) % tabsize)
-            else:
-                newline += line[i:]
-                break
-
-        line = newline
+        for i,line in enumerate(lines):
+            if i in regions:
+                line = self.fix_line(line)
 
-    return line.rstrip() + '\n'
+            f.write(line)
+        f.close()
 
-def fixwhite(filename, fixonly=None):
-    if lang_type(filename) not in whitespace_types:
-        return
+class Whitespace(Verifier):
+    languages = set(('C', 'C++', 'swig', 'python', 'asm', 'isa', 'scons'))
+    test_name = 'whitespace'
+    def check_line(self, line):
+        match = lead.search(line)
+        if match and match.group(1).find('\t') != -1:
+            return False
 
-    try:
-        f = file(filename, 'r+')
-    except OSError, msg:
-        print 'could not open file %s: %s' % (filename, msg)
-        return
+        match = trail.search(line)
+        if match:
+            return False
 
-    lines = list(f)
+        return True
 
-    f.seek(0)
-    f.truncate()
+    def fix_line(self, line):
+        if lead.search(line):
+            newline = ''
+            for i,c in enumerate(line):
+                if c == ' ':
+                    newline += ' '
+                elif c == '\t':
+                    newline += ' ' * (tabsize - len(newline) % tabsize)
+                else:
+                    newline += line[i:]
+                    break
 
-    for i,line in enumerate(lines):
-        if fixonly is None or i in fixonly:
-            line = fixwhite_line(line)
+            line = newline
 
-        print >>f, line,
+        return line.rstrip() + '\n'
 
 def linelen(line):
     tabs = line.count('\t')
@@ -241,22 +316,6 @@ def validate(filename, stats, verbose, exit_code):
                     msg(i, line, 'improper spacing after %s' % match.group(1))
                 bad()
 
-def modified_lines(old_data, new_data, max_lines):
-    from itertools import count
-    from mercurial import bdiff, mdiff
-
-    modified = set()
-    counter = count()
-    for pbeg, pend, fbeg, fend in bdiff.blocks(old_data, new_data):
-        for i in counter:
-            if i < fbeg:
-                modified.add(i)
-            elif i + 1 >= fend:
-                break
-            elif i > max_lines:
-                break
-    return modified
-
 def do_check_style(hgui, repo, *files, **args):
     """check files for proper m5 style guidelines"""
     from mercurial import mdiff, util
@@ -272,30 +331,26 @@ def do_check_style(hgui, repo, *files, **args):
     def skip(name):
         return files and name in files
 
-    def prompt(name, func, fixonly=None):
+    def prompt(name, func, regions=all_regions):
         result = ui.prompt("(a)bort, (i)gnore, or (f)ix?", 'aif', 'a')
         if result == 'a':
             return True
         elif result == 'f':
-            func(repo.wjoin(name), fixonly)
+            func(repo.wjoin(name), regions)
 
         return False
 
     modified, added, removed, deleted, unknown, ignore, clean = repo.status()
 
+    whitespace = Whitespace(ui)
     for fname in added:
-        if skip(fname):
+        if skip(fname) or whitespace.skip(fname):
             continue
 
-        ok = True
-        for line,num in checkwhite(repo.wjoin(fname)):
-            ui.write("invalid whitespace in %s:%d\n" % (fname, num))
-            if ui.verbose:
-                ui.write(">>%s<<\n" % line[-1])
-            ok = False
-
-        if not ok:
-            if prompt(fname, fixwhite):
+        errors = whitespace.check(fname)
+        if errors:
+            print errors
+            if prompt(fname, whitespace.fix):
                 return True
 
     try:
@@ -305,41 +360,14 @@ def do_check_style(hgui, repo, *files, **args):
         wctx = context.workingctx(repo)
 
     for fname in modified:
-        if skip(fname):
+        if skip(fname) or whitespace.skip(fname):
             continue
 
-        if lang_type(fname) not in whitespace_types:
-            continue
-
-        fctx = wctx.filectx(fname)
-        pctx = fctx.parents()
-
-        file_data = fctx.data()
-        lines = mdiff.splitnewlines(file_data)
-        if len(pctx) in (1, 2):
-            mod_lines = modified_lines(pctx[0].data(), file_data, len(lines))
-            if len(pctx) == 2:
-                m2 = modified_lines(pctx[1].data(), file_data, len(lines))
-                # only the lines that are new in both
-                mod_lines = mod_lines & m2
-        else:
-            mod_lines = xrange(0, len(lines))
-
-        fixonly = set()
-        for i,line in enumerate(lines):
-            if i not in mod_lines:
-                continue
-
-            if checkwhite_line(line):
-                continue
-
-            ui.write("invalid whitespace: %s:%d\n" % (fname, i+1))
-            if ui.verbose:
-                ui.write(">>%s<<\n" % line[:-1])
-            fixonly.add(i)
+        regions = modregions(wctx, fname)
 
-        if fixonly:
-            if prompt(fname, fixwhite, fixonly):
+        errors = whitespace.check(fname, regions)
+        if errors:
+            if prompt(fname, whitespace.fix, regions):
                 return True
 
 def do_check_format(hgui, repo, **args):