chisel

How to pass an operator as a parameter


I'm trying to pass an operator to a module so the module can be built generically. I pass a two-input operator parameter and then use it in a reduction operation. If I replace the passed parameter with a concrete operator, it works.

What's the correct way to pass a Chisel/UInt/Data operator as a module parameter?

  val io = IO(new Bundle {
    val a = Vec(n, Flipped(Decoupled(UInt(width.W))))
    val z = Decoupled(UInt(width.W))
  })
  val a_int = for (i <- 0 until n) yield DCInput(io.a(i))
  val z_int = Wire(Decoupled(UInt(width.W)))

  val all_valid = a_int.map(_.valid).reduce(_ & _)
  // Does not compile: `x op y` is parsed as x.op(y), and UInt has no method named op
  z_int.bits := a_int.map(_.bits).reduce(_ op _)
...

Solution

  • Here's a fancy Scala way of doing it, passing the operator as a curried function:

    import chisel3._
    import chiseltest._
    import org.scalatest.freespec.AnyFreeSpec
    
    // mathFunc is curried: given the first operand it returns a function
    // that takes the second operand and builds the result hardware.
    class ChiselFuncParam(mathFunc: UInt => UInt => UInt) extends Module {
      val io = IO(new Bundle {
        val a = Input(UInt(8.W))
        val b = Input(UInt(8.W))
        val out = Output(UInt(8.W))
      })
    
      io.out := mathFunc(io.a)(io.b)
    }
    
    class CFPTest extends AnyFreeSpec with ChiselScalatestTester {
      // Curried methods; Scala eta-expands them to UInt => UInt => UInt when passed.
      def add(a: UInt)(b: UInt): UInt = a + b
      def sub(a: UInt)(b: UInt): UInt = a - b
    
      "add works" in {
        test(new ChiselFuncParam(add)) { c =>
          c.io.a.poke(9.U)
          c.io.b.poke(5.U)
          c.io.out.expect(14.U)
        }
      }
      "sub works" in {
        test(new ChiselFuncParam(sub)) { c =>
          c.io.a.poke(9.U)
          c.io.b.poke(2.U)
          c.io.out.expect(7.U)
        }
      }
    }
    

    Alternatively, it might be clearer to pass in a string naming the operator and use a plain Scala `match` (or `if`s) to choose which hardware gets generated at elaboration time. Something like:

    class MathOp(code: String) extends Module {
      val io = IO(new Bundle {
        val a = Input(UInt(8.W))
        val b = Input(UInt(8.W))
        val out = Output(UInt(8.W))
      })
    
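      io.out := (code match {
        case "+" => io.a + io.b
        case "-" => io.a - io.b
        // ...
        // Fail at elaboration time with a clear message instead of a bare MatchError.
        case other => throw new IllegalArgumentException(s"unsupported operator: $other")
      })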
    }
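For the reduction in the question, a plain two-argument function type `(UInt, UInt) => UInt` may be simpler than the curried form, because Scala's `reduce` expects exactly that shape. Below is a minimal sketch, assuming a hypothetical `ReduceOp` module whose `n`, `width`, and `op` parameters are named here for illustration only:

```scala
import chisel3._

// Hypothetical module: reduces n inputs of the given width with a caller-supplied operator.
// `op` is an ordinary Scala function value; it runs at elaboration time and
// returns the hardware that combines two operands.
class ReduceOp(n: Int, width: Int, op: (UInt, UInt) => UInt) extends Module {
  require(n >= 1, "need at least one input to reduce")

  val io = IO(new Bundle {
    val a = Input(Vec(n, UInt(width.W)))
    val z = Output(UInt(width.W))
  })

  // Vec is a Scala IndexedSeq, so reduce takes the (UInt, UInt) => UInt directly.
  io.z := io.a.reduce(op)
}
```

Usage would be, for example, `new ReduceOp(4, 8, _ + _)` or `new ReduceOp(4, 8, _ & _)`. In the question's code the corresponding change is `z_int.bits := a_int.map(_.bits).reduce(op)`. Because `op` runs during elaboration, each call simply builds another piece of combinational logic.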