Search code examples
scala chisel

Chisel: How to change module parameters from command line?


I have many modules with multiple parameters. Take as a toy example a modified version of the GCD in the template:

/** Toy GCD module (subtraction method) with its parameters passed as plain constructor arguments.
  *
  * @param len       bit width of the operand inputs and of the result output
  * @param validHigh polarity of `outputValid`: when true it is driven by `y === 0`,
  *                  otherwise by `y =/= 0` (an elaboration-time choice, not a hardware mux)
  */
class GCD (len: Int = 16, validHigh: Boolean = true) extends Module {
  val io = IO(new Bundle {
    val value1        = Input(UInt(len.W))
    val value2        = Input(UInt(len.W))
    val loadingValues = Input(Bool())
    val outputGCD     = Output(UInt(len.W))
    val outputValid   = Output(Bool())
  })

  // No explicit width: Chisel infers the register widths from their connections.
  val x  = Reg(UInt())
  val y  = Reg(UInt())

  // One subtraction step per cycle: subtract the smaller register from the larger one.
  when(x > y) { x := x - y }
    .otherwise { y := y - x }

  // Placed after the subtraction on purpose: in Chisel the last connection wins,
  // so loading new operands overrides the subtraction step in that cycle.
  when(io.loadingValues) {
    x := io.value1
    y := io.value2
  }

  io.outputGCD := x
  // Scala-level `if`: chooses which comparison is generated when the module is elaborated.
  if (validHigh) {
    io.outputValid := (y === 0.U)
  } else {
    io.outputValid := (y =/= 0.U)
  }
}

To test or synthesize many different designs, I want to change the values from the command line when I call the tester or the generator apps. Preferably, like this:

[generation or test command] --len 12 --validHigh false

but this or something similar would also be okay

[generation or test command] --param "len=12" --param "validHigh=false"

After some trial and error, I came up with a solution that looks like this:

gcd.scala

package gcd

import firrtl._
import chisel3._

/** Elaboration-time parameters of [[GCD]].
  *
  * @param len       bit width of the GCD operand inputs and of the result output
  * @param validHigh polarity of `outputValid`: true drives it with `y === 0`, false with `y =/= 0`
  */
case class GCDConfig(len: Int = 16, validHigh: Boolean = true)

/** GCD module (subtraction method) whose parameters are bundled in a [[GCDConfig]].
  *
  * After `loadingValues` has been asserted, the smaller register is subtracted from the
  * larger one every cycle; the result is presented on `outputGCD`.
  *
  * @param conf elaboration-time parameters; exposed as a `val` so testers can read
  *             `conf.validHigh` to know which `outputValid` level means "done"
  */
class GCD (val conf: GCDConfig = GCDConfig()) extends Module {
  val io = IO(new Bundle {
    val value1        = Input(UInt(conf.len.W))
    val value2        = Input(UInt(conf.len.W))
    val loadingValues = Input(Bool())
    val outputGCD     = Output(UInt(conf.len.W))
    val outputValid   = Output(Bool())
  })

  // No explicit width: Chisel infers the register widths from their connections.
  val x  = Reg(UInt())
  val y  = Reg(UInt())

  // One subtraction step per cycle: subtract the smaller register from the larger one.
  when(x > y) { x := x - y }
    .otherwise { y := y - x }

  // Placed after the subtraction on purpose: in Chisel the last connection wins,
  // so loading new operands overrides the subtraction step in that cycle.
  when(io.loadingValues) {
    x := io.value1
    y := io.value2
  }

  io.outputGCD := x
  // Scala-level `if`: selects the valid polarity at elaboration time.
  if (conf.validHigh) {
    io.outputValid := y === 0.U
  } else {
    io.outputValid := y =/= 0.U
  }
}

/** Mixin for a (deprecated) `ExecutionOptionsManager` that adds a `-p` / `--params k1=v1,k2=v2` option.
  *
  * Once `parse` has run, `params` holds the key/value pairs given with the option;
  * it keeps its initial empty value when the option is absent.
  */
trait HasParams {
  self: ExecutionOptionsManager =>

  var params: Map[String, String] = Map.empty

  // Section heading in the generated usage text, registered before the option itself.
  parser.note("Design Parameters")

  parser
    .opt[Map[String, String]]('p', "params")
    .valueName("k1=v1,k2=v2")
    .foreach(parsed => params = parsed)
    .text("Parameters of Design")
}

object GCD {

  /** Builds a [[GCD]] configured from string key/value pairs (see [[params2conf]]). */
  def apply(params: Map[String, String]): GCD =
    new GCD(params2conf(params))

  /** Converts string key/value pairs (e.g. from `--params`) into a [[GCDConfig]].
    *
    * Starts from the default configuration and overrides each recognised key:
    *  - `len`       parsed with `toInt`
    *  - `validHigh` parsed with `toBoolean` (accepts "true"/"false", case-insensitive)
    * Unrecognised keys are ignored.
    *
    * @throws NumberFormatException    if the `len` value is not an integer
    * @throws IllegalArgumentException if the `validHigh` value is not a boolean
    */
  def params2conf(params: Map[String, String]): GCDConfig =
    params.foldLeft(GCDConfig()) {
      case (conf, ("len", v))       => conf.copy(len = v.toInt)
      case (conf, ("validHigh", v)) => conf.copy(validHigh = v.toBoolean)
      case (conf, _)                => conf // unknown keys are deliberately ignored
    }
}

/** Generator entry point for [[GCD]] using the legacy `ExecutionOptionsManager` API.
  *
  * Accepts the Chisel and FIRRTL options contributed by `HasChiselExecutionOptions` and
  * `HasFirrtlOptions`, plus `-p/--params k1=v1,...` from [[HasParams]].
  * Throws if the arguments cannot be parsed or the driver returns a `ChiselExecutionFailure`,
  * so the invoking process (e.g. `sbt runMain`) reports the run as failed.
  */
object GCDGen extends App {
  val optionsManager = new ExecutionOptionsManager("gcdgen")
    with HasChiselExecutionOptions with HasFirrtlOptions with HasParams

  if (optionsManager.parse(args)) {
    chisel3.Driver.execute(optionsManager, () => GCD(optionsManager.params)) match {
      case failure: ChiselExecutionFailure =>
        throw new IllegalStateException(s"Chisel execution failed: $failure")
      case _ =>
    }
  } else {
    // Merely constructing a ChiselExecutionFailure value reports nothing; throw instead
    // so a bad command line does not end in an apparently successful run.
    throw new IllegalArgumentException("could not parse command-line arguments")
  }
}

and for tests

GCDSpec.scala

package gcd

import chisel3._
import firrtl._
import chisel3.tester._
import org.scalatest.FreeSpec
import chisel3.experimental.BundleLiterals._
import chiseltest.internal._
import chiseltest.experimental.TestOptionBuilder._

/** Test entry point: parses `-p/--params k1=v1,...` and runs [[GCDSpec]] with those parameters.
  *
  * Throws if the arguments cannot be parsed, so the invoking process reports the failure.
  */
object GCDTest extends App {
  val optionsManager = new ExecutionOptionsManager("gcdtest") with HasParams

  if (optionsManager.parse(args)) {
    new GCDSpec(optionsManager.params).execute()
  } else {
    // Merely constructing a ChiselExecutionFailure value reports nothing; throw instead.
    throw new IllegalArgumentException("could not parse command-line arguments")
  }
}

/** ScalaTest spec for [[GCD]], built from string parameters via `GCD.apply`.
  *
  * @param params key/value pairs understood by `GCD.params2conf`; empty means all defaults.
  *               The test pokes the value 95, so it assumes `len >= 7` — TODO confirm for small widths.
  */
class GCDSpec(params: Map[String, String] = Map()) extends FreeSpec with ChiselScalatestTester {

  "Gcd should calculate proper greatest common divisor" in {
    test(GCD(params)) { dut =>
      // Hold the operands on the inputs for one load cycle, then let the design iterate.
      dut.io.value1.poke(95.U)
      dut.io.value2.poke(10.U)
      dut.io.loadingValues.poke(true.B)
      dut.clock.step(1)
      dut.io.loadingValues.poke(false.B)
      // outputValid's polarity depends on conf.validHigh, so wait for the configured "valid" level.
      while (dut.io.outputValid.peek().litToBoolean != dut.conf.validHigh) {
        dut.clock.step(1)
      }
      dut.io.outputGCD.expect(5.U)
    }
  }
}

This way, I can generate different designs and test them with

sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'
sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'

But there are a couple of problems or annoyances with this solution:

  1. It uses deprecated features (ExecutionOptionsManager and HasFirrtlOptions). I'm not sure if this solution is portable to the new FirrtlStage Infrastructure.
  2. There is a lot of boilerplate involved. It becomes tedious to write new case classes and params2conf functions for every module and rewrite both when a parameter is added or removed.
  3. Having to write conf.x instead of x all the time. But I guess this is unavoidable, because there is nothing like Python's kwargs in Scala.

Is there a better way or one that is at least not deprecated?


Solution

  • Based on http://blog.echo.sh/2013/11/04/exploring-scala-macros-map-to-case-class-conversion.html, I was able to find another way of removing the params2conf boilerplate using Scala macros. I also extended Chick's answer with Verilog generation, since that was also part of the original question. A full repository of my solution can be found on GitHub.

    Basically, there are four files:

    The macro that converts a map to a case class:

    package mappable
    
    import scala.language.experimental.macros
    import scala.reflect.macros.whitebox.Context
    
    /** Type class converting values of `T` to and from a flat `String -> String` map,
      * keyed by field name. Instances are generated by `Mappable.materializeMappable`.
      */
    trait Mappable[T] {

      /** Rebuilds a `T` from its field-name -> string-value pairs. */
      def fromMap(map: Map[String, String]): T

      /** Renders `t` as field-name -> string-value pairs. */
      def toMap(t: T): Map[String, String]
    }
    
    object Mappable {

      /** Derives a [[Mappable]] instance for a case class `T` at compile time. */
      implicit def materializeMappable[T]: Mappable[T] = macro materializeMappableImpl[T]

      /** Macro implementation behind [[materializeMappable]].
        *
        * Reads the first parameter list of `T`'s primary constructor and generates a
        * `Mappable[T]` whose `toMap` renders each field with `toString`, and whose
        * `fromMap` parses each field back from its string and calls `T`'s companion `apply`.
        *
        * Supported field types: `String`, `Int`, `Long`, `Double`, `Boolean`. Any other
        * field type, a missing primary constructor or a missing companion aborts the
        * expansion with a descriptive compile error instead of emitting broken code.
        *
        * The generated `fromMap` looks up every field with `map(name)`, so the map passed
        * to it must contain an entry for each constructor parameter.
        *
        * Names in the generated tree are fully qualified so the expansion does not depend
        * on what happens to be imported at the call site.
        */
      def materializeMappableImpl[T: c.WeakTypeTag](c: Context): c.Expr[Mappable[T]] = {
        import c.universe._
        val tpe = weakTypeOf[T]
        val companion = tpe.typeSymbol.companion

        // Reports a compile error at the implicit-search site and stops the expansion.
        def fail(reason: String): Nothing =
          c.abort(c.enclosingPosition, s"Cannot derive Mappable[$tpe]: $reason")

        if (companion == NoSymbol) fail("no companion object found")

        val constructor = tpe.decls
          .collectFirst { case m: MethodSymbol if m.isPrimaryConstructor => m }
          .getOrElse(fail("no primary constructor found"))
        val fields = constructor.paramLists.headOption.getOrElse(Nil)

        val (toMapParams, fromMapParams) = fields.map { field =>
          val name = field.name.toTermName
          val decoded = name.decodedName.toString

          // The case-class accessor has a NullaryMethodType whose result is the field type.
          val fromMapLine = tpe.decl(name).typeSignature match {
            case NullaryMethodType(res) if res =:= typeOf[String]  => q"map($decoded)"
            case NullaryMethodType(res) if res =:= typeOf[Int]     => q"map($decoded).toInt"
            case NullaryMethodType(res) if res =:= typeOf[Long]    => q"map($decoded).toLong"
            case NullaryMethodType(res) if res =:= typeOf[Double]  => q"map($decoded).toDouble"
            case NullaryMethodType(res) if res =:= typeOf[Boolean] => q"map($decoded).toBoolean"
            case other => fail(s"unsupported type $other for field '$decoded'")
          }

          (q"$decoded -> t.$name.toString", fromMapLine)
        }.unzip

        c.Expr[Mappable[T]] { q"""
          new _root_.mappable.Mappable[$tpe] {
            def toMap(t: $tpe): _root_.scala.collection.immutable.Map[_root_.java.lang.String, _root_.java.lang.String] =
              _root_.scala.collection.immutable.Map(..$toMapParams)
            def fromMap(map: _root_.scala.collection.immutable.Map[_root_.java.lang.String, _root_.java.lang.String]): $tpe =
              $companion(..$fromMapParams)
          }
        """ }
      }
    }
    

    Library-like tools:

    package cliparams
    
    import chisel3.stage.{ChiselStage, ChiselGeneratorAnnotation, ChiselCli}
    import firrtl.AnnotationSeq
    import firrtl.annotations.{Annotation, NoTargetAnnotation}
    import firrtl.options.{HasShellOptions, Shell, ShellOption, Stage, Unserializable, StageMain}
    import firrtl.stage.FirrtlCli
    
    import mappable._
    
    /** Marker trait for this package's annotations. The self-type means it can only be mixed
      * into a FIRRTL [[Annotation]]. (The misspelled name is kept so existing code still compiles.)
      */
    trait SomeAnnotaion { self: Annotation => }
    
    /** Annotation holding the key/value pairs parsed from one `--params` command-line option.
      *
      * Mixes in `NoTargetAnnotation` and `Unserializable`; `map` is the raw, still
      * string-typed parameter map.
      */
    case class ParameterAnnotation(
        map: Map[String, String]
    ) extends SomeAnnotaion with NoTargetAnnotation with Unserializable
    
    /** Registers the `--params k1=v1,k2=v2` command-line option.
      *
      * An occurrence of the option is turned into a [[ParameterAnnotation]] holding its map.
      */
    object ParameterAnnotation extends HasShellOptions {
      val options = Seq(
        new ShellOption[Map[String, String]](
          longOption = "params",
          toAnnotationSeq = (a: Map[String, String]) => Seq(ParameterAnnotation(a)),
          // The previous text advertised a nonexistent `--param-string` option.
          helpText = """a comma separated, space free list of additional parameters, e.g. --params "k1=7,k2=dog" """,
          helpValueName = Some("k1=v1,k2=v2")
        )
      )
    }
    
    /** Shell mixin that adds the `--params` option (see [[ParameterAnnotation]]) to a firrtl [[Shell]]. */
    trait ParameterCli { self: Shell =>
      ParameterAnnotation.addOptions(parser)
    }
    
    /** A firrtl [[Stage]] that resolves `--params` into a value of type `P` and passes it to `thunk`.
      *
      * Resolution: render `default` as a string map (via [[Mappable]]), overlay the map of
      * every [[ParameterAnnotation]] in annotation-sequence order (later entries override
      * earlier ones), then convert the result back to `P`. When no `--params` annotation is
      * present, `default` is used unchanged.
      *
      * @param thunk   receives the resolved parameters and the full annotation sequence
      * @param default values used for every key not given on the command line
      * @param appName application name given to the command-line [[Shell]]
      * @tparam P      parameter type; needs a [[Mappable]] instance
      */
    class GenericParameterCliStage[P: Mappable](
        thunk: (P, AnnotationSeq) => Unit,
        default: P,
        appName: String = "chiseltest")
        extends Stage {

      /** Renders parameters as field-name -> string-value pairs. */
      def mapify(p: P): Map[String, String] = implicitly[Mappable[P]].toMap(p)

      /** Rebuilds parameters from a map that must contain every field of `P`. */
      def materialize(map: Map[String, String]): P = implicitly[Mappable[P]].fromMap(map)

      val shell: Shell = new Shell(appName) with ParameterCli with ChiselCli with FirrtlCli

      def run(annotations: AnnotationSeq): AnnotationSeq = {
        // Previously only the first --params annotation was honoured; merge all of them.
        val overrides = annotations.collect { case ParameterAnnotation(map) => map }
        val params =
          if (overrides.isEmpty) default
          else materialize(overrides.foldLeft(mapify(default))(_ ++ _))

        thunk(params, annotations)
        annotations
      }
    }
    

    The GCD source file

    // See README.md for license details.
    
    package gcd
    
    import firrtl._
    import chisel3._
    import chisel3.stage.{ChiselStage, ChiselGeneratorAnnotation}
    import firrtl.options.{StageMain}
    
    // Both have to be imported
    import mappable._
    import cliparams._
    
    /** Elaboration-time parameters of [[GCD]]; also the target type of the `--params` CLI option.
      *
      * @param len       bit width of the GCD operand inputs and of the result output
      * @param validHigh polarity of `outputValid`: true drives it with `y === 0`, false with `y =/= 0`
      */
    case class GCDConfig(len: Int = 16, validHigh: Boolean = true)
    
    /**
      * Compute GCD using the subtraction method.
      * Each cycle the smaller register is subtracted from the larger one until register y is zero;
      * the value in register x is then the GCD.
      *
      * @param conf elaboration-time parameters: `len` sets the width of the operand and result
      *             ports, `validHigh` sets the polarity of `outputValid`. Exposed as a `val` so
      *             testers can read `conf.validHigh`.
      */
    class GCD (val conf: GCDConfig = GCDConfig()) extends Module {
      val io = IO(new Bundle {
        val value1        = Input(UInt(conf.len.W))
        val value2        = Input(UInt(conf.len.W))
        val loadingValues = Input(Bool())
        val outputGCD     = Output(UInt(conf.len.W))
        val outputValid   = Output(Bool())
      })
    
      // No explicit width: Chisel infers the register widths from their connections.
      val x  = Reg(UInt())
      val y  = Reg(UInt())
    
      // One subtraction step per cycle: subtract the smaller register from the larger one.
      when(x > y) { x := x - y }
        .otherwise { y := y - x }
    
      // Placed after the subtraction on purpose: in Chisel the last connection wins,
      // so loading new operands overrides the subtraction step in that cycle.
      when(io.loadingValues) {
        x := io.value1
        y := io.value2
      }
    
      io.outputGCD := x
      // Scala-level `if`: selects the valid polarity at elaboration time.
      if (conf.validHigh) {
        io.outputValid := y === 0.U
      } else {
        io.outputValid := y =/= 0.U
      }
    }
    
    /** Verilog generation stage for [[GCD]], parameterized from the command line.
      *
      * The resolved [[GCDConfig]] selects the design. The remaining parsed command-line
      * annotations (e.g. options contributed by `ChiselCli`/`FirrtlCli`) are forwarded to
      * `ChiselStage` instead of being discarded. The `--params` annotation is dropped
      * because it has already been consumed.
      */
    class GCDGenStage extends GenericParameterCliStage[GCDConfig]((params, annotations) => {
      val forwarded = annotations.filterNot {
        case _: ParameterAnnotation => true
        case _                      => false
      }
      (new chisel3.stage.ChiselStage).execute(
        Array("-X", "verilog"),
        forwarded :+ ChiselGeneratorAnnotation(() => new GCD(params)))
    }, GCDConfig())
    
    /** Command-line entry point for [[GCDGenStage]], e.g.
      * `sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'`.
      */
    object GCDGen extends StageMain(new GCDGenStage)
    

    and the tests

    // See README.md for license details.
    
    package gcd
    
    import chisel3._
    import firrtl._
    import chisel3.tester._
    import org.scalatest.FreeSpec
    import chisel3.experimental.BundleLiterals._
    import chiseltest.internal._
    import chiseltest.experimental.TestOptionBuilder._
    import firrtl.options.{StageMain}
    
    import mappable._
    import cliparams._
    
    /** ScalaTest spec for [[GCD]] with the given configuration.
      *
      * @param params      design configuration; the test pokes the value 95, so it assumes
      *                    `params.len >= 7` — TODO confirm for small widths
      * @param annotations command-line annotations; all except the consumed `--params`
      *                    annotation are passed to the simulator via `withAnnotations`
      */
    class GCDSpec(params: GCDConfig, annotations: AnnotationSeq = Seq()) extends FreeSpec with ChiselScalatestTester {

      // Previously `annotations` was accepted but never used.
      private val simulationAnnotations: AnnotationSeq = annotations.filterNot {
        case _: ParameterAnnotation => true
        case _                      => false
      }

      "Gcd should calculate proper greatest common divisor" in {
        test(new GCD(params)).withAnnotations(simulationAnnotations) { dut =>
          // Hold the operands on the inputs for one load cycle, then let the design iterate.
          dut.io.value1.poke(95.U)
          dut.io.value2.poke(10.U)
          dut.io.loadingValues.poke(true.B)
          dut.clock.step(1)
          dut.io.loadingValues.poke(false.B)
          // outputValid's polarity depends on conf.validHigh, so wait for the configured "valid" level.
          while (dut.io.outputValid.peek().litToBoolean != dut.conf.validHigh) {
            dut.clock.step(1)
          }
          dut.io.outputGCD.expect(5.U)
        }
      }
    }
    
    /** Test stage: resolves `--params` into a [[GCDConfig]] and runs [[GCDSpec]] with it,
      * handing over the parsed command-line annotations as well.
      */
    class GCDTestStage
        extends GenericParameterCliStage[GCDConfig](
          (params, annotations) => new GCDSpec(params, annotations).execute(),
          GCDConfig()
        )
    
    /** Command-line entry point for [[GCDTestStage]], e.g.
      * `sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'`.
      */
    object GCDTest extends StageMain(new GCDTestStage)
    

    Both generation and tests can be parameterized via the CLI, as in the original question (OQ):

    sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'
    sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'