diff --git a/src/manta/uart/cobs_encode.py b/src/manta/uart/cobs_encode.py index 638dfac..4b3f25f 100644 --- a/src/manta/uart/cobs_encode.py +++ b/src/manta/uart/cobs_encode.py @@ -1,5 +1,6 @@ from amaranth import * from amaranth.lib import wiring +from amaranth.lib.fifo import SyncFIFO from amaranth.lib.memory import Memory from amaranth.lib.wiring import In, Out @@ -8,99 +9,69 @@ from manta.utils import * class COBSEncode(wiring.Component): sink: In(StreamSignature(8)) - source: Out(StreamSignature(8, has_last=False)) + source: Out(StreamSignature(8)) def elaborate(self, platform): m = Module() - # Internal Signals - head_pointer = Signal(range(256)) - tail_pointer = Signal(range(256)) + m.submodules.fifo = fifo = SyncFIFO(width=8, depth=256) - # Add memory and read/write ports - m.submodules.memory = memory = Memory(shape=8, depth=256, init=[0] * 256) - rd_port = memory.read_port() - wr_port = memory.write_port() + fsm_data = Signal(8) + was_last = Signal(1) + fifo_written_to_last_cycle = Signal(1) + m.d.comb += fifo_written_to_last_cycle.eq( + (self.sink.ready) & (self.sink.valid) & (self.sink.data != 0) + ) - # Reset top-level IO - m.d.sync += self.source.data.eq(0) - m.d.sync += self.source.valid.eq(0) - - # Generate rd_port_addr_prev - rd_port_addr_prev = Signal().like(rd_port.addr) - m.d.sync += rd_port_addr_prev.eq(rd_port.addr) - - # State Machine: with m.FSM() as fsm: - with m.State("IDLE"): - with m.If(self.sink.last): - m.d.sync += head_pointer.eq(0) - m.d.sync += tail_pointer.eq(0) - m.d.sync += rd_port.addr.eq(0) - m.next = "SEARCH_FOR_ZERO" + with m.State("COUNT_BYTES"): + with m.If(self.sink.valid & self.sink.ready): + # End of packet or zero found, clock out length + with m.If((self.sink.last) | (self.sink.data == 0)): + m.d.sync += fsm_data.eq( + fifo.r_level + fifo_written_to_last_cycle + 1 + ) - with m.State("SEARCH_FOR_ZERO"): - # Drive read addr until length is reached - with m.If(rd_port.addr < wr_port.addr): - m.d.sync += rd_port.addr.eq(rd_port.addr + 1) + m.d.sync += was_last.eq(self.sink.last) + m.next = "WAIT_FOR_LENGTH" - # Watch prev_addr and data - with m.If((rd_port_addr_prev == wr_port.addr) | (rd_port.data == 0)): - # Either reached the end of the input buffer or found a zero + with m.State("WAIT_FOR_LENGTH"): + with m.If(self.source.valid & self.source.ready): + m.next = "SEND_BYTES" - m.d.sync += head_pointer.eq(rd_port_addr_prev) - m.d.sync += rd_port.addr.eq(tail_pointer) - m.d.sync += self.source.data.eq( - rd_port_addr_prev - tail_pointer + 1 - ) - m.d.sync += self.source.valid.eq(1) + with m.State("SEND_BYTES"): + # Wait until the FIFO will be empty on next cycle + with m.If( + (fifo.r_level == 1) & (self.source.ready) & (self.source.valid) + ): + m.next = "COUNT_BYTES" - m.next = "CLOCK_OUT_BYTES_STALL" + with m.If(was_last): + m.d.sync += fsm_data.eq(0) + m.next = "SEND_DELIMITER" - with m.Else(): - m.next = "SEARCH_FOR_ZERO" + with m.State("SEND_DELIMITER"): + with m.If(self.source.valid & self.source.ready): + m.next = "COUNT_BYTES" - with m.State("CLOCK_OUT_BYTES_STALL"): - m.d.sync += rd_port.addr.eq(rd_port.addr + 1) - m.next = "CLOCK_OUT_BYTES" + # Wire FIFO input to sink + m.d.comb += fifo.w_data.eq(self.sink.data) + m.d.comb += fifo.w_en.eq( + (self.sink.ready) & (self.sink.valid) & (self.sink.data != 0) + ) + m.d.comb += self.sink.ready.eq(fifo.w_rdy & fsm.ongoing("COUNT_BYTES")) - with m.State("CLOCK_OUT_BYTES"): - # Drive rd_port.addr - with m.If(rd_port.addr < head_pointer): - m.d.sync += rd_port.addr.eq(rd_port.addr + 1) + # Wire FIFO output to source, allow FSM to preempt FIFO + with m.If(fsm.ongoing("WAIT_FOR_LENGTH") | fsm.ongoing("SEND_DELIMITER")): + m.d.comb += self.source.data.eq(fsm_data) + m.d.comb += self.source.valid.eq(1) + m.d.comb += fifo.r_en.eq(0) - # Watch prev_addr - with m.If(rd_port_addr_prev < head_pointer): - m.d.sync += self.source.data.eq(rd_port.data) - m.d.sync += self.source.valid.eq(1) - m.next = "CLOCK_OUT_BYTES" + with m.Else(): + m.d.comb += self.source.data.eq(fifo.r_data) + m.d.comb += self.source.valid.eq(fifo.r_rdy & fsm.ongoing("SEND_BYTES")) + m.d.comb += fifo.r_en.eq(self.source.valid & self.source.ready) - with m.If(rd_port_addr_prev == head_pointer): - # Reached end of message - with m.If(head_pointer == wr_port.addr): - m.d.sync += self.source.data.eq(0) - m.d.sync += self.source.valid.eq(1) - - m.next = "IDLE" - - with m.Else(): # this section is a beautiful! - m.d.sync += tail_pointer.eq(head_pointer + 1) - m.d.sync += head_pointer.eq(head_pointer + 1) - m.d.sync += rd_port.addr.eq(head_pointer + 1) - m.d.sync += self.source.valid.eq( - 0 - ) # i have no idea why this works - - m.next = "SEARCH_FOR_ZERO_STALL" - - with m.State("SEARCH_FOR_ZERO_STALL"): - m.next = "SEARCH_FOR_ZERO" - - # Fill memory from input stream - m.d.comb += wr_port.en.eq((fsm.ongoing("IDLE")) & (self.sink.valid)) - m.d.comb += wr_port.data.eq(self.sink.data) - m.d.sync += wr_port.addr.eq(wr_port.addr + wr_port.en) - - m.d.comb += self.sink.ready.eq(fsm.ongoing("IDLE")) + m.d.comb += self.source.last.eq(fsm.ongoing("SEND_DELIMITER")) return m diff --git a/test/test_cobs_encode.py b/test/test_cobs_encode.py new file mode 100644 index 0000000..54e5802 --- /dev/null +++ b/test/test_cobs_encode.py @@ -0,0 +1,110 @@ +import random + +from cobs import cobs + +from manta.uart.cobs_encode import COBSEncode +from manta.utils import * + +ce = COBSEncode() + + +@simulate(ce) +async def test_cobs_encode(ctx): + await ctx.tick() + await ctx.tick() + await ctx.tick() + await ctx.tick() + + print("") + + ctx.set(ce.source.ready, 1) + + await encode_and_compare(ctx, [0x11, 0x22, 0x33, 0x44]) + # await encode_and_compare(ctx, [0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99]) + await ctx.tick() + await ctx.tick() + await ctx.tick() + + await encode_and_compare(ctx, [0x11, 0x22, 0x00, 0x33]) + await ctx.tick() + await ctx.tick() + await ctx.tick() + + +@simulate(ce) +async def test_cobs_encode_random(ctx): + ctx.set(ce.source.ready, 1) + + num_tests = random.randint(1, 5) + for _ in range(num_tests): + length = random.randint(1, 254) + input_data = [random.randint(0, 255) for _ in range(length)] + result = await encode_and_compare(ctx, input_data, quiet=True) + await ctx.tick() + await ctx.tick() + await ctx.tick() + + if not result: + await encode_and_compare(ctx, input_data, quiet=False) + await ctx.tick() + await ctx.tick() + await ctx.tick() + + +async def encode(ctx, data): + tx_done = False + tx_index = 0 + rx_done = False + rx_buf = [] + + while not (tx_done and rx_done): + # Feed data to encoder + if not tx_done: + ctx.set(ce.sink.valid, 1) + ctx.set(ce.sink.data, data[tx_index]) + ctx.set(ce.sink.last, tx_index == len(data) - 1) + + if ctx.get(ce.sink.valid) and ctx.get(ce.sink.ready): + if tx_index == len(data) - 1: + tx_done = True + + else: + tx_index += 1 + else: + ctx.set(ce.sink.data, 0) + ctx.set(ce.sink.valid, 0) + ctx.set(ce.sink.last, 0) + + # Randomly set source.ready + # ctx.set(ce.source.ready, random.randint(0, 1)) + ctx.set(ce.source.ready, 1) + + # Pull output data from buffer + if ctx.get(ce.source.valid) and ctx.get(ce.source.ready): + rx_buf += [ctx.get(ce.source.data)] + + if ctx.get(ce.source.last): + rx_done = True + + await ctx.tick() + + await ctx.tick() + await ctx.tick() + await ctx.tick() + + return rx_buf + + +async def encode_and_compare(ctx, data, quiet=False): + expected = cobs.encode(bytes(data)) + b"\0" + actual = await encode(ctx, data) + matched = bytes(actual) == expected + + if not quiet: + print(f" input: {[hex(d) for d in data]}") + print(f"expected: {[hex(d) for d in expected]}") + print(f" actual: {[hex(d) for d in actual]}") + print(f" result: {'PASS' if matched else 'FAIL'}") + print("") + + return matched