style: add sort_includes to the style hook
authorNathan Binkert <nate@binkert.org>
Fri, 15 Apr 2011 17:43:51 +0000 (10:43 -0700)
committerNathan Binkert <nate@binkert.org>
Fri, 15 Apr 2011 17:43:51 +0000 (10:43 -0700)
util/style.py

index 68f1cdffe52c5dfc9727d024eb20a7af8a37262a..cd9cd965260fbbc0ffcad381d42ed88d9e69f4b1 100644 (file)
@@ -42,6 +42,7 @@ sys.path.insert(0, current_dir)
 sys.path.insert(1, joinpath(dirname(current_dir), 'src', 'python'))
 
 from m5.util import neg_inf, pos_inf, Region, Regions
+import sort_includes
 from file_types import lang_type
 
 all_regions = Region(neg_inf, pos_inf)
@@ -184,6 +185,15 @@ class Verifier(object):
             f.write(line)
         f.close()
 
+    def apply(self, filename, prompt, regions=all_regions):
+        if not self.skip(filename):
+            errors = self.check(filename, regions)
+            if errors:
+                if prompt(filename, self.fix, regions):
+                    return True
+        return False
+
+
 class Whitespace(Verifier):
     languages = set(('C', 'C++', 'swig', 'python', 'asm', 'isa', 'scons'))
     test_name = 'whitespace'
@@ -214,6 +224,53 @@ class Whitespace(Verifier):
 
         return line.rstrip() + '\n'
 
+class SortedIncludes(Verifier):
+    languages = sort_includes.default_languages
+    def __init__(self, *args, **kwargs):
+        super(SortedIncludes, self).__init__(*args, **kwargs)
+        self.sort_includes = sort_includes.SortIncludes()
+
+    def check(self, filename, regions=all_regions):
+        f = self.open(filename, 'r')
+
+        lines = [ l.rstrip('\n') for l in f.xreadlines() ]
+        old = ''.join(line + '\n' for line in lines)
+        f.close()
+
+        language = lang_type(filename, lines[0])
+        sort_lines = list(self.sort_includes(lines, filename, language))
+        new = ''.join(line + '\n' for line in sort_lines)
+
+        mod = modified_regions(old, new)
+        modified = mod & regions
+        print mod, regions, modified
+
+        if modified:
+            self.write("invalid sorting of includes\n")
+            if self.ui.verbose:
+                for start, end in modified.regions:
+                    self.write("bad region [%d, %d)\n" % (start, end))
+            return 1
+
+        return 0
+
+    def fix(self, filename, regions=all_regions):
+        f = self.open(filename, 'r+')
+
+        old = f.readlines()
+        lines = [ l.rstrip('\n') for l in old ]
+        language = lang_type(filename, lines[0])
+        sort_lines = list(self.sort_includes(lines, filename, language))
+        new = ''.join(line + '\n' for line in sort_lines)
+
+        f.seek(0)
+        f.truncate()
+
+        for i,line in enumerate(sort_lines):
+            f.write(line)
+            f.write('\n')
+        f.close()
+
 def linelen(line):
     tabs = line.count('\t')
     if not tabs:
@@ -343,15 +400,16 @@ def do_check_style(hgui, repo, *files, **args):
     modified, added, removed, deleted, unknown, ignore, clean = repo.status()
 
     whitespace = Whitespace(ui)
+    sorted_includes = SortedIncludes(ui)
     for fname in added:
-        if skip(fname) or whitespace.skip(fname):
+        if skip(fname):
             continue
 
-        errors = whitespace.check(fname)
-        if errors:
-            print errors
-            if prompt(fname, whitespace.fix):
-                return True
+        if whitespace.apply(fname, prompt):
+            return True
+
+        if sorted_includes.apply(fname, prompt):
+            return True
 
     try:
         wctx = repo.workingctx()
@@ -360,15 +418,18 @@ def do_check_style(hgui, repo, *files, **args):
         wctx = context.workingctx(repo)
 
     for fname in modified:
-        if skip(fname) or whitespace.skip(fname):
+        if skip(fname):
             continue
 
         regions = modregions(wctx, fname)
 
-        errors = whitespace.check(fname, regions)
-        if errors:
-            if prompt(fname, whitespace.fix, regions):
-                return True
+        if whitespace.apply(fname, prompt, regions):
+            return True
+
+        if sorted_includes.apply(fname, prompt, regions):
+            return True
+
+    return False
 
 def do_check_format(hgui, repo, **args):
     ui = MercurialUI(hgui, hgui.verbose, auto)