freedreno/ir3: start dealing with half-precision
authorRob Clark <robdclark@gmail.com>
Fri, 2 Mar 2018 15:21:55 +0000 (10:21 -0500)
committerRob Clark <robdclark@gmail.com>
Mon, 5 Mar 2018 13:05:33 +0000 (08:05 -0500)
Some instructions, assume src and/or dst is half-precision based on a
type field (ie. f32/s32/u32 are full precision but others are half
precision).  So add some code to sanity check the src/dst registers to
catch mixups.

Also propagate half-precision flag for SSA sources.  The instruction
consuming a SSA value needs to be of the same type as the one producing
it.

This is probably not complete half-precision support, but a useful first
step.  We do still need to add support for nir alu instructions for
converting between half/full precision.

Signed-off-by: Rob Clark <robdclark@gmail.com>
src/gallium/drivers/freedreno/ir3/ir3.c
src/gallium/drivers/freedreno/ir3/ir3.h
src/gallium/drivers/freedreno/ir3/ir3_compiler_nir.c

index dd5fb2fbbe5ae483d5cb69f7601233883cf80d03..d1a73ddc727c494532ea2027114d7a6c009de02c 100644 (file)
@@ -68,10 +68,17 @@ void ir3_destroy(struct ir3 *shader)
 
 #define iassert(cond) do { \
        if (!(cond)) { \
-               assert(cond); \
+               debug_assert(cond); \
                return -1; \
        } } while (0)
 
+#define iassert_type(reg, full) do { \
+       if ((full)) { \
+               iassert(!((reg)->flags & IR3_REG_HALF)); \
+       } else { \
+               iassert((reg)->flags & IR3_REG_HALF); \
+       } } while (0);
+
 static uint32_t reg(struct ir3_register *reg, struct ir3_info *info,
                uint32_t repeat, uint32_t valid_flags)
 {
@@ -142,11 +149,6 @@ static int emit_cat0(struct ir3_instruction *instr, void *ptr,
        return 0;
 }
 
-static uint32_t type_flags(type_t type)
-{
-       return (type_size(type) == 32) ? 0 : IR3_REG_HALF;
-}
-
 static int emit_cat1(struct ir3_instruction *instr, void *ptr,
                struct ir3_info *info)
 {
@@ -155,9 +157,9 @@ static int emit_cat1(struct ir3_instruction *instr, void *ptr,
        instr_cat1_t *cat1 = ptr;
 
        iassert(instr->regs_count == 2);
-       iassert(!((dst->flags ^ type_flags(instr->cat1.dst_type)) & IR3_REG_HALF));
-       iassert((src->flags & IR3_REG_IMMED) ||
-                       !((src->flags ^ type_flags(instr->cat1.src_type)) & IR3_REG_HALF));
+       iassert_type(dst, type_size(instr->cat1.dst_type) == 32);
+       if (!(src->flags & IR3_REG_IMMED))
+               iassert_type(src, type_size(instr->cat1.src_type) == 32);
 
        if (src->flags & IR3_REG_IMMED) {
                cat1->iim_val = src->iim_val;
@@ -425,7 +427,7 @@ static int emit_cat5(struct ir3_instruction *instr, void *ptr,
        struct ir3_register *src3 = instr->regs[3];
        instr_cat5_t *cat5 = ptr;
 
-       iassert(!((dst->flags ^ type_flags(instr->cat5.type)) & IR3_REG_HALF));
+       iassert_type(dst, type_size(instr->cat5.type) == 32)
 
        assume(src1 || !src2);
        assume(src2 || !src3);
@@ -477,6 +479,7 @@ static int emit_cat6(struct ir3_instruction *instr, void *ptr,
 {
        struct ir3_register *dst, *src1, *src2;
        instr_cat6_t *cat6 = ptr;
+       bool type_full = type_size(instr->cat6.type) == 32;
 
        cat6->type     = instr->cat6.type;
        cat6->opc      = instr->opc;
@@ -485,6 +488,36 @@ static int emit_cat6(struct ir3_instruction *instr, void *ptr,
        cat6->g        = !!(instr->flags & IR3_INSTR_G);
        cat6->opc_cat  = 6;
 
+       switch (instr->opc) {
+       case OPC_RESINFO:
+       case OPC_RESFMT:
+               iassert_type(instr->regs[0], type_full); /* dst */
+               iassert_type(instr->regs[1], type_full); /* src1 */
+               break;
+       case OPC_L2G:
+       case OPC_G2L:
+               iassert_type(instr->regs[0], true);      /* dst */
+               iassert_type(instr->regs[1], true);      /* src1 */
+               break;
+       case OPC_STG:
+       case OPC_STL:
+       case OPC_STP:
+       case OPC_STI:
+       case OPC_STLW:
+       case OPC_STIB:
+               /* no dst, so regs[0] is dummy */
+               iassert_type(instr->regs[1], true);      /* dst */
+               iassert_type(instr->regs[2], type_full); /* src1 */
+               iassert_type(instr->regs[3], true);      /* src2 */
+               break;
+       default:
+               iassert_type(instr->regs[0], type_full); /* dst */
+               iassert_type(instr->regs[1], true);      /* src1 */
+               if (instr->regs_count > 2)
+                       iassert_type(instr->regs[2], true);  /* src1 */
+               break;
+       }
+
        /* the "dst" for a store instruction is (from the perspective
         * of data flow in the shader, ie. register use/def, etc) in
         * fact a register that is read by the instruction, rather
@@ -628,7 +661,7 @@ static int emit_cat6(struct ir3_instruction *instr, void *ptr,
 
                cat6->src_off = false;
 
-               cat6b->src1 = reg(src1, info, instr->repeat, IR3_REG_IMMED);
+               cat6b->src1 = reg(src1, info, instr->repeat, IR3_REG_IMMED | IR3_REG_HALF);
                cat6b->src1_im = !!(src1->flags & IR3_REG_IMMED);
                if (src2) {
                        cat6b->src2 = reg(src2, info, instr->repeat, IR3_REG_IMMED);
index dd13e3238000440cbc0bad6885fae0b551a3e8a3..eaac2b7dfaca1be6d4c24d51f9efe14e9f10f003 100644 (file)
@@ -1015,18 +1015,28 @@ void ir3_legalize(struct ir3 *ir, bool *has_samp, bool *has_ssbo, int *max_bary)
 /* ************************************************************************* */
 /* instruction helpers */
 
+/* creates SSA src of correct type (ie. half vs full precision) */
+static inline struct ir3_register * __ssa_src(struct ir3_instruction *instr,
+               struct ir3_instruction *src, unsigned flags)
+{
+       struct ir3_register *reg;
+       if (src->regs[0]->flags & IR3_REG_HALF)
+               flags |= IR3_REG_HALF;
+       reg = ir3_reg_create(instr, 0, IR3_REG_SSA | flags);
+       reg->instr = src;
+       return reg;
+}
+
 static inline struct ir3_instruction *
 ir3_MOV(struct ir3_block *block, struct ir3_instruction *src, type_t type)
 {
        struct ir3_instruction *instr = ir3_instr_create(block, OPC_MOV);
        ir3_reg_create(instr, 0, 0);   /* dst */
        if (src->regs[0]->flags & IR3_REG_ARRAY) {
-               struct ir3_register *src_reg =
-                       ir3_reg_create(instr, 0, IR3_REG_ARRAY);
+               struct ir3_register *src_reg = __ssa_src(instr, src, IR3_REG_ARRAY);
                src_reg->array = src->regs[0]->array;
-               src_reg->instr = src;
        } else {
-               ir3_reg_create(instr, 0, IR3_REG_SSA)->instr = src;
+               __ssa_src(instr, src, 0);
        }
        debug_assert(!(src->regs[0]->flags & IR3_REG_RELATIV));
        instr->cat1.src_type = type;
@@ -1040,7 +1050,7 @@ ir3_COV(struct ir3_block *block, struct ir3_instruction *src,
 {
        struct ir3_instruction *instr = ir3_instr_create(block, OPC_MOV);
        ir3_reg_create(instr, 0, 0);   /* dst */
-       ir3_reg_create(instr, 0, IR3_REG_SSA)->instr = src;
+       __ssa_src(instr, src, 0);
        instr->cat1.src_type = src_type;
        instr->cat1.dst_type = dst_type;
        debug_assert(!(src->regs[0]->flags & IR3_REG_ARRAY));
@@ -1070,7 +1080,7 @@ ir3_##name(struct ir3_block *block,                                      \
        struct ir3_instruction *instr =                                      \
                ir3_instr_create(block, OPC_##name);                             \
        ir3_reg_create(instr, 0, 0);   /* dst */                             \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | aflags)->instr = a;           \
+       __ssa_src(instr, a, aflags);                                         \
        return instr;                                                        \
 }
 
@@ -1083,8 +1093,8 @@ ir3_##name(struct ir3_block *block,                                      \
        struct ir3_instruction *instr =                                      \
                ir3_instr_create(block, OPC_##name);                             \
        ir3_reg_create(instr, 0, 0);   /* dst */                             \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | aflags)->instr = a;           \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | bflags)->instr = b;           \
+       __ssa_src(instr, a, aflags);                                         \
+       __ssa_src(instr, b, bflags);                                         \
        return instr;                                                        \
 }
 
@@ -1098,9 +1108,9 @@ ir3_##name(struct ir3_block *block,                                      \
        struct ir3_instruction *instr =                                      \
                ir3_instr_create(block, OPC_##name);                             \
        ir3_reg_create(instr, 0, 0);   /* dst */                             \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | aflags)->instr = a;           \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | bflags)->instr = b;           \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | cflags)->instr = c;           \
+       __ssa_src(instr, a, aflags);                                         \
+       __ssa_src(instr, b, bflags);                                         \
+       __ssa_src(instr, c, cflags);                                         \
        return instr;                                                        \
 }
 
@@ -1115,10 +1125,10 @@ ir3_##name(struct ir3_block *block,                                      \
        struct ir3_instruction *instr =                                      \
                ir3_instr_create2(block, OPC_##name, 5);                         \
        ir3_reg_create(instr, 0, 0);   /* dst */                             \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | aflags)->instr = a;           \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | bflags)->instr = b;           \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | cflags)->instr = c;           \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | dflags)->instr = d;           \
+       __ssa_src(instr, a, aflags);                                         \
+       __ssa_src(instr, b, bflags);                                         \
+       __ssa_src(instr, c, cflags);                                         \
+       __ssa_src(instr, d, dflags);                                         \
        return instr;                                                        \
 }
 
@@ -1133,10 +1143,10 @@ ir3_##name##_##f(struct ir3_block *block,                                \
        struct ir3_instruction *instr =                                      \
                ir3_instr_create2(block, OPC_##name, 5);                         \
        ir3_reg_create(instr, 0, 0);   /* dst */                             \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | aflags)->instr = a;           \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | bflags)->instr = b;           \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | cflags)->instr = c;           \
-       ir3_reg_create(instr, 0, IR3_REG_SSA | dflags)->instr = d;           \
+       __ssa_src(instr, a, aflags);                                         \
+       __ssa_src(instr, b, bflags);                                         \
+       __ssa_src(instr, c, cflags);                                         \
+       __ssa_src(instr, d, dflags);                                         \
        instr->flags |= IR3_INSTR_##f;                                       \
        return instr;                                                        \
 }
index 8644bc1921890c5c652a7702d746a86006356aa4..1d6403b5260f18c29f23cdcc5acf9a4d2b22cf2e 100644 (file)
@@ -472,6 +472,14 @@ get_src(struct ir3_context *ctx, nir_src *src)
 static void
 put_dst(struct ir3_context *ctx, nir_dest *dst)
 {
+       unsigned bit_size = nir_dest_bit_size(*dst);
+
+       if (bit_size < 32) {
+               for (unsigned i = 0; i < ctx->last_dst_n; i++) {
+                       ctx->last_dst[i]->regs[0]->flags |= IR3_REG_HALF;
+               }
+       }
+
        if (!dst->is_ssa) {
                nir_register *reg = dst->reg.reg;
                struct ir3_array *arr = get_array(ctx, reg);