2 * Copyright © 2014 Intel Corporation
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 * Jason Ekstrand (jason@jlekstrand.net)
29 #include "nir_search.h"
30 #include "nir_builder.h"
31 #include "nir_worklist.h"
32 #include "util/half_float.h"
34 /* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */
35 #define NIR_SEARCH_MAX_COMM_OPS 8
/* NOTE(review): the fields below appear to belong to a "struct match_state"
 * whose opening declaration line is not visible in this extract; the text is
 * also line-mangled (declarators and punctuation split across lines).  Code
 * is left byte-identical; only comments are added.
 */
/* Bitmask: bit i selects whether the i-th commutative sub-expression of the
 * pattern tries its two sources in swapped order (see comm_expr_idx use in
 * match_expression below). */
40 uint8_t comm_op_direction
;
/* Bitmask of pattern variables already bound during this match attempt
 * (tested/set via (1 << var->variable) in match_value). */
41 unsigned variables_seen
;
43 /* Used for running the automaton on newly-constructed instructions. */
44 struct util_dynarray
*states
;
45 const struct per_op_table
*pass_op_table
;
/* The ALU source captured for each bound pattern variable. */
47 nir_alu_src variables
[NIR_SEARCH_MAX_VARIABLES
];
/* Table handed to per-variable condition callbacks (var->cond); presumably a
 * range-analysis cache — TODO confirm against nir_search_helpers. */
48 struct hash_table
*range_ht
;
/* Forward declarations of the mutually-recursive matcher and the automaton
 * stepper.  NOTE(review): the storage-class/return-type lines are not visible
 * in this extract; the definitions appear later in this file. */
52 match_expression(const nir_search_expression
*expr
, nir_alu_instr
*instr
,
53 unsigned num_components
, const uint8_t *swizzle
,
54 struct match_state
*state
);
56 nir_algebraic_automaton(nir_instr
*instr
, struct util_dynarray
*states
,
57 const struct per_op_table
*pass_op_table
);
/* Identity swizzle, used to reset swizzles on explicitly sized sources
 * (match_value) and on freshly built ALU results (construct_value).
 * NOTE(review): the initializer list is not visible in this extract. */
59 static const uint8_t identity_swizzle
[NIR_MAX_VEC_COMPONENTS
] =
/* NOTE(review): line-mangled extract with several original lines elided
 * (case labels, closing braces, the final return).  Code left byte-identical;
 * comments only. */
68 * Check if a source produces a value of the given type.
70 * Used for satisfying 'a@type' constraints.
73 src_is_type(nir_src src
, nir_alu_type type
)
75 assert(type
!= nir_type_invalid
);
/* ALU producers: inspect the opcode's declared output type. */
80 if (src
.ssa
->parent_instr
->type
== nir_instr_type_alu
) {
81 nir_alu_instr
*src_alu
= nir_instr_as_alu(src
.ssa
->parent_instr
);
82 nir_alu_type output_type
= nir_op_infos
[src_alu
->op
].output_type
;
/* Booleans need special handling: some opcodes (the case labels are elided
 * here — presumably logical and/or/not, per upstream nir_search.c) produce
 * a bool whenever all of their inputs are bools. */
84 if (type
== nir_type_bool
) {
85 switch (src_alu
->op
) {
/* Two-source case: both operands must themselves be boolean. */
89 return src_is_type(src_alu
->src
[0].src
, nir_type_bool
) &&
90 src_is_type(src_alu
->src
[1].src
, nir_type_bool
);
/* One-source case: recurse into the single operand. */
92 return src_is_type(src_alu
->src
[0].src
, nir_type_bool
);
/* Default: compare the base type of the opcode's output type. */
98 return nir_alu_type_get_base_type(output_type
) == type
;
99 } else if (src
.ssa
->parent_instr
->type
== nir_instr_type_intrinsic
) {
100 nir_intrinsic_instr
*intr
= nir_instr_as_intrinsic(src
.ssa
->parent_instr
);
/* Only a couple of intrinsics are known to produce booleans. */
102 if (type
== nir_type_bool
) {
103 return intr
->intrinsic
== nir_intrinsic_load_front_face
||
104 intr
->intrinsic
== nir_intrinsic_load_helper_invocation
;
/* Returns whether the concrete nir_op `nop` matches the (possibly
 * bit-size-generic) search opcode `sop`.  Search opcodes above
 * nir_last_opcode stand for a conversion at any bit size.
 * NOTE(review): line-mangled extract; the direct-comparison branch body and
 * the switch header are elided.  Comments only added. */
113 nir_op_matches_search_op(nir_op nop
, uint16_t sop
)
/* Plain NIR opcodes compare directly (elided body presumably
 * "return sop == nop;" — confirm against upstream). */
115 if (sop
<= nir_last_opcode
)
/* Each macro matches one generic conversion against all of its
 * bit-size-specific NIR opcodes. */
118 #define MATCH_FCONV_CASE(op) \
119 case nir_search_op_##op: \
120 return nop == nir_op_##op##16 || \
121 nop == nir_op_##op##32 || \
122 nop == nir_op_##op##64;
124 #define MATCH_ICONV_CASE(op) \
125 case nir_search_op_##op: \
126 return nop == nir_op_##op##8 || \
127 nop == nir_op_##op##16 || \
128 nop == nir_op_##op##32 || \
129 nop == nir_op_##op##64;
131 #define MATCH_BCONV_CASE(op) \
132 case nir_search_op_##op: \
133 return nop == nir_op_##op##1 || \
134 nop == nir_op_##op##32;
137 MATCH_FCONV_CASE(i2f
)
138 MATCH_FCONV_CASE(u2f
)
139 MATCH_FCONV_CASE(f2f
)
140 MATCH_ICONV_CASE(f2u
)
141 MATCH_ICONV_CASE(f2i
)
142 MATCH_ICONV_CASE(u2u
)
143 MATCH_ICONV_CASE(i2i
)
144 MATCH_FCONV_CASE(b2f
)
145 MATCH_ICONV_CASE(b2i
)
146 MATCH_BCONV_CASE(i2b
)
147 MATCH_BCONV_CASE(f2b
)
149 unreachable("Invalid nir_search_op");
152 #undef MATCH_FCONV_CASE
153 #undef MATCH_ICONV_CASE
154 #undef MATCH_BCONV_CASE
158 nir_search_op_for_nir_op(nir_op nop
)
160 #define MATCH_FCONV_CASE(op) \
161 case nir_op_##op##16: \
162 case nir_op_##op##32: \
163 case nir_op_##op##64: \
164 return nir_search_op_##op;
166 #define MATCH_ICONV_CASE(op) \
167 case nir_op_##op##8: \
168 case nir_op_##op##16: \
169 case nir_op_##op##32: \
170 case nir_op_##op##64: \
171 return nir_search_op_##op;
173 #define MATCH_BCONV_CASE(op) \
174 case nir_op_##op##1: \
175 case nir_op_##op##32: \
176 return nir_search_op_##op;
180 MATCH_FCONV_CASE(i2f
)
181 MATCH_FCONV_CASE(u2f
)
182 MATCH_FCONV_CASE(f2f
)
183 MATCH_ICONV_CASE(f2u
)
184 MATCH_ICONV_CASE(f2i
)
185 MATCH_ICONV_CASE(u2u
)
186 MATCH_ICONV_CASE(i2i
)
187 MATCH_FCONV_CASE(b2f
)
188 MATCH_ICONV_CASE(b2i
)
189 MATCH_BCONV_CASE(i2b
)
190 MATCH_BCONV_CASE(f2b
)
195 #undef MATCH_FCONV_CASE
196 #undef MATCH_ICONV_CASE
197 #undef MATCH_BCONV_CASE
/* Resolve a generic search opcode to the concrete nir_op for a given
 * destination bit size.  Plain NIR opcodes pass through unchanged (that
 * branch's body is elided in this extract). */
201 nir_op_for_search_op(uint16_t sop
, unsigned bit_size
)
203 if (sop
<= nir_last_opcode
)
/* Select the bit-size-specific variant; unsupported sizes are fatal. */
206 #define RET_FCONV_CASE(op) \
207 case nir_search_op_##op: \
208 switch (bit_size) { \
209 case 16: return nir_op_##op##16; \
210 case 32: return nir_op_##op##32; \
211 case 64: return nir_op_##op##64; \
212 default: unreachable("Invalid bit size"); \
215 #define RET_ICONV_CASE(op) \
216 case nir_search_op_##op: \
217 switch (bit_size) { \
218 case 8: return nir_op_##op##8; \
219 case 16: return nir_op_##op##16; \
220 case 32: return nir_op_##op##32; \
221 case 64: return nir_op_##op##64; \
222 default: unreachable("Invalid bit size"); \
225 #define RET_BCONV_CASE(op) \
226 case nir_search_op_##op: \
227 switch (bit_size) { \
228 case 1: return nir_op_##op##1; \
229 case 32: return nir_op_##op##32; \
230 default: unreachable("Invalid bit size"); \
246 unreachable("Invalid nir_search_op");
249 #undef RET_FCONV_CASE
250 #undef RET_ICONV_CASE
251 #undef RET_BCONV_CASE
/* Match one search value (expression, variable, or constant) against source
 * `src` of `instr`, accumulating bindings in `state`.
 * NOTE(review): line-mangled extract; the early "return false" lines, several
 * case terminators and closing braces are elided.  Code left byte-identical;
 * comments only. */
255 match_value(const nir_search_value
*value
, nir_alu_instr
*instr
, unsigned src
,
256 unsigned num_components
, const uint8_t *swizzle
,
257 struct match_state
*state
)
259 uint8_t new_swizzle
[NIR_MAX_VEC_COMPONENTS
];
261 /* Searching only works on SSA values because, if it's not SSA, we can't
262 * know if the value changed between one instance of that value in the
263 * expression and another. Also, the replace operation will place reads of
264 * that value right before the last instruction in the expression we're
265 * replacing so those reads will happen after the original reads and may
266 * not be valid if they're register reads.
268 assert(instr
->src
[src
].src
.is_ssa
);
270 /* If the source is an explicitly sized source, then we need to reset
271 * both the number of components and the swizzle.
273 if (nir_op_infos
[instr
->op
].input_sizes
[src
] != 0) {
274 num_components
= nir_op_infos
[instr
->op
].input_sizes
[src
];
275 swizzle
= identity_swizzle
;
/* Compose the incoming swizzle with this source's own swizzle. */
278 for (unsigned i
= 0; i
< num_components
; ++i
)
279 new_swizzle
[i
] = instr
->src
[src
].swizzle
[swizzle
[i
]];
281 /* If the value has a specific bit size and it doesn't match, bail */
282 if (value
->bit_size
> 0 &&
283 nir_src_bit_size(instr
->src
[src
].src
) != value
->bit_size
)
286 switch (value
->type
) {
/* Sub-expression: the source must come from an ALU instruction, then
 * recurse into match_expression. */
287 case nir_search_value_expression
:
288 if (instr
->src
[src
].src
.ssa
->parent_instr
->type
!= nir_instr_type_alu
)
291 return match_expression(nir_search_value_as_expression(value
),
292 nir_instr_as_alu(instr
->src
[src
].src
.ssa
->parent_instr
),
293 num_components
, new_swizzle
, state
);
295 case nir_search_value_variable
: {
296 nir_search_variable
*var
= nir_search_value_as_variable(value
);
297 assert(var
->variable
< NIR_SEARCH_MAX_VARIABLES
);
/* Already-bound variable: this occurrence must refer to the same SSA
 * value with the same swizzle as the first occurrence. */
299 if (state
->variables_seen
& (1 << var
->variable
)) {
300 if (state
->variables
[var
->variable
].src
.ssa
!= instr
->src
[src
].src
.ssa
)
303 assert(!instr
->src
[src
].abs
&& !instr
->src
[src
].negate
);
305 for (unsigned i
= 0; i
< num_components
; ++i
) {
306 if (state
->variables
[var
->variable
].swizzle
[i
] != new_swizzle
[i
])
/* First occurrence: enforce the variable's constraints before binding. */
/* '#a' constants must come from a load_const instruction. */
312 if (var
->is_constant
&&
313 instr
->src
[src
].src
.ssa
->parent_instr
->type
!= nir_instr_type_load_const
)
/* Optional predicate attached to the variable (e.g. range checks). */
316 if (var
->cond
&& !var
->cond(state
->range_ht
, instr
,
317 src
, num_components
, new_swizzle
))
/* 'a@type' constraints. */
320 if (var
->type
!= nir_type_invalid
&&
321 !src_is_type(instr
->src
[src
].src
, var
->type
))
/* Record the binding; abs/negate are normalized to false here. */
324 state
->variables_seen
|= (1 << var
->variable
);
325 state
->variables
[var
->variable
].src
= instr
->src
[src
].src
;
326 state
->variables
[var
->variable
].abs
= false;
327 state
->variables
[var
->variable
].negate
= false;
/* Store the composed swizzle, zero-padding unused components. */
329 for (unsigned i
= 0; i
< NIR_MAX_VEC_COMPONENTS
; ++i
) {
330 if (i
< num_components
)
331 state
->variables
[var
->variable
].swizzle
[i
] = new_swizzle
[i
];
333 state
->variables
[var
->variable
].swizzle
[i
] = 0;
340 case nir_search_value_constant
: {
341 nir_search_constant
*const_val
= nir_search_value_as_constant(value
);
343 if (!nir_src_is_const(instr
->src
[src
].src
))
346 switch (const_val
->type
) {
347 case nir_type_float
: {
348 nir_load_const_instr
*const load
=
349 nir_instr_as_load_const(instr
->src
[src
].src
.ssa
->parent_instr
);
351 /* There are 8-bit and 1-bit integer types, but there are no 8-bit or
352 * 1-bit float types. This prevents potential assertion failures in
353 * nir_src_comp_as_float.
355 if (load
->def
.bit_size
< 16)
/* Every matched component must equal the pattern's float constant. */
358 for (unsigned i
= 0; i
< num_components
; ++i
) {
359 double val
= nir_src_comp_as_float(instr
->src
[src
].src
,
361 if (val
!= const_val
->data
.d
)
/* NOTE(review): integer case labels elided here; bool/int constants
 * compare modulo the source bit width. */
369 case nir_type_bool
: {
370 unsigned bit_size
= nir_src_bit_size(instr
->src
[src
].src
);
/* Mask to the source's bit width (1ull << 64 would be UB, hence the
 * UINT64_MAX special case). */
371 uint64_t mask
= bit_size
== 64 ? UINT64_MAX
: (1ull << bit_size
) - 1;
372 for (unsigned i
= 0; i
< num_components
; ++i
) {
373 uint64_t val
= nir_src_comp_as_uint(instr
->src
[src
].src
,
375 if ((val
& mask
) != (const_val
->data
.u
& mask
))
382 unreachable("Invalid alu source type");
387 unreachable("Invalid search value type");
/* Match a search expression against an ALU instruction, trying the source
 * order given by the current comm_op_direction bit for commutative
 * expressions.  NOTE(review): line-mangled extract; several "return false"
 * lines, loop bodies and the final return are elided.  Comments only. */
392 match_expression(const nir_search_expression
*expr
, nir_alu_instr
*instr
,
393 unsigned num_components
, const uint8_t *swizzle
,
394 struct match_state
*state
)
/* Optional per-expression predicate. */
396 if (expr
->cond
&& !expr
->cond(instr
))
399 if (!nir_op_matches_search_op(instr
->op
, expr
->opcode
))
402 assert(instr
->dest
.dest
.is_ssa
);
/* Destination bit-size constraint ('expr@N'). */
404 if (expr
->value
.bit_size
> 0 &&
405 instr
->dest
.dest
.ssa
.bit_size
!= expr
->value
.bit_size
)
/* Inexact (~) patterns must never rewrite exact ALU instructions. */
408 state
->inexact_match
= expr
->inexact
|| state
->inexact_match
;
409 state
->has_exact_alu
= instr
->exact
|| state
->has_exact_alu
;
410 if (state
->inexact_match
&& state
->has_exact_alu
)
413 assert(!instr
->dest
.saturate
);
414 assert(nir_op_infos
[instr
->op
].num_inputs
> 0);
416 /* If we have an explicitly sized destination, we can only handle the
417 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
418 * expression, we don't have the information right now to propagate that
419 * swizzle through. We can only properly propagate swizzles if the
420 * instruction is vectorized.
422 if (nir_op_infos
[instr
->op
].output_size
!= 0) {
423 for (unsigned i
= 0; i
< num_components
; i
++) {
429 /* If this is a commutative expression and it's one of the first few, look
430 * up its direction for the current search operation. We'll use that value
431 * to possibly flip the sources for the match.
433 unsigned comm_op_flip
=
434 (expr
->comm_expr_idx
>= 0 &&
435 expr
->comm_expr_idx
< NIR_SEARCH_MAX_COMM_OPS
) ?
436 ((state
->comm_op_direction
>> expr
->comm_expr_idx
) & 1) : 0;
/* Recurse into every source of the instruction. */
439 for (unsigned i
= 0; i
< nir_op_infos
[instr
->op
].num_inputs
; i
++) {
440 /* 2src_commutative instructions that have 3 sources are only commutative
441 * in the first two sources. Source 2 is always source 2.
443 if (!match_value(expr
->srcs
[i
], instr
,
444 i
< 2 ? i
^ comm_op_flip
: i
,
445 num_components
, swizzle
, state
)) {
/* Resolve the bit size of a replacement value: an explicit positive size
 * wins; a negative size means "same as matched variable (-bit_size - 1)";
 * zero falls back to the bit size of the matched destination.
 * NOTE(review): the return-type/storage-class line is elided in this
 * extract. */
455 replace_bitsize(const nir_search_value
*value
, unsigned search_bitsize
,
456 struct match_state
*state
)
458 if (value
->bit_size
> 0)
459 return value
->bit_size
;
/* Negative encodes a back-reference to a bound pattern variable. */
460 if (value
->bit_size
< 0)
461 return nir_src_bit_size(state
->variables
[-value
->bit_size
- 1].src
);
462 return search_bitsize
;
466 construct_value(nir_builder
*build
,
467 const nir_search_value
*value
,
468 unsigned num_components
, unsigned search_bitsize
,
469 struct match_state
*state
,
472 switch (value
->type
) {
473 case nir_search_value_expression
: {
474 const nir_search_expression
*expr
= nir_search_value_as_expression(value
);
475 unsigned dst_bit_size
= replace_bitsize(value
, search_bitsize
, state
);
476 nir_op op
= nir_op_for_search_op(expr
->opcode
, dst_bit_size
);
478 if (nir_op_infos
[op
].output_size
!= 0)
479 num_components
= nir_op_infos
[op
].output_size
;
481 nir_alu_instr
*alu
= nir_alu_instr_create(build
->shader
, op
);
482 nir_ssa_dest_init(&alu
->instr
, &alu
->dest
.dest
, num_components
,
484 alu
->dest
.write_mask
= (1 << num_components
) - 1;
485 alu
->dest
.saturate
= false;
487 /* We have no way of knowing what values in a given search expression
488 * map to a particular replacement value. Therefore, if the
489 * expression we are replacing has any exact values, the entire
490 * replacement should be exact.
492 alu
->exact
= state
->has_exact_alu
|| expr
->exact
;
494 for (unsigned i
= 0; i
< nir_op_infos
[op
].num_inputs
; i
++) {
495 /* If the source is an explicitly sized source, then we need to reset
496 * the number of components to match.
498 if (nir_op_infos
[alu
->op
].input_sizes
[i
] != 0)
499 num_components
= nir_op_infos
[alu
->op
].input_sizes
[i
];
501 alu
->src
[i
] = construct_value(build
, expr
->srcs
[i
],
502 num_components
, search_bitsize
,
506 nir_builder_instr_insert(build
, &alu
->instr
);
508 assert(alu
->dest
.dest
.ssa
.index
==
509 util_dynarray_num_elements(state
->states
, uint16_t));
510 util_dynarray_append(state
->states
, uint16_t, 0);
511 nir_algebraic_automaton(&alu
->instr
, state
->states
, state
->pass_op_table
);
514 val
.src
= nir_src_for_ssa(&alu
->dest
.dest
.ssa
);
517 memcpy(val
.swizzle
, identity_swizzle
, sizeof val
.swizzle
);
522 case nir_search_value_variable
: {
523 const nir_search_variable
*var
= nir_search_value_as_variable(value
);
524 assert(state
->variables_seen
& (1 << var
->variable
));
526 nir_alu_src val
= { NIR_SRC_INIT
};
527 nir_alu_src_copy(&val
, &state
->variables
[var
->variable
],
528 (void *)build
->shader
);
529 assert(!var
->is_constant
);
531 for (unsigned i
= 0; i
< NIR_MAX_VEC_COMPONENTS
; i
++)
532 val
.swizzle
[i
] = state
->variables
[var
->variable
].swizzle
[var
->swizzle
[i
]];
537 case nir_search_value_constant
: {
538 const nir_search_constant
*c
= nir_search_value_as_constant(value
);
539 unsigned bit_size
= replace_bitsize(value
, search_bitsize
, state
);
544 cval
= nir_imm_floatN_t(build
, c
->data
.d
, bit_size
);
549 cval
= nir_imm_intN_t(build
, c
->data
.i
, bit_size
);
553 cval
= nir_imm_boolN_t(build
, c
->data
.u
, bit_size
);
557 unreachable("Invalid alu source type");
560 assert(cval
->index
==
561 util_dynarray_num_elements(state
->states
, uint16_t));
562 util_dynarray_append(state
->states
, uint16_t, 0);
563 nir_algebraic_automaton(cval
->parent_instr
, state
->states
,
564 state
->pass_op_table
);
567 val
.src
= nir_src_for_ssa(cval
);
570 memset(val
.swizzle
, 0, sizeof val
.swizzle
);
576 unreachable("Invalid search value type");
/* Debug helper: print a search value (constant, variable, or expression) to
 * stderr in the nir_algebraic-style s-expression syntax.  NOTE(review):
 * line-mangled extract; switch headers, several case labels, break
 * statements and closing braces are elided.  Comments only added. */
580 UNUSED
static void dump_value(const nir_search_value
*val
)
583 case nir_search_value_constant
: {
584 const nir_search_constant
*sconst
= nir_search_value_as_constant(val
);
585 switch (sconst
->type
) {
/* Float constant. */
587 fprintf(stderr
, "%f", sconst
->data
.d
);
/* Signed integer constant. */
590 fprintf(stderr
, "%"PRId64
, sconst
->data
.i
);
/* Unsigned integer constant, printed in hex. */
593 fprintf(stderr
, "0x%"PRIx64
, sconst
->data
.u
);
/* Boolean constant. */
596 fprintf(stderr
, "%s", sconst
->data
.u
!= 0 ? "True" : "False");
599 unreachable("bad const type");
604 case nir_search_value_variable
: {
605 const nir_search_variable
*var
= nir_search_value_as_variable(val
);
/* '#' marks a constant-only variable; variables print as a, b, c... */
606 if (var
->is_constant
)
607 fprintf(stderr
, "#");
608 fprintf(stderr
, "%c", var
->variable
+ 'a');
612 case nir_search_value_expression
: {
613 const nir_search_expression
*expr
= nir_search_value_as_expression(val
);
614 fprintf(stderr
, "(");
/* '~' marks an inexact expression (guard condition elided here). */
616 fprintf(stderr
, "~");
617 switch (expr
->opcode
) {
/* Macro body used for the generic conversion opcodes (its #define line
 * is elided in this extract). */
619 case nir_search_op_##n: fprintf(stderr, #n); break;
/* Plain NIR opcodes print their info-table name. */
629 fprintf(stderr
, "%s", nir_op_infos
[expr
->opcode
].name
);
/* Generic conversions take one source; plain opcodes take num_inputs. */
632 unsigned num_srcs
= 1;
633 if (expr
->opcode
<= nir_last_opcode
)
634 num_srcs
= nir_op_infos
[expr
->opcode
].num_inputs
;
636 for (unsigned i
= 0; i
< num_srcs
; i
++) {
637 fprintf(stderr
, " ");
638 dump_value(expr
->srcs
[i
]);
641 fprintf(stderr
, ")");
/* Trailing '@N' bit-size annotation, if any. */
646 if (val
->bit_size
> 0)
647 fprintf(stderr
, "@%d", val
->bit_size
);
/* Push every instruction that uses `instr`'s SSA def onto `worklist`.
 * NOTE(review): the return-type line and closing braces are elided in this
 * extract; a NULL-def guard, if any, is not visible. */
651 add_uses_to_worklist(nir_instr
*instr
, nir_instr_worklist
*worklist
)
653 nir_ssa_def
*def
= nir_instr_ssa_def(instr
);
655 nir_foreach_use_safe(use_src
, def
) {
656 nir_instr_worklist_push_tail(worklist
, use_src
->parent_instr
);
/* Re-run the automaton over the transitive users of `new_instr` until states
 * stop changing; every instruction whose state changed is also pushed onto
 * the algebraic worklist for re-optimization.  NOTE(review): the declaration
 * of the loop's `instr` variable and closing braces are elided. */
661 nir_algebraic_update_automaton(nir_instr
*new_instr
,
662 nir_instr_worklist
*algebraic_worklist
,
663 struct util_dynarray
*states
,
664 const struct per_op_table
*pass_op_table
)
667 nir_instr_worklist
*automaton_worklist
= nir_instr_worklist_create();
669 /* Walk through the tree of uses of our new instruction's SSA value,
670 * recursively updating the automaton state until it stabilizes.
672 add_uses_to_worklist(new_instr
, automaton_worklist
);
675 while ((instr
= nir_instr_worklist_pop_head(automaton_worklist
))) {
/* A state change propagates to this instruction's own users. */
676 if (nir_algebraic_automaton(instr
, states
, pass_op_table
)) {
677 nir_instr_worklist_push_tail(algebraic_worklist
, instr
);
679 add_uses_to_worklist(instr
, automaton_worklist
);
683 nir_instr_worklist_destroy(automaton_worklist
);
/* Try to match `search` against `instr` (over all commutative-source
 * orderings) and, on success, build `replace` in its place, keep the
 * automaton states consistent, and remove the old instruction.
 * NOTE(review): line-mangled extract; the return statements, some loop
 * bodies and closing braces are elided.  Comments only added. */
687 nir_replace_instr(nir_builder
*build
, nir_alu_instr
*instr
,
688 struct hash_table
*range_ht
,
689 struct util_dynarray
*states
,
690 const struct per_op_table
*pass_op_table
,
691 const nir_search_expression
*search
,
692 const nir_search_value
*replace
,
693 nir_instr_worklist
*algebraic_worklist
)
695 uint8_t swizzle
[NIR_MAX_VEC_COMPONENTS
] = { 0 };
/* Identity swizzle over the destination's components (loop body elided). */
697 for (unsigned i
= 0; i
< instr
->dest
.dest
.ssa
.num_components
; ++i
)
700 assert(instr
->dest
.dest
.is_ssa
);
702 struct match_state state
;
703 state
.inexact_match
= false;
704 state
.has_exact_alu
= false;
705 state
.range_ht
= range_ht
;
706 state
.pass_op_table
= pass_op_table
;
708 STATIC_ASSERT(sizeof(state
.comm_op_direction
) * 8 >= NIR_SEARCH_MAX_COMM_OPS
);
/* One bit per commutative sub-expression: try every source ordering. */
710 unsigned comm_expr_combinations
=
711 1 << MIN2(search
->comm_exprs
, NIR_SEARCH_MAX_COMM_OPS
);
714 for (unsigned comb
= 0; comb
< comm_expr_combinations
; comb
++) {
715 /* The bitfield of directions is just the current iteration. Hooray for
718 state
.comm_op_direction
= comb
;
719 state
.variables_seen
= 0;
721 if (match_expression(search
, instr
,
722 instr
->dest
.dest
.ssa
.num_components
,
/* Debug output for a successful match (presumably guarded by an elided
 * #ifdef/verbosity check — confirm against upstream). */
732 fprintf(stderr
, "matched: ");
733 dump_value(&search
->value
);
734 fprintf(stderr
, " -> ");
736 fprintf(stderr
, " ssa_%d\n", instr
->dest
.dest
.ssa
.index
);
/* Build the replacement right before the matched instruction. */
739 build
->cursor
= nir_before_instr(&instr
->instr
);
741 state
.states
= states
;
743 nir_alu_src val
= construct_value(build
, replace
,
744 instr
->dest
.dest
.ssa
.num_components
,
745 instr
->dest
.dest
.ssa
.bit_size
,
746 &state
, &instr
->instr
);
748 /* Note that NIR builder will elide the MOV if it's a no-op, which may
749 * allow more work to be done in a single pass through algebraic.
751 nir_ssa_def
*ssa_val
=
752 nir_mov_alu(build
, val
, instr
->dest
.dest
.ssa
.num_components
);
/* If the MOV was not elided it created a brand-new def: grow and update
 * the automaton state array for it. */
753 if (ssa_val
->index
== util_dynarray_num_elements(states
, uint16_t)) {
754 util_dynarray_append(states
, uint16_t, 0);
755 nir_algebraic_automaton(ssa_val
->parent_instr
, states
, pass_op_table
);
758 /* Rewrite the uses of the old SSA value to the new one, and recurse
759 * through the uses updating the automaton's state.
761 nir_ssa_def_rewrite_uses(&instr
->dest
.dest
.ssa
, nir_src_for_ssa(ssa_val
));
762 nir_algebraic_update_automaton(ssa_val
->parent_instr
, algebraic_worklist
,
763 states
, pass_op_table
);
765 /* Nothing uses the instr any more, so drop it out of the program. Note
766 * that the instr may be in the worklist still, so we can't free it
769 nir_instr_remove(&instr
->instr
);
/* Run one step of the pattern-matching automaton on `instr`, updating its
 * destination's entry in `states`.  Returns whether the state changed (the
 * return statements themselves are elided in this extract).
 * NOTE(review): the declarations of `op` and `index`, the default case, and
 * closing braces are also elided.  Comments only added. */
775 nir_algebraic_automaton(nir_instr
*instr
, struct util_dynarray
*states
,
776 const struct per_op_table
*pass_op_table
)
778 switch (instr
->type
) {
779 case nir_instr_type_alu
: {
780 nir_alu_instr
*alu
= nir_instr_as_alu(instr
);
782 uint16_t search_op
= nir_search_op_for_nir_op(op
);
783 const struct per_op_table
*tbl
= &pass_op_table
[search_op
];
/* No transitions for this opcode: nothing to do. */
784 if (tbl
->num_filtered_states
== 0)
787 /* Calculate the index into the transition table. Note the index
788 * calculated must match the iteration order of Python's
789 * itertools.product(), which was used to emit the transition
793 for (unsigned i
= 0; i
< nir_op_infos
[op
].num_inputs
; i
++) {
/* Mixed-radix index: one filtered-state digit per source. */
794 index
*= tbl
->num_filtered_states
;
795 index
+= tbl
->filter
[*util_dynarray_element(states
, uint16_t,
796 alu
->src
[i
].src
.ssa
->index
)];
799 uint16_t *state
= util_dynarray_element(states
, uint16_t,
800 alu
->dest
.dest
.ssa
.index
);
801 if (*state
!= tbl
->table
[index
]) {
802 *state
= tbl
->table
[index
];
/* Constants get the dedicated CONST_STATE sentinel. */
808 case nir_instr_type_load_const
: {
809 nir_load_const_instr
*load_const
= nir_instr_as_load_const(instr
);
810 uint16_t *state
= util_dynarray_element(states
, uint16_t,
811 load_const
->def
.index
);
812 if (*state
!= CONST_STATE
) {
813 *state
= CONST_STATE
;
/* Try every transform associated with `instr`'s current automaton state,
 * applying the first one whose condition flag is set and that actually
 * matches.  NOTE(review): line-mangled extract; the return statements, the
 * loop's break/return on success, and closing braces are elided. */
825 nir_algebraic_instr(nir_builder
*build
, nir_instr
*instr
,
826 struct hash_table
*range_ht
,
827 const bool *condition_flags
,
828 const struct transform
**transforms
,
829 const uint16_t *transform_counts
,
830 struct util_dynarray
*states
,
831 const struct per_op_table
*pass_op_table
,
832 nir_instr_worklist
*worklist
)
/* Only SSA-destination ALU instructions are candidates. */
835 if (instr
->type
!= nir_instr_type_alu
)
838 nir_alu_instr
*alu
= nir_instr_as_alu(instr
);
839 if (!alu
->dest
.dest
.is_ssa
)
/* Float-controls modes can forbid applying inexact (~) patterns. */
842 unsigned bit_size
= alu
->dest
.dest
.ssa
.bit_size
;
843 const unsigned execution_mode
=
844 build
->shader
->info
.float_controls_execution_mode
;
845 const bool ignore_inexact
=
846 nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode
, bit_size
) ||
847 nir_is_denorm_flush_to_zero(execution_mode
, bit_size
);
/* The automaton state of this def indexes the per-state transform list. */
849 int xform_idx
= *util_dynarray_element(states
, uint16_t,
850 alu
->dest
.dest
.ssa
.index
);
851 for (uint16_t i
= 0; i
< transform_counts
[xform_idx
]; i
++) {
852 const struct transform
*xform
= &transforms
[xform_idx
][i
];
853 if (condition_flags
[xform
->condition_offset
] &&
854 !(xform
->search
->inexact
&& ignore_inexact
) &&
855 nir_replace_instr(build
, alu
, range_ht
, states
, pass_op_table
,
856 xform
->search
, xform
->replace
, worklist
)) {
/* A rewrite happened: flush the range-analysis cache, which may now
 * be stale. */
857 _mesa_hash_table_clear(range_ht
, NULL
);
/* Entry point for one generated algebraic pass over a function impl:
 * initialize automaton states bottom-up, then drain a reverse-order worklist
 * through nir_algebraic_instr.  NOTE(review): line-mangled extract; the
 * builder declaration, the worklist `instr` declaration, the final return
 * and closing braces are elided.  Comments only added. */
866 nir_algebraic_impl(nir_function_impl
*impl
,
867 const bool *condition_flags
,
868 const struct transform
**transforms
,
869 const uint16_t *transform_counts
,
870 const struct per_op_table
*pass_op_table
)
872 bool progress
= false;
875 nir_builder_init(&build
, impl
);
877 /* Note: it's important here that we're allocating a zeroed array, since
878 * state 0 is the default state, which means we don't have to visit
879 * anything other than constants and ALU instructions.
881 struct util_dynarray states
= {0};
/* One uint16_t automaton state per SSA def; bail on allocation failure
 * (the failure branch body is elided here). */
882 if (!util_dynarray_resize(&states
, uint16_t, impl
->ssa_alloc
))
884 memset(states
.data
, 0, states
.size
);
886 struct hash_table
*range_ht
= _mesa_pointer_hash_table_create(NULL
);
888 nir_instr_worklist
*worklist
= nir_instr_worklist_create();
890 /* Walk top-to-bottom setting up the automaton state. */
891 nir_foreach_block(block
, impl
) {
892 nir_foreach_instr(instr
, block
) {
893 nir_algebraic_automaton(instr
, &states
, pass_op_table
);
897 /* Put our instrs in the worklist such that we're popping the last instr
898 * first. This will encourage us to match the biggest source patterns when
901 nir_foreach_block_reverse(block
, impl
) {
902 nir_foreach_instr_reverse(instr
, block
) {
903 nir_instr_worklist_push_tail(worklist
, instr
);
908 while ((instr
= nir_instr_worklist_pop_head(worklist
))) {
909 /* The worklist can have an instr pushed to it multiple times if it was
910 * the src of multiple instrs that also got optimized, so make sure that
911 * we don't try to re-optimize an instr we already handled.
913 if (exec_node_is_tail_sentinel(&instr
->node
))
916 progress
|= nir_algebraic_instr(&build
, instr
,
917 range_ht
, condition_flags
,
918 transforms
, transform_counts
, &states
,
919 pass_op_table
, worklist
);
/* Teardown of the per-pass scratch structures. */
922 nir_instr_worklist_destroy(worklist
);
923 ralloc_free(range_ht
);
924 util_dynarray_fini(&states
);
/* Presumably only reached when progress was made — the surrounding
 * conditional is elided; confirm against upstream. */
927 nir_metadata_preserve(impl
, nir_metadata_block_index
|
928 nir_metadata_dominance
);
931 impl
->valid_metadata
&= ~nir_metadata_not_properly_reset
;