2 * Copyright © 2014 Intel Corporation
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 * Jason Ekstrand (jason@jlekstrand.net)
29 #include "nir_search.h"
30 #include "nir_builder.h"
31 #include "util/half_float.h"
/* Members of struct match_state (the struct's opening line and its other
 * members, such as inexact_match and has_exact_alu, are on lines not shown
 * in this view). */
/* Bitmask of search-variable indices bound so far: bit i is set once
 * variables[i] holds the source captured for variable i. */
36 unsigned variables_seen
;
/* Source (SSA value, swizzle, abs/negate flags) captured for each search
 * variable, indexed by var->variable. */
37 nir_alu_src variables
[NIR_SEARCH_MAX_VARIABLES
];
/* Forward declaration: match_value() recurses into match_expression() for
 * expression-typed search values, and match_expression() calls
 * match_value() for each of its operands. */
41 match_expression(const nir_search_expression
*expr
, nir_alu_instr
*instr
,
42 unsigned num_components
, const uint8_t *swizzle
,
43 struct match_state
*state
);
/* Swizzle that maps component i to itself. Used when a source or result has
 * an explicit size, so a caller-supplied swizzle cannot be propagated.
 * NOTE(review): only four initializers are listed; if NIR_MAX_VEC_COMPONENTS
 * is ever larger than 4, the remaining entries are zero rather than
 * identity — confirm before widening vectors. */
45 static const uint8_t identity_swizzle
[NIR_MAX_VEC_COMPONENTS
] = { 0, 1, 2, 3 };
48 * Check if a source produces a value of the given type.
50 * Used for satisfying 'a@type' constraints.
/* Only the instruction that produces 'src' is inspected. For ALU producers
 * the opcode's declared output_type is consulted, with a recursive special
 * case for bool. For intrinsic producers only the bool check is visible in
 * this view. 'type' must not be nir_type_invalid (asserted). */
53 src_is_type(nir_src src
, nir_alu_type type
)
55 assert(type
!= nir_type_invalid
);
60 if (src
.ssa
->parent_instr
->type
== nir_instr_type_alu
) {
61 nir_alu_instr
*src_alu
= nir_instr_as_alu(src
.ssa
->parent_instr
);
62 nir_alu_type output_type
= nir_op_infos
[src_alu
->op
].output_type
;
64 if (type
== nir_type_bool
) {
65 switch (src_alu
->op
) {
/* Two-operand case: the result is bool only if both operands are bool.
 * The case labels selecting these opcodes are on lines not shown here. */
69 return src_is_type(src_alu
->src
[0].src
, nir_type_bool
) &&
70 src_is_type(src_alu
->src
[1].src
, nir_type_bool
);
/* One-operand case: the result is bool only if its operand is bool. */
72 return src_is_type(src_alu
->src
[0].src
, nir_type_bool
);
/* When no special case above returned, compare the base type of the
 * opcode's declared output type with the requested type. */
78 return nir_alu_type_get_base_type(output_type
) == type
;
79 } else if (src
.ssa
->parent_instr
->type
== nir_instr_type_intrinsic
) {
80 nir_intrinsic_instr
*intr
= nir_instr_as_intrinsic(src
.ssa
->parent_instr
);
82 if (type
== nir_type_bool
) {
/* These two intrinsics are accepted as bool producers. */
83 return intr
->intrinsic
== nir_intrinsic_load_front_face
||
84 intr
->intrinsic
== nir_intrinsic_load_helper_invocation
;
/* Return whether the concrete NIR opcode 'nop' satisfies search opcode
 * 'sop'. Values of 'sop' up to nir_last_opcode name ordinary NIR opcodes;
 * the statement taken in that case is on a line not shown here (presumably
 * a direct comparison of nop and sop — TODO confirm). Larger values are
 * size-agnostic conversion search ops, handled by the MATCH_*_CASE macros
 * below. Each of those macros accepts any of the sized variants it lists.
 */
93 nir_op_matches_search_op(nir_op nop
, uint16_t sop
)
95 if (sop
<= nir_last_opcode
)
/* Accepts op##16, op##32 or op##64. */
98 #define MATCH_FCONV_CASE(op) \
99 case nir_search_op_##op: \
100 return nop == nir_op_##op##16 || \
101 nop == nir_op_##op##32 || \
102 nop == nir_op_##op##64;
/* Accepts op##8, op##16, op##32 or op##64. */
104 #define MATCH_ICONV_CASE(op) \
105 case nir_search_op_##op: \
106 return nop == nir_op_##op##8 || \
107 nop == nir_op_##op##16 || \
108 nop == nir_op_##op##32 || \
109 nop == nir_op_##op##64;
/* Accepts op##1 or op##32. */
111 #define MATCH_BCONV_CASE(op) \
112 case nir_search_op_##op: \
113 return nop == nir_op_##op##1 || \
114 nop == nir_op_##op##32;
/* One case per size-agnostic conversion search op. */
117 MATCH_FCONV_CASE(i2f
)
118 MATCH_FCONV_CASE(u2f
)
119 MATCH_FCONV_CASE(f2f
)
120 MATCH_ICONV_CASE(f2u
)
121 MATCH_ICONV_CASE(f2i
)
122 MATCH_ICONV_CASE(u2u
)
123 MATCH_ICONV_CASE(i2i
)
124 MATCH_FCONV_CASE(b2f
)
125 MATCH_ICONV_CASE(b2i
)
126 MATCH_BCONV_CASE(i2b
)
127 MATCH_BCONV_CASE(f2b
)
129 unreachable("Invalid nir_search_op");
/* NOTE(review): no #undef of MATCH_BCONV_CASE is visible in this view; it
 * may be on a line not shown here. Confirm the macro does not leak. */
132 #undef MATCH_FCONV_CASE
133 #undef MATCH_ICONV_CASE
/* Return the concrete NIR opcode that search opcode 'sop' denotes at
 * 'bit_size'. Values of 'sop' up to nir_last_opcode are ordinary NIR
 * opcodes; the statement taken in that case is on a line not shown here
 * (presumably 'sop' is returned unchanged — TODO confirm). For the
 * size-agnostic conversion search ops, the RET_*_CASE macros select the
 * op##<bit_size> variant:
 *  - RET_FCONV_CASE: 16/32/64
 *  - RET_ICONV_CASE: 8/16/32/64
 *  - RET_BCONV_CASE: 1/32
 * Any other bit size reaches unreachable(), as does an unrecognized 'sop'.
 */
137 nir_op_for_search_op(uint16_t sop
, unsigned bit_size
)
139 if (sop
<= nir_last_opcode
)
142 #define RET_FCONV_CASE(op) \
143 case nir_search_op_##op: \
144 switch (bit_size) { \
145 case 16: return nir_op_##op##16; \
146 case 32: return nir_op_##op##32; \
147 case 64: return nir_op_##op##64; \
148 default: unreachable("Invalid bit size"); \
151 #define RET_ICONV_CASE(op) \
152 case nir_search_op_##op: \
153 switch (bit_size) { \
154 case 8: return nir_op_##op##8; \
155 case 16: return nir_op_##op##16; \
156 case 32: return nir_op_##op##32; \
157 case 64: return nir_op_##op##64; \
158 default: unreachable("Invalid bit size"); \
161 #define RET_BCONV_CASE(op) \
162 case nir_search_op_##op: \
163 switch (bit_size) { \
164 case 1: return nir_op_##op##1; \
165 case 32: return nir_op_##op##32; \
166 default: unreachable("Invalid bit size"); \
182 unreachable("Invalid nir_search_op");
/* NOTE(review): no #undef of RET_BCONV_CASE is visible in this view; it may
 * be on a line not shown here. Confirm the macro does not leak. */
185 #undef RET_FCONV_CASE
186 #undef RET_ICONV_CASE
/* Try to match search value 'value' against source 'src' of ALU
 * instruction 'instr'.
 *
 * 'num_components' and 'swizzle' say which components of that source are
 * read. For an explicitly sized source they are replaced by the source's
 * declared size and the identity swizzle. The effective per-component
 * swizzle (new_swizzle) composes 'swizzle' with the ALU source's own
 * swizzle.
 *
 * How each kind of search value is handled:
 *  - Expressions recurse into match_expression().
 *  - Variables are bound into 'state' on first occurrence. Later
 *    occurrences must repeat the same SSA value and swizzle.
 *  - Constants are compared component by component.
 */
190 match_value(const nir_search_value
*value
, nir_alu_instr
*instr
, unsigned src
,
191 unsigned num_components
, const uint8_t *swizzle
,
192 struct match_state
*state
)
/* new_swizzle[i] = this ALU source's swizzle indexed by the caller's
 * swizzle, filled in for the first num_components entries below. */
194 uint8_t new_swizzle
[NIR_MAX_VEC_COMPONENTS
];
196 /* Searching only works on SSA values because, if it's not SSA, we can't
197 * know if the value changed between one instance of that value in the
198 * expression and another. Also, the replace operation will place reads of
199 * that value right before the last instruction in the expression we're
200 * replacing so those reads will happen after the original reads and may
201 * not be valid if they're register reads.
203 assert(instr
->src
[src
].src
.is_ssa
);
205 /* If the source is an explicitly sized source, then we need to reset
206 * both the number of components and the swizzle.
208 if (nir_op_infos
[instr
->op
].input_sizes
[src
] != 0) {
209 num_components
= nir_op_infos
[instr
->op
].input_sizes
[src
];
210 swizzle
= identity_swizzle
;
213 for (unsigned i
= 0; i
< num_components
; ++i
)
214 new_swizzle
[i
] = instr
->src
[src
].swizzle
[swizzle
[i
]];
216 /* If the value has a specific bit size and it doesn't match, bail */
217 if (value
->bit_size
> 0 &&
218 nir_src_bit_size(instr
->src
[src
].src
) != value
->bit_size
)
221 switch (value
->type
) {
/* Expression: only a source produced by an ALU instruction can match.
 * Recurse into that instruction with the composed swizzle. */
222 case nir_search_value_expression
:
223 if (instr
->src
[src
].src
.ssa
->parent_instr
->type
!= nir_instr_type_alu
)
226 return match_expression(nir_search_value_as_expression(value
),
227 nir_instr_as_alu(instr
->src
[src
].src
.ssa
->parent_instr
),
228 num_components
, new_swizzle
, state
);
230 case nir_search_value_variable
: {
231 nir_search_variable
*var
= nir_search_value_as_variable(value
);
232 assert(var
->variable
< NIR_SEARCH_MAX_VARIABLES
);
/* Already bound: the source must be the same SSA value, read through the
 * same swizzle, as at the variable's first occurrence. */
234 if (state
->variables_seen
& (1 << var
->variable
)) {
235 if (state
->variables
[var
->variable
].src
.ssa
!= instr
->src
[src
].src
.ssa
)
238 assert(!instr
->src
[src
].abs
&& !instr
->src
[src
].negate
);
240 for (unsigned i
= 0; i
< num_components
; ++i
) {
241 if (state
->variables
[var
->variable
].swizzle
[i
] != new_swizzle
[i
])
/* First occurrence: apply the variable's constraints. These are, in order:
 * a constant-only variable must be fed by a load_const; the optional cond()
 * callback must accept the source; an '@type' annotation is checked with
 * src_is_type(). */
247 if (var
->is_constant
&&
248 instr
->src
[src
].src
.ssa
->parent_instr
->type
!= nir_instr_type_load_const
)
251 if (var
->cond
&& !var
->cond(instr
, src
, num_components
, new_swizzle
))
254 if (var
->type
!= nir_type_invalid
&&
255 !src_is_type(instr
->src
[src
].src
, var
->type
))
/* Bind the variable. abs/negate are recorded as false. Swizzle entries
 * past num_components are zeroed so later comparisons see stable data. */
258 state
->variables_seen
|= (1 << var
->variable
);
259 state
->variables
[var
->variable
].src
= instr
->src
[src
].src
;
260 state
->variables
[var
->variable
].abs
= false;
261 state
->variables
[var
->variable
].negate
= false;
263 for (unsigned i
= 0; i
< NIR_MAX_VEC_COMPONENTS
; ++i
) {
264 if (i
< num_components
)
265 state
->variables
[var
->variable
].swizzle
[i
] = new_swizzle
[i
];
267 state
->variables
[var
->variable
].swizzle
[i
] = 0;
/* Constant: the source itself must be constant. Each of the
 * num_components read components is compared against const_val->data. */
274 case nir_search_value_constant
: {
275 nir_search_constant
*const_val
= nir_search_value_as_constant(value
);
277 if (!nir_src_is_const(instr
->src
[src
].src
))
280 switch (const_val
->type
) {
/* This loop compares components as doubles against data.d, using exact
 * (!=) comparison. Its case label is on a line not shown here. */
282 for (unsigned i
= 0; i
< num_components
; ++i
) {
283 double val
= nir_src_comp_as_float(instr
->src
[src
].src
,
285 if (val
!= const_val
->data
.d
)
292 case nir_type_bool
: {
/* Compare only the low bit_size bits of each component against data.u.
 * The bit_size == 64 special case avoids the undefined shift 1ull << 64.
 * Other type labels that may share this path are on lines not shown here. */
293 unsigned bit_size
= nir_src_bit_size(instr
->src
[src
].src
);
294 uint64_t mask
= bit_size
== 64 ? UINT64_MAX
: (1ull << bit_size
) - 1;
295 for (unsigned i
= 0; i
< num_components
; ++i
) {
296 uint64_t val
= nir_src_comp_as_uint(instr
->src
[src
].src
,
298 if ((val
& mask
) != (const_val
->data
.u
& mask
))
305 unreachable("Invalid alu source type");
310 unreachable("Invalid search value type");
/* Try to match search expression 'expr' against ALU instruction 'instr',
 * reading num_components components through 'swizzle'.
 *
 * The visible checks run in this order:
 *  1. the optional expr->cond() callback;
 *  2. the opcode, via nir_op_matches_search_op();
 *  3. the destination bit size, when expr->value.bit_size > 0;
 *  4. exact/inexact compatibility, accumulated in 'state'.
 *
 * Each operand is then matched with match_value(). For ops flagged
 * NIR_OP_IS_COMMUTATIVE (asserted to have exactly two inputs), a second
 * attempt is made with the operands swapped, after restoring
 * state->variables_seen.
 */
315 match_expression(const nir_search_expression
*expr
, nir_alu_instr
*instr
,
316 unsigned num_components
, const uint8_t *swizzle
,
317 struct match_state
*state
)
319 if (expr
->cond
&& !expr
->cond(instr
))
322 if (!nir_op_matches_search_op(instr
->op
, expr
->opcode
))
325 assert(instr
->dest
.dest
.is_ssa
);
327 if (expr
->value
.bit_size
> 0 &&
328 instr
->dest
.dest
.ssa
.bit_size
!= expr
->value
.bit_size
)
/* Accumulate over the whole match: whether any matched search expression
 * is marked inexact, and whether any matched instruction is exact. The
 * combination is tested right below; the statement taken when both are set
 * is on a line not shown here (presumably the match is rejected so inexact
 * patterns never rewrite exact instructions — confirm). */
331 state
->inexact_match
= expr
->inexact
|| state
->inexact_match
;
332 state
->has_exact_alu
= instr
->exact
|| state
->has_exact_alu
;
333 if (state
->inexact_match
&& state
->has_exact_alu
)
336 assert(!instr
->dest
.saturate
);
337 assert(nir_op_infos
[instr
->op
].num_inputs
> 0);
339 /* If we have an explicitly sized destination, we can only handle the
340 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
341 * expression, we don't have the information right now to propagate that
342 * swizzle through. We can only properly propagate swizzles if the
343 * instruction is vectorized.
345 if (nir_op_infos
[instr
->op
].output_size
!= 0) {
346 for (unsigned i
= 0; i
< num_components
; i
++) {
352 /* Stash off the current variables_seen bitmask. This way we can
353 * restore it prior to matching in the commutative case below.
355 unsigned variables_seen_stash
= state
->variables_seen
;
358 for (unsigned i
= 0; i
< nir_op_infos
[instr
->op
].num_inputs
; i
++) {
359 if (!match_value(expr
->srcs
[i
], instr
, i
, num_components
,
369 if (nir_op_infos
[instr
->op
].algebraic_properties
& NIR_OP_IS_COMMUTATIVE
) {
370 assert(nir_op_infos
[instr
->op
].num_inputs
== 2);
372 /* Restore the variables_seen bitmask. If we don't do this, then we
373 * could end up with an erroneous failure due to variables found in the
374 * first match attempt above not matching those in the second.
376 state
->variables_seen
= variables_seen_stash
;
// Retry with the operands swapped: expr->srcs[0] is matched against
// instruction source 1 and expr->srcs[1] against instruction source 0.
378 if (!match_value(expr
->srcs
[0], instr
, 1, num_components
,
382 return match_value(expr
->srcs
[1], instr
, 0, num_components
,
// Choose the bit size for a replacement value:
//  - value->bit_size > 0: use that explicit size.
//  - value->bit_size < 0: use the bit size of the source bound to search
//    variable (-value->bit_size - 1); for example, -1 refers to variable 0.
//  - value->bit_size == 0: use search_bitsize, as passed in by the caller.
// NOTE(review): the negative case reads state->variables without checking
// state->variables_seen. Callers presumably guarantee the variable was
// bound during matching; confirm.
390 replace_bitsize(const nir_search_value
*value
, unsigned search_bitsize
,
391 struct match_state
*state
)
393 if (value
->bit_size
> 0)
394 return value
->bit_size
;
395 if (value
->bit_size
< 0)
396 return nir_src_bit_size(state
->variables
[-value
->bit_size
- 1].src
);
397 return search_bitsize
;
// Build the replacement value 'value' with 'build' and return it as an ALU
// source.
//  - Expression: creates an ALU instruction for the search op at the bit
//    size chosen by replace_bitsize() and builds its operands recursively.
//    Its exact flag is copied from state->has_exact_alu. It is inserted
//    with nir_builder_instr_insert(), and its result is returned with the
//    identity swizzle.
//  - Variable: returns a copy of the source bound during matching. The
//    binding is asserted.
//  - Constant: emits an immediate of the chosen bit size and returns it
//    with an all-zero swizzle.
401 construct_value(nir_builder
*build
,
402 const nir_search_value
*value
,
403 unsigned num_components
, unsigned search_bitsize
,
404 struct match_state
*state
,
407 switch (value
->type
) {
408 case nir_search_value_expression
: {
409 const nir_search_expression
*expr
= nir_search_value_as_expression(value
);
410 unsigned dst_bit_size
= replace_bitsize(value
, search_bitsize
, state
);
411 nir_op op
= nir_op_for_search_op(expr
->opcode
, dst_bit_size
);
// An opcode with a fixed output size overrides the requested width.
413 if (nir_op_infos
[op
].output_size
!= 0)
414 num_components
= nir_op_infos
[op
].output_size
;
416 nir_alu_instr
*alu
= nir_alu_instr_create(build
->shader
, op
);
417 nir_ssa_dest_init(&alu
->instr
, &alu
->dest
.dest
, num_components
,
// Enable all num_components destination channels.
419 alu
->dest
.write_mask
= (1 << num_components
) - 1;
420 alu
->dest
.saturate
= false;
422 /* We have no way of knowing what values in a given search expression
423 * map to a particular replacement value. Therefore, if the
424 * expression we are replacing has any exact values, the entire
425 * replacement should be exact.
427 alu
->exact
= state
->has_exact_alu
;
429 for (unsigned i
= 0; i
< nir_op_infos
[op
].num_inputs
; i
++) {
430 /* If the source is an explicitly sized source, then we need to reset
431 * the number of components to match.
// NOTE(review): num_components is overwritten here and never restored, so
// an explicit input size carries over to later, unsized inputs of the same
// instruction. Confirm this is intended.
433 if (nir_op_infos
[alu
->op
].input_sizes
[i
] != 0)
434 num_components
= nir_op_infos
[alu
->op
].input_sizes
[i
];
436 alu
->src
[i
] = construct_value(build
, expr
->srcs
[i
],
437 num_components
, search_bitsize
,
441 nir_builder_instr_insert(build
, &alu
->instr
);
444 val
.src
= nir_src_for_ssa(&alu
->dest
.dest
.ssa
);
447 memcpy(val
.swizzle
, identity_swizzle
, sizeof val
.swizzle
);
452 case nir_search_value_variable
: {
453 const nir_search_variable
*var
= nir_search_value_as_variable(value
);
454 assert(state
->variables_seen
& (1 << var
->variable
));
456 nir_alu_src val
= { NIR_SRC_INIT
};
457 nir_alu_src_copy(&val
, &state
->variables
[var
->variable
],
458 (void *)build
->shader
);
459 assert(!var
->is_constant
);
464 case nir_search_value_constant
: {
465 const nir_search_constant
*c
= nir_search_value_as_constant(value
);
466 unsigned bit_size
= replace_bitsize(value
, search_bitsize
, state
);
// One immediate per constant type: float from data.d, integer from data.i,
// bool from data.u. The case labels are on lines not shown here.
471 cval
= nir_imm_floatN_t(build
, c
->data
.d
, bit_size
);
476 cval
= nir_imm_intN_t(build
, c
->data
.i
, bit_size
);
480 cval
= nir_imm_boolN_t(build
, c
->data
.u
, bit_size
);
484 unreachable("Invalid alu source type");
488 val
.src
= nir_src_for_ssa(cval
);
// Every component of the returned source reads component 0 of the
// immediate.
491 memset(val
.swizzle
, 0, sizeof val
.swizzle
);
497 unreachable("Invalid search value type");
502 nir_replace_instr(nir_builder
*build
, nir_alu_instr
*instr
,
503 const nir_search_expression
*search
,
504 const nir_search_value
*replace
)
506 uint8_t swizzle
[NIR_MAX_VEC_COMPONENTS
] = { 0 };
508 for (unsigned i
= 0; i
< instr
->dest
.dest
.ssa
.num_components
; ++i
)
511 assert(instr
->dest
.dest
.is_ssa
);
513 struct match_state state
;
514 state
.inexact_match
= false;
515 state
.has_exact_alu
= false;
516 state
.variables_seen
= 0;
518 if (!match_expression(search
, instr
, instr
->dest
.dest
.ssa
.num_components
,
522 build
->cursor
= nir_before_instr(&instr
->instr
);
524 nir_alu_src val
= construct_value(build
, replace
,
525 instr
->dest
.dest
.ssa
.num_components
,
526 instr
->dest
.dest
.ssa
.bit_size
,
527 &state
, &instr
->instr
);
529 /* Inserting a mov may be unnecessary. However, it's much easier to
530 * simply let copy propagation clean this up than to try to go through
531 * and rewrite swizzles ourselves.
533 nir_ssa_def
*ssa_val
=
534 nir_imov_alu(build
, val
, instr
->dest
.dest
.ssa
.num_components
);
535 nir_ssa_def_rewrite_uses(&instr
->dest
.dest
.ssa
, nir_src_for_ssa(ssa_val
));
537 /* We know this one has no more uses because we just rewrote them all,
538 * so we can remove it. The rest of the matched expression, however, we
539 * don't know so much about. We'll just let dead code clean them up.
541 nir_instr_remove(&instr
->instr
);