/*
 * Copyright © 2014 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Jason Ekstrand (jason@jlekstrand.net)
 *
 */
#include <inttypes.h>
#include "nir_search.h"
#include "nir_builder.h"
#include "util/half_float.h"
/* Upper bound on the number of commutative expressions whose source order
 * is explored while searching; expressions whose index is at or above this
 * limit are never flipped.
 *
 * This should be the same as nir_search_max_comm_ops in nir_algebraic.py.
 */
#define NIR_SEARCH_MAX_COMM_OPS 8
39 uint8_t comm_op_direction
;
40 unsigned variables_seen
;
41 nir_alu_src variables
[NIR_SEARCH_MAX_VARIABLES
];
45 match_expression(const nir_search_expression
*expr
, nir_alu_instr
*instr
,
46 unsigned num_components
, const uint8_t *swizzle
,
47 struct match_state
*state
);
49 static const uint8_t identity_swizzle
[NIR_MAX_VEC_COMPONENTS
] = { 0, 1, 2, 3 };
52 * Check if a source produces a value of the given type.
54 * Used for satisfying 'a@type' constraints.
57 src_is_type(nir_src src
, nir_alu_type type
)
59 assert(type
!= nir_type_invalid
);
64 if (src
.ssa
->parent_instr
->type
== nir_instr_type_alu
) {
65 nir_alu_instr
*src_alu
= nir_instr_as_alu(src
.ssa
->parent_instr
);
66 nir_alu_type output_type
= nir_op_infos
[src_alu
->op
].output_type
;
68 if (type
== nir_type_bool
) {
69 switch (src_alu
->op
) {
73 return src_is_type(src_alu
->src
[0].src
, nir_type_bool
) &&
74 src_is_type(src_alu
->src
[1].src
, nir_type_bool
);
76 return src_is_type(src_alu
->src
[0].src
, nir_type_bool
);
82 return nir_alu_type_get_base_type(output_type
) == type
;
83 } else if (src
.ssa
->parent_instr
->type
== nir_instr_type_intrinsic
) {
84 nir_intrinsic_instr
*intr
= nir_instr_as_intrinsic(src
.ssa
->parent_instr
);
86 if (type
== nir_type_bool
) {
87 return intr
->intrinsic
== nir_intrinsic_load_front_face
||
88 intr
->intrinsic
== nir_intrinsic_load_helper_invocation
;
97 nir_op_matches_search_op(nir_op nop
, uint16_t sop
)
99 if (sop
<= nir_last_opcode
)
102 #define MATCH_FCONV_CASE(op) \
103 case nir_search_op_##op: \
104 return nop == nir_op_##op##16 || \
105 nop == nir_op_##op##32 || \
106 nop == nir_op_##op##64;
108 #define MATCH_ICONV_CASE(op) \
109 case nir_search_op_##op: \
110 return nop == nir_op_##op##8 || \
111 nop == nir_op_##op##16 || \
112 nop == nir_op_##op##32 || \
113 nop == nir_op_##op##64;
115 #define MATCH_BCONV_CASE(op) \
116 case nir_search_op_##op: \
117 return nop == nir_op_##op##1 || \
118 nop == nir_op_##op##32;
121 MATCH_FCONV_CASE(i2f
)
122 MATCH_FCONV_CASE(u2f
)
123 MATCH_FCONV_CASE(f2f
)
124 MATCH_ICONV_CASE(f2u
)
125 MATCH_ICONV_CASE(f2i
)
126 MATCH_ICONV_CASE(u2u
)
127 MATCH_ICONV_CASE(i2i
)
128 MATCH_FCONV_CASE(b2f
)
129 MATCH_ICONV_CASE(b2i
)
130 MATCH_BCONV_CASE(i2b
)
131 MATCH_BCONV_CASE(f2b
)
133 unreachable("Invalid nir_search_op");
136 #undef MATCH_FCONV_CASE
137 #undef MATCH_ICONV_CASE
138 #undef MATCH_BCONV_CASE
142 nir_search_op_for_nir_op(nir_op nop
)
144 #define MATCH_FCONV_CASE(op) \
145 case nir_op_##op##16: \
146 case nir_op_##op##32: \
147 case nir_op_##op##64: \
148 return nir_search_op_##op;
150 #define MATCH_ICONV_CASE(op) \
151 case nir_op_##op##8: \
152 case nir_op_##op##16: \
153 case nir_op_##op##32: \
154 case nir_op_##op##64: \
155 return nir_search_op_##op;
157 #define MATCH_BCONV_CASE(op) \
158 case nir_op_##op##1: \
159 case nir_op_##op##32: \
160 return nir_search_op_##op;
164 MATCH_FCONV_CASE(i2f
)
165 MATCH_FCONV_CASE(u2f
)
166 MATCH_FCONV_CASE(f2f
)
167 MATCH_ICONV_CASE(f2u
)
168 MATCH_ICONV_CASE(f2i
)
169 MATCH_ICONV_CASE(u2u
)
170 MATCH_ICONV_CASE(i2i
)
171 MATCH_FCONV_CASE(b2f
)
172 MATCH_ICONV_CASE(b2i
)
173 MATCH_BCONV_CASE(i2b
)
174 MATCH_BCONV_CASE(f2b
)
179 #undef MATCH_FCONV_CASE
180 #undef MATCH_ICONV_CASE
181 #undef MATCH_BCONV_CASE
185 nir_op_for_search_op(uint16_t sop
, unsigned bit_size
)
187 if (sop
<= nir_last_opcode
)
190 #define RET_FCONV_CASE(op) \
191 case nir_search_op_##op: \
192 switch (bit_size) { \
193 case 16: return nir_op_##op##16; \
194 case 32: return nir_op_##op##32; \
195 case 64: return nir_op_##op##64; \
196 default: unreachable("Invalid bit size"); \
199 #define RET_ICONV_CASE(op) \
200 case nir_search_op_##op: \
201 switch (bit_size) { \
202 case 8: return nir_op_##op##8; \
203 case 16: return nir_op_##op##16; \
204 case 32: return nir_op_##op##32; \
205 case 64: return nir_op_##op##64; \
206 default: unreachable("Invalid bit size"); \
209 #define RET_BCONV_CASE(op) \
210 case nir_search_op_##op: \
211 switch (bit_size) { \
212 case 1: return nir_op_##op##1; \
213 case 32: return nir_op_##op##32; \
214 default: unreachable("Invalid bit size"); \
230 unreachable("Invalid nir_search_op");
233 #undef RET_FCONV_CASE
234 #undef RET_ICONV_CASE
235 #undef RET_BCONV_CASE
239 match_value(const nir_search_value
*value
, nir_alu_instr
*instr
, unsigned src
,
240 unsigned num_components
, const uint8_t *swizzle
,
241 struct match_state
*state
)
243 uint8_t new_swizzle
[NIR_MAX_VEC_COMPONENTS
];
245 /* Searching only works on SSA values because, if it's not SSA, we can't
246 * know if the value changed between one instance of that value in the
247 * expression and another. Also, the replace operation will place reads of
248 * that value right before the last instruction in the expression we're
249 * replacing so those reads will happen after the original reads and may
250 * not be valid if they're register reads.
252 assert(instr
->src
[src
].src
.is_ssa
);
254 /* If the source is an explicitly sized source, then we need to reset
255 * both the number of components and the swizzle.
257 if (nir_op_infos
[instr
->op
].input_sizes
[src
] != 0) {
258 num_components
= nir_op_infos
[instr
->op
].input_sizes
[src
];
259 swizzle
= identity_swizzle
;
262 for (unsigned i
= 0; i
< num_components
; ++i
)
263 new_swizzle
[i
] = instr
->src
[src
].swizzle
[swizzle
[i
]];
265 /* If the value has a specific bit size and it doesn't match, bail */
266 if (value
->bit_size
> 0 &&
267 nir_src_bit_size(instr
->src
[src
].src
) != value
->bit_size
)
270 switch (value
->type
) {
271 case nir_search_value_expression
:
272 if (instr
->src
[src
].src
.ssa
->parent_instr
->type
!= nir_instr_type_alu
)
275 return match_expression(nir_search_value_as_expression(value
),
276 nir_instr_as_alu(instr
->src
[src
].src
.ssa
->parent_instr
),
277 num_components
, new_swizzle
, state
);
279 case nir_search_value_variable
: {
280 nir_search_variable
*var
= nir_search_value_as_variable(value
);
281 assert(var
->variable
< NIR_SEARCH_MAX_VARIABLES
);
283 if (state
->variables_seen
& (1 << var
->variable
)) {
284 if (state
->variables
[var
->variable
].src
.ssa
!= instr
->src
[src
].src
.ssa
)
287 assert(!instr
->src
[src
].abs
&& !instr
->src
[src
].negate
);
289 for (unsigned i
= 0; i
< num_components
; ++i
) {
290 if (state
->variables
[var
->variable
].swizzle
[i
] != new_swizzle
[i
])
296 if (var
->is_constant
&&
297 instr
->src
[src
].src
.ssa
->parent_instr
->type
!= nir_instr_type_load_const
)
300 if (var
->cond
&& !var
->cond(instr
, src
, num_components
, new_swizzle
))
303 if (var
->type
!= nir_type_invalid
&&
304 !src_is_type(instr
->src
[src
].src
, var
->type
))
307 state
->variables_seen
|= (1 << var
->variable
);
308 state
->variables
[var
->variable
].src
= instr
->src
[src
].src
;
309 state
->variables
[var
->variable
].abs
= false;
310 state
->variables
[var
->variable
].negate
= false;
312 for (unsigned i
= 0; i
< NIR_MAX_VEC_COMPONENTS
; ++i
) {
313 if (i
< num_components
)
314 state
->variables
[var
->variable
].swizzle
[i
] = new_swizzle
[i
];
316 state
->variables
[var
->variable
].swizzle
[i
] = 0;
323 case nir_search_value_constant
: {
324 nir_search_constant
*const_val
= nir_search_value_as_constant(value
);
326 if (!nir_src_is_const(instr
->src
[src
].src
))
329 switch (const_val
->type
) {
330 case nir_type_float
: {
331 nir_load_const_instr
*const load
=
332 nir_instr_as_load_const(instr
->src
[src
].src
.ssa
->parent_instr
);
334 /* There are 8-bit and 1-bit integer types, but there are no 8-bit or
335 * 1-bit float types. This prevents potential assertion failures in
336 * nir_src_comp_as_float.
338 if (load
->def
.bit_size
< 16)
341 for (unsigned i
= 0; i
< num_components
; ++i
) {
342 double val
= nir_src_comp_as_float(instr
->src
[src
].src
,
344 if (val
!= const_val
->data
.d
)
352 case nir_type_bool
: {
353 unsigned bit_size
= nir_src_bit_size(instr
->src
[src
].src
);
354 uint64_t mask
= bit_size
== 64 ? UINT64_MAX
: (1ull << bit_size
) - 1;
355 for (unsigned i
= 0; i
< num_components
; ++i
) {
356 uint64_t val
= nir_src_comp_as_uint(instr
->src
[src
].src
,
358 if ((val
& mask
) != (const_val
->data
.u
& mask
))
365 unreachable("Invalid alu source type");
370 unreachable("Invalid search value type");
375 match_expression(const nir_search_expression
*expr
, nir_alu_instr
*instr
,
376 unsigned num_components
, const uint8_t *swizzle
,
377 struct match_state
*state
)
379 if (expr
->cond
&& !expr
->cond(instr
))
382 if (!nir_op_matches_search_op(instr
->op
, expr
->opcode
))
385 assert(instr
->dest
.dest
.is_ssa
);
387 if (expr
->value
.bit_size
> 0 &&
388 instr
->dest
.dest
.ssa
.bit_size
!= expr
->value
.bit_size
)
391 state
->inexact_match
= expr
->inexact
|| state
->inexact_match
;
392 state
->has_exact_alu
= instr
->exact
|| state
->has_exact_alu
;
393 if (state
->inexact_match
&& state
->has_exact_alu
)
396 assert(!instr
->dest
.saturate
);
397 assert(nir_op_infos
[instr
->op
].num_inputs
> 0);
399 /* If we have an explicitly sized destination, we can only handle the
400 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
401 * expression, we don't have the information right now to propagate that
402 * swizzle through. We can only properly propagate swizzles if the
403 * instruction is vectorized.
405 if (nir_op_infos
[instr
->op
].output_size
!= 0) {
406 for (unsigned i
= 0; i
< num_components
; i
++) {
412 /* If this is a commutative expression and it's one of the first few, look
413 * up its direction for the current search operation. We'll use that value
414 * to possibly flip the sources for the match.
416 unsigned comm_op_flip
=
417 (expr
->comm_expr_idx
>= 0 &&
418 expr
->comm_expr_idx
< NIR_SEARCH_MAX_COMM_OPS
) ?
419 ((state
->comm_op_direction
>> expr
->comm_expr_idx
) & 1) : 0;
422 for (unsigned i
= 0; i
< nir_op_infos
[instr
->op
].num_inputs
; i
++) {
423 /* 2src_commutative instructions that have 3 sources are only commutative
424 * in the first two sources. Source 2 is always source 2.
426 if (!match_value(expr
->srcs
[i
], instr
,
427 i
< 2 ? i
^ comm_op_flip
: i
,
428 num_components
, swizzle
, state
)) {
438 replace_bitsize(const nir_search_value
*value
, unsigned search_bitsize
,
439 struct match_state
*state
)
441 if (value
->bit_size
> 0)
442 return value
->bit_size
;
443 if (value
->bit_size
< 0)
444 return nir_src_bit_size(state
->variables
[-value
->bit_size
- 1].src
);
445 return search_bitsize
;
449 construct_value(nir_builder
*build
,
450 const nir_search_value
*value
,
451 unsigned num_components
, unsigned search_bitsize
,
452 struct match_state
*state
,
455 switch (value
->type
) {
456 case nir_search_value_expression
: {
457 const nir_search_expression
*expr
= nir_search_value_as_expression(value
);
458 unsigned dst_bit_size
= replace_bitsize(value
, search_bitsize
, state
);
459 nir_op op
= nir_op_for_search_op(expr
->opcode
, dst_bit_size
);
461 if (nir_op_infos
[op
].output_size
!= 0)
462 num_components
= nir_op_infos
[op
].output_size
;
464 nir_alu_instr
*alu
= nir_alu_instr_create(build
->shader
, op
);
465 nir_ssa_dest_init(&alu
->instr
, &alu
->dest
.dest
, num_components
,
467 alu
->dest
.write_mask
= (1 << num_components
) - 1;
468 alu
->dest
.saturate
= false;
470 /* We have no way of knowing what values in a given search expression
471 * map to a particular replacement value. Therefore, if the
472 * expression we are replacing has any exact values, the entire
473 * replacement should be exact.
475 alu
->exact
= state
->has_exact_alu
;
477 for (unsigned i
= 0; i
< nir_op_infos
[op
].num_inputs
; i
++) {
478 /* If the source is an explicitly sized source, then we need to reset
479 * the number of components to match.
481 if (nir_op_infos
[alu
->op
].input_sizes
[i
] != 0)
482 num_components
= nir_op_infos
[alu
->op
].input_sizes
[i
];
484 alu
->src
[i
] = construct_value(build
, expr
->srcs
[i
],
485 num_components
, search_bitsize
,
489 nir_builder_instr_insert(build
, &alu
->instr
);
492 val
.src
= nir_src_for_ssa(&alu
->dest
.dest
.ssa
);
495 memcpy(val
.swizzle
, identity_swizzle
, sizeof val
.swizzle
);
500 case nir_search_value_variable
: {
501 const nir_search_variable
*var
= nir_search_value_as_variable(value
);
502 assert(state
->variables_seen
& (1 << var
->variable
));
504 nir_alu_src val
= { NIR_SRC_INIT
};
505 nir_alu_src_copy(&val
, &state
->variables
[var
->variable
],
506 (void *)build
->shader
);
507 assert(!var
->is_constant
);
512 case nir_search_value_constant
: {
513 const nir_search_constant
*c
= nir_search_value_as_constant(value
);
514 unsigned bit_size
= replace_bitsize(value
, search_bitsize
, state
);
519 cval
= nir_imm_floatN_t(build
, c
->data
.d
, bit_size
);
524 cval
= nir_imm_intN_t(build
, c
->data
.i
, bit_size
);
528 cval
= nir_imm_boolN_t(build
, c
->data
.u
, bit_size
);
532 unreachable("Invalid alu source type");
536 val
.src
= nir_src_for_ssa(cval
);
539 memset(val
.swizzle
, 0, sizeof val
.swizzle
);
545 unreachable("Invalid search value type");
549 MAYBE_UNUSED
static void dump_value(const nir_search_value
*val
)
552 case nir_search_value_constant
: {
553 const nir_search_constant
*sconst
= nir_search_value_as_constant(val
);
554 switch (sconst
->type
) {
556 printf("%f", sconst
->data
.d
);
559 printf("%"PRId64
, sconst
->data
.i
);
562 printf("0x%"PRIx64
, sconst
->data
.u
);
565 printf("%s", sconst
->data
.u
!= 0 ? "True" : "False");
568 unreachable("bad const type");
573 case nir_search_value_variable
: {
574 const nir_search_variable
*var
= nir_search_value_as_variable(val
);
575 if (var
->is_constant
)
577 printf("%c", var
->variable
+ 'a');
581 case nir_search_value_expression
: {
582 const nir_search_expression
*expr
= nir_search_value_as_expression(val
);
586 switch (expr
->opcode
) {
588 case nir_search_op_##n: printf(#n); break;
598 printf("%s", nir_op_infos
[expr
->opcode
].name
);
601 unsigned num_srcs
= 1;
602 if (expr
->opcode
<= nir_last_opcode
)
603 num_srcs
= nir_op_infos
[expr
->opcode
].num_inputs
;
605 for (unsigned i
= 0; i
< num_srcs
; i
++) {
607 dump_value(expr
->srcs
[i
]);
615 if (val
->bit_size
> 0)
616 printf("@%d", val
->bit_size
);
620 nir_replace_instr(nir_builder
*build
, nir_alu_instr
*instr
,
621 const nir_search_expression
*search
,
622 const nir_search_value
*replace
)
624 uint8_t swizzle
[NIR_MAX_VEC_COMPONENTS
] = { 0 };
626 for (unsigned i
= 0; i
< instr
->dest
.dest
.ssa
.num_components
; ++i
)
629 assert(instr
->dest
.dest
.is_ssa
);
631 struct match_state state
;
632 state
.inexact_match
= false;
633 state
.has_exact_alu
= false;
635 STATIC_ASSERT(sizeof(state
.comm_op_direction
) * 8 >= NIR_SEARCH_MAX_COMM_OPS
);
637 unsigned comm_expr_combinations
=
638 1 << MIN2(search
->comm_exprs
, NIR_SEARCH_MAX_COMM_OPS
);
641 for (unsigned comb
= 0; comb
< comm_expr_combinations
; comb
++) {
642 /* The bitfield of directions is just the current iteration. Hooray for
645 state
.comm_op_direction
= comb
;
646 state
.variables_seen
= 0;
648 if (match_expression(search
, instr
,
649 instr
->dest
.dest
.ssa
.num_components
,
660 dump_value(&search
->value
);
663 printf(" ssa_%d\n", instr
->dest
.dest
.ssa
.index
);
666 build
->cursor
= nir_before_instr(&instr
->instr
);
668 nir_alu_src val
= construct_value(build
, replace
,
669 instr
->dest
.dest
.ssa
.num_components
,
670 instr
->dest
.dest
.ssa
.bit_size
,
671 &state
, &instr
->instr
);
673 /* Inserting a mov may be unnecessary. However, it's much easier to
674 * simply let copy propagation clean this up than to try to go through
675 * and rewrite swizzles ourselves.
677 nir_ssa_def
*ssa_val
=
678 nir_mov_alu(build
, val
, instr
->dest
.dest
.ssa
.num_components
);
679 nir_ssa_def_rewrite_uses(&instr
->dest
.dest
.ssa
, nir_src_for_ssa(ssa_val
));
681 /* We know this one has no more uses because we just rewrote them all,
682 * so we can remove it. The rest of the matched expression, however, we
683 * don't know so much about. We'll just let dead code clean them up.
685 nir_instr_remove(&instr
->instr
);