scala scodec

Records are discriminated hierarchically


I have to implement a proprietary binary format and want to do this with scodec, but I cannot find a concise solution. The format is as follows: a file consists of multiple records, each prefixed with a little-endian 16-bit number "t" (uint16L). Records fall into 4 categories, depending on the values of the first and second byte of t:

  • Normal: t.first != 0 && t.second == 0
  • Bar: t.first == 0x08 && t.second == 0xFF
  • Foo: t.first == 0x04 && t.second == 0x05
  • Invalid: t is none of the above

If t is Invalid, the program should exit, because the file is corrupted. If t is either Normal or Bar, the length of the record follows as a 32-bit little-endian int. If t is Foo, another 16-bit big-endian int must be parsed before the length can be parsed as a 32-bit big-endian int.

- Normal: ("t" | uint16L) :: ("length" | uint32L) :: [Record data discriminated by t]
- Bar: ("t" | constant(0x08FF)) ::  ("length" | uint32L) :: [Record data of Bar]
- Foo: ("t" | constant(0x0405)) :: uint16 :: ("length" | uint32) :: [Record data of foo]
- Invalid: ("t" | uint16L ) :: fail(Err(s"invalid type: $t"))
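
To make the layout above concrete, here is a minimal sketch that classifies t and reads the record prefix with plain `java.nio.ByteBuffer`, without scodec. It assumes, as in the `SplitInt` helper in the code below, that `first` is the low byte and `second` the high byte of the decoded uint16L value. The names `PrefixSketch`, `Prefix`, `classify` and `readPrefix` are hypothetical and only serve the illustration.

```scala
import java.nio.{ByteBuffer, ByteOrder}

object PrefixSketch {
  sealed trait Category
  case object Normal  extends Category
  case object Bar     extends Category
  case object Foo     extends Category
  case object Invalid extends Category

  // Assumption: `first` is the low byte and `second` the high byte of t.
  def classify(t: Int): Category = {
    val first  = t & 0xFF
    val second = (t >> 8) & 0xFF
    if (first != 0 && second == 0) Normal
    else if (first == 0x08 && second == 0xFF) Bar
    else if (first == 0x04 && second == 0x05) Foo
    else Invalid
  }

  // Hypothetical prefix model: `extra` is only present for Foo records.
  final case class Prefix(t: Int, extra: Option[Int], length: Long)

  // Note: truncated input is not handled here and would throw BufferUnderflowException.
  def readPrefix(bytes: Array[Byte]): Either[String, Prefix] = {
    val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
    val t = buf.getShort() & 0xFFFF // uint16L
    classify(t) match {
      case Normal | Bar =>
        Right(Prefix(t, None, buf.getInt() & 0xFFFFFFFFL)) // uint32L
      case Foo =>
        buf.order(ByteOrder.BIG_ENDIAN)
        val extra = buf.getShort() & 0xFFFF                // uint16, big endian
        Right(Prefix(t, Some(extra), buf.getInt() & 0xFFFFFFFFL)) // uint32, big endian
      case Invalid =>
        Left(s"invalid type: $t")
    }
  }
}
```

Returning an `Either` keeps the Invalid case explicit, so the caller decides whether to abort on a corrupted file.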

Furthermore, some values for t in "Normal" are unused and should produce an UnknownRecord (similar to the mpeg implementation here: https://github.com/scodec/scodec-protocols/blob/series/1.0.x/src/main/scala/scodec/protocols/mpeg/Descriptor.scala)

This is my current approach, but it does not feel clear, and I get the feeling that I'm working around scodec more than with it. Any ideas? Feel free to scrap my code below.

sealed trait ContainerType
object ContainerType{
    implicit class SplitInt(val self: Int) extends AnyVal{
        def first = self & 0xFF
        def second = (self >> 8) & 0xFF
    }

    case object Normal extends ContainerType
    case object Bar extends ContainerType
    case object Foo extends ContainerType
    case object Invalid extends ContainerType

    val codec: Codec[ContainerType] = {
        def to(value: Int): ContainerType = value match{
            case v if value.first != 0 && value.second == 0 => Normal
            case v if value.first == 0x08 && value.second == 0xFF => Bar
            case v if value.first == 0x04 && value.second == 0x05 => Foo
            case _ => Invalid
        }

        uint16L.xmap(to, ???) // don't have the value here
        // if I use case classes and save the value I can't discriminate by it in RecordPrefix
    }        
}

sealed trait RecordPrefix{
    def t: Int
    def length: Int
}
object RecordPrefix {

    case class Normal( override val t: Int, override val length: Int) extends RecordPrefix
    object Normal{
        val codec: Codec[Normal] = ???
    }
    case class Bar(override val t: Int, override val length: Int) extends RecordPrefix
    object Bar{
        val codec: Codec[Bar] = ???
    }
    case class Foo(override val t: Int, foobar: Int, length: Int) extends RecordPrefix
    object Foo{
        val codec: Codec[Foo] = ???
    }


    val codec: Codec[RecordPrefix] = {
        discriminated[RecordPrefix].by(ContainerType.codec)
        .typecase(Normal, Normal.codec)
        .typecase(Bar, Bar.codec)
        .typecase(Foo, Foo.codec)
        // how to handle invalid case ?
    }

}

case class Record(prefix: RecordPrefix, body: RecordBody)

sealed trait RecordBody
//.... How can I implement the codecs?

PS: This is my first question here, I hope it was clear enough. =)

Edit1: I found an implementation that at least does the job. As a tradeoff, the conditions are checked again if the record is unknown, in order to get a cleaner hierarchy.

trait KnownRecord
sealed trait NormalRecord extends KnownRecord

case class BarRecord(length: Int, ..,) extends KnownRecord
object BarRecord {
    val codec: Codec[BarRecord] = {
        ("Length" | int32L) ::
        //...
    }.as[BarRecord]

}

case class FooRecord(...) extends KnownRecord
object FooRecord {
    val codec: Codec[FooRecord] = // analogue
}

case class A() extends NormalRecord
case class B() extends NormalRecord
// ...

case class UnknownRecord(rtype: Int, length: Int, data: ByteVector)
object UnknownRecord{

    val codec: Codec[UnknownRecord] = {
        ("Type" | Record.validTypeCodec) ::
        (("Length" | int32L) >>:~ { length =>
            ("Data" | bytes(length - 6)).hlist
        })
    }.as[UnknownRecord]
}

object Record{
    type Record = Either[UnknownRecord, KnownRecord]

    val validTypeCodec: Codec[Int] = {
        uint16L.consume[Int] { rtype =>
            val first = rtype & 0xFF
            val second = (rtype >> 8) & 0xFF
            rtype match {
                case i if first != 0 && second == 0 => provide(i)
                case i if first == 0x04 && second == 0x05 => provide(i)
                case i if first == 0xFF && second == 0x08 => provide(i)
                case _ => fail(Err(s"Invalid Type: $rtype!"))
            }
        } (identity)
    }

    def normalCodec(rtype: Int): Codec[NormalRecord] = {
        discriminated[NormalRecord].by(provide(rtype))
        .typecase(1, A.codec)
        .typecase(2, B.codec)
        .typecase(3, C.codec)
        .typecase(4, D.codec)
        .framing(new CodecTransformation {
            def apply[X](c: Codec[X]) = variableSizeBytes(int32L, c.complete,
                                                          sizePadding=6)
        })
    }.as[NormalRecord]


    val knownCodec: Codec[KnownRecord] = {
        val b = discriminated[KnownRecord].by(("Type" | uint16L))
            .typecase(0x0504, FooRecord.codec)
            .typecase(0x08FF, BarRecord.codec)
        (1 to 0xFF).foldLeft(b) {
            (acc, x) => acc.typecase(x, normalCodec(x))
        }
    }

    implicit val codec: Codec[Record] = {
        discriminatorFallback(UnknownRecord.codec, knownCodec)
    }
}
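
Both `bytes(length - 6)` and `sizePadding=6` in the code above suggest that the length field counts the whole record, including the 2-byte type and the 4-byte length itself. The following sketch spells out that arithmetic without scodec, assuming this reading is correct; `FramingSketch`, `frame` and `unframe` are hypothetical names.

```scala
import java.nio.{ByteBuffer, ByteOrder}

object FramingSketch {
  // Assumption: 2 bytes of type plus 4 bytes of length precede the body.
  val PrefixSize = 6

  // Writes type (uint16L), total record length (int32L), then the body.
  def frame(rtype: Int, body: Array[Byte]): Array[Byte] = {
    val buf = ByteBuffer.allocate(PrefixSize + body.length).order(ByteOrder.LITTLE_ENDIAN)
    buf.putShort(rtype.toShort)
    buf.putInt(body.length + PrefixSize)
    buf.put(body)
    buf.array()
  }

  // Reads the prefix back and slices out exactly `length - 6` body bytes.
  def unframe(record: Array[Byte]): (Int, Array[Byte]) = {
    val buf = ByteBuffer.wrap(record).order(ByteOrder.LITTLE_ENDIAN)
    val rtype = buf.getShort() & 0xFFFF
    val length = buf.getInt()
    val body = new Array[Byte](length - PrefixSize)
    buf.get(body)
    (rtype, body)
  }
}
```

For a 3-byte body, `frame` writes a length of 9, and `unframe` recovers the 3 bytes by subtracting the same 6-byte padding.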

Edit2: I posted an alternative solution as an answer below.


Solution

  • I'm posting this as an answer because I'm satisfied with this solution, although choosing between my first solution (Edit1 in the question) and this one is probably a matter of personal preference. Shastick's answer provides a useful approach as well if one wants to keep track of the discriminator value (which I would rather not).

    I hope this is helpful for others as well.

    Here is Solution 2: Instead of using a predefined codec, I decode and encode separately.

    decode chooses the right codec without decoding the type multiple times, while encode deduces the correct type value from the record type (Bar and Foo records have a constant type, and NormalRecords are encoded by the DiscriminatorCodec in Record.normalCodec).

    trait KnownRecord
    sealed trait NormalRecord extends KnownRecord
    
    case class BarRecord(..,) extends KnownRecord
    object BarRecord {
        val codec: Codec[BarRecord] = {
            //...
        }.as[BarRecord]
    
    }
    
    case class FooRecord(...) extends KnownRecord
    object FooRecord {
        val codec: Codec[FooRecord] = // ...
    }
    
    case class A() extends NormalRecord
    case class B() extends NormalRecord
    // ...
    
    case class UnknownRecord(rtype: Int, length: Int, data: ByteVector)
    object UnknownRecord{
    
        val codec: Codec[UnknownRecord] = {
            ("Type" | uint16L) ::
            (("Length" | int32L) >>:~ { length =>
                ("Data" | bytes(length - 6)).hlist
            })
        }.as[UnknownRecord]
    }
    
    sealed trait ContainerType
    object ContainerType{
    
        case object FooType extends ContainerType
        case object BarType extends ContainerType
        case class NormalType(rtype: Int) extends ContainerType
        case class Invalid(rtype: Int) extends ContainerType
    
        implicit val codec: Codec[ContainerType] = {
            def from(value: Int): ContainerType = {
                val first = value & 0xFF
                val second = (value >> 8) & 0xFF
                value match {
                    case 0x0504 =>  FooType
                    case 0x08FF => BarType
                    case i if (second == 0 && first != 0) => NormalType(i)
                    case other => Invalid(other)
                }
            }
    
            // Must be the inverse of `from`, otherwise encoding writes a type that decoding rejects.
            def to(ct: ContainerType): Int = ct match {
                case FooType => 0x0504
                case BarType => 0x08FF
                case NormalType(i) => i
                case Invalid(i) => i
            }
    
            uint16L.xmap(from, to)
        }
    
    }
    
    object Record{
        type Record = Either[UnknownRecord, KnownRecord]
    
        val ensureSize = new CodecTransformation {
            def apply[X](c: Codec[X]) = variableSizeBytes(int32L, c.complete,
                                                          sizePadding=6)
        }
    
        val normalCodec: Codec[NormalRecord] =
            normalCodecBy(uint16L).framing(ensureSize).as[NormalRecord]
    
        // Named differently from the val above: a val and a def cannot share a name in one scope.
        def normalCodecBy(discr: Codec[Int]) =
            discriminated[NormalRecord].by(discr)
            .typecase(1, A.codec)
            .typecase(2, B.codec)
            .typecase(3, C.codec)
            .typecase(4, D.codec)
    
        val knownCodec: Codec[KnownRecord] = {
            import ContainerType._
    
            def decodeRecord(bits: BitVector): Attempt[DecodeResult[KnownRecord]] =
            for {
                ct <- ContainerType.codec.decode(bits)
                rec <- ct.value match {
                    case FooType => FooRecord.codec.decode(ct.remainder)
                    case BarType =>
                        ensureSize(BarRecord.codec).decode(ct.remainder)
                    case NormalType(i) =>
                        ensureSize(normalCodecBy(provide(i))).decode(ct.remainder)
                    case Invalid(rtype) =>
                        Attempt.failure(Err(s"Invalid Type: $rtype!"))
                }
            } yield rec
    
    
            def encodeRecord(rec: KnownRecord): Attempt[BitVector] =
            rec match {
                case c: NormalRecord => normalCodec.encode(c)
    
                case fr: FooRecord => for {
                    rtype <- ContainerType.codec.encode(FooType)
                    record <- FooRecord.codec.encode(fr)
                } yield rtype ++ record
    
                case br: BarRecord => for {
                    rtype <- ContainerType.codec.encode(BarType)
                    record <- BarRecord.codec.encode(br)
                    length <- int32L.encode((record.size / 8).toInt + 6)
                } yield rtype ++ length ++ record
            }
    
            Codec(Encoder(encodeRecord _), Decoder(decodeRecord _))
        }
    
    
        implicit val codec: Codec[Record] = {
            discriminatorFallback(UnknownRecord.codec, knownCodec)
        }
    }
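
Since `xmap` only yields a lawful codec when `from` and `to` are inverses of each other, a mapping like the one in `ContainerType.codec` can be sanity-checked over every 16-bit value without involving scodec. The sketch below mirrors the ADT above in a standalone object; `RoundTripSketch` and `mismatches` are hypothetical names, and the constants assume Foo is 0x0504 and Bar is 0x08FF, as in `from`.

```scala
object RoundTripSketch {
  sealed trait ContainerType
  case object FooType extends ContainerType
  case object BarType extends ContainerType
  final case class NormalType(rtype: Int) extends ContainerType
  final case class Invalid(rtype: Int) extends ContainerType

  // Same classification as `from` in the answer.
  def from(value: Int): ContainerType = {
    val first  = value & 0xFF
    val second = (value >> 8) & 0xFF
    value match {
      case 0x0504 => FooType
      case 0x08FF => BarType
      case i if second == 0 && first != 0 => NormalType(i)
      case other => Invalid(other)
    }
  }

  // Assumed inverse of `from`: constants must match the ones matched above.
  def to(ct: ContainerType): Int = ct match {
    case FooType       => 0x0504
    case BarType       => 0x08FF
    case NormalType(i) => i
    case Invalid(i)    => i
  }

  // All 16-bit values that do not survive a decode/encode round trip.
  def mismatches: Seq[Int] = (0 to 0xFFFF).filter(v => to(from(v)) != v)
}
```

If the constants in `to` and `from` ever drift apart, `mismatches` lists exactly the type values that would be re-encoded incorrectly.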