How to create a new Instruction to replace the previous one in an LLVM pass?


Given the following test case, named TestCase1.ll:

; RUN: opt -load-pass-plugin=%dylibdir/libLocalOpts.so \
; RUN:     -p=algebraic-identity,strength-reduction,multi-inst-opt \
; RUN:     -S %s -o %basename_t
; RUN: FileCheck --match-full-lines %s --input-file=%basename_t

; #include <stdio.h>

; void foo(int a) {
;   int r0 = a + 0;
;   int r1 = r0 * 16;
;   int r2 = r1 * r0;
;   int r3 = r2 / a;
;   int r4 = r2 / 10;
;   int r5 = 54 * r3;
;   int r6 = r4 / 128;
;   int r7 = r5 / 54;
;   int r8 = r4 / 1;
;   int r9 = r7 - 0;
;   printf("%d,%d,%d,%d,%d,%d,%d,%d,%d,%d\n", r0, r1, r2, r3, r4, r5, r6, r7, r8,
;          r9);
; }

@.str = private unnamed_addr constant [31 x i8] c"%d,%d,%d,%d,%d,%d,%d,%d,%d,%d\0A\00", align 1

define dso_local void @foo(i32 noundef %0) {
; CHECK-LABEL: define dso_local void @foo(i32 noundef %0) {
; @todo(CSCD70) Please complete the CHECK directives.
  %2 = add nsw i32 %0, 0
  %3 = mul nsw i32 %2, 16
  %4 = mul nsw i32 %3, %2
  %5 = sdiv i32 %4, %0
  %6 = sdiv i32 %4, 10
  %7 = mul nsw i32 54, %5
  %8 = sdiv i32 %6, 128
  %9 = sdiv i32 %7, 54
  %10 = sdiv i32 %6, 1
  %11 = sub nsw i32 %9, 0
  %12 = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %2, i32 noundef %3, i32 noundef %4, i32 noundef %5, i32 noundef %6, i32 noundef %7, i32 noundef %8, i32 noundef %9, i32 noundef %10, i32 noundef %11)
  ret void
}

declare i32 @printf(ptr noundef, ...) #1

I would like to optimize the mul and sdiv instructions by replacing them with shift instructions for faster execution. In simple terms, I want to optimize x * 4 to x << 2.
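One caveat worth checking before rewriting sdiv this way: LLVM's sdiv rounds toward zero, while an arithmetic shift right rounds toward negative infinity, so the two disagree for negative dividends that are not exact multiples of the divisor. The sketch below shows the difference in plain C++ together with the usual bias correction. The names ashr and sdivPow2 are hypothetical helpers for illustration, not part of the pass, and the code assumes that >> on a negative signed value is an arithmetic shift (guaranteed from C++20, and what mainstream compilers do for earlier standards).

```cpp
#include <cstdint>

// Plain arithmetic shift right: rounds toward negative infinity.
// Assumption: >> on a negative int32_t is an arithmetic shift.
std::int32_t ashr(std::int32_t x, unsigned k) { return x >> k; }

// Signed division by 2^k that rounds toward zero, like C++ `/` and LLVM sdiv.
// Negative inputs get a bias of 2^k - 1 added before the shift.
// Intended for 1 <= k <= 30 only.
std::int32_t sdivPow2(std::int32_t x, unsigned k) {
  // All ones if x is negative, zero otherwise.
  std::uint32_t signMask = static_cast<std::uint32_t>(x >> 31);
  // Keep the low k bits of the mask: 2^k - 1 for negative x, else 0.
  std::int32_t bias = static_cast<std::int32_t>(signMask >> (32 - k));
  return (x + bias) >> k;
}
```

For example, -7 / 2 is -3, but ashr(-7, 1) is -4, whereas sdivPow2(-7, 1) gives -3. For non-negative values both forms agree with the division.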

I wrote an LLVM pass like this:

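// Returns log2(num) if num is a positive power of two, otherwise -1.
int getShift(int64_t num){
  // Reject zero as well: 0 & (0 - 1) == 0, so it would otherwise pass the
  // power-of-two test and x * 0 would be rewritten to x << 0.
  if(num <= 0 || (num & (num - 1)) != 0)
    return -1;

  int cnt = 0;
  while(num > 1){
    cnt++;
    num >>= 1;
  }
  return cnt;
}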

PreservedAnalyses StrengthReductionPass::run([[maybe_unused]] Function &F,
                                             FunctionAnalysisManager &) {

  /// @todo(CSCD70) Please complete this method.
  for(auto &BB : F)
    for(auto &I : BB) {
      if(I.getNumOperands() == 2) {
        Value* op0 = I.getOperand(0);
        Value* op1 = I.getOperand(1);

        int64_t v0 = -1, v1 = -2;
        if(isa<ConstantInt>(op0))
          v0 = dyn_cast<ConstantInt>(op0) -> getSExtValue();
        if(isa<ConstantInt>(op1))
          v1 = dyn_cast<ConstantInt>(op1) -> getSExtValue();

        switch (I.getOpcode()) {
          case Instruction::Mul:
            if(isa<ConstantInt>(op0) && getShift(v0) != -1) {
              I.replaceAllUsesWith(BinaryOperator::CreateShl(op1, ConstantInt::get(I.getType(), getShift(v0))));
            }
            else if(isa<ConstantInt>(op1) && getShift(v1) != -1) {
              I.replaceAllUsesWith(BinaryOperator::CreateShl(op0, ConstantInt::get(I.getType(), getShift(v1))));
            }
            break;

          case Instruction::SDiv:
            if(isa<ConstantInt>(op1) && getShift(v1) != -1) {
              I.replaceAllUsesWith(BinaryOperator::CreateAShr(op0, ConstantInt::get(I.getType(), getShift(v1))));
            }
            break;
          
          default:
            break;
        }
      }
    }
  
  return PreservedAnalyses::none();
}

And it gives me an error:

Instruction does not dominate all uses!
  <badref> = shl i32 %0, 4
  %4 = mul nsw i32 <badref>, %0
Instruction does not dominate all uses!
  <badref> = shl i32 %0, 4
  %12 = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %0, i32 noundef <badref>, i32 noundef %4, i32 noundef %5, i32 noundef %6, i32 noundef %7, i32 noundef <badref>, i32 noundef %9, i32 noundef %6, i32 noundef %9)
Instruction does not dominate all uses!
  <badref> = ashr i32 128, -1
  %12 = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %0, i32 noundef <badref>, i32 noundef %4, i32 noundef %5, i32 noundef %6, i32 noundef %7, i32 noundef <badref>, i32 noundef %9, i32 noundef %6, i32 noundef %9)
LLVM ERROR: Broken module found, compilation aborted!

However, it works when I change the I.replaceAllUsesWith() call to either of the following two approaches:

// Approach 1: pass the insertion point (&I) to Create
I.replaceAllUsesWith(BinaryOperator::Create(Instruction::Shl, op1, ConstantInt::get(I.getType(), getShift(v0)),"shl",&I));

// Approach 2: create the instruction, then insert it explicitly
auto *NewInst = BinaryOperator::CreateShl(op1, ConstantInt::get(I.getType(), getShift(v0)));
NewInst -> insertBefore(&I);
I.replaceAllUsesWith(NewInst);

I am really confused about what replaceAllUsesWith() does. Does it change the instruction list itself in any way?

I think the key is that BinaryOperator::CreateShl only creates an Instruction in memory and does not insert it into the instruction list, so when it is used later, it cannot be found in the function. Am I right? I would really like to understand the behavior of replaceAllUsesWith() in detail.


Solution

  • Forget the IR builder, it's just an unnecessary layer of indirection in this case. You need about three lines of code.

    If you look at the BinaryOperator documentation, you'll see that the very first function mentioned is called Create and takes an argument called Instruction * insertBefore. That's the one you want. Call that to create a new BinaryOperator and insert it into the program immediately before the instruction you want to replace. That's the first line you need.

    Your second line is to update the usages of the old instruction to reference the new one. Calling old->replaceAllUsesWith(…) will do that, IIRC.

    Now that your new instruction is used by the rest of the code and the old one is disused, it's time to get rid of the old one and you're done. Three lines.

    A comment: I've found it best to split the code in two. One part iterates over the instructions in a function and identifies the ones to change, and a separate part deletes and creates instructions. That way I don't have to change a list while iterating over it. My brain is very small, so I prefer two simple problems over one conglomerate.

    If you want to make life easy for yourself, add a call to verifyFunction() afterwards. That tends to uncover any stupid mistakes quickly and easily.
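To make the three steps and the "does not dominate all uses" error concrete, here is a toy model in plain C++. It is not LLVM code: Inst, Block, Example, makeExample and rewriteMuls are hypothetical names. It only models a single block, and it finds users by scanning the block, whereas LLVM keeps a use list per value. It does illustrate the point from the question: replacing uses only redirects operands and never puts the new instruction into the block, so a separate insertion step is needed. The rewrite is also split into a collecting phase and a mutating phase, as suggested above.

```cpp
#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an instruction (hypothetical, not llvm::Instruction).
struct Inst {
  std::string op;                // "arg", "mul", "shl", "use", ...
  std::vector<Inst *> operands;  // values this instruction reads
};

// Toy stand-in for a basic block: an ordered list of instructions.
struct Block {
  std::vector<Inst *> order;

  // Only rewrites operands of existing users; never adds `to` to `order`.
  void replaceAllUsesWith(Inst *from, Inst *to) {
    for (Inst *user : order)
      std::replace(user->operands.begin(), user->operands.end(), from, to);
  }
  void insertBefore(Inst *pos, Inst *fresh) {
    order.insert(std::find(order.begin(), order.end(), pos), fresh);
  }
  void erase(Inst *dead) {
    order.erase(std::remove(order.begin(), order.end(), dead), order.end());
  }
  // Single-block "dominance" check: each operand must appear before its user.
  bool verify() const {
    for (auto user = order.begin(); user != order.end(); ++user)
      for (Inst *operand : (*user)->operands)
        if (std::find(order.begin(), user, operand) == user)
          return false;
    return true;
  }
};

// Owns the instructions so the raw pointers stored in Block stay valid.
struct Example {
  std::vector<std::unique_ptr<Inst>> storage;
  Block bb;
  Inst *add(std::string op, std::vector<Inst *> operands) {
    storage.push_back(
        std::make_unique<Inst>(Inst{std::move(op), std::move(operands)}));
    return storage.back().get();
  }
};

// Builds the block: arg; mul(arg); use(mul).
Example makeExample() {
  Example ex;
  Inst *arg = ex.add("arg", {});
  Inst *mul = ex.add("mul", {arg});
  Inst *use = ex.add("use", {mul});
  ex.bb.order = {arg, mul, use};
  return ex;
}

// Phase 1 collects the instructions to change; phase 2 mutates the block.
// With insertNew == false it reproduces the mistake from the question.
int rewriteMuls(Example &ex, bool insertNew) {
  std::vector<Inst *> worklist;
  for (Inst *inst : ex.bb.order)
    if (inst->op == "mul")
      worklist.push_back(inst);

  for (Inst *old : worklist) {
    Inst *fresh = ex.add("shl", old->operands);
    if (insertNew)
      ex.bb.insertBefore(old, fresh);      // step 1: place the new instruction
    ex.bb.replaceAllUsesWith(old, fresh);  // step 2: redirect the users
    ex.bb.erase(old);                      // step 3: drop the old instruction
  }
  return static_cast<int>(worklist.size());
}
```

Calling rewriteMuls(ex, true) leaves a block that passes verify(), while rewriteMuls(ex, false) leaves the user pointing at an instruction that is not in the block, which is the toy equivalent of the <badref> lines in the error output. In real LLVM code the three steps correspond to creating the instruction with an insertion point (or calling insertBefore), replaceAllUsesWith(), and eraseFromParent() on the old instruction.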