From 0c682073f6f199aec4489655a9998d48b82abb37 Mon Sep 17 00:00:00 2001 From: Fischer Moseley <42497969+fischermoseley@users.noreply.github.com> Date: Tue, 10 Feb 2026 20:54:29 -0700 Subject: [PATCH] uart: use wiring.Component instead of plain Elaborateable --- src/manta/ethernet/bridge.py | 126 +++++++++++++++--------------- src/manta/uart/cobs_decode.py | 43 +++++----- src/manta/uart/cobs_encode.py | 61 +++++++-------- src/manta/uart/receiver.py | 71 ++++++++--------- src/manta/uart/stream_packer.py | 48 ++++++------ src/manta/uart/stream_unpacker.py | 53 +++++++------ src/manta/uart/transmitter.py | 69 ++++++++-------- src/manta/utils.py | 22 +++++- test/test_uart_bridge_sim.py | 72 ++++++++++++----- 9 files changed, 309 insertions(+), 256 deletions(-) diff --git a/src/manta/ethernet/bridge.py b/src/manta/ethernet/bridge.py index dd3e856..6922e5a 100644 --- a/src/manta/ethernet/bridge.py +++ b/src/manta/ethernet/bridge.py @@ -1,20 +1,16 @@ from amaranth import * +from amaranth.lib import wiring +from amaranth.lib.wiring import In, Out from manta.utils import * -class EthernetBridge(Elaboratable): +class EthernetBridge(wiring.Component): + sink: In(StreamSignature(32)) + source: Out(StreamSignature(32)) + def __init__(self): - self.data_i = Signal(32) - self.valid_i = Signal() - self.last_i = Signal() - self.ready_o = Signal() - - self.data_o = Signal(32) - self.valid_o = Signal() - self.last_o = Signal() - self.ready_i = Signal() - + super().__init__() self.bus_o = Signal(InternalBus()) self.bus_i = Signal(InternalBus()) @@ -26,62 +22,62 @@ class EthernetBridge(Elaboratable): with m.FSM(init="IDLE"): with m.State("IDLE"): - m.d.sync += self.ready_o.eq(1) - m.d.sync += self.valid_o.eq(0) + m.d.sync += self.sink.ready.eq(1) + m.d.sync += self.source.valid.eq(0) # TODO: not necessary, but makes debugging way easier - m.d.sync += self.last_o.eq(0) - m.d.sync += self.data_o.eq(0) + m.d.sync += self.source.last.eq(0) + m.d.sync += self.source.data.eq(0) - with m.If(self.valid_i & self.ready_o): + with m.If(self.sink.valid & self.sink.ready): # First 32 bits was presented, which contains message type (first 3 bits) # as well as sequence number (next 13 bits). The remaining 16 bits are unused. # Send NACK if message type or sequence number is incorrect with m.If( - (self.data_i[:3] > max(MessageTypes)) - | (self.data_i[3:16] != seq_num_expected) + (self.sink.data[:3] > max(MessageTypes)) + | (self.sink.data[3:16] != seq_num_expected) ): # Wait to NACK if this isn't the last beat in message - with m.If(~self.last_i): + with m.If(~self.sink.last): m.next = "NACK_WAIT_FOR_LAST" # Otherwise, NACK immediately with m.Else(): - m.d.sync += self.data_o.eq( + m.d.sync += self.source.data.eq( EthernetMessageHeader.concat_signals( MessageTypes.NACK, seq_num_expected, ) ) - m.d.sync += self.valid_o.eq(1) - m.d.sync += self.last_o.eq(1) - m.d.sync += self.ready_o.eq(0) + m.d.sync += self.source.valid.eq(1) + m.d.sync += self.source.last.eq(1) + m.d.sync += self.sink.ready.eq(0) m.next = "NACK_WAIT_FOR_READY" - with m.Elif(self.data_i[:3] == MessageTypes.READ_REQUEST): + with m.Elif(self.sink.data[:3] == MessageTypes.READ_REQUEST): m.d.sync += seq_num_expected.eq(seq_num_expected + 1) - m.d.sync += read_len.eq(self.data_i[16:23] - 1) + m.d.sync += read_len.eq(self.sink.data[16:23] - 1) - m.d.sync += self.data_o.eq( + m.d.sync += self.source.data.eq( EthernetMessageHeader.concat_signals( MessageTypes.READ_RESPONSE, seq_num_expected, ) ) - m.d.sync += self.valid_o.eq(1) + m.d.sync += self.source.valid.eq(1) m.next = "READ_WAIT_FOR_ADDR" - with m.Elif(self.data_i[:3] == MessageTypes.WRITE_REQUEST): + with m.Elif(self.sink.data[:3] == MessageTypes.WRITE_REQUEST): m.next = "WRITE_WAIT_FOR_ADDR" with m.State("READ_WAIT_FOR_ADDR"): - m.d.sync += self.valid_o.eq(0) - m.d.sync += self.data_o.eq(0) + m.d.sync += self.source.valid.eq(0) + m.d.sync += self.source.data.eq(0) - with m.If(self.valid_i): + with m.If(self.sink.valid): # we have the length and the address to read from, let's go! - m.d.sync += self.bus_o.addr.eq(self.data_i) + m.d.sync += self.bus_o.addr.eq(self.sink.data) m.d.sync += self.bus_o.data.eq(0) m.d.sync += self.bus_o.rw.eq(0) m.d.sync += self.bus_o.valid.eq(1) @@ -94,7 +90,7 @@ class EthernetBridge(Elaboratable): m.next = "READ" with m.State("READ"): - m.d.sync += self.ready_o.eq(0) + m.d.sync += self.sink.ready.eq(0) # Clock out read requests to the bus with m.If(read_len > 0): @@ -111,47 +107,47 @@ class EthernetBridge(Elaboratable): # Clock out any read data from the bus with m.If(self.bus_i.valid): - m.d.sync += self.data_o.eq(self.bus_i.data) - m.d.sync += self.valid_o.eq(1) - m.d.sync += self.last_o.eq(self.bus_i.last) + m.d.sync += self.source.data.eq(self.bus_i.data) + m.d.sync += self.source.valid.eq(1) + m.d.sync += self.source.last.eq(self.bus_i.last) - with m.If(self.last_o): - m.d.sync += self.data_o.eq(0) - m.d.sync += self.valid_o.eq(0) - m.d.sync += self.last_o.eq(0) + with m.If(self.source.last): + m.d.sync += self.source.data.eq(0) + m.d.sync += self.source.valid.eq(0) + m.d.sync += self.source.last.eq(0) m.next = "IDLE" # TODO: could save a cycle by checking valid_i to see if there's more work to do with m.State("WRITE_WAIT_FOR_ADDR"): - with m.If(self.valid_i): - m.d.sync += self.bus_o.addr.eq(self.data_i) + with m.If(self.sink.valid): + m.d.sync += self.bus_o.addr.eq(self.sink.data) m.next = "WRITE_FIRST" # Don't want to increment address on the first write, # and I'm lazy so I'm making a new state to keep track of that with m.State("WRITE_FIRST"): - with m.If(self.valid_i): - m.d.sync += self.bus_o.data.eq(self.data_i) + with m.If(self.sink.valid): + m.d.sync += self.bus_o.data.eq(self.sink.data) m.d.sync += self.bus_o.rw.eq(1) m.d.sync += self.bus_o.valid.eq(1) - m.d.sync += self.bus_o.last.eq(self.last_i) + m.d.sync += self.bus_o.last.eq(self.sink.last) - with m.If(self.last_i): - m.d.sync += self.ready_o.eq(0) + with m.If(self.sink.last): + m.d.sync += self.sink.ready.eq(0) m.next = "WRITE_WAIT_FOR_LAST" with m.Else(): m.next = "WRITE" with m.State("WRITE"): - with m.If(self.valid_i): + with m.If(self.sink.valid): m.d.sync += self.bus_o.addr.eq(self.bus_o.addr + 1) - m.d.sync += self.bus_o.data.eq(self.data_i) + m.d.sync += self.bus_o.data.eq(self.sink.data) m.d.sync += self.bus_o.rw.eq(1) m.d.sync += self.bus_o.valid.eq(1) - m.d.sync += self.bus_o.last.eq(self.last_i) + m.d.sync += self.bus_o.last.eq(self.sink.last) - with m.If(self.last_i): - m.d.sync += self.ready_o.eq(0) + with m.If(self.sink.last): + m.d.sync += self.sink.ready.eq(0) m.next = "WRITE_WAIT_FOR_LAST" with m.Else(): @@ -165,38 +161,38 @@ class EthernetBridge(Elaboratable): with m.If(self.bus_i.last): m.d.sync += seq_num_expected.eq(seq_num_expected + 1) - m.d.sync += self.data_o.eq( + m.d.sync += self.source.data.eq( EthernetMessageHeader.concat_signals( MessageTypes.WRITE_RESPONSE, seq_num_expected, ) ) - m.d.sync += self.valid_o.eq(1) - m.d.sync += self.last_o.eq(1) + m.d.sync += self.source.valid.eq(1) + m.d.sync += self.source.last.eq(1) m.next = "IDLE" # TODO: could save a cycle by checking valid_i to see if there's more work to do with m.State("NACK_WAIT_FOR_LAST"): - with m.If(self.last_i): - m.d.sync += self.data_o.eq( + with m.If(self.sink.last): + m.d.sync += self.source.data.eq( EthernetMessageHeader.concat_signals( MessageTypes.NACK, seq_num_expected, ) ) - m.d.sync += self.valid_o.eq(1) - m.d.sync += self.last_o.eq(1) - m.d.sync += self.ready_o.eq(0) + m.d.sync += self.source.valid.eq(1) + m.d.sync += self.source.last.eq(1) + m.d.sync += self.sink.ready.eq(0) m.next = "NACK_WAIT_FOR_READY" with m.State("NACK_WAIT_FOR_READY"): - with m.If(self.ready_i): - m.d.sync += self.valid_o.eq(0) + with m.If(self.source.ready): + m.d.sync += self.source.valid.eq(0) # TODO: remove these next two lines, they're not necessary # although they are nice for debug... - m.d.sync += self.data_o.eq(0) - m.d.sync += self.last_o.eq(0) - m.d.sync += self.ready_o.eq(1) + m.d.sync += self.source.data.eq(0) + m.d.sync += self.source.last.eq(0) + m.d.sync += self.sink.ready.eq(1) m.next = "IDLE" diff --git a/src/manta/uart/cobs_decode.py b/src/manta/uart/cobs_decode.py index 9d9b5e0..fffdb47 100644 --- a/src/manta/uart/cobs_decode.py +++ b/src/manta/uart/cobs_decode.py @@ -1,57 +1,54 @@ from amaranth import * +from amaranth.lib import wiring +from amaranth.lib.wiring import In, Out + +from manta.utils import * -class COBSDecode(Elaboratable): - def __init__(self): - # Stream-like data input - self.data_i = Signal(8) - self.valid_i = Signal() - - # Stream-like data output - self.data_o = Signal(8) - self.valid_o = Signal() - self.last_o = Signal() +class COBSDecode(wiring.Component): + sink: In(StreamSignature(8, has_last=False, has_ready=False)) + source: Out(StreamSignature(8)) def elaborate(self, platform): m = Module() counter = Signal(8) - m.d.sync += self.data_o.eq(0) - m.d.sync += self.valid_o.eq(0) - m.d.sync += self.last_o.eq(0) + m.d.sync += self.source.data.eq(0) + m.d.sync += self.source.valid.eq(0) + m.d.sync += self.source.last.eq(0) # State Machine: with m.FSM(): # TODO: determine if wait for packet logic should stay # with m.State("WAIT_FOR_PACKET_START"): - # with m.If((self.data_i == 0) & (self.valid_i)): + # with m.If((self.sink.data == 0) & (self.sink.valid)): # m.next = "START_OF_PACKET" with m.State("START_OF_PACKET"): - with m.If(self.valid_i): - m.d.sync += counter.eq(self.data_i - 1) + with m.If(self.sink.valid): + m.d.sync += counter.eq(self.sink.data - 1) m.next = "DECODING" # with m.Else(): # m.next = "START_OF_PACKET" with m.State("DECODING"): - with m.If(self.valid_i): + with m.If(self.sink.valid): with m.If(counter > 0): m.d.sync += counter.eq(counter - 1) - m.d.sync += self.data_o.eq(self.data_i) - m.d.sync += self.valid_o.eq(1) + m.d.sync += self.source.data.eq(self.sink.data) + m.d.sync += self.source.valid.eq(1) m.next = "DECODING" with m.Else(): - with m.If(self.data_i == 0): - m.d.sync += self.last_o.eq(1) + with m.If(self.sink.data == 0): + m.d.sync += self.source.last.eq(1) m.next = "START_OF_PACKET" with m.Else(): - m.d.sync += counter.eq(self.data_i - 1) - m.d.sync += self.valid_o.eq(1) + m.d.sync += counter.eq(self.sink.data - 1) + m.d.sync += self.source.valid.eq(1) m.next = "DECODING" return m diff --git a/src/manta/uart/cobs_encode.py b/src/manta/uart/cobs_encode.py index 69fc94a..638dfac 100644 --- a/src/manta/uart/cobs_encode.py +++ b/src/manta/uart/cobs_encode.py @@ -1,23 +1,14 @@ from amaranth import * +from amaranth.lib import wiring from amaranth.lib.memory import Memory +from amaranth.lib.wiring import In, Out + +from manta.utils import * -class COBSEncode(Elaboratable): - def __init__(self): - # Top-Level IO - # Stream-like data input - self.data_i = Signal(8) - self.valid_i = Signal() - self.ready_o = Signal() - self.last_i = Signal() - - # Stream-like data output - self.data_o = Signal(8) - self.valid_o = Signal() - self.ready_i = Signal() - - # Define memory - self.memory = Memory(shape=8, depth=256, init=[0] * 256) +class COBSEncode(wiring.Component): + sink: In(StreamSignature(8)) + source: Out(StreamSignature(8, has_last=False)) def elaborate(self, platform): m = Module() @@ -27,13 +18,13 @@ class COBSEncode(Elaboratable): tail_pointer = Signal(range(256)) # Add memory and read/write ports - m.submodules.memory = self.memory - rd_port = self.memory.read_port() - wr_port = self.memory.write_port() + m.submodules.memory = memory = Memory(shape=8, depth=256, init=[0] * 256) + rd_port = memory.read_port() + wr_port = memory.write_port() # Reset top-level IO - m.d.sync += self.data_o.eq(0) - m.d.sync += self.valid_o.eq(0) + m.d.sync += self.source.data.eq(0) + m.d.sync += self.source.valid.eq(0) # Generate rd_port_addr_prev rd_port_addr_prev = Signal().like(rd_port.addr) @@ -42,7 +33,7 @@ class COBSEncode(Elaboratable): # State Machine: with m.FSM() as fsm: with m.State("IDLE"): - with m.If(self.last_i): + with m.If(self.sink.last): m.d.sync += head_pointer.eq(0) m.d.sync += tail_pointer.eq(0) m.d.sync += rd_port.addr.eq(0) @@ -59,8 +50,10 @@ class COBSEncode(Elaboratable): m.d.sync += head_pointer.eq(rd_port_addr_prev) m.d.sync += rd_port.addr.eq(tail_pointer) - m.d.sync += self.data_o.eq(rd_port_addr_prev - tail_pointer + 1) - m.d.sync += self.valid_o.eq(1) + m.d.sync += self.source.data.eq( + rd_port_addr_prev - tail_pointer + 1 + ) + m.d.sync += self.source.valid.eq(1) m.next = "CLOCK_OUT_BYTES_STALL" @@ -77,16 +70,16 @@ class COBSEncode(Elaboratable): m.d.sync += rd_port.addr.eq(rd_port.addr + 1) # Watch prev_addr - with m.If(rd_port_addr_prev <= head_pointer): - m.d.sync += self.data_o.eq(rd_port.data) - m.d.sync += self.valid_o.eq(1) + with m.If(rd_port_addr_prev < head_pointer): + m.d.sync += self.source.data.eq(rd_port.data) + m.d.sync += self.source.valid.eq(1) m.next = "CLOCK_OUT_BYTES" with m.If(rd_port_addr_prev == head_pointer): # Reached end of message with m.If(head_pointer == wr_port.addr): - m.d.sync += self.data_o.eq(0) - m.d.sync += self.valid_o.eq(1) + m.d.sync += self.source.data.eq(0) + m.d.sync += self.source.valid.eq(1) m.next = "IDLE" @@ -94,7 +87,9 @@ class COBSEncode(Elaboratable): m.d.sync += tail_pointer.eq(head_pointer + 1) m.d.sync += head_pointer.eq(head_pointer + 1) m.d.sync += rd_port.addr.eq(head_pointer + 1) - m.d.sync += self.valid_o.eq(0) # i have no idea why this works + m.d.sync += self.source.valid.eq( + 0 + ) # i have no idea why this works m.next = "SEARCH_FOR_ZERO_STALL" @@ -102,10 +97,10 @@ class COBSEncode(Elaboratable): m.next = "SEARCH_FOR_ZERO" # Fill memory from input stream - m.d.comb += wr_port.en.eq((fsm.ongoing("IDLE")) & (self.valid_i)) - m.d.comb += wr_port.data.eq(self.data_i) + m.d.comb += wr_port.en.eq((fsm.ongoing("IDLE")) & (self.sink.valid)) + m.d.comb += wr_port.data.eq(self.sink.data) m.d.sync += wr_port.addr.eq(wr_port.addr + wr_port.en) - m.d.comb += self.ready_o.eq(fsm.ongoing("IDLE")) + m.d.comb += self.sink.ready.eq(fsm.ongoing("IDLE")) return m diff --git a/src/manta/uart/receiver.py b/src/manta/uart/receiver.py index faadb48..82acc13 100644 --- a/src/manta/uart/receiver.py +++ b/src/manta/uart/receiver.py @@ -1,64 +1,65 @@ from amaranth import * +from amaranth.lib import wiring +from amaranth.lib.wiring import In, Out + +from manta.utils import * -class UARTReceiver(Elaboratable): +class UARTReceiver(wiring.Component): """ A module for receiving bytes on a 8N1 UART at a configurable baudrate. Outputs bytes as a stream. """ + rx: In(1) + source: Out(StreamSignature(8, has_last=False, has_ready=False)) + def __init__(self, clocks_per_baud): + super().__init__() self._clocks_per_baud = clocks_per_baud - # Top-Level Ports - self.rx = Signal() - self.data_o = Signal(8) - self.valid_o = Signal() - - # Internal Signals - self._busy = Signal() - self._bit_index = Signal(range(10)) - self._baud_counter = Signal(range(2 * clocks_per_baud)) - - self._rx_d = Signal() - self._rx_q = Signal() - self._rx_q_prev = Signal() - def elaborate(self, platform): m = Module() + busy = Signal() + bit_index = Signal(range(10)) + baud_counter = Signal(range(2 * self._clocks_per_baud)) + + rx_d = Signal() + rx_q = Signal() + rx_q_prev = Signal() + # Two Flip-Flop Synchronizer m.d.sync += [ - self._rx_d.eq(self.rx), - self._rx_q.eq(self._rx_d), - self._rx_q_prev.eq(self._rx_q), + rx_d.eq(self.rx), + rx_q.eq(rx_d), + rx_q_prev.eq(rx_q), ] - m.d.sync += self.valid_o.eq(0) + m.d.sync += self.source.valid.eq(0) - with m.If(~self._busy): - with m.If((~self._rx_q) & (self._rx_q_prev)): - m.d.sync += self._busy.eq(1) - m.d.sync += self._bit_index.eq(8) - m.d.sync += self._baud_counter.eq( + with m.If(~busy): + with m.If((~rx_q) & (rx_q_prev)): + m.d.sync += busy.eq(1) + m.d.sync += bit_index.eq(8) + m.d.sync += baud_counter.eq( self._clocks_per_baud + (self._clocks_per_baud // 2) - 2 ) with m.Else(): - with m.If(self._baud_counter == 0): - with m.If(self._bit_index == 0): - m.d.sync += self.valid_o.eq(1) - m.d.sync += self._busy.eq(0) - m.d.sync += self._bit_index.eq(0) - m.d.sync += self._baud_counter.eq(0) + with m.If(baud_counter == 0): + with m.If(bit_index == 0): + m.d.sync += self.source.valid.eq(1) + m.d.sync += busy.eq(0) + m.d.sync += bit_index.eq(0) + m.d.sync += baud_counter.eq(0) with m.Else(): - # m.d.sync += self.data_o.eq(Cat(self._rx_q, self.data_o[0:7])) - m.d.sync += self.data_o.eq(Cat(self.data_o[1:8], self._rx_q)) - m.d.sync += self._bit_index.eq(self._bit_index - 1) - m.d.sync += self._baud_counter.eq(self._clocks_per_baud - 1) + m.d.sync += self.source.data.eq(Cat(self.source.data[1:8], rx_q)) + m.d.sync += bit_index.eq(bit_index - 1) + m.d.sync += baud_counter.eq(self._clocks_per_baud - 1) with m.Else(): - m.d.sync += self._baud_counter.eq(self._baud_counter - 1) + m.d.sync += baud_counter.eq(baud_counter - 1) return m diff --git a/src/manta/uart/stream_packer.py b/src/manta/uart/stream_packer.py index 8586d9f..89a8840 100644 --- a/src/manta/uart/stream_packer.py +++ b/src/manta/uart/stream_packer.py @@ -1,40 +1,44 @@ from amaranth import * +from amaranth.lib import wiring +from amaranth.lib.wiring import In, Out + +from manta.utils import * -class StreamPacker(Elaboratable): - def __init__(self): - self.data_i = Signal(8) - self.valid_i = Signal() - self.ready_o = Signal(init=1) - self.last_i = Signal() - - self.data_o = Signal(32) - self.valid_o = Signal() - self.ready_i = Signal() - self.last_o = Signal() +class StreamPacker(wiring.Component): + sink: In(StreamSignature(8)) + source: Out(StreamSignature(32)) def elaborate(self, platform): m = Module() + # Defining an internal idle signal and combinationally assigning it to + # self.sink.ready allows us to specify an initial value for the signal + # without having to modify the members of StreamSignature + idle = Signal(init=1) + m.d.comb += self.sink.ready.eq(idle) + count = Signal(range(4)) - with m.If(self.ready_o): - with m.If(self.valid_i): - m.d.sync += self.data_o.eq(Cat(self.data_o[8:], self.data_i)) + with m.If(idle): + with m.If(self.sink.valid): + m.d.sync += self.source.data.eq( + Cat(self.source.data[8:], self.sink.data) + ) m.d.sync += count.eq(count + 1) with m.If(count == 3): - m.d.sync += self.ready_o.eq(0) - m.d.sync += self.valid_o.eq(1) - m.d.sync += self.last_o.eq(self.last_i) + m.d.sync += idle.eq(0) + m.d.sync += self.source.valid.eq(1) + m.d.sync += self.source.last.eq(self.sink.last) with m.Else(): - with m.If(self.valid_o & self.ready_i): - m.d.sync += self.ready_o.eq(1) - m.d.sync += self.valid_o.eq(0) + with m.If(self.source.valid & self.source.ready): + m.d.sync += idle.eq(1) + m.d.sync += self.source.valid.eq(0) # TODO: not necessary, but makes debugging much easier! - m.d.sync += self.data_o.eq(0) - m.d.sync += self.last_o.eq(0) + m.d.sync += self.source.data.eq(0) + m.d.sync += self.source.last.eq(0) return m diff --git a/src/manta/uart/stream_unpacker.py b/src/manta/uart/stream_unpacker.py index 10c1c7e..dbd626c 100644 --- a/src/manta/uart/stream_unpacker.py +++ b/src/manta/uart/stream_unpacker.py @@ -1,54 +1,57 @@ from amaranth import * +from amaranth.lib import wiring +from amaranth.lib.wiring import In, Out + +from manta.utils import * -class StreamUnpacker(Elaboratable): - def __init__(self): - self.data_i = Signal(32) - self.valid_i = Signal() - self.ready_o = Signal(init=1) - self.last_i = Signal() - - self.data_o = Signal(8) - self.valid_o = Signal() - self.ready_i = Signal() - self.last_o = Signal() +class StreamUnpacker(wiring.Component): + sink: In(StreamSignature(32)) + source: Out(StreamSignature(8)) def elaborate(self, platform): m = Module() # Turn a stream of 32-bit numbers into a stream of 8-bit numbers + + # Defining an internal idle signal and combinationally assigning it to + # self.sink.ready allows us to specify an initial value for the signal + # without having to modify the members of StreamSignature + idle = Signal(init=1) + m.d.comb += self.sink.ready.eq(idle) + buf = Signal(24) last = Signal() count = Signal(range(3)) - with m.If(self.ready_o): - with m.If(self.valid_i): - m.d.sync += buf.eq(self.data_i[8:]) - m.d.sync += last.eq(self.last_i) - m.d.sync += self.ready_o.eq(0) + with m.If(idle): + with m.If(self.sink.valid): + m.d.sync += buf.eq(self.sink.data[8:]) + m.d.sync += last.eq(self.sink.last) + m.d.sync += idle.eq(0) - m.d.sync += self.data_o.eq(self.data_i[:7]) - m.d.sync += self.valid_o.eq(1) + m.d.sync += self.source.data.eq(self.sink.data[:7]) + m.d.sync += self.source.valid.eq(1) m.d.sync += count.eq(0) # Have some data in the buffer with m.Else(): - with m.If(self.valid_o & self.ready_i): + with m.If(self.source.valid & self.source.ready): # if done, clean up and signal ready for next word with m.If(count == 3): - m.d.sync += self.valid_o.eq(0) - m.d.sync += self.ready_o.eq(1) + m.d.sync += self.source.valid.eq(0) + m.d.sync += idle.eq(1) # TODO: not necessary, but makes debugging much easier! - m.d.sync += self.data_o.eq(0) - m.d.sync += self.last_o.eq(0) + m.d.sync += self.source.data.eq(0) + m.d.sync += self.source.last.eq(0) # if not done, clock out next byte with m.Else(): - m.d.sync += self.data_o.eq(buf[8:]) + m.d.sync += self.source.data.eq(buf[8:]) m.d.sync += buf.eq(buf >> 8) m.d.sync += count.eq(count + 1) - m.d.sync += self.last_o.eq((last) & (count == 2)) + m.d.sync += self.source.last.eq((last) & (count == 2)) return m diff --git a/src/manta/uart/transmitter.py b/src/manta/uart/transmitter.py index 55a779b..b5dc8d3 100644 --- a/src/manta/uart/transmitter.py +++ b/src/manta/uart/transmitter.py @@ -1,57 +1,64 @@ from amaranth import * +from amaranth.lib import wiring +from amaranth.lib.wiring import In, Out + +from manta.utils import * -class UARTTransmitter(Elaboratable): +class UARTTransmitter(wiring.Component): """ A module for transmitting bytes on a 8N1 UART at a configurable baudrate. Accepts bytes as a stream. """ + sink: In(StreamSignature(8, has_last=False)) + tx: Out(1, init=1) + def __init__(self, clocks_per_baud): + super().__init__() self._clocks_per_baud = clocks_per_baud - # Top-Level Ports - self.data_i = Signal(8) - self.start_i = Signal() - self.done_o = Signal(init=1) - - self.tx = Signal(init=1) - - # Internal Signals - self._baud_counter = Signal(range(self._clocks_per_baud)) - self._buffer = Signal(9) - self._bit_index = Signal(4) - def elaborate(self, platform): m = Module() - with m.If((self.start_i) & (self.done_o)): - m.d.sync += self._baud_counter.eq(self._clocks_per_baud - 1) - m.d.sync += self._buffer.eq(Cat(self.data_i, 1)) - m.d.sync += self._bit_index.eq(0) - m.d.sync += self.done_o.eq(0) - m.d.sync += self.tx.eq(0) + # Defining an internal idle signal and combinationally assigning it to + # self.sink.ready allows us to specify an initial value for the signal + # without having to modify the members of StreamSignature + idle = Signal(init=1) + m.d.comb += self.sink.ready.eq(idle) - with m.Elif(~self.done_o): - m.d.sync += self._baud_counter.eq(self._baud_counter - 1) - m.d.sync += self.done_o.eq((self._baud_counter == 1) & (self._bit_index == 9)) + baud_counter = Signal(range(self._clocks_per_baud)) + buffer = Signal(9) + bit_index = Signal(4) + + with m.If(idle): + with m.If(self.sink.valid): + m.d.sync += baud_counter.eq(self._clocks_per_baud - 1) + m.d.sync += buffer.eq(Cat(self.sink.data, 1)) + m.d.sync += bit_index.eq(0) + m.d.sync += idle.eq(0) + m.d.sync += self.tx.eq(0) + + with m.Else(): + m.d.sync += baud_counter.eq(baud_counter - 1) + m.d.sync += idle.eq((baud_counter == 1) & (bit_index == 9)) # A baud period has elapsed - with m.If(self._baud_counter == 0): - m.d.sync += self._baud_counter.eq(self._clocks_per_baud - 1) + with m.If(baud_counter == 0): + m.d.sync += baud_counter.eq(self._clocks_per_baud - 1) # Clock out another bit if there are any left - with m.If(self._bit_index < 9): - m.d.sync += self.tx.eq(self._buffer.bit_select(self._bit_index, 1)) - m.d.sync += self._bit_index.eq(self._bit_index + 1) + with m.If(bit_index < 9): + m.d.sync += self.tx.eq(buffer.bit_select(bit_index, 1)) + m.d.sync += bit_index.eq(bit_index + 1) # Byte has been sent, send out next one or go to idle with m.Else(): - with m.If(self.start_i): - m.d.sync += self._buffer.eq(Cat(self.data_i, 1)) - m.d.sync += self._bit_index.eq(0) + with m.If(self.sink.valid): + m.d.sync += buffer.eq(Cat(self.sink.data, 1)) + m.d.sync += bit_index.eq(0) m.d.sync += self.tx.eq(0) with m.Else(): - m.d.sync += self.done_o.eq(1) + m.d.sync += idle.eq(1) return m diff --git a/src/manta/utils.py b/src/manta/utils.py index a1359b8..d2ed07c 100644 --- a/src/manta/utils.py +++ b/src/manta/utils.py @@ -4,9 +4,10 @@ from pathlib import Path from random import sample from amaranth import Cat, Const, Elaboratable, Signal, unsigned -from amaranth.lib import data +from amaranth.lib import data, wiring from amaranth.lib.data import Struct from amaranth.lib.enum import IntEnum +from amaranth.lib.wiring import In, Out from amaranth.sim import Simulator @@ -104,6 +105,25 @@ class InternalBus(data.StructLayout): ) +class StreamSignature(wiring.Signature): + def __init__(self, data_shape, has_last=True, has_ready=True): + sig = { + "data": Out(data_shape), + "valid": Out(1), + } + + if has_last: + sig["last"] = Out(1) + + if has_ready: + sig["ready"] = In(1) + + super().__init__(sig) + + def __eq__(self, other): + return self.members == other.members + + class MessageTypes(IntEnum, shape=unsigned(3)): READ_REQUEST = 0 WRITE_REQUEST = 1 diff --git a/test/test_uart_bridge_sim.py b/test/test_uart_bridge_sim.py index b76bdd7..5a4105b 100644 --- a/test/test_uart_bridge_sim.py +++ b/test/test_uart_bridge_sim.py @@ -1,5 +1,7 @@ from amaranth import * +from cobs import cobs +from manta import * from manta.ethernet.bridge import EthernetBridge from manta.uart.cobs_decode import COBSDecode from manta.uart.cobs_encode import COBSEncode @@ -32,27 +34,15 @@ class UARTHardware(Elaboratable): m.submodules.cobs_encode = cobs_encode = COBSEncode() m.submodules.uart_tx = uart_tx = UARTTransmitter(self._clocks_per_baud) + wiring.connect(m, uart_rx.source, cobs_decode.sink) + wiring.connect(m, cobs_decode.source, stream_packer.sink) + wiring.connect(m, stream_packer.source, bridge.sink) + wiring.connect(m, bridge.source, stream_unpacker.sink) + wiring.connect(m, stream_unpacker.source, cobs_encode.sink) + wiring.connect(m, cobs_encode.source, uart_tx.sink) + m.d.comb += [ uart_rx.rx.eq(self.rx), - cobs_decode.data_i.eq(uart_rx.data_o), - cobs_decode.valid_i.eq(uart_rx.valid_o), - stream_packer.data_i.eq(cobs_decode.data_o), - stream_packer.valid_i.eq(cobs_decode.valid_o), - stream_packer.last_i.eq(cobs_decode.last_o), - bridge.data_i.eq(stream_packer.data_o), - stream_packer.ready_i.eq(bridge.ready_o), - bridge.valid_i.eq(stream_packer.valid_o), - bridge.last_i.eq(stream_packer.last_o), - stream_unpacker.data_i.eq(bridge.data_o), - bridge.ready_i.eq(stream_unpacker.ready_o), - stream_unpacker.valid_i.eq(bridge.valid_o), - stream_unpacker.last_i.eq(bridge.last_o), - cobs_encode.data_i.eq(stream_unpacker.data_o), - cobs_encode.valid_i.eq(stream_unpacker.valid_o), - # not quite sure what the rest of these signals will be... - uart_tx.data_i.eq(cobs_encode.data_o), - uart_tx.start_i.eq(cobs_encode.valid_o), - cobs_encode.ready_i.eq(uart_tx.done_o), self.tx.eq(uart_tx.tx), self.bus_o.eq(bridge.bus_o), bridge.bus_i.eq(self.bus_i), @@ -61,7 +51,31 @@ class UARTHardware(Elaboratable): return m -uart_hw = UARTHardware() +class UARTHardwarePlusMemoryCore(Elaboratable): + def __init__(self): + self.rx = Signal() + self.tx = Signal() + + self._clocks_per_baud = 10 + + def elaborate(self, platform): + m = Module() + m.submodules.uart = uart = UARTHardware() + m.submodules.mem_core = mem_core = MemoryCore("bidirectional", 32, 1024) + mem_core.base_addr = 0 + + m.d.comb += uart.bus_i.eq(mem_core.bus_o) + m.d.comb += mem_core.bus_i.eq(uart.bus_o) + + m.d.comb += [ + self.tx.eq(uart.tx), + uart.rx.eq(self.rx), + ] + + return m + + +uart_hw = UARTHardwarePlusMemoryCore() async def send_byte(ctx, module, data): @@ -77,5 +91,21 @@ async def send_byte(ctx, module, data): @simulate(uart_hw) async def test_read_request(ctx): - await send_byte(ctx, uart_hw, 0) + addr = 0x5678_9ABC + header = EthernetMessageHeader.from_params( + MessageTypes.READ_REQUEST, seq_num=0x0, length=1 + ) + request = bytestring_from_ints([header.as_bits(), addr], byteorder="little") + encoded = cobs.encode(request) + encoded = encoded + int(0).to_bytes(1) + + ctx.set(uart_hw.rx, 1) + await ctx.tick() + await ctx.tick() + await ctx.tick() + + for byte in encoded: + await send_byte(ctx, uart_hw, int(byte)) + print(hex(int(byte))) + await ctx.tick()