d8cf84db363fd4b869b022145730d4e9bbb43a6e
[mesa.git] / src / gallium / tests / python / tests / base.py
1 #!/usr/bin/env python
2 ##########################################################################
3 #
4 # Copyright 2009 VMware, Inc.
5 # Copyright 2008 Tungsten Graphics, Inc., Cedar Park, Texas.
6 # All Rights Reserved.
7 #
8 # Permission is hereby granted, free of charge, to any person obtaining a
9 # copy of this software and associated documentation files (the
10 # "Software"), to deal in the Software without restriction, including
11 # without limitation the rights to use, copy, modify, merge, publish,
12 # distribute, sub license, and/or sell copies of the Software, and to
13 # permit persons to whom the Software is furnished to do so, subject to
14 # the following conditions:
15 #
16 # The above copyright notice and this permission notice (including the
17 # next paragraph) shall be included in all copies or substantial portions
18 # of the Software.
19 #
20 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
21 # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
23 # IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
24 # ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
25 # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
26 # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 #
28 ##########################################################################
29
30
31 """Base classes for tests.
32
33 Loosely inspired by Python's unittest module.
34 """
35
36
37 import os.path
38 import sys
39
40 from gallium import *
41
42
# Map every gallium pixel-format constant value to its symbolic name, taken
# from the PIPE_FORMAT_* integers present in this module's namespace.  The
# PIPE_FORMAT_NONE and PIPE_FORMAT_COUNT placeholders are not real formats.
#
# Iterate over a snapshot (list(...)) of the namespace: the loop variables of a
# plain module-level for loop are themselves globals, so binding them while
# iterating the live dict raises "dictionary changed size during iteration" on
# Python 3.  The generator expression also keeps `name`/`value` from leaking
# into the module namespace (and so into `from base import *` users).
formats = dict(
    (value, name)
    for name, value in list(globals().items())
    if name.startswith("PIPE_FORMAT_")
    and isinstance(value, int)
    and name not in ("PIPE_FORMAT_NONE", "PIPE_FORMAT_COUNT")
)
48
def make_image(width, height, rgba):
    """Convert a flat RGBA float buffer into an RGB PIL image.

    :param width: image width in pixels.
    :param height: image height in pixels.
    :param rgba: indexable sequence of floats, four (R, G, B, A) per pixel,
        in row-major order starting at the top-left pixel.  Each colour
        channel is clamped to [0.0, 1.0] and scaled to 0..255; alpha is
        ignored.
    :returns: a new ``RGB`` mode ``Image`` of size (width, height).
    """
    # Pillow only provides the ``PIL`` package; very old standalone PIL
    # installs exposed a top-level ``Image`` module instead.
    try:
        from PIL import Image
    except ImportError:
        import Image

    outimage = Image.new(
        mode='RGB',
        size=(width, height),
        color=(0, 0, 0))
    outpixels = outimage.load()
    for y in range(height):
        for x in range(width):
            offset = (y*width + x)*4
            # Clamp before scaling so out-of-range values saturate at 0/255.
            outpixels[x, y] = tuple(
                int(min(max(rgba[offset + ch], 0.0), 1.0)*255)
                for ch in range(3))
    return outimage
62
def save_image(width, height, rgba, filename):
    """Render the RGBA buffer with make_image() and store it as a PNG file.

    :param width: image width in pixels.
    :param height: image height in pixels.
    :param rgba: flat RGBA float buffer as accepted by make_image().
    :param filename: destination path; always written in PNG format.
    """
    make_image(width, height, rgba).save(filename, "PNG")
66
def show_image(width, height, **rgbas):
    """Display one or more RGBA buffers in Tk windows, side by side.

    Each keyword argument becomes one window titled with the keyword name.
    Windows are created in sorted-name order and placed left to right,
    starting at screen position (64, 64).  Blocks in the Tk main loop until
    the windows are closed.

    :param width: width of every buffer in pixels.
    :param height: height of every buffer in pixels.
    :param rgbas: name -> flat RGBA float buffer, as accepted by make_image().
    """
    try:
        import Tkinter as tk  # Python 2
    except ImportError:
        import tkinter as tk  # Python 3
    from PIL import ImageTk

    root = tk.Tk()

    x = 64
    y = 64

    # sorted() rather than keys().sort(): keys() is not a list on Python 3.
    for i, label in enumerate(sorted(rgbas)):
        outimage = make_image(width, height, rgbas[label])

        # The first image reuses the root window; the rest get their own
        # top-level windows.
        if i:
            window = tk.Toplevel(root)
        else:
            window = root
        window.title(label)
        photo = ImageTk.PhotoImage(outimage)
        w = photo.width()
        h = photo.height()
        window.geometry("%dx%d+%d+%d" % (w, h, x, y))
        panel = tk.Label(window, image=photo)
        panel.pack(side='top', fill='both', expand='yes')
        # Keep a Python reference to the PhotoImage on the widget; Tk does
        # not hold one, so the image would otherwise be garbage-collected.
        panel.image = photo
        x += w + 2

    root.mainloop()
97
98
class TestFailure(Exception):
    """Raised inside a test to mark it as failed (see TestCase._run)."""
102
class TestSkip(Exception):
    """Raised inside a test to mark it as skipped (see TestCase._run)."""
106
107
class Test:
    """Base class of everything runnable: single test cases and suites.

    Subclasses implement ``_run(result)``, reporting their outcomes on the
    TestResult they are given.
    """

    def __init__(self):
        pass

    def _run(self, result):
        """Run the test(s), reporting outcomes to *result*.  Abstract."""
        raise NotImplementedError

    def run(self):
        """Run with a fresh TestResult, then print and write its report."""
        result = TestResult()
        self._run(result)
        result.report()

    def assert_rgba(self, ctx, surface, x, y, w, h, expected_rgba,
                    pixel_tol=4.0/256, surface_tol=0.85, debug=False):
        """Compare a rectangle of *surface* against expected RGBA values.

        :param ctx: object providing ``surface_compare_rgba`` (returns the
            number of differing pixels) and ``surface_read_rgba``.
        :param surface: surface to inspect.
        :param x, y, w, h: the rectangle to compare.
        :param expected_rgba: expected values, passed through to ``ctx``.
        :param pixel_tol: passed as ``tol`` to ``ctx.surface_compare_rgba``.
        :param surface_tol: minimum fraction of pixels that must match.
        :param debug: on failure, read back the rectangle, show result and
            expected images in Tk windows (blocking until closed) and save
            them as ``result.png`` / ``expected.png``.
        :raises TestFailure: if less than *surface_tol* of the pixels match.
        """
        total = h*w
        if total == 0:
            # Empty rectangle: nothing to compare, and the match ratio below
            # would divide by zero.
            return

        different = ctx.surface_compare_rgba(surface, x, y, w, h, expected_rgba, tol=pixel_tol)
        if different:
            sys.stderr.write("%u out of %u pixels differ\n" % (different, total))

        if float(total - different)/float(total) < surface_tol:
            if debug:
                rgba = FloatArray(h*w*4)
                ctx.surface_read_rgba(surface, x, y, w, h, rgba)
                show_image(w, h, Result=rgba, Expected=expected_rgba)
                save_image(w, h, rgba, "result.png")
                save_image(w, h, expected_rgba, "expected.png")

            raise TestFailure
137
138
class TestCase(Test):
    """A single parameterized test.

    Keyword arguments passed to the constructor become instance attributes.
    ``tags`` lists the attribute names that identify the test; their values
    (through get()) appear in progress messages and in the result log.
    Subclasses implement test() and may raise TestSkip or TestFailure.
    """

    tags = ()

    def __init__(self, dev, **kargs):
        Test.__init__(self)
        self.dev = dev
        # Every keyword argument becomes an attribute of this test case.
        self.__dict__.update(kargs)

    def description(self):
        """Return "tag=value" pairs for the tags, joined by spaces.

        Tags whose value is None or the empty string are omitted.
        """
        # Lazy generator so get() and str() interleave per tag.
        tagged = ((tag, self.get(tag)) for tag in self.tags)
        return ' '.join([tag + '=' + str(value)
                         for tag, value in tagged
                         if value is not None and value != ''])

    def get(self, tag):
        """Return the value of *tag*.

        Uses a ``_get_<tag>()`` method when one exists, else the attribute
        named *tag*, else None.
        """
        missing = object()
        getter = getattr(self, '_get_' + tag, missing)
        if getter is missing:
            return getattr(self, tag, None)
        return getter()

    def _get_target(self):
        """Short name of the texture target: "1d", "2d", "3d" or "cube"."""
        target_names = {
            PIPE_TEXTURE_1D: "1d",
            PIPE_TEXTURE_2D: "2d",
            PIPE_TEXTURE_3D: "3d",
            PIPE_TEXTURE_CUBE: "cube",
        }
        return target_names[self.target]

    def _get_format(self):
        """Lower-case format name with any "PIPE_FORMAT_" prefix removed."""
        prefix = 'PIPE_FORMAT_'
        format_name = formats[self.format]
        if format_name.startswith(prefix):
            format_name = format_name[len(prefix):]
        return format_name.lower()

    def _get_face(self):
        """Cube face as "+x", "-x", ... for cube targets; '' for any other."""
        if self.target != PIPE_TEXTURE_CUBE:
            return ''
        face_names = {
            PIPE_TEX_FACE_POS_X: "+x",
            PIPE_TEX_FACE_NEG_X: "-x",
            PIPE_TEX_FACE_POS_Y: "+y",
            PIPE_TEX_FACE_NEG_Y: "-y",
            PIPE_TEX_FACE_POS_Z: "+z",
            PIPE_TEX_FACE_NEG_Z: "-z",
        }
        return face_names[self.face]

    def test(self):
        """The test body.  Abstract; subclasses must override."""
        raise NotImplementedError

    def _run(self, result):
        """Run test() once and record pass, skip or fail on *result*.

        Only TestSkip and TestFailure are handled here; anything else
        (including KeyboardInterrupt) propagates to the caller without an
        outcome being recorded.
        """
        result.test_start(self)
        try:
            self.test()
        except TestSkip:
            record = result.test_skipped
        except TestFailure:
            record = result.test_failed
        else:
            record = result.test_passed
        record(self)
207
208
class TestSuite(Test):
    """An ordered sequence of tests (cases or nested suites)."""

    def __init__(self, tests=None):
        Test.__init__(self)
        # A list passed in is used as-is (not copied); otherwise start empty.
        self.tests = tests
        if self.tests is None:
            self.tests = []

    def add_test(self, test):
        """Append *test* to the end of the suite."""
        self.tests.append(test)

    def _run(self, result):
        """Run each member test in order, all reporting into *result*."""
        for member in self.tests:
            member._run(result)
224
225
class TestResult:
    """Collect test outcomes, print progress and write the final reports.

    Besides pass/skip/fail counters, each outcome is logged as a row of a
    table.  The first column holds the result; the remaining columns hold
    the tests' tag values, with a column added the first time a tag is
    seen.  report() writes that table as tab-separated text (plus an
    optional Orange classification tree) and as JUnit XML.
    """

    # Text types: (str, unicode) on Python 2, (str, str) on Python 3.
    # Spelled this way because `basestring` does not exist on Python 3.
    _string_types = (str, type(u''))

    def __init__(self):
        self.tests = 0
        self.passed = 0
        self.skipped = 0
        self.failed = 0

        # Column names and column types of the result table.  Column 0 is the
        # result itself, whose type lists its possible (discrete) values.
        self.names = ['result']
        self.types = ['pass skip fail']
        # One list of strings per logged test.  Rows logged before a column
        # was added are shorter than self.names.
        self.rows = []

    def test_start(self, test):
        """Announce *test* on stdout and count it."""
        sys.stdout.write("Running %s...\n" % test.description())
        sys.stdout.flush()
        self.tests += 1

    def test_passed(self, test):
        """Record a pass for *test*."""
        sys.stdout.write("PASS\n")
        sys.stdout.flush()
        self.passed += 1
        self.log_result(test, 'pass')

    def test_skipped(self, test):
        """Record a skip for *test*."""
        sys.stdout.write("SKIP\n")
        sys.stdout.flush()
        self.skipped += 1
        self.log_result(test, 'skip')

    def test_failed(self, test):
        """Record a failure for *test*."""
        sys.stdout.write("FAIL\n")
        sys.stdout.flush()
        self.failed += 1
        self.log_result(test, 'fail')

    def log_result(self, test, result):
        """Append a table row for *test* with outcome *result*.

        Tags whose value is None are left empty.  Numbers go into continuous
        ('c') columns and strings into discrete ('d') columns.  A tag seen
        for the first time gets a new column.
        """
        row = ['']*len(self.names)

        # add result
        assert self.names[0] == 'result'
        assert result in ('pass', 'skip', 'fail')
        row[0] = result

        # add tags
        for tag in test.tags:
            value = test.get(tag)

            # infer the column type
            if value is None:
                continue
            elif isinstance(value, (int, float)):
                value = str(value)
                kind = 'c' # continuous
            elif isinstance(value, self._string_types):
                kind = 'd' # discrete
            else:
                assert False, 'unsupported value %r for tag %r' % (value, tag)
                # Only reached when asserts are disabled (python -O).
                value = str(value)
                kind = 'd' # discrete

            # insert value
            try:
                col = self.names.index(tag, 1)
            except ValueError:
                # New column.  `row` has exactly one entry per existing
                # column, so appending keeps it aligned with self.names.
                self.names.append(tag)
                self.types.append(kind)
                row.append(value)
            else:
                row[col] = value
                assert self.types[col] == kind

        self.rows.append(row)

    def report(self):
        """Print the summary line and write the report files.

        Reports are named after the running script (sys.argv[0] without
        directory and extension) and written to the current directory.
        """
        sys.stdout.write("%u tests, %u passed, %u skipped, %u failed\n\n" % (self.tests, self.passed, self.skipped, self.failed))
        sys.stdout.flush()

        name = os.path.splitext(os.path.basename(sys.argv[0]))[0]

        tree = self.report_tree(name)
        self.report_junit(name, stdout=tree)

    def report_tree(self, name):
        """Write the table to ``name.tsv`` and, when Orange is available,
        learn a classification tree of the result from the tags.

        Skipped tests are left out of the table.  Rows that are written are
        padded *in place* to the full column count, so report_junit() later
        sees them padded too.

        :returns: the tree as text (also saved to ``name.txt`` and drawn to
            ``name.dot``), or None when Orange cannot be imported.
        """
        filename = name + '.tsv'
        stream = open(filename, 'wt')
        try:
            # header: column names, column types, then a flags row; per the
            # Orange tab-delimited format, 'class' marks column 0 (result) as
            # the class variable.
            stream.write('\t'.join(self.names) + '\n')
            stream.write('\t'.join(self.types) + '\n')
            stream.write('class\n')

            # rows
            for row in self.rows:
                if row[0] == 'skip':
                    continue
                # Pad rows logged before later columns existed.
                row += ['']*(len(self.names) - len(row))
                stream.write('\t'.join(row) + '\n')
        finally:
            stream.close()

        # See http://www.ailab.si/orange/doc/ofb/c_otherclass.htm
        try:
            import orange
            import orngTree
        except ImportError:
            sys.stderr.write('Install Orange from http://www.ailab.si/orange/ for a classification tree.\n')
            return None

        data = orange.ExampleTable(filename)

        tree = orngTree.TreeLearner(data, sameMajorityPruning=1, mForPruning=2)

        orngTree.printTxt(tree, maxDepth=4)

        text_tree = orngTree.dumpTree(tree)

        text_stream = open(name + '.txt', 'wt')
        try:
            text_stream.write(text_tree)
        finally:
            text_stream.close()

        orngTree.printDot(tree, fileName=name+'.dot', nodeShape='ellipse', leafShape='box')

        return text_tree

    def report_junit(self, name, stdout=None, stderr=None):
        """Write test results in ANT's junit XML format, to use with Hudson CI.

        The output goes to ``name.xml``.  Each logged row becomes a testcase
        named by its "tag=value" pairs; skips and failures are marked with
        <skipped/> and <failure/>.  *stdout*/*stderr*, when non-empty, are
        embedded as <system-out>/<system-err>.

        See also:
        - http://fisheye.hudson-ci.org/browse/Hudson/trunk/hudson/main/core/src/test/resources/hudson/tasks/junit
        - http://www.junit.org/node/399
        - http://wiki.apache.org/ant/Proposals/EnhancedTestReports
        """

        stream = open(name + '.xml', 'wt')
        try:
            stream.write('<?xml version="1.0" encoding="UTF-8" ?>\n')
            stream.write('<testsuite name="%s">\n' % self.escape_xml(name))
            stream.write(' <properties>\n')
            stream.write(' </properties>\n')

            tag_names = self.names[1:]

            for row in self.rows:

                test_name = ' '.join(['%s=%s' % pair for pair in zip(tag_names, row[1:])])

                stream.write(' <testcase name="%s">\n' % (self.escape_xml(test_name)))

                result = row[0]
                if result == 'skip':
                    stream.write(' <skipped/>\n')
                elif result != 'pass':
                    stream.write(' <failure/>\n')

                stream.write(' </testcase>\n')

            if stdout:
                stream.write(' <system-out>%s</system-out>\n' % self.escape_xml(stdout))
            if stderr:
                stream.write(' <system-err>%s</system-err>\n' % self.escape_xml(stderr))

            stream.write('</testsuite>\n')
        finally:
            stream.close()

    def escape_xml(self, s):
        '''Escape *s* for use in XML text and double- or single-quoted attributes.

        '&' is replaced first so the entities introduced afterwards are not
        escaped a second time.
        '''
        s = s.replace('&', '&amp;')
        s = s.replace('<', '&lt;')
        s = s.replace('>', '&gt;')
        s = s.replace('"', '&quot;')
        s = s.replace("'", '&apos;')
        return s
399