uart: fix COBS encoder bug where 254th byte is zero

This commit is contained in:
Fischer Moseley 2026-02-25 09:20:22 -07:00
parent 29f8603728
commit e7306e51c8
2 changed files with 57 additions and 38 deletions

View File

@ -18,6 +18,7 @@ class COBSEncode(wiring.Component):
fsm_data = Signal(8) fsm_data = Signal(8)
was_last = Signal(1) was_last = Signal(1)
was_zero = Signal(1)
fifo_written_to_last_cycle = Signal(1) fifo_written_to_last_cycle = Signal(1)
m.d.comb += fifo_written_to_last_cycle.eq( m.d.comb += fifo_written_to_last_cycle.eq(
(self.sink.ready) & (self.sink.valid) & (self.sink.data != 0) (self.sink.ready) & (self.sink.valid) & (self.sink.data != 0)
@ -26,12 +27,18 @@ class COBSEncode(wiring.Component):
with m.FSM() as fsm: with m.FSM() as fsm:
with m.State("COUNT_BYTES"): with m.State("COUNT_BYTES"):
with m.If(self.sink.valid & self.sink.ready): with m.If(self.sink.valid & self.sink.ready):
# End of packet or zero found, clock out length # Send the length byte, as either a zero has been found, the end of the packet
# has been reached, or 254 bytes have been read in.
with m.If( with m.If(
(self.sink.last) | (self.sink.data == 0) | (fifo.r_level == 253) (self.sink.last) | (self.sink.data == 0) | (fifo.r_level == 253)
): ):
# Handle edge case where 254th byte is a zero
with m.If(fifo.r_level == 253): with m.If(fifo.r_level == 253):
m.d.sync += fsm_data.eq(255) with m.If(self.sink.data == 0):
m.d.sync += fsm_data.eq(254)
with m.Else():
m.d.sync += fsm_data.eq(255)
with m.Else(): with m.Else():
m.d.sync += fsm_data.eq( m.d.sync += fsm_data.eq(
@ -39,22 +46,34 @@ class COBSEncode(wiring.Component):
) )
m.d.sync += was_last.eq(self.sink.last) m.d.sync += was_last.eq(self.sink.last)
m.next = "WAIT_FOR_LENGTH" m.d.sync += was_zero.eq(self.sink.data == 0)
m.next = "SEND_LENGTH"
with m.State("WAIT_FOR_LENGTH"): with m.State("SEND_LENGTH"):
with m.If(self.source.valid & self.source.ready): with m.If(self.source.valid & self.source.ready):
m.next = "SEND_BYTES" m.next = "SEND_BYTES"
with m.State("SEND_BYTES"): with m.State("SEND_BYTES"):
# Wait until the FIFO will be empty on next cycle # Wait until the FIFO will be empty on next cycle
with m.If( with m.If(
(fifo.r_level == 1) & (self.source.ready) & (self.source.valid) ((fifo.r_level == 1) & (self.source.ready) & (self.source.valid))
| (fifo.r_level == 0)
): ):
m.next = "COUNT_BYTES" m.next = "COUNT_BYTES"
with m.If(was_last): with m.If(was_last):
m.d.sync += fsm_data.eq(0) with m.If(was_zero):
m.next = "SEND_DELIMITER" m.d.sync += fsm_data.eq(1)
m.next = "SEND_ONE"
with m.Else():
m.d.sync += fsm_data.eq(0)
m.next = "SEND_DELIMITER"
with m.State("SEND_ONE"):
with m.If(self.source.valid & self.source.ready):
m.d.sync += fsm_data.eq(0)
m.next = "SEND_DELIMITER"
with m.State("SEND_DELIMITER"): with m.State("SEND_DELIMITER"):
with m.If(self.source.valid & self.source.ready): with m.If(self.source.valid & self.source.ready):
@ -68,7 +87,11 @@ class COBSEncode(wiring.Component):
m.d.comb += self.sink.ready.eq(fifo.w_rdy & fsm.ongoing("COUNT_BYTES")) m.d.comb += self.sink.ready.eq(fifo.w_rdy & fsm.ongoing("COUNT_BYTES"))
# Wire FIFO output to source, allow FSM to preempt FIFO # Wire FIFO output to source, allow FSM to preempt FIFO
with m.If(fsm.ongoing("WAIT_FOR_LENGTH") | fsm.ongoing("SEND_DELIMITER")): with m.If(
fsm.ongoing("SEND_LENGTH")
| fsm.ongoing("SEND_ONE")
| fsm.ongoing("SEND_DELIMITER")
):
m.d.comb += self.source.data.eq(fsm_data) m.d.comb += self.source.data.eq(fsm_data)
m.d.comb += self.source.valid.eq(1) m.d.comb += self.source.valid.eq(1)
m.d.comb += fifo.r_en.eq(0) m.d.comb += fifo.r_en.eq(0)

View File

@ -10,24 +10,32 @@ ce = COBSEncode()
@simulate(ce) @simulate(ce)
async def test_cobs_encode(ctx): async def test_cobs_encode(ctx):
await ctx.tick() await ctx.tick().repeat(5)
await ctx.tick()
await ctx.tick()
await ctx.tick()
print("")
ctx.set(ce.source.ready, 1) ctx.set(ce.source.ready, 1)
await encode_and_compare(ctx, [0x11, 0x22, 0x33, 0x44]) # Test cases taken from Wikipedia:
await ctx.tick() # https://en.wikipedia.org/wiki/Consistent_Overhead_Byte_Stuffing
await ctx.tick() await encode_and_compare(ctx, [0x00])
await ctx.tick() await encode_and_compare(ctx, [0x00, 0x00])
await encode_and_compare(ctx, [0x00, 0x11, 0x00])
await encode_and_compare(ctx, [0x11, 0x22, 0x00, 0x33]) await encode_and_compare(ctx, [0x11, 0x22, 0x00, 0x33])
await ctx.tick() await encode_and_compare(ctx, [0x11, 0x22, 0x33, 0x44])
await ctx.tick() await encode_and_compare(ctx, [0x11, 0x00, 0x00, 0x00])
await ctx.tick() await encode_and_compare(ctx, [i for i in range(1, 255)])
await encode_and_compare(ctx, [0x00] + [i for i in range(1, 255)])
await encode_and_compare(ctx, [i for i in range(256)])
await encode_and_compare(ctx, [i for i in range(2, 256)] + [0x00])
await encode_and_compare(ctx, [i for i in range(3, 256)] + [0x00, 0x01])
# Selected edge and corner cases
await encode_and_compare(ctx, [0x00] * 253)
await encode_and_compare(ctx, [0x00] * 254)
await encode_and_compare(ctx, [0x00] * 255)
await encode_and_compare(ctx, ([0x11] * 253) + [0])
await encode_and_compare(ctx, ([0x11] * 253) + [0] + ([0x11] * 5))
await encode_and_compare(ctx, ([0x11] * 254) + [0])
await encode_and_compare(ctx, ([0x11] * 255) + [0])
@simulate(ce) @simulate(ce)
@ -38,16 +46,7 @@ async def test_cobs_encode_random(ctx):
for _ in range(num_tests): for _ in range(num_tests):
length = random.randint(1, 2000) length = random.randint(1, 2000)
input_data = [random.randint(0, 255) for _ in range(length)] input_data = [random.randint(0, 255) for _ in range(length)]
result = await encode_and_compare(ctx, input_data, quiet=True) await encode_and_compare(ctx, input_data)
await ctx.tick()
await ctx.tick()
await ctx.tick()
if not result:
await encode_and_compare(ctx, input_data, quiet=False)
await ctx.tick()
await ctx.tick()
await ctx.tick()
async def encode(ctx, data): async def encode(ctx, data):
@ -88,19 +87,16 @@ async def encode(ctx, data):
await ctx.tick() await ctx.tick()
await ctx.tick() await ctx.tick().repeat(5)
await ctx.tick()
await ctx.tick()
return rx_buf return rx_buf
async def encode_and_compare(ctx, data, quiet=False): async def encode_and_compare(ctx, data, only_print_on_fail=True):
expected = cobs.encode(bytes(data)) + b"\0" expected = cobs.encode(bytes(data)) + b"\0"
actual = await encode(ctx, data) actual = await encode(ctx, data)
matched = bytes(actual) == expected matched = bytes(actual) == expected
if not quiet: if (not only_print_on_fail) or (only_print_on_fail and not matched):
print(f" input: {[hex(d) for d in data]}") print(f" input: {[hex(d) for d in data]}")
print(f"expected: {[hex(d) for d in expected]}") print(f"expected: {[hex(d) for d in expected]}")
print(f" actual: {[hex(d) for d in actual]}") print(f" actual: {[hex(d) for d in actual]}")