uart: rewrite COBS encoder to allow backpressure
This commit is contained in:
parent
858c9554dc
commit
df4e1b2e0c
|
|
@ -1,5 +1,6 @@
|
|||
from amaranth import *
|
||||
from amaranth.lib import wiring
|
||||
from amaranth.lib.fifo import SyncFIFO
|
||||
from amaranth.lib.memory import Memory
|
||||
from amaranth.lib.wiring import In, Out
|
||||
|
||||
|
|
@ -8,99 +9,69 @@ from manta.utils import *
|
|||
|
||||
class COBSEncode(wiring.Component):
|
||||
sink: In(StreamSignature(8))
|
||||
source: Out(StreamSignature(8, has_last=False))
|
||||
source: Out(StreamSignature(8))
|
||||
|
||||
def elaborate(self, platform):
|
||||
m = Module()
|
||||
|
||||
# Internal Signals
|
||||
head_pointer = Signal(range(256))
|
||||
tail_pointer = Signal(range(256))
|
||||
m.submodules.fifo = fifo = SyncFIFO(width=8, depth=256)
|
||||
|
||||
# Add memory and read/write ports
|
||||
m.submodules.memory = memory = Memory(shape=8, depth=256, init=[0] * 256)
|
||||
rd_port = memory.read_port()
|
||||
wr_port = memory.write_port()
|
||||
fsm_data = Signal(8)
|
||||
was_last = Signal(1)
|
||||
fifo_written_to_last_cycle = Signal(1)
|
||||
m.d.comb += fifo_written_to_last_cycle.eq(
|
||||
(self.sink.ready) & (self.sink.valid) & (self.sink.data != 0)
|
||||
)
|
||||
|
||||
# Reset top-level IO
|
||||
m.d.sync += self.source.data.eq(0)
|
||||
m.d.sync += self.source.valid.eq(0)
|
||||
|
||||
# Generate rd_port_addr_prev
|
||||
rd_port_addr_prev = Signal().like(rd_port.addr)
|
||||
m.d.sync += rd_port_addr_prev.eq(rd_port.addr)
|
||||
|
||||
# State Machine:
|
||||
with m.FSM() as fsm:
|
||||
with m.State("IDLE"):
|
||||
with m.If(self.sink.last):
|
||||
m.d.sync += head_pointer.eq(0)
|
||||
m.d.sync += tail_pointer.eq(0)
|
||||
m.d.sync += rd_port.addr.eq(0)
|
||||
m.next = "SEARCH_FOR_ZERO"
|
||||
with m.State("COUNT_BYTES"):
|
||||
with m.If(self.sink.valid & self.sink.ready):
|
||||
# End of packet or zero found, clock out length
|
||||
with m.If((self.sink.last) | (self.sink.data == 0)):
|
||||
m.d.sync += fsm_data.eq(
|
||||
fifo.r_level + fifo_written_to_last_cycle + 1
|
||||
)
|
||||
|
||||
with m.State("SEARCH_FOR_ZERO"):
|
||||
# Drive read addr until length is reached
|
||||
with m.If(rd_port.addr < wr_port.addr):
|
||||
m.d.sync += rd_port.addr.eq(rd_port.addr + 1)
|
||||
m.d.sync += was_last.eq(self.sink.last)
|
||||
m.next = "WAIT_FOR_LENGTH"
|
||||
|
||||
# Watch prev_addr and data
|
||||
with m.If((rd_port_addr_prev == wr_port.addr) | (rd_port.data == 0)):
|
||||
# Either reached the end of the input buffer or found a zero
|
||||
with m.State("WAIT_FOR_LENGTH"):
|
||||
with m.If(self.source.valid & self.source.ready):
|
||||
m.next = "SEND_BYTES"
|
||||
|
||||
m.d.sync += head_pointer.eq(rd_port_addr_prev)
|
||||
m.d.sync += rd_port.addr.eq(tail_pointer)
|
||||
m.d.sync += self.source.data.eq(
|
||||
rd_port_addr_prev - tail_pointer + 1
|
||||
)
|
||||
m.d.sync += self.source.valid.eq(1)
|
||||
with m.State("SEND_BYTES"):
|
||||
# Wait until the FIFO will be empty on next cycle
|
||||
with m.If(
|
||||
(fifo.r_level == 1) & (self.source.ready) & (self.source.valid)
|
||||
):
|
||||
m.next = "COUNT_BYTES"
|
||||
|
||||
m.next = "CLOCK_OUT_BYTES_STALL"
|
||||
with m.If(was_last):
|
||||
m.d.sync += fsm_data.eq(0)
|
||||
m.next = "SEND_DELIMITER"
|
||||
|
||||
with m.Else():
|
||||
m.next = "SEARCH_FOR_ZERO"
|
||||
with m.State("SEND_DELIMITER"):
|
||||
with m.If(self.source.valid & self.source.ready):
|
||||
m.next = "COUNT_BYTES"
|
||||
|
||||
with m.State("CLOCK_OUT_BYTES_STALL"):
|
||||
m.d.sync += rd_port.addr.eq(rd_port.addr + 1)
|
||||
m.next = "CLOCK_OUT_BYTES"
|
||||
# Wire FIFO input to sink
|
||||
m.d.comb += fifo.w_data.eq(self.sink.data)
|
||||
m.d.comb += fifo.w_en.eq(
|
||||
(self.sink.ready) & (self.sink.valid) & (self.sink.data != 0)
|
||||
)
|
||||
m.d.comb += self.sink.ready.eq(fifo.w_rdy & fsm.ongoing("COUNT_BYTES"))
|
||||
|
||||
with m.State("CLOCK_OUT_BYTES"):
|
||||
# Drive rd_port.addr
|
||||
with m.If(rd_port.addr < head_pointer):
|
||||
m.d.sync += rd_port.addr.eq(rd_port.addr + 1)
|
||||
# Wire FIFO output to source, allow FSM to preempt FIFO
|
||||
with m.If(fsm.ongoing("WAIT_FOR_LENGTH") | fsm.ongoing("SEND_DELIMITER")):
|
||||
m.d.comb += self.source.data.eq(fsm_data)
|
||||
m.d.comb += self.source.valid.eq(1)
|
||||
m.d.comb += fifo.r_en.eq(0)
|
||||
|
||||
# Watch prev_addr
|
||||
with m.If(rd_port_addr_prev < head_pointer):
|
||||
m.d.sync += self.source.data.eq(rd_port.data)
|
||||
m.d.sync += self.source.valid.eq(1)
|
||||
m.next = "CLOCK_OUT_BYTES"
|
||||
with m.Else():
|
||||
m.d.comb += self.source.data.eq(fifo.r_data)
|
||||
m.d.comb += self.source.valid.eq(fifo.r_rdy & fsm.ongoing("SEND_BYTES"))
|
||||
m.d.comb += fifo.r_en.eq(self.source.valid & self.source.ready)
|
||||
|
||||
with m.If(rd_port_addr_prev == head_pointer):
|
||||
# Reached end of message
|
||||
with m.If(head_pointer == wr_port.addr):
|
||||
m.d.sync += self.source.data.eq(0)
|
||||
m.d.sync += self.source.valid.eq(1)
|
||||
|
||||
m.next = "IDLE"
|
||||
|
||||
with m.Else(): # this section is a beautiful!
|
||||
m.d.sync += tail_pointer.eq(head_pointer + 1)
|
||||
m.d.sync += head_pointer.eq(head_pointer + 1)
|
||||
m.d.sync += rd_port.addr.eq(head_pointer + 1)
|
||||
m.d.sync += self.source.valid.eq(
|
||||
0
|
||||
) # i have no idea why this works
|
||||
|
||||
m.next = "SEARCH_FOR_ZERO_STALL"
|
||||
|
||||
with m.State("SEARCH_FOR_ZERO_STALL"):
|
||||
m.next = "SEARCH_FOR_ZERO"
|
||||
|
||||
# Fill memory from input stream
|
||||
m.d.comb += wr_port.en.eq((fsm.ongoing("IDLE")) & (self.sink.valid))
|
||||
m.d.comb += wr_port.data.eq(self.sink.data)
|
||||
m.d.sync += wr_port.addr.eq(wr_port.addr + wr_port.en)
|
||||
|
||||
m.d.comb += self.sink.ready.eq(fsm.ongoing("IDLE"))
|
||||
m.d.comb += self.source.last.eq(fsm.ongoing("SEND_DELIMITER"))
|
||||
|
||||
return m
|
||||
|
|
|
|||
|
|
@ -0,0 +1,110 @@
|
|||
import random
|
||||
|
||||
from cobs import cobs
|
||||
|
||||
from manta.uart.cobs_encode import COBSEncode
|
||||
from manta.utils import *
|
||||
|
||||
ce = COBSEncode()
|
||||
|
||||
|
||||
@simulate(ce)
|
||||
async def test_cobs_encode(ctx):
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
|
||||
print("")
|
||||
|
||||
ctx.set(ce.source.ready, 1)
|
||||
|
||||
await encode_and_compare(ctx, [0x11, 0x22, 0x33, 0x44])
|
||||
# await encode_and_compare(ctx, [0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99])
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
|
||||
await encode_and_compare(ctx, [0x11, 0x22, 0x00, 0x33])
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
|
||||
|
||||
@simulate(ce)
|
||||
async def test_cobs_encode_random(ctx):
|
||||
ctx.set(ce.source.ready, 1)
|
||||
|
||||
num_tests = random.randint(1, 5)
|
||||
for _ in range(num_tests):
|
||||
length = random.randint(1, 254)
|
||||
input_data = [random.randint(0, 255) for _ in range(length)]
|
||||
result = await encode_and_compare(ctx, input_data, quiet=True)
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
|
||||
if not result:
|
||||
await encode_and_compare(ctx, input_data, quiet=False)
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
|
||||
|
||||
async def encode(ctx, data):
|
||||
tx_done = False
|
||||
tx_index = 0
|
||||
rx_done = False
|
||||
rx_buf = []
|
||||
|
||||
while not (tx_done and rx_done):
|
||||
# Feed data to encoder
|
||||
if not tx_done:
|
||||
ctx.set(ce.sink.valid, 1)
|
||||
ctx.set(ce.sink.data, data[tx_index])
|
||||
ctx.set(ce.sink.last, tx_index == len(data) - 1)
|
||||
|
||||
if ctx.get(ce.sink.valid) and ctx.get(ce.sink.ready):
|
||||
if tx_index == len(data) - 1:
|
||||
tx_done = True
|
||||
|
||||
else:
|
||||
tx_index += 1
|
||||
else:
|
||||
ctx.set(ce.sink.data, 0)
|
||||
ctx.set(ce.sink.valid, 0)
|
||||
ctx.set(ce.sink.last, 0)
|
||||
|
||||
# Randomly set source.ready
|
||||
# ctx.set(ce.source.ready, random.randint(0, 1))
|
||||
ctx.set(ce.source.ready, 1)
|
||||
|
||||
# Pull output data from buffer
|
||||
if ctx.get(ce.source.valid) and ctx.get(ce.source.ready):
|
||||
rx_buf += [ctx.get(ce.source.data)]
|
||||
|
||||
if ctx.get(ce.source.last):
|
||||
rx_done = True
|
||||
|
||||
await ctx.tick()
|
||||
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
await ctx.tick()
|
||||
|
||||
return rx_buf
|
||||
|
||||
|
||||
async def encode_and_compare(ctx, data, quiet=False):
|
||||
expected = cobs.encode(bytes(data)) + b"\0"
|
||||
actual = await encode(ctx, data)
|
||||
matched = bytes(actual) == expected
|
||||
|
||||
if not quiet:
|
||||
print(f" input: {[hex(d) for d in data]}")
|
||||
print(f"expected: {[hex(d) for d in expected]}")
|
||||
print(f" actual: {[hex(d) for d in actual]}")
|
||||
print(f" result: {'PASS' if matched else 'FAIL'}")
|
||||
print("")
|
||||
|
||||
return matched
|
||||
Loading…
Reference in New Issue