llvmpipe: Specialize arithmetic operations.
[mesa.git] / src / gallium / drivers / llvmpipe / lp_bld_arit.c
1 /**************************************************************************
2 *
3 * Copyright 2009 VMware, Inc.
4 * All Rights Reserved.
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a
7 * copy of this software and associated documentation files (the
8 * "Software"), to deal in the Software without restriction, including
9 * without limitation the rights to use, copy, modify, merge, publish,
10 * distribute, sub license, and/or sell copies of the Software, and to
11 * permit persons to whom the Software is furnished to do so, subject to
12 * the following conditions:
13 *
14 * The above copyright notice and this permission notice (including the
15 * next paragraph) shall be included in all copies or substantial portions
16 * of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
19 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
21 * IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
22 * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
23 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
24 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 *
26 **************************************************************************/
27
28
29 /**
30 * @file
31 * Helper
32 *
33 * LLVM IR doesn't support all basic arithmetic operations we care about (most
34 * notably min/max and saturated operations), and it is often necessary to
35 * resort to machine-specific intrinsics directly. The functions here hide all
36 * these implementation details from the other modules.
37 *
38 * We also do simple expression simplification here. Reasons are:
39 * - it is very easy given we have all necessary information readily available
40 * - LLVM optimization passes fail to simplify several vector expressions
41 * - We often know value constraints which the optimization passes have no way
42 * of knowing, such as when source arguments are known to be in [0, 1] range.
43 *
44 * @author Jose Fonseca <jfonseca@vmware.com>
45 */
46
47
48 #include "pipe/p_state.h"
49
50 #include "lp_bld_arit.h"
51
52
53 LLVMTypeRef
54 lp_build_elem_type(union lp_type type)
55 {
56 if (type.floating) {
57 assert(type.sign);
58 switch(type.width) {
59 case 32:
60 return LLVMFloatType();
61 break;
62 case 64:
63 return LLVMDoubleType();
64 break;
65 default:
66 assert(0);
67 return LLVMFloatType();
68 }
69 }
70 else {
71 return LLVMIntType(type.width);
72 }
73 }
74
75
76 LLVMTypeRef
77 lp_build_vec_type(union lp_type type)
78 {
79 LLVMTypeRef elem_type = lp_build_elem_type(type);
80 return LLVMVectorType(elem_type, type.length);
81 }
82
83
84 /**
85 * This function is a mirrot of lp_build_elem_type() above.
86 *
87 * XXX: I'm not sure if it wouldn't be easier/efficient to just recreate the
88 * type and check for identity.
89 */
90 boolean
91 lp_check_elem_type(union lp_type type, LLVMTypeRef elem_type)
92 {
93 LLVMTypeKind elem_kind;
94
95 assert(elem_type);
96 if(!elem_type)
97 return FALSE;
98
99 elem_kind = LLVMGetTypeKind(elem_type);
100
101 if (type.floating) {
102 switch(type.width) {
103 case 32:
104 if(elem_kind != LLVMFloatTypeKind)
105 return FALSE;
106 break;
107 case 64:
108 if(elem_kind != LLVMDoubleTypeKind)
109 return FALSE;
110 break;
111 default:
112 assert(0);
113 return FALSE;
114 }
115 }
116 else {
117 if(elem_kind != LLVMIntegerTypeKind)
118 return FALSE;
119
120 if(LLVMGetIntTypeWidth(elem_type) != type.width)
121 return FALSE;
122 }
123
124 return TRUE;
125 }
126
127
128 boolean
129 lp_check_vec_type(union lp_type type, LLVMTypeRef vec_type)
130 {
131 LLVMTypeRef elem_type;
132
133 assert(vec_type);
134 if(!vec_type)
135 return FALSE;
136
137 if(LLVMGetTypeKind(vec_type) != LLVMVectorTypeKind)
138 return FALSE;
139
140 if(LLVMGetVectorSize(vec_type) != type.length)
141 return FALSE;
142
143 elem_type = LLVMGetElementType(vec_type);
144
145 return lp_check_elem_type(type, elem_type);
146 }
147
148
149 boolean
150 lp_check_value(union lp_type type, LLVMValueRef val)
151 {
152 LLVMTypeRef vec_type;
153
154 assert(val);
155 if(!val)
156 return FALSE;
157
158 vec_type = LLVMTypeOf(val);
159
160 return lp_check_vec_type(type, vec_type);
161 }
162
163
164 LLVMValueRef
165 lp_build_undef(union lp_type type)
166 {
167 LLVMTypeRef vec_type = lp_build_vec_type(type);
168 return LLVMGetUndef(vec_type);
169 }
170
171
172 LLVMValueRef
173 lp_build_zero(union lp_type type)
174 {
175 LLVMTypeRef vec_type = lp_build_vec_type(type);
176 return LLVMConstNull(vec_type);
177 }
178
179
180 LLVMValueRef
181 lp_build_one(union lp_type type)
182 {
183 LLVMTypeRef elem_type;
184 LLVMValueRef elems[LP_MAX_VECTOR_LENGTH];
185 unsigned i;
186
187 assert(type.length < LP_MAX_VECTOR_LENGTH);
188
189 elem_type = lp_build_elem_type(type);
190
191 if(type.floating)
192 elems[0] = LLVMConstReal(elem_type, 1.0);
193 else if(type.fixed)
194 elems[0] = LLVMConstInt(elem_type, 1LL << (type.width/2), 0);
195 else if(!type.norm)
196 elems[0] = LLVMConstInt(elem_type, 1, 0);
197 else {
198 /* special case' -- 1.0 for normalized types is more easily attained if
199 * we start with a vector consisting of all bits set */
200 LLVMTypeRef vec_type = LLVMVectorType(elem_type, type.length);
201 LLVMValueRef vec = LLVMConstAllOnes(vec_type);
202
203 if(type.sign)
204 vec = LLVMConstLShr(vec, LLVMConstInt(LLVMInt32Type(), 1, 0));
205
206 return vec;
207 }
208
209 for(i = 1; i < type.length; ++i)
210 elems[i] = elems[0];
211
212 return LLVMConstVector(elems, type.length);
213 }
214
215
216 LLVMValueRef
217 lp_build_const_aos(union lp_type type,
218 double r, double g, double b, double a,
219 const unsigned char *swizzle)
220 {
221 const unsigned char default_swizzle[4] = {0, 1, 2, 3};
222 LLVMTypeRef elem_type;
223 LLVMValueRef elems[LP_MAX_VECTOR_LENGTH];
224 unsigned i;
225
226 assert(type.length % 4 == 0);
227 assert(type.length < LP_MAX_VECTOR_LENGTH);
228
229 elem_type = lp_build_elem_type(type);
230
231 if(swizzle == NULL)
232 swizzle = default_swizzle;
233
234 if(type.floating) {
235 elems[swizzle[0]] = LLVMConstReal(elem_type, r);
236 elems[swizzle[1]] = LLVMConstReal(elem_type, g);
237 elems[swizzle[2]] = LLVMConstReal(elem_type, b);
238 elems[swizzle[3]] = LLVMConstReal(elem_type, a);
239 }
240 else {
241 unsigned shift;
242 long long llscale;
243 double dscale;
244
245 if(type.fixed)
246 shift = type.width/2;
247 else if(type.norm)
248 shift = type.sign ? type.width - 1 : type.width;
249 else
250 shift = 0;
251
252 llscale = (long long)1 << shift;
253 dscale = (double)llscale;
254 assert((long long)dscale == llscale);
255
256 elems[swizzle[0]] = LLVMConstInt(elem_type, r*dscale + 0.5, 0);
257 elems[swizzle[1]] = LLVMConstInt(elem_type, g*dscale + 0.5, 0);
258 elems[swizzle[2]] = LLVMConstInt(elem_type, b*dscale + 0.5, 0);
259 elems[swizzle[3]] = LLVMConstInt(elem_type, a*dscale + 0.5, 0);
260 }
261
262 for(i = 4; i < type.length; ++i)
263 elems[i] = elems[i % 4];
264
265 return LLVMConstVector(elems, type.length);
266 }
267
268
269 static LLVMValueRef
270 lp_build_intrinsic_binary(LLVMBuilderRef builder,
271 const char *name,
272 LLVMValueRef a,
273 LLVMValueRef b)
274 {
275 LLVMModuleRef module = LLVMGetGlobalParent(LLVMGetBasicBlockParent(LLVMGetInsertBlock(builder)));
276 LLVMValueRef function;
277 LLVMValueRef args[2];
278
279 function = LLVMGetNamedFunction(module, name);
280 if(!function) {
281 LLVMTypeRef type = LLVMTypeOf(a);
282 LLVMTypeRef arg_types[2];
283 arg_types[0] = type;
284 arg_types[1] = type;
285 function = LLVMAddFunction(module, name, LLVMFunctionType(type, arg_types, 2, 0));
286 LLVMSetFunctionCallConv(function, LLVMCCallConv);
287 LLVMSetLinkage(function, LLVMExternalLinkage);
288 }
289 assert(LLVMIsDeclaration(function));
290
291 args[0] = a;
292 args[1] = b;
293
294 return LLVMBuildCall(builder, function, args, 2, "");
295 }
296
297
298 static LLVMValueRef
299 lp_build_min_simple(struct lp_build_context *bld,
300 LLVMValueRef a,
301 LLVMValueRef b)
302 {
303 const union lp_type type = bld->type;
304 const char *intrinsic = NULL;
305 LLVMValueRef cond;
306
307 /* TODO: optimize the constant case */
308
309 #if defined(PIPE_ARCH_X86) || defined(PIPE_ARCH_X86_64)
310 if(type.width * type.length == 128) {
311 if(type.floating)
312 if(type.width == 32)
313 intrinsic = "llvm.x86.sse.min.ps";
314 if(type.width == 64)
315 intrinsic = "llvm.x86.sse2.min.pd";
316 else {
317 if(type.width == 8 && !type.sign)
318 intrinsic = "llvm.x86.sse2.pminu.b";
319 if(type.width == 16 && type.sign)
320 intrinsic = "llvm.x86.sse2.pmins.w";
321 }
322 }
323 #endif
324
325 if(intrinsic)
326 return lp_build_intrinsic_binary(bld->builder, intrinsic, a, b);
327
328 if(type.floating)
329 cond = LLVMBuildFCmp(bld->builder, LLVMRealULT, a, b, "");
330 else
331 cond = LLVMBuildICmp(bld->builder, type.sign ? LLVMIntSLT : LLVMIntULT, a, b, "");
332 return LLVMBuildSelect(bld->builder, cond, a, b, "");
333 }
334
335
336 static LLVMValueRef
337 lp_build_max_simple(struct lp_build_context *bld,
338 LLVMValueRef a,
339 LLVMValueRef b)
340 {
341 const union lp_type type = bld->type;
342 const char *intrinsic = NULL;
343 LLVMValueRef cond;
344
345 /* TODO: optimize the constant case */
346
347 #if defined(PIPE_ARCH_X86) || defined(PIPE_ARCH_X86_64)
348 if(type.width * type.length == 128) {
349 if(type.floating)
350 if(type.width == 32)
351 intrinsic = "llvm.x86.sse.max.ps";
352 if(type.width == 64)
353 intrinsic = "llvm.x86.sse2.max.pd";
354 else {
355 if(type.width == 8 && !type.sign)
356 intrinsic = "llvm.x86.sse2.pmaxu.b";
357 if(type.width == 16 && type.sign)
358 intrinsic = "llvm.x86.sse2.pmaxs.w";
359 }
360 }
361 #endif
362
363 if(intrinsic)
364 return lp_build_intrinsic_binary(bld->builder, intrinsic, a, b);
365
366 if(type.floating)
367 cond = LLVMBuildFCmp(bld->builder, LLVMRealULT, a, b, "");
368 else
369 cond = LLVMBuildICmp(bld->builder, type.sign ? LLVMIntSLT : LLVMIntULT, a, b, "");
370 return LLVMBuildSelect(bld->builder, cond, b, a, "");
371 }
372
373
374 LLVMValueRef
375 lp_build_comp(struct lp_build_context *bld,
376 LLVMValueRef a)
377 {
378 const union lp_type type = bld->type;
379
380 if(a == bld->one)
381 return bld->zero;
382 if(a == bld->zero)
383 return bld->one;
384
385 if(type.norm && !type.floating && !type.fixed && !type.sign) {
386 if(LLVMIsConstant(a))
387 return LLVMConstNot(a);
388 else
389 return LLVMBuildNot(bld->builder, a, "");
390 }
391
392 if(LLVMIsConstant(a))
393 return LLVMConstSub(bld->one, a);
394 else
395 return LLVMBuildSub(bld->builder, bld->one, a, "");
396 }
397
398
399 LLVMValueRef
400 lp_build_add(struct lp_build_context *bld,
401 LLVMValueRef a,
402 LLVMValueRef b)
403 {
404 const union lp_type type = bld->type;
405 LLVMValueRef res;
406
407 if(a == bld->zero)
408 return b;
409 if(b == bld->zero)
410 return a;
411 if(a == bld->undef || b == bld->undef)
412 return bld->undef;
413
414 if(bld->type.norm) {
415 const char *intrinsic = NULL;
416
417 if(a == bld->one || b == bld->one)
418 return bld->one;
419
420 #if defined(PIPE_ARCH_X86) || defined(PIPE_ARCH_X86_64)
421 if(type.width * type.length == 128 &&
422 !type.floating && !type.fixed) {
423 if(type.width == 8)
424 intrinsic = type.sign ? "llvm.x86.sse2.adds.b" : "llvm.x86.sse2.addus.b";
425 if(type.width == 16)
426 intrinsic = type.sign ? "llvm.x86.sse2.adds.w" : "llvm.x86.sse2.addus.w";
427 }
428 #endif
429
430 if(intrinsic)
431 return lp_build_intrinsic_binary(bld->builder, intrinsic, a, b);
432 }
433
434 if(LLVMIsConstant(a) && LLVMIsConstant(b))
435 res = LLVMConstAdd(a, b);
436 else
437 res = LLVMBuildAdd(bld->builder, a, b, "");
438
439 if(bld->type.norm && (bld->type.floating || bld->type.fixed))
440 res = lp_build_min_simple(bld, res, bld->one);
441
442 return res;
443 }
444
445
446 LLVMValueRef
447 lp_build_sub(struct lp_build_context *bld,
448 LLVMValueRef a,
449 LLVMValueRef b)
450 {
451 const union lp_type type = bld->type;
452 LLVMValueRef res;
453
454 if(b == bld->zero)
455 return a;
456 if(a == bld->undef || b == bld->undef)
457 return bld->undef;
458 if(a == b)
459 return bld->zero;
460
461 if(bld->type.norm) {
462 const char *intrinsic = NULL;
463
464 if(b == bld->one)
465 return bld->zero;
466
467 #if defined(PIPE_ARCH_X86) || defined(PIPE_ARCH_X86_64)
468 if(type.width * type.length == 128 &&
469 !type.floating && !type.fixed) {
470 if(type.width == 8)
471 intrinsic = type.sign ? "llvm.x86.sse2.subs.b" : "llvm.x86.sse2.subus.b";
472 if(type.width == 16)
473 intrinsic = type.sign ? "llvm.x86.sse2.subs.w" : "llvm.x86.sse2.subus.w";
474 }
475 #endif
476
477 if(intrinsic)
478 return lp_build_intrinsic_binary(bld->builder, intrinsic, a, b);
479 }
480
481 if(LLVMIsConstant(a) && LLVMIsConstant(b))
482 res = LLVMConstSub(a, b);
483 else
484 res = LLVMBuildSub(bld->builder, a, b, "");
485
486 if(bld->type.norm && (bld->type.floating || bld->type.fixed))
487 res = lp_build_max_simple(bld, res, bld->zero);
488
489 return res;
490 }
491
492
493 LLVMValueRef
494 lp_build_mul(struct lp_build_context *bld,
495 LLVMValueRef a,
496 LLVMValueRef b)
497 {
498 if(a == bld->zero)
499 return bld->zero;
500 if(a == bld->one)
501 return b;
502 if(b == bld->zero)
503 return bld->zero;
504 if(b == bld->one)
505 return a;
506 if(a == bld->undef || b == bld->undef)
507 return bld->undef;
508
509 if(LLVMIsConstant(a) && LLVMIsConstant(b))
510 return LLVMConstMul(a, b);
511
512 return LLVMBuildMul(bld->builder, a, b, "");
513 }
514
515
516 LLVMValueRef
517 lp_build_min(struct lp_build_context *bld,
518 LLVMValueRef a,
519 LLVMValueRef b)
520 {
521 if(a == bld->undef || b == bld->undef)
522 return bld->undef;
523
524 if(bld->type.norm) {
525 if(a == bld->zero || b == bld->zero)
526 return bld->zero;
527 if(a == bld->one)
528 return b;
529 if(b == bld->one)
530 return a;
531 }
532
533 return lp_build_min_simple(bld, a, b);
534 }
535
536
537 LLVMValueRef
538 lp_build_max(struct lp_build_context *bld,
539 LLVMValueRef a,
540 LLVMValueRef b)
541 {
542 if(a == bld->undef || b == bld->undef)
543 return bld->undef;
544
545 if(bld->type.norm) {
546 if(a == bld->one || b == bld->one)
547 return bld->one;
548 if(a == bld->zero)
549 return b;
550 if(b == bld->zero)
551 return a;
552 }
553
554 return lp_build_max_simple(bld, a, b);
555 }