One way to accomplish this is to use the predefined combinators to define only the structure of the binary format -- not the combination logic -- and then put the combination logic inside the functions passed to `xmap`.
import scodec._
import bits._
import codecs._
import shapeless._
/**
 * Decoded form of the header.
 *
 * @param totalLength length value; the companion codec carries it on the wire as a
 *                    5-bit upper part and a 7-bit lower part
 * @param otherJunk   remaining header field; the companion codec carries it as an
 *                    unsigned 8-bit integer
 */
case class Header(
  totalLength: Int,
  otherJunk: Int
)
/**
 * Companion of [[Header]] providing its binary codec.
 *
 * The wire layout is described purely with predefined combinators (`struct`); the logic
 * that splits and recombines the length bits lives in `to` / `from`, which are handed
 * to `xmap`.
 */
object Header {

  implicit val codec: Codec[Header] = {
    // Raw wire fields, in order:
    // identifier, moreFlag, upper length bits, moreFlag2, lower length bits, other junk.
    type Struct = Int :: Boolean :: BitVector :: Boolean :: BitVector :: Int :: HNil

    val struct: Codec[Struct] =
      ("identifier" | uint2) ::
        ("moreFlag" | bool) ::
        ("upper length bits" | codecs.bits(5)) ::
        ("moreFlag2" | bool) ::
        ("lower length bits" | codecs.bits(7)) ::
        ("other junk" | uint8)

    // Header => raw fields.
    // NOTE: encodeValid is the unsafe variant -- it throws when totalLength cannot be
    // represented as an unsigned 12-bit value.
    def to(header: Header): Struct = {
      val encodedLength = uint(12).encodeValid(header.totalLength)
      val upperBits = encodedLength.take(5)
      val lowerBits = encodedLength.drop(5)
      // moreFlag is set exactly when the upper five bits carry information.
      val needsUpperBits = !encodedLength.startsWith(bin"00000")
      // identifier is always written as 0 and moreFlag2 as false.
      0 :: needsUpperBits :: upperBits :: false :: lowerBits :: header.otherJunk :: HNil
    }

    // Raw fields => Header. The identifier and moreFlag2 fields are not used.
    def from(fields: Struct): Header = fields match {
      case _ :: moreFlag :: upperBits :: _ :: lowerBits :: otherJunk :: HNil =>
        // Without moreFlag only the lower 7 bits form the length; with it, all 12 do.
        val length =
          if (!moreFlag) uint(7).decodeValidValue(lowerBits)
          else uint(12).decodeValidValue(upperBits ++ lowerBits)
        Header(length, otherJunk)
    }

    struct.xmap[Header](from, to)
  }
}
Note that the upper and lower length bits are encoded as `bits(5)` and `bits(7)` and then manually combined in the `from` function and decoded. The decoding is safe, despite using the unsafe `decodeValidValue`, because the `uint(12)` codec is total on 12-bit inputs (and, likewise, `uint(7)` is total on 7-bit inputs) -- it is impossible for it to return a left.
However, the `to` function uses `encodeValid`, which is clearly unsafe -- e.g., a `totalLength` of `Int.MaxValue` will cause an exception. We can fix this with a slightly more complicated version, where the `to` function returns an `Err \/ Struct`, and we call `widen` instead of `xmap`:
/**
 * Companion of [[Header]] providing its binary codec.
 *
 * Same wire layout as a plain `xmap`-based codec would use, but the Header => fields
 * direction (`to`) returns `Err \/ Struct`, so a length that cannot be encoded in
 * 12 bits surfaces as an encoding error; hence `widen` is used to build the codec.
 */
object Header {

  implicit val codec: Codec[Header] = {
    // Raw wire fields, in order:
    // identifier, moreFlag, upper length bits, moreFlag2, lower length bits, other junk.
    type Struct = Int :: Boolean :: BitVector :: Boolean :: BitVector :: Int :: HNil

    val struct: Codec[Struct] =
      ("identifier" | uint2) ::
        ("moreFlag" | bool) ::
        ("upper length bits" | codecs.bits(5)) ::
        ("moreFlag2" | bool) ::
        ("lower length bits" | codecs.bits(7)) ::
        ("other junk" | uint8)

    // Header => raw fields; a failure to encode the length is propagated as the left.
    def to(header: Header): Err \/ Struct =
      for {
        encodedLength <- uint(12).encode(header.totalLength)
      } yield {
        val upperBits = encodedLength.take(5)
        val lowerBits = encodedLength.drop(5)
        // moreFlag is set exactly when the upper five bits carry information.
        val needsUpperBits = !encodedLength.startsWith(bin"00000")
        // identifier is always written as 0 and moreFlag2 as false.
        0 :: needsUpperBits :: upperBits :: false :: lowerBits :: header.otherJunk :: HNil
      }

    // Raw fields => Header. The identifier and moreFlag2 fields are not used.
    def from(fields: Struct): Header = fields match {
      case _ :: moreFlag :: upperBits :: _ :: lowerBits :: otherJunk :: HNil =>
        // Without moreFlag only the lower 7 bits form the length; with it, all 12 do.
        val length =
          if (!moreFlag) uint(7).decodeValidValue(lowerBits)
          else uint(12).decodeValidValue(upperBits ++ lowerBits)
        Header(length, otherJunk)
    }

    struct.widen[Header](from, to)
  }
}