freedreno/a2xx: ir2: fix lowering of instructions after float lowering
[mesa.git] / src / gallium / drivers / freedreno / a2xx / ir2_nir.c
index 36f3a679ff6df3596f270f9a34b593bf0eac518a..bb3ca9c9c0ccd5bb6b3a61dabd4f25ad0bef7544 100644 (file)
@@ -39,7 +39,9 @@ static const nir_shader_compiler_options options = {
        /* .fdot_replicates = true, it is replicated, but it makes things worse */
        .lower_all_io_to_temps = true,
        .vertex_id_zero_based = true, /* its not implemented anyway */
-       .lower_bitshift = true,
+       .lower_bitops = true,
+       .lower_rotate = true,
+       .lower_vector_cmp = true,
 };
 
 const nir_shader_compiler_options *
@@ -120,7 +122,7 @@ ir2_optimize_nir(nir_shader *s, bool lower)
        ir2_optimize_loop(s);
 
        OPT_V(s, nir_remove_dead_variables, nir_var_function_temp);
-       OPT_V(s, nir_move_load_const);
+       OPT_V(s, nir_opt_sink, nir_move_const_undef);
 
        /* TODO we dont want to get shaders writing to depth for depth textures */
        if (s->info.stage == MESA_SHADER_FRAGMENT) {
@@ -285,11 +287,10 @@ instr_create_alu(struct ir2_context *ctx, nir_op opcode, unsigned ncomp)
                [0 ... nir_num_opcodes - 1] = {-1, -1},
 
                [nir_op_mov] = {MAXs, MAXv},
+               [nir_op_fneg] = {MAXs, MAXv},
+               [nir_op_fabs] = {MAXs, MAXv},
+               [nir_op_fsat] = {MAXs, MAXv},
                [nir_op_fsign] = {-1, CNDGTEv},
-               [nir_op_fnot] = {SETEs, SETEv},
-               [nir_op_for] = {MAXs, MAXv},
-               [nir_op_fand] = {MINs, MINv},
-               [nir_op_fxor] = {-1, SETNEv},
                [nir_op_fadd] = {ADDs, ADDv},
                [nir_op_fsub] = {ADDs, ADDv},
                [nir_op_fmul] = {MULs, MULv},
@@ -431,6 +432,15 @@ emit_alu(struct ir2_context *ctx, nir_alu_instr * alu)
 
        /* workarounds for NIR ops that don't map directly to a2xx ops */
        switch (alu->op) {
+       case nir_op_fneg:
+               instr->src[0].negate = 1;
+               break;
+       case nir_op_fabs:
+               instr->src[0].abs = 1;
+               break;
+       case nir_op_fsat:
+               instr->alu.saturate = 1;
+               break;
        case nir_op_slt:
                tmp = instr->src[0];
                instr->src[0] = instr->src[1];
@@ -719,7 +729,7 @@ emit_tex(struct ir2_context *ctx, nir_tex_instr * tex)
 
        instr = ir2_instr_create_fetch(ctx, &tex->dest, TEX_FETCH);
        instr->src[0] = src_coord;
-       instr->src[0].swizzle = is_cube ? IR2_SWIZZLE_XYW : 0;
+       instr->src[0].swizzle = is_cube ? IR2_SWIZZLE_YXW : 0;
        instr->fetch.tex.is_cube = is_cube;
        instr->fetch.tex.is_rect = is_rect;
        instr->fetch.tex.samp_id = tex->sampler_index;
@@ -1052,6 +1062,29 @@ static void cleanup_binning(struct ir2_context *ctx)
        ir2_optimize_nir(ctx->nir, false);
 }
 
+static bool
+ir2_alu_to_scalar_filter_cb(const nir_instr *instr, const void *data)
+{
+       if (instr->type != nir_instr_type_alu)
+               return false;
+
+       nir_alu_instr *alu = nir_instr_as_alu(instr);
+       switch (alu->op) {
+       case nir_op_frsq:
+       case nir_op_frcp:
+       case nir_op_flog2:
+       case nir_op_fexp2:
+       case nir_op_fsqrt:
+       case nir_op_fcos:
+       case nir_op_fsin:
+               return true;
+       default:
+               break;
+       }
+
+       return false;
+}
+
 void
 ir2_nir_compile(struct ir2_context *ctx, bool binning)
 {
@@ -1064,19 +1097,17 @@ ir2_nir_compile(struct ir2_context *ctx, bool binning)
        if (binning)
                cleanup_binning(ctx);
 
-       /* postprocess */
-       OPT_V(ctx->nir, nir_opt_algebraic_late);
-
-       OPT_V(ctx->nir, nir_lower_to_source_mods, nir_lower_all_source_mods);
        OPT_V(ctx->nir, nir_copy_prop);
        OPT_V(ctx->nir, nir_opt_dce);
-       OPT_V(ctx->nir, nir_opt_move_comparisons);
+       OPT_V(ctx->nir, nir_opt_move, nir_move_comparisons);
 
-       OPT_V(ctx->nir, nir_lower_bool_to_float);
        OPT_V(ctx->nir, nir_lower_int_to_float);
+       OPT_V(ctx->nir, nir_lower_bool_to_float);
+       while(OPT(ctx->nir, nir_opt_algebraic));
+       OPT_V(ctx->nir, nir_opt_algebraic_late);
+       OPT_V(ctx->nir, nir_lower_to_source_mods, nir_lower_all_source_mods);
 
-       /* lower to scalar instructions that can only be scalar on a2xx */
-       OPT_V(ctx->nir, ir2_nir_lower_scalar);
+       OPT_V(ctx->nir, nir_lower_alu_to_scalar, ir2_alu_to_scalar_filter_cb, NULL);
 
        OPT_V(ctx->nir, nir_lower_locals_to_regs);