aco: don't move memory accesses to before control barriers
[mesa.git] / src / amd / compiler / aco_print_ir.cpp
index 5ced1d2d7bb8c64e4c14aed1955af9f5f29eeeca..8b8b5d0f306bc39f549317678a613342725a4985 100644 (file)
@@ -7,30 +7,52 @@
 namespace aco {
 
 static const char *reduce_ops[] = {
+   [iadd8] = "iadd8",
+   [iadd16] = "iadd16",
    [iadd32] = "iadd32",
    [iadd64] = "iadd64",
+   [imul8] = "imul8",
+   [imul16] = "imul16",
    [imul32] = "imul32",
    [imul64] = "imul64",
+   [fadd16] = "fadd16",
    [fadd32] = "fadd32",
    [fadd64] = "fadd64",
+   [fmul16] = "fmul16",
    [fmul32] = "fmul32",
    [fmul64] = "fmul64",
+   [imin8] = "imin8",
+   [imin16] = "imin16",
    [imin32] = "imin32",
    [imin64] = "imin64",
+   [imax8] = "imax8",
+   [imax16] = "imax16",
    [imax32] = "imax32",
    [imax64] = "imax64",
+   [umin8] = "umin8",
+   [umin16] = "umin16",
    [umin32] = "umin32",
    [umin64] = "umin64",
+   [umax8] = "umax8",
+   [umax16] = "umax16",
    [umax32] = "umax32",
    [umax64] = "umax64",
+   [fmin16] = "fmin16",
    [fmin32] = "fmin32",
    [fmin64] = "fmin64",
+   [fmax16] = "fmax16",
    [fmax32] = "fmax32",
    [fmax64] = "fmax64",
+   [iand8] = "iand8",
+   [iand16] = "iand16",
    [iand32] = "iand32",
    [iand64] = "iand64",
+   [ior8] = "ior8",
+   [ior16] = "ior16",
    [ior32] = "ior32",
    [ior64] = "ior64",
+   [ixor8] = "ixor8",
+   [ixor16] = "ixor16",
    [ixor32] = "ixor32",
    [ixor64] = "ixor64",
 };
@@ -53,12 +75,18 @@ static void print_reg_class(const RegClass rc, FILE *output)
       case RegClass::v6: fprintf(output, " v6: "); return;
       case RegClass::v7: fprintf(output, " v7: "); return;
       case RegClass::v8: fprintf(output, " v8: "); return;
+      case RegClass::v1b: fprintf(output, " v1b: "); return;
+      case RegClass::v2b: fprintf(output, " v2b: "); return;
+      case RegClass::v3b: fprintf(output, " v3b: "); return;
+      case RegClass::v4b: fprintf(output, " v4b: "); return;
+      case RegClass::v6b: fprintf(output, " v6b: "); return;
+      case RegClass::v8b: fprintf(output, " v8b: "); return;
       case RegClass::v1_linear: fprintf(output, " v1: "); return;
       case RegClass::v2_linear: fprintf(output, " v2: "); return;
    }
 }
 
-void print_physReg(unsigned reg, unsigned size, FILE *output)
+void print_physReg(PhysReg reg, unsigned bytes, FILE *output)
 {
    if (reg == 124) {
       fprintf(output, ":m0");
@@ -70,12 +98,15 @@ void print_physReg(unsigned reg, unsigned size, FILE *output)
       fprintf(output, ":exec");
    } else {
       bool is_vgpr = reg / 256;
-      reg = reg % 256;
-      fprintf(output, ":%c[%d", is_vgpr ? 'v' : 's', reg);
+      unsigned r = reg % 256;
+      unsigned size = DIV_ROUND_UP(bytes, 4);
+      fprintf(output, ":%c[%d", is_vgpr ? 'v' : 's', r);
       if (size > 1)
-         fprintf(output, "-%d]", reg + size -1);
+         fprintf(output, "-%d]", r + size -1);
       else
          fprintf(output, "]");
+      if (reg.byte() || bytes % 4)
+         fprintf(output, "[%d:%d]", reg.byte()*8, (reg.byte()+bytes) * 8);
    }
 }
 
@@ -122,55 +153,121 @@ static void print_constant(uint8_t reg, FILE *output)
 
 static void print_operand(const Operand *operand, FILE *output)
 {
-   if (operand->isLiteral()) {
-      fprintf(output, "0x%x", operand->constantValue());
+   if (operand->isLiteral() || (operand->isConstant() && operand->bytes() == 1)) {
+      if (operand->bytes() == 1)
+         fprintf(output, "0x%.2x", operand->constantValue());
+      else if (operand->bytes() == 2)
+         fprintf(output, "0x%.4x", operand->constantValue());
+      else
+         fprintf(output, "0x%x", operand->constantValue());
    } else if (operand->isConstant()) {
-      print_constant(operand->physReg().reg, output);
+      print_constant(operand->physReg().reg(), output);
    } else if (operand->isUndefined()) {
       print_reg_class(operand->regClass(), output);
       fprintf(output, "undef");
    } else {
+      if (operand->isLateKill())
+         fprintf(output, "(latekill)");
+
       fprintf(output, "%%%d", operand->tempId());
 
       if (operand->isFixed())
-         print_physReg(operand->physReg(), operand->size(), output);
+         print_physReg(operand->physReg(), operand->bytes(), output);
    }
 }
 
 static void print_definition(const Definition *definition, FILE *output)
 {
    print_reg_class(definition->regClass(), output);
+   if (definition->isPrecise())
+      fprintf(output, "(precise)");
+   if (definition->isNUW())
+      fprintf(output, "(nuw)");
    fprintf(output, "%%%d", definition->tempId());
 
    if (definition->isFixed())
-      print_physReg(definition->physReg(), definition->size(), output);
+      print_physReg(definition->physReg(), definition->bytes(), output);
 }
 
-static void print_barrier_reorder(bool can_reorder, barrier_interaction barrier, FILE *output)
+static void print_storage(storage_class storage, FILE *output)
 {
-   if (can_reorder)
-      fprintf(output, " reorder");
+   fprintf(output, " storage:");
+   int printed = 0;
+   if (storage & storage_buffer)
+      printed += fprintf(output, "%sbuffer", printed ? "," : "");
+   if (storage & storage_atomic_counter)
+      printed += fprintf(output, "%satomic_counter", printed ? "," : "");
+   if (storage & storage_image)
+      printed += fprintf(output, "%simage", printed ? "," : "");
+   if (storage & storage_shared)
+      printed += fprintf(output, "%sshared", printed ? "," : "");
+   if (storage & storage_vmem_output)
+      printed += fprintf(output, "%svmem_output", printed ? "," : "");
+   if (storage & storage_scratch)
+      printed += fprintf(output, "%sscratch", printed ? "," : "");
+   if (storage & storage_vgpr_spill)
+      printed += fprintf(output, "%svgpr_spill", printed ? "," : "");
+}
+
+static void print_semantics(memory_semantics sem, FILE *output)
+{
+   fprintf(output, " semantics:");
+   int printed = 0;
+   if (sem & semantic_acquire)
+      printed += fprintf(output, "%sacquire", printed ? "," : "");
+   if (sem & semantic_release)
+      printed += fprintf(output, "%srelease", printed ? "," : "");
+   if (sem & semantic_volatile)
+      printed += fprintf(output, "%svolatile", printed ? "," : "");
+   if (sem & semantic_private)
+      printed += fprintf(output, "%sprivate", printed ? "," : "");
+   if (sem & semantic_can_reorder)
+      printed += fprintf(output, "%sreorder", printed ? "," : "");
+   if (sem & semantic_atomic)
+      printed += fprintf(output, "%satomic", printed ? "," : "");
+   if (sem & semantic_rmw)
+      printed += fprintf(output, "%srmw", printed ? "," : "");
+}
+
+static void print_scope(sync_scope scope, FILE *output, const char *prefix="scope")
+{
+   fprintf(output, " %s:", prefix);
+   switch (scope) {
+   case scope_invocation:
+      fprintf(output, "invocation");
+      break;
+   case scope_subgroup:
+      fprintf(output, "subgroup");
+      break;
+   case scope_workgroup:
+      fprintf(output, "workgroup");
+      break;
+   case scope_queuefamily:
+      fprintf(output, "queuefamily");
+      break;
+   case scope_device:
+      fprintf(output, "device");
+      break;
+   }
+}
 
-   if (barrier & barrier_buffer)
-      fprintf(output, " buffer");
-   if (barrier & barrier_image)
-      fprintf(output, " image");
-   if (barrier & barrier_atomic)
-      fprintf(output, " atomic");
-   if (barrier & barrier_shared)
-      fprintf(output, " shared");
+static void print_sync(memory_sync_info sync, FILE *output)
+{
+   print_storage(sync.storage, output);
+   print_semantics(sync.semantics, output);
+   print_scope(sync.scope, output);
 }
 
-static void print_instr_format_specific(struct Instruction *instr, FILE *output)
+static void print_instr_format_specific(const Instruction *instr, FILE *output)
 {
    switch (instr->format) {
    case Format::SOPK: {
-      SOPK_instruction* sopk = static_cast<SOPK_instruction*>(instr);
+      const SOPK_instruction* sopk = static_cast<const SOPK_instruction*>(instr);
       fprintf(output, " imm:%d", sopk->imm & 0x8000 ? (sopk->imm - 65536) : sopk->imm);
       break;
    }
    case Format::SOPP: {
-      SOPP_instruction* sopp = static_cast<SOPP_instruction*>(instr);
+      const SOPP_instruction* sopp = static_cast<const SOPP_instruction*>(instr);
       uint16_t imm = sopp->imm;
       switch (instr->opcode) {
       case aco_opcode::s_waitcnt: {
@@ -192,6 +289,41 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
       case aco_opcode::s_set_gpr_idx_off: {
          break;
       }
+      case aco_opcode::s_sendmsg: {
+         unsigned id = imm & sendmsg_id_mask;
+         switch (id) {
+         case sendmsg_none:
+            fprintf(output, " sendmsg(MSG_NONE)");
+            break;
+         case _sendmsg_gs:
+            fprintf(output, " sendmsg(gs%s%s, %u)",
+                    imm & 0x10 ? ", cut" : "", imm & 0x20 ? ", emit" : "", imm >> 8);
+            break;
+         case _sendmsg_gs_done:
+            fprintf(output, " sendmsg(gs_done%s%s, %u)",
+                    imm & 0x10 ? ", cut" : "", imm & 0x20 ? ", emit" : "", imm >> 8);
+            break;
+         case sendmsg_save_wave:
+            fprintf(output, " sendmsg(save_wave)");
+            break;
+         case sendmsg_stall_wave_gen:
+            fprintf(output, " sendmsg(stall_wave_gen)");
+            break;
+         case sendmsg_halt_waves:
+            fprintf(output, " sendmsg(halt_waves)");
+            break;
+         case sendmsg_ordered_ps_done:
+            fprintf(output, " sendmsg(ordered_ps_done)");
+            break;
+         case sendmsg_early_prim_dealloc:
+            fprintf(output, " sendmsg(early_prim_dealloc)");
+            break;
+         case sendmsg_gs_alloc_req:
+            fprintf(output, " sendmsg(gs_alloc_req)");
+            break;
+         }
+         break;
+      }
       default: {
          if (imm)
             fprintf(output, " imm:%u", imm);
@@ -203,39 +335,42 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
       break;
    }
    case Format::SMEM: {
-      SMEM_instruction* smem = static_cast<SMEM_instruction*>(instr);
+      const SMEM_instruction* smem = static_cast<const SMEM_instruction*>(instr);
       if (smem->glc)
          fprintf(output, " glc");
       if (smem->dlc)
          fprintf(output, " dlc");
       if (smem->nv)
          fprintf(output, " nv");
-      print_barrier_reorder(smem->can_reorder, smem->barrier, output);
+      print_sync(smem->sync, output);
       break;
    }
    case Format::VINTRP: {
-      Interp_instruction* vintrp = static_cast<Interp_instruction*>(instr);
+      const Interp_instruction* vintrp = static_cast<const Interp_instruction*>(instr);
       fprintf(output, " attr%d.%c", vintrp->attribute, "xyzw"[vintrp->component]);
       break;
    }
    case Format::DS: {
-      DS_instruction* ds = static_cast<DS_instruction*>(instr);
+      const DS_instruction* ds = static_cast<const DS_instruction*>(instr);
       if (ds->offset0)
          fprintf(output, " offset0:%u", ds->offset0);
       if (ds->offset1)
          fprintf(output, " offset1:%u", ds->offset1);
       if (ds->gds)
          fprintf(output, " gds");
+      print_sync(ds->sync, output);
       break;
    }
    case Format::MUBUF: {
-      MUBUF_instruction* mubuf = static_cast<MUBUF_instruction*>(instr);
+      const MUBUF_instruction* mubuf = static_cast<const MUBUF_instruction*>(instr);
       if (mubuf->offset)
          fprintf(output, " offset:%u", mubuf->offset);
       if (mubuf->offen)
          fprintf(output, " offen");
       if (mubuf->idxen)
          fprintf(output, " idxen");
+      if (mubuf->addr64)
+         fprintf(output, " addr64");
       if (mubuf->glc)
          fprintf(output, " glc");
       if (mubuf->dlc)
@@ -248,11 +383,11 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
          fprintf(output, " lds");
       if (mubuf->disable_wqm)
          fprintf(output, " disable_wqm");
-      print_barrier_reorder(mubuf->can_reorder, mubuf->barrier, output);
+      print_sync(mubuf->sync, output);
       break;
    }
    case Format::MIMG: {
-      MIMG_instruction* mimg = static_cast<MIMG_instruction*>(instr);
+      const MIMG_instruction* mimg = static_cast<const MIMG_instruction*>(instr);
       unsigned identity_dmask = !instr->definitions.empty() ?
                                 (1 << instr->definitions[0].size()) - 1 :
                                 0xf;
@@ -308,11 +443,11 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
          fprintf(output, " d16");
       if (mimg->disable_wqm)
          fprintf(output, " disable_wqm");
-      print_barrier_reorder(mimg->can_reorder, mimg->barrier, output);
+      print_sync(mimg->sync, output);
       break;
    }
    case Format::EXP: {
-      Export_instruction* exp = static_cast<Export_instruction*>(instr);
+      const Export_instruction* exp = static_cast<const Export_instruction*>(instr);
       unsigned identity_mask = exp->compressed ? 0x5 : 0xf;
       if ((exp->enabled_mask & identity_mask) != identity_mask)
          fprintf(output, " en:%c%c%c%c",
@@ -340,7 +475,7 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
       break;
    }
    case Format::PSEUDO_BRANCH: {
-      Pseudo_branch_instruction* branch = static_cast<Pseudo_branch_instruction*>(instr);
+      const Pseudo_branch_instruction* branch = static_cast<const Pseudo_branch_instruction*>(instr);
       /* Note: BB0 cannot be a branch target */
       if (branch->target[0] != 0)
          fprintf(output, " BB%d", branch->target[0]);
@@ -349,16 +484,22 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
       break;
    }
    case Format::PSEUDO_REDUCTION: {
-      Pseudo_reduction_instruction* reduce = static_cast<Pseudo_reduction_instruction*>(instr);
+      const Pseudo_reduction_instruction* reduce = static_cast<const Pseudo_reduction_instruction*>(instr);
       fprintf(output, " op:%s", reduce_ops[reduce->reduce_op]);
       if (reduce->cluster_size)
          fprintf(output, " cluster_size:%u", reduce->cluster_size);
       break;
    }
+   case Format::PSEUDO_BARRIER: {
+      const Pseudo_barrier_instruction* barrier = static_cast<const Pseudo_barrier_instruction*>(instr);
+      print_sync(barrier->sync, output);
+      print_scope(barrier->exec_scope, output, "exec_scope");
+      break;
+   }
    case Format::FLAT:
    case Format::GLOBAL:
    case Format::SCRATCH: {
-      FLAT_instruction* flat = static_cast<FLAT_instruction*>(instr);
+      const FLAT_instruction* flat = static_cast<const FLAT_instruction*>(instr);
       if (flat->offset)
          fprintf(output, " offset:%u", flat->offset);
       if (flat->glc)
@@ -373,10 +514,11 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
          fprintf(output, " nv");
       if (flat->disable_wqm)
          fprintf(output, " disable_wqm");
+      print_sync(flat->sync, output);
       break;
    }
    case Format::MTBUF: {
-      MTBUF_instruction* mtbuf = static_cast<MTBUF_instruction*>(instr);
+      const MTBUF_instruction* mtbuf = static_cast<const MTBUF_instruction*>(instr);
       fprintf(output, " dfmt:");
       switch (mtbuf->dfmt) {
       case V_008F0C_BUF_DATA_FORMAT_8: fprintf(output, "8"); break;
@@ -422,7 +564,12 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
          fprintf(output, " tfe");
       if (mtbuf->disable_wqm)
          fprintf(output, " disable_wqm");
-      print_barrier_reorder(mtbuf->can_reorder, mtbuf->barrier, output);
+      print_sync(mtbuf->sync, output);
+      break;
+   }
+   case Format::VOP3P: {
+      if (static_cast<const VOP3P_instruction*>(instr)->clamp)
+         fprintf(output, " clamp");
       break;
    }
    default: {
@@ -430,7 +577,7 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
    }
    }
    if (instr->isVOP3()) {
-      VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*>(instr);
+      const VOP3A_instruction* vop3 = static_cast<const VOP3A_instruction*>(instr);
       switch (vop3->omod) {
       case 1:
          fprintf(output, " *2");
@@ -444,8 +591,10 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
       }
       if (vop3->clamp)
          fprintf(output, " clamp");
+      if (vop3->opsel & (1 << 3))
+         fprintf(output, " opsel_hi");
    } else if (instr->isDPP()) {
-      DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
+      const DPP_instruction* dpp = static_cast<const DPP_instruction*>(instr);
       if (dpp->dpp_ctrl <= 0xff) {
          fprintf(output, " quad_perm:[%d,%d,%d,%d]",
                  dpp->dpp_ctrl & 0x3, (dpp->dpp_ctrl >> 2) & 0x3,
@@ -482,11 +631,42 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
       if (dpp->bound_ctrl)
          fprintf(output, " bound_ctrl:1");
    } else if ((int)instr->format & (int)Format::SDWA) {
-      fprintf(output, " (printing unimplemented)");
+      const SDWA_instruction* sdwa = static_cast<const SDWA_instruction*>(instr);
+      switch (sdwa->omod) {
+      case 1:
+         fprintf(output, " *2");
+         break;
+      case 2:
+         fprintf(output, " *4");
+         break;
+      case 3:
+         fprintf(output, " *0.5");
+         break;
+      }
+      if (sdwa->clamp)
+         fprintf(output, " clamp");
+      switch (sdwa->dst_sel & sdwa_asuint) {
+      case sdwa_udword:
+         break;
+      case sdwa_ubyte0:
+      case sdwa_ubyte1:
+      case sdwa_ubyte2:
+      case sdwa_ubyte3:
+         fprintf(output, " dst_sel:%sbyte%u", sdwa->dst_sel & sdwa_sext ? "s" : "u",
+                 sdwa->dst_sel & sdwa_bytenum);
+         break;
+      case sdwa_uword0:
+      case sdwa_uword1:
+         fprintf(output, " dst_sel:%sword%u", sdwa->dst_sel & sdwa_sext ? "s" : "u",
+                 sdwa->dst_sel & sdwa_wordnum);
+         break;
+      }
+      if (sdwa->dst_preserve)
+         fprintf(output, " dst_preserve");
    }
 }
 
-void aco_print_instr(struct Instruction *instr, FILE *output)
+void aco_print_instr(const Instruction *instr, FILE *output)
 {
    if (!instr->definitions.empty()) {
       for (unsigned i = 0; i < instr->definitions.size(); ++i) {
@@ -500,23 +680,38 @@ void aco_print_instr(struct Instruction *instr, FILE *output)
    if (instr->operands.size()) {
       bool abs[instr->operands.size()];
       bool neg[instr->operands.size()];
+      bool opsel[instr->operands.size()];
+      uint8_t sel[instr->operands.size()];
       if ((int)instr->format & (int)Format::VOP3A) {
-         VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*>(instr);
+         const VOP3A_instruction* vop3 = static_cast<const VOP3A_instruction*>(instr);
          for (unsigned i = 0; i < instr->operands.size(); ++i) {
             abs[i] = vop3->abs[i];
             neg[i] = vop3->neg[i];
+            opsel[i] = vop3->opsel & (1 << i);
+            sel[i] = sdwa_udword;
          }
       } else if (instr->isDPP()) {
-         DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
-         assert(instr->operands.size() <= 2);
+         const DPP_instruction* dpp = static_cast<const DPP_instruction*>(instr);
          for (unsigned i = 0; i < instr->operands.size(); ++i) {
-            abs[i] = dpp->abs[i];
-            neg[i] = dpp->neg[i];
+            abs[i] = i < 2 ? dpp->abs[i] : false;
+            neg[i] = i < 2 ? dpp->neg[i] : false;
+            opsel[i] = false;
+            sel[i] = sdwa_udword;
+         }
+      } else if (instr->isSDWA()) {
+         const SDWA_instruction* sdwa = static_cast<const SDWA_instruction*>(instr);
+         for (unsigned i = 0; i < instr->operands.size(); ++i) {
+            abs[i] = i < 2 ? sdwa->abs[i] : false;
+            neg[i] = i < 2 ? sdwa->neg[i] : false;
+            opsel[i] = false;
+            sel[i] = i < 2 ? sdwa->sel[i] : sdwa_udword;
          }
       } else {
          for (unsigned i = 0; i < instr->operands.size(); ++i) {
             abs[i] = false;
             neg[i] = false;
+            opsel[i] = false;
+            sel[i] = sdwa_udword;
          }
       }
       for (unsigned i = 0; i < instr->operands.size(); ++i) {
@@ -529,10 +724,42 @@ void aco_print_instr(struct Instruction *instr, FILE *output)
             fprintf(output, "-");
          if (abs[i])
             fprintf(output, "|");
+         if (opsel[i])
+            fprintf(output, "hi(");
+         else if (sel[i] & sdwa_sext)
+            fprintf(output, "sext(");
          print_operand(&instr->operands[i], output);
+         if (opsel[i] || (sel[i] & sdwa_sext))
+            fprintf(output, ")");
+         if (!(sel[i] & sdwa_isra)) {
+            if (sel[i] & sdwa_udword) {
+               /* print nothing */
+            } else if (sel[i] & sdwa_isword) {
+               unsigned index = sel[i] & sdwa_wordnum;
+               fprintf(output, "[%u:%u]", index * 16, index * 16 + 15);
+            } else {
+               unsigned index = sel[i] & sdwa_bytenum;
+               fprintf(output, "[%u:%u]", index * 8, index * 8 + 7);
+            }
+         }
          if (abs[i])
             fprintf(output, "|");
-       }
+
+         if (instr->format == Format::VOP3P) {
+            const VOP3P_instruction* vop3 = static_cast<const VOP3P_instruction*>(instr);
+            if ((vop3->opsel_lo & (1 << i)) || !(vop3->opsel_hi & (1 << i))) {
+               fprintf(output, ".%c%c",
+                       vop3->opsel_lo & (1 << i) ? 'y' : 'x',
+                       vop3->opsel_hi & (1 << i) ? 'y' : 'x');
+            }
+            if (vop3->neg_lo[i] && vop3->neg_hi[i])
+               fprintf(output, "*[-1,-1]");
+            else if (vop3->neg_lo[i])
+               fprintf(output, "*[-1,1]");
+            else if (vop3->neg_hi[i])
+               fprintf(output, "*[1,-1]");
+         }
+      }
    }
    print_instr_format_specific(instr, output);
 }
@@ -569,9 +796,55 @@ static void print_block_kind(uint16_t kind, FILE *output)
       fprintf(output, "needs_lowering, ");
    if (kind & block_kind_uses_demote)
       fprintf(output, "uses_demote, ");
+   if (kind & block_kind_export_end)
+      fprintf(output, "export_end, ");
 }
 
-void aco_print_block(const struct Block* block, FILE *output)
+static void print_stage(Stage stage, FILE *output)
+{
+   fprintf(output, "ACO shader stage: ");
+
+   if (stage == compute_cs)
+      fprintf(output, "compute_cs");
+   else if (stage == fragment_fs)
+      fprintf(output, "fragment_fs");
+   else if (stage == gs_copy_vs)
+      fprintf(output, "gs_copy_vs");
+   else if (stage == vertex_ls)
+      fprintf(output, "vertex_ls");
+   else if (stage == vertex_es)
+      fprintf(output, "vertex_es");
+   else if (stage == vertex_vs)
+      fprintf(output, "vertex_vs");
+   else if (stage == tess_control_hs)
+      fprintf(output, "tess_control_hs");
+   else if (stage == vertex_tess_control_hs)
+      fprintf(output, "vertex_tess_control_hs");
+   else if (stage == tess_eval_es)
+      fprintf(output, "tess_eval_es");
+   else if (stage == tess_eval_vs)
+      fprintf(output, "tess_eval_vs");
+   else if (stage == geometry_gs)
+      fprintf(output, "geometry_gs");
+   else if (stage == vertex_geometry_gs)
+      fprintf(output, "vertex_geometry_gs");
+   else if (stage == tess_eval_geometry_gs)
+      fprintf(output, "tess_eval_geometry_gs");
+   else if (stage == ngg_vertex_gs)
+      fprintf(output, "ngg_vertex_gs");
+   else if (stage == ngg_tess_eval_gs)
+      fprintf(output, "ngg_tess_eval_gs");
+   else if (stage == ngg_vertex_geometry_gs)
+      fprintf(output, "ngg_vertex_geometry_gs");
+   else if (stage == ngg_tess_eval_geometry_gs)
+      fprintf(output, "ngg_tess_eval_geometry_gs");
+   else
+      fprintf(output, "unknown");
+
+   fprintf(output, "\n");
+}
+
+void aco_print_block(const Block* block, FILE *output)
 {
    fprintf(output, "BB%d\n", block->index);
    fprintf(output, "/* logical preds: ");
@@ -590,8 +863,10 @@ void aco_print_block(const struct Block* block, FILE *output)
    }
 }
 
-void aco_print_program(Program *program, FILE *output)
+void aco_print_program(const Program *program, FILE *output)
 {
+   print_stage(program->stage, output);
+
    for (Block const& block : program->blocks)
       aco_print_block(&block, output);