uart: rewrite COBS encoder to allow backpressure

This commit is contained in:
Fischer Moseley 2026-02-20 09:25:17 -07:00
parent 858c9554dc
commit df4e1b2e0c
2 changed files with 159 additions and 78 deletions

View File

@ -1,5 +1,6 @@
from amaranth import *
from amaranth.lib import wiring
from amaranth.lib.fifo import SyncFIFO
from amaranth.lib.memory import Memory
from amaranth.lib.wiring import In, Out
@ -8,99 +9,69 @@ from manta.utils import *
class COBSEncode(wiring.Component):
sink: In(StreamSignature(8))
source: Out(StreamSignature(8, has_last=False))
source: Out(StreamSignature(8))
def elaborate(self, platform):
m = Module()
# Internal Signals
head_pointer = Signal(range(256))
tail_pointer = Signal(range(256))
m.submodules.fifo = fifo = SyncFIFO(width=8, depth=256)
# Add memory and read/write ports
m.submodules.memory = memory = Memory(shape=8, depth=256, init=[0] * 256)
rd_port = memory.read_port()
wr_port = memory.write_port()
fsm_data = Signal(8)
was_last = Signal(1)
fifo_written_to_last_cycle = Signal(1)
m.d.comb += fifo_written_to_last_cycle.eq(
(self.sink.ready) & (self.sink.valid) & (self.sink.data != 0)
)
# Reset top-level IO
m.d.sync += self.source.data.eq(0)
m.d.sync += self.source.valid.eq(0)
# Generate rd_port_addr_prev
rd_port_addr_prev = Signal().like(rd_port.addr)
m.d.sync += rd_port_addr_prev.eq(rd_port.addr)
# State Machine:
with m.FSM() as fsm:
with m.State("IDLE"):
with m.If(self.sink.last):
m.d.sync += head_pointer.eq(0)
m.d.sync += tail_pointer.eq(0)
m.d.sync += rd_port.addr.eq(0)
m.next = "SEARCH_FOR_ZERO"
with m.State("COUNT_BYTES"):
with m.If(self.sink.valid & self.sink.ready):
# End of packet or zero found, clock out length
with m.If((self.sink.last) | (self.sink.data == 0)):
m.d.sync += fsm_data.eq(
fifo.r_level + fifo_written_to_last_cycle + 1
)
with m.State("SEARCH_FOR_ZERO"):
# Drive read addr until length is reached
with m.If(rd_port.addr < wr_port.addr):
m.d.sync += rd_port.addr.eq(rd_port.addr + 1)
m.d.sync += was_last.eq(self.sink.last)
m.next = "WAIT_FOR_LENGTH"
# Watch prev_addr and data
with m.If((rd_port_addr_prev == wr_port.addr) | (rd_port.data == 0)):
# Either reached the end of the input buffer or found a zero
with m.State("WAIT_FOR_LENGTH"):
with m.If(self.source.valid & self.source.ready):
m.next = "SEND_BYTES"
m.d.sync += head_pointer.eq(rd_port_addr_prev)
m.d.sync += rd_port.addr.eq(tail_pointer)
m.d.sync += self.source.data.eq(
rd_port_addr_prev - tail_pointer + 1
)
m.d.sync += self.source.valid.eq(1)
with m.State("SEND_BYTES"):
# Wait until the FIFO will be empty on next cycle
with m.If(
(fifo.r_level == 1) & (self.source.ready) & (self.source.valid)
):
m.next = "COUNT_BYTES"
m.next = "CLOCK_OUT_BYTES_STALL"
with m.If(was_last):
m.d.sync += fsm_data.eq(0)
m.next = "SEND_DELIMITER"
with m.Else():
m.next = "SEARCH_FOR_ZERO"
with m.State("SEND_DELIMITER"):
with m.If(self.source.valid & self.source.ready):
m.next = "COUNT_BYTES"
with m.State("CLOCK_OUT_BYTES_STALL"):
m.d.sync += rd_port.addr.eq(rd_port.addr + 1)
m.next = "CLOCK_OUT_BYTES"
# Wire FIFO input to sink
m.d.comb += fifo.w_data.eq(self.sink.data)
m.d.comb += fifo.w_en.eq(
(self.sink.ready) & (self.sink.valid) & (self.sink.data != 0)
)
m.d.comb += self.sink.ready.eq(fifo.w_rdy & fsm.ongoing("COUNT_BYTES"))
with m.State("CLOCK_OUT_BYTES"):
# Drive rd_port.addr
with m.If(rd_port.addr < head_pointer):
m.d.sync += rd_port.addr.eq(rd_port.addr + 1)
# Wire FIFO output to source, allow FSM to preempt FIFO
with m.If(fsm.ongoing("WAIT_FOR_LENGTH") | fsm.ongoing("SEND_DELIMITER")):
m.d.comb += self.source.data.eq(fsm_data)
m.d.comb += self.source.valid.eq(1)
m.d.comb += fifo.r_en.eq(0)
# Watch prev_addr
with m.If(rd_port_addr_prev < head_pointer):
m.d.sync += self.source.data.eq(rd_port.data)
m.d.sync += self.source.valid.eq(1)
m.next = "CLOCK_OUT_BYTES"
with m.Else():
m.d.comb += self.source.data.eq(fifo.r_data)
m.d.comb += self.source.valid.eq(fifo.r_rdy & fsm.ongoing("SEND_BYTES"))
m.d.comb += fifo.r_en.eq(self.source.valid & self.source.ready)
with m.If(rd_port_addr_prev == head_pointer):
# Reached end of message
with m.If(head_pointer == wr_port.addr):
m.d.sync += self.source.data.eq(0)
m.d.sync += self.source.valid.eq(1)
m.next = "IDLE"
with m.Else(): # this section is a beautiful!
m.d.sync += tail_pointer.eq(head_pointer + 1)
m.d.sync += head_pointer.eq(head_pointer + 1)
m.d.sync += rd_port.addr.eq(head_pointer + 1)
m.d.sync += self.source.valid.eq(
0
) # i have no idea why this works
m.next = "SEARCH_FOR_ZERO_STALL"
with m.State("SEARCH_FOR_ZERO_STALL"):
m.next = "SEARCH_FOR_ZERO"
# Fill memory from input stream
m.d.comb += wr_port.en.eq((fsm.ongoing("IDLE")) & (self.sink.valid))
m.d.comb += wr_port.data.eq(self.sink.data)
m.d.sync += wr_port.addr.eq(wr_port.addr + wr_port.en)
m.d.comb += self.sink.ready.eq(fsm.ongoing("IDLE"))
m.d.comb += self.source.last.eq(fsm.ongoing("SEND_DELIMITER"))
return m

110
test/test_cobs_encode.py Normal file
View File

@ -0,0 +1,110 @@
import random
from cobs import cobs
from manta.uart.cobs_encode import COBSEncode
from manta.utils import *
ce = COBSEncode()
@simulate(ce)
async def test_cobs_encode(ctx):
await ctx.tick()
await ctx.tick()
await ctx.tick()
await ctx.tick()
print("")
ctx.set(ce.source.ready, 1)
await encode_and_compare(ctx, [0x11, 0x22, 0x33, 0x44])
# await encode_and_compare(ctx, [0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99])
await ctx.tick()
await ctx.tick()
await ctx.tick()
await encode_and_compare(ctx, [0x11, 0x22, 0x00, 0x33])
await ctx.tick()
await ctx.tick()
await ctx.tick()
@simulate(ce)
async def test_cobs_encode_random(ctx):
ctx.set(ce.source.ready, 1)
num_tests = random.randint(1, 5)
for _ in range(num_tests):
length = random.randint(1, 254)
input_data = [random.randint(0, 255) for _ in range(length)]
result = await encode_and_compare(ctx, input_data, quiet=True)
await ctx.tick()
await ctx.tick()
await ctx.tick()
if not result:
await encode_and_compare(ctx, input_data, quiet=False)
await ctx.tick()
await ctx.tick()
await ctx.tick()
async def encode(ctx, data):
tx_done = False
tx_index = 0
rx_done = False
rx_buf = []
while not (tx_done and rx_done):
# Feed data to encoder
if not tx_done:
ctx.set(ce.sink.valid, 1)
ctx.set(ce.sink.data, data[tx_index])
ctx.set(ce.sink.last, tx_index == len(data) - 1)
if ctx.get(ce.sink.valid) and ctx.get(ce.sink.ready):
if tx_index == len(data) - 1:
tx_done = True
else:
tx_index += 1
else:
ctx.set(ce.sink.data, 0)
ctx.set(ce.sink.valid, 0)
ctx.set(ce.sink.last, 0)
# Randomly set source.ready
# ctx.set(ce.source.ready, random.randint(0, 1))
ctx.set(ce.source.ready, 1)
# Pull output data from buffer
if ctx.get(ce.source.valid) and ctx.get(ce.source.ready):
rx_buf += [ctx.get(ce.source.data)]
if ctx.get(ce.source.last):
rx_done = True
await ctx.tick()
await ctx.tick()
await ctx.tick()
await ctx.tick()
return rx_buf
async def encode_and_compare(ctx, data, quiet=False):
expected = cobs.encode(bytes(data)) + b"\0"
actual = await encode(ctx, data)
matched = bytes(actual) == expected
if not quiet:
print(f" input: {[hex(d) for d in data]}")
print(f"expected: {[hex(d) for d in expected]}")
print(f" actual: {[hex(d) for d in actual]}")
print(f" result: {'PASS' if matched else 'FAIL'}")
print("")
return matched