diff --git a/src/manta/uart/cobs_decode.py b/src/manta/uart/cobs_decode.py new file mode 100644 index 0000000..f8028a8 --- /dev/null +++ b/src/manta/uart/cobs_decode.py @@ -0,0 +1,56 @@ +from amaranth import * + + +class COBSDecode(Elaboratable): + def __init__(self): + # Stream-like data input + self.data_i = Signal(8) + self.valid_i = Signal() + + # Stream-like data output + self.data_o = Signal(8) + self.valid_o = Signal() + self.last_o = Signal() + + def elaborate(self, platform): + m = Module() + + counter = Signal(8) + + m.d.sync += self.data_o.eq(0) + m.d.sync += self.valid_o.eq(0) + m.d.sync += self.last_o.eq(0) + + # State Machine: + with m.FSM(): + with m.State("WAIT_FOR_PACKET_START"): + with m.If((self.data_i == 0) & (self.valid_i)): + m.next = "START_OF_PACKET" + + with m.State("START_OF_PACKET"): + with m.If(self.valid_i): + m.d.sync += counter.eq(self.data_i - 1) + m.next = "DECODING" + + with m.Else(): + m.next = "START_OF_PACKET" + + with m.State("DECODING"): + with m.If(self.valid_i): + with m.If(counter > 0): + m.d.sync += counter.eq(counter - 1) + m.d.sync += self.data_o.eq(self.data_i) + m.d.sync += self.valid_o.eq(1) + m.next = "DECODING" + + with m.Else(): + with m.If(self.data_i == 0): + m.d.sync += self.last_o.eq(1) + m.next = "START_OF_PACKET" + + with m.Else(): + m.d.sync += counter.eq(self.data_i - 1) + m.d.sync += self.valid_o.eq(1) + m.next = "DECODING" + + return m diff --git a/src/manta/uart/cobs_encode.py b/src/manta/uart/cobs_encode.py new file mode 100644 index 0000000..3adf3a5 --- /dev/null +++ b/src/manta/uart/cobs_encode.py @@ -0,0 +1,111 @@ +from amaranth import * +from amaranth.lib.memory import Memory + + +class COBSEncode(Elaboratable): + def __init__(self): + # Top-Level IO + self.start = Signal() + self.done = Signal() + + # Stream-like data input + self.data_i = Signal(8) + self.valid_i = Signal() + self.ready_o = Signal() + + # Stream-like data output + self.data_o = Signal(8) + self.valid_o = Signal() + self.ready_i = Signal() + + # Define memory + self.memory = Memory(shape=8, depth=256, init=[0] * 256) + + def elaborate(self, platform): + m = Module() + + # Internal Signals + head_pointer = Signal(range(256)) + tail_pointer = Signal(range(256)) + + # Add memory and read/write ports + m.submodules.memory = self.memory + rd_port = self.memory.read_port() + wr_port = self.memory.write_port() + + # Reset top-level IO + m.d.sync += self.data_o.eq(0) + m.d.sync += self.valid_o.eq(0) + + # Generate rd_port_addr_prev + rd_port_addr_prev = Signal().like(rd_port.addr) + m.d.sync += rd_port_addr_prev.eq(rd_port.addr) + + # State Machine: + with m.FSM() as fsm: + with m.State("IDLE"): + with m.If(self.start): + m.d.sync += head_pointer.eq(0) + m.d.sync += tail_pointer.eq(0) + m.d.sync += rd_port.addr.eq(0) + m.next = "SEARCH_FOR_ZERO" + + with m.State("SEARCH_FOR_ZERO"): + # Drive read addr until length is reached + with m.If(rd_port.addr < wr_port.addr): + m.d.sync += rd_port.addr.eq(rd_port.addr + 1) + + # Watch prev_addr and data + with m.If((rd_port_addr_prev == wr_port.addr) | (rd_port.data == 0)): + # Either reached the end of the input buffer or found a zero + + m.d.sync += head_pointer.eq(rd_port_addr_prev) + m.d.sync += rd_port.addr.eq(tail_pointer) + m.d.sync += self.data_o.eq(rd_port_addr_prev - tail_pointer + 1) + m.d.sync += self.valid_o.eq(1) + + m.next = "CLOCK_OUT_BYTES_STALL" + + with m.Else(): + m.next = "SEARCH_FOR_ZERO" + + with m.State("CLOCK_OUT_BYTES_STALL"): + m.d.sync += rd_port.addr.eq(rd_port.addr + 1) + m.next = "CLOCK_OUT_BYTES" + + with m.State("CLOCK_OUT_BYTES"): + # Drive rd_port.addr + with m.If(rd_port.addr < head_pointer): + m.d.sync += rd_port.addr.eq(rd_port.addr + 1) + + # Watch prev_addr + with m.If(rd_port_addr_prev <= head_pointer): + m.d.sync += self.data_o.eq(rd_port.data) + m.d.sync += self.valid_o.eq(1) + m.next = "CLOCK_OUT_BYTES" + + with m.If(rd_port_addr_prev == head_pointer): + # Reached end of message + with m.If(head_pointer == wr_port.addr): + m.d.sync += self.data_o.eq(0) + m.d.sync += self.valid_o.eq(1) + + m.next = "IDLE" + + with m.Else(): # this section is a beautiful! + m.d.sync += tail_pointer.eq(head_pointer + 1) + m.d.sync += head_pointer.eq(head_pointer + 1) + m.d.sync += rd_port.addr.eq(head_pointer + 1) + m.d.sync += self.valid_o.eq(0) # i have no idea why this works + + m.next = "SEARCH_FOR_ZERO_STALL" + + with m.State("SEARCH_FOR_ZERO_STALL"): + m.next = "SEARCH_FOR_ZERO" + + # Fill memory from input stream + m.d.comb += wr_port.en.eq((fsm.ongoing("IDLE")) & (self.valid_i)) + m.d.comb += wr_port.data.eq(self.data_i) + m.d.sync += wr_port.addr.eq(wr_port.addr + wr_port.en) + + return m diff --git a/src/manta/uart/stream_packer.py b/src/manta/uart/stream_packer.py new file mode 100644 index 0000000..c8bb7e8 --- /dev/null +++ b/src/manta/uart/stream_packer.py @@ -0,0 +1,17 @@ +from amaranth import * + + +class StreamPacker(Elaboratable): + def __init__(self): + self.data_i = Signal(8) + self.valid_i = Signal() + self.last_i = Signal() + + self.data_o = Signal(32) + self.valid_o = Signal() + self.ready_i = Signal() + self.last_o = Signal() + + def elaborate(self, platform): + m = Module() + return m diff --git a/src/manta/uart/stream_unpacker.py b/src/manta/uart/stream_unpacker.py new file mode 100644 index 0000000..912d13b --- /dev/null +++ b/src/manta/uart/stream_unpacker.py @@ -0,0 +1,18 @@ +from amaranth import * + + +class StreamUnpacker(Elaboratable): + def __init__(self): + self.data_i = Signal(32) + self.valid_i = Signal() + self.ready_o = Signal() + self.last_i = Signal() + + self.data_o = Signal(8) + self.valid_o = Signal() + self.ready_i = Signal() + self.last_o = Signal() + + def elaborate(self, platform): + m = Module() + return m diff --git a/test/test_uart_bridge_sim.py b/test/test_uart_bridge_sim.py new file mode 100644 index 0000000..b76bdd7 --- /dev/null +++ b/test/test_uart_bridge_sim.py @@ -0,0 +1,81 @@ +from amaranth import * + +from manta.ethernet.bridge import EthernetBridge +from manta.uart.cobs_decode import COBSDecode +from manta.uart.cobs_encode import COBSEncode +from manta.uart.receiver import UARTReceiver +from manta.uart.stream_packer import StreamPacker +from manta.uart.stream_unpacker import StreamUnpacker +from manta.uart.transmitter import UARTTransmitter +from manta.utils import * + +# uart_rx -> COBS decode -> pack_stream -> bridge -> unpack_stream -> COBS encode -> uart_tx + + +class UARTHardware(Elaboratable): + def __init__(self): + self.rx = Signal() + self.tx = Signal() + + self.bus_o = Signal(InternalBus()) + self.bus_i = Signal(InternalBus()) + self._clocks_per_baud = 10 + + def elaborate(self, platform): + m = Module() + + m.submodules.uart_rx = uart_rx = UARTReceiver(self._clocks_per_baud) + m.submodules.cobs_decode = cobs_decode = COBSDecode() + m.submodules.stream_packer = stream_packer = StreamPacker() + m.submodules.bridge = bridge = EthernetBridge() + m.submodules.stream_unpacker = stream_unpacker = StreamUnpacker() + m.submodules.cobs_encode = cobs_encode = COBSEncode() + m.submodules.uart_tx = uart_tx = UARTTransmitter(self._clocks_per_baud) + + m.d.comb += [ + uart_rx.rx.eq(self.rx), + cobs_decode.data_i.eq(uart_rx.data_o), + cobs_decode.valid_i.eq(uart_rx.valid_o), + stream_packer.data_i.eq(cobs_decode.data_o), + stream_packer.valid_i.eq(cobs_decode.valid_o), + stream_packer.last_i.eq(cobs_decode.last_o), + bridge.data_i.eq(stream_packer.data_o), + stream_packer.ready_i.eq(bridge.ready_o), + bridge.valid_i.eq(stream_packer.valid_o), + bridge.last_i.eq(stream_packer.last_o), + stream_unpacker.data_i.eq(bridge.data_o), + bridge.ready_i.eq(stream_unpacker.ready_o), + stream_unpacker.valid_i.eq(bridge.valid_o), + stream_unpacker.last_i.eq(bridge.last_o), + cobs_encode.data_i.eq(stream_unpacker.data_o), + cobs_encode.valid_i.eq(stream_unpacker.valid_o), + # not quite sure what the rest of these signals will be... + uart_tx.data_i.eq(cobs_encode.data_o), + uart_tx.start_i.eq(cobs_encode.valid_o), + cobs_encode.ready_i.eq(uart_tx.done_o), + self.tx.eq(uart_tx.tx), + self.bus_o.eq(bridge.bus_o), + bridge.bus_i.eq(self.bus_i), + ] + + return m + + +uart_hw = UARTHardware() + + +async def send_byte(ctx, module, data): + # 8N1 serial, LSB sent first + data_bits = "0" + f"{data:08b}"[::-1] + "1" + data_bits = [int(bit) for bit in data_bits] + + for i in range(10 * uart_hw._clocks_per_baud): + bit_index = i // uart_hw._clocks_per_baud + ctx.set(module.rx, data_bits[bit_index]) + await ctx.tick() + + +@simulate(uart_hw) +async def test_read_request(ctx): + await send_byte(ctx, uart_hw, 0) + await ctx.tick()