#include "nir.h"
#include "nir_builder.h"
+/* Filter state threaded through nir_shader_lower_instructions' single
+ * void *cb_data slot; replaces the old BITSET_WORD opcode mask. */
+struct alu_to_scalar_data {
+ /* Optional per-instruction filter; when set, an ALU instr is only
+ * scalarized if cb(instr, data) returns true. */
+ nir_instr_filter_cb cb;
+ /* Opaque user data forwarded unmodified to cb. */
+ const void *data;
+};
+
/** @file nir_lower_alu_to_scalar.c
*
* Replaces nir_alu_instr operations with more than one channel used in the
assert(alu->dest.dest.is_ssa);
assert(alu->src[0].src.is_ssa);
/* NOTE(review): the predicate changes from the source's actual
 * num_components to the opcode's *declared* input size — presumably so
 * ops with inherently vector inputs (e.g. reductions like fdot) are
 * still treated as vector even when the destination is scalar.
 * Confirm against the full function, which is elided in this hunk. */
return alu->dest.dest.ssa.num_components > 1 ||
- alu->src[0].src.ssa->num_components > 1;
+ nir_op_infos[alu->op].input_sizes[0] > 1;
}
static void
}
/* Per-instruction lowering callback for nir_shader_lower_instructions:
 * rewrites a vector ALU instr into per-channel scalar ops (or a scalar
 * reduction chain), returning the replacement SSA def, or NULL to leave
 * the instruction untouched.  Large parts of the body are elided in
 * this hunk. */
static nir_ssa_def *
-lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_state)
+lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
{
- BITSET_WORD *lower_set = _state;
+ /* State is now the (callback, user-data) pair instead of an opcode
+ * bitset — see struct alu_to_scalar_data. */
+ struct alu_to_scalar_data *data = _data;
nir_alu_instr *alu = nir_instr_as_alu(instr);
unsigned num_src = nir_op_infos[alu->op].num_inputs;
unsigned i, chan;
b->cursor = nir_before_instr(&alu->instr);
b->exact = alu->exact;
- if (lower_set && !BITSET_TEST(lower_set, alu->op))
+ /* With a filter installed, skip any instruction it rejects. */
+ if (data->cb && !data->cb(instr, data->data))
return NULL;
/* Expand an N-wide reduction opcode family into a chain of scalar
 * `chan` ops merged with `merge` (handled by lower_reduction, not
 * visible in this hunk). */
#define LOWER_REDUCTION(name, chan, merge) \
case name##2: \
case name##3: \
case name##4: \
+ /* NOTE(review): 8- and 16-wide variants added alongside vec8/vec16
+ * support below. */
+ case name##8: \
+ case name##16: \
return lower_reduction(alu, chan, merge, b); \
switch (alu->op) {
+ case nir_op_vec16:
+ case nir_op_vec8:
case nir_op_vec4:
case nir_op_vec3:
case nir_op_vec2:
*/
return NULL;
+ /* The flush-to-zero variant shares the lower_unpack_half_2x16 path
+ * but must lower to the matching _flush_to_zero split opcodes so the
+ * denorm-flushing semantics are preserved. */
+ case nir_op_unpack_half_2x16_flush_to_zero:
case nir_op_unpack_half_2x16: {
if (!b->shader->options->lower_unpack_half_2x16)
return NULL;
nir_ssa_def *packed = nir_ssa_for_alu_src(b, alu, 0);
- return nir_vec2(b, nir_unpack_half_2x16_split_x(b, packed),
- nir_unpack_half_2x16_split_y(b, packed));
+ if (alu->op == nir_op_unpack_half_2x16_flush_to_zero) {
+ return nir_vec2(b,
+ nir_unpack_half_2x16_split_x_flush_to_zero(b,
+ packed),
+ nir_unpack_half_2x16_split_y_flush_to_zero(b,
+ packed));
+ } else {
+ return nir_vec2(b,
+ nir_unpack_half_2x16_split_x(b, packed),
+ nir_unpack_half_2x16_split_y(b, packed));
+ }
}
case nir_op_pack_uvec2_to_uint: {
}
case nir_op_unpack_64_2x32:
+ case nir_op_unpack_64_4x16:
case nir_op_unpack_32_2x16:
return NULL;
LOWER_REDUCTION(nir_op_fdot, nir_op_fmul, nir_op_fadd);
LOWER_REDUCTION(nir_op_ball_fequal, nir_op_feq, nir_op_iand);
LOWER_REDUCTION(nir_op_ball_iequal, nir_op_ieq, nir_op_iand);
/* NOTE(review): fne was renamed fneu (unordered not-equal, true for
 * NaN operands) — behavior unchanged, rename only. Same for the
 * fne32 case below. */
- LOWER_REDUCTION(nir_op_bany_fnequal, nir_op_fne, nir_op_ior);
+ LOWER_REDUCTION(nir_op_bany_fnequal, nir_op_fneu, nir_op_ior);
LOWER_REDUCTION(nir_op_bany_inequal, nir_op_ine, nir_op_ior);
+ /* New 8-bit and 16-bit boolean reduction families, mirroring the
+ * existing 32-bit b32* cases. */
+ LOWER_REDUCTION(nir_op_b8all_fequal, nir_op_feq8, nir_op_iand);
+ LOWER_REDUCTION(nir_op_b8all_iequal, nir_op_ieq8, nir_op_iand);
+ LOWER_REDUCTION(nir_op_b8any_fnequal, nir_op_fneu8, nir_op_ior);
+ LOWER_REDUCTION(nir_op_b8any_inequal, nir_op_ine8, nir_op_ior);
+ LOWER_REDUCTION(nir_op_b16all_fequal, nir_op_feq16, nir_op_iand);
+ LOWER_REDUCTION(nir_op_b16all_iequal, nir_op_ieq16, nir_op_iand);
+ LOWER_REDUCTION(nir_op_b16any_fnequal, nir_op_fneu16, nir_op_ior);
+ LOWER_REDUCTION(nir_op_b16any_inequal, nir_op_ine16, nir_op_ior);
LOWER_REDUCTION(nir_op_b32all_fequal, nir_op_feq32, nir_op_iand);
LOWER_REDUCTION(nir_op_b32all_iequal, nir_op_ieq32, nir_op_iand);
- LOWER_REDUCTION(nir_op_b32any_fnequal, nir_op_fne32, nir_op_ior);
+ LOWER_REDUCTION(nir_op_b32any_fnequal, nir_op_fneu32, nir_op_ior);
LOWER_REDUCTION(nir_op_b32any_inequal, nir_op_ine32, nir_op_ior);
LOWER_REDUCTION(nir_op_fall_equal, nir_op_seq, nir_op_fmin);
LOWER_REDUCTION(nir_op_fany_nequal, nir_op_sne, nir_op_fmax);
}
/* NOTE(review): signature changed from a BITSET_WORD opcode mask to a
 * generic (filter callback, opaque user data) pair; this is an API
 * break, so every caller of nir_lower_alu_to_scalar must be updated in
 * the same series. Passing cb == NULL scalarizes everything, matching
 * the old lower_set == NULL behavior. Returns true if any instruction
 * was lowered. */
bool
-nir_lower_alu_to_scalar(nir_shader *shader, BITSET_WORD *lower_set)
+nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *_data)
{
+ /* Bundle the pair so it fits through the single void *cb_data slot
+ * of nir_shader_lower_instructions. Stack lifetime is fine: the
+ * helper runs synchronously and does not retain the pointer. */
+ struct alu_to_scalar_data data = {
+ .cb = cb,
+ .data = _data,
+ };
+
return nir_shader_lower_instructions(shader,
inst_is_vector_alu,
lower_alu_instr_scalar,
- lower_set);
+ &data);
}