diff --git a/src/manta/ethernet/__init__.py b/src/manta/ethernet/__init__.py index c2a5ec1..e64df6d 100644 --- a/src/manta/ethernet/__init__.py +++ b/src/manta/ethernet/__init__.py @@ -94,8 +94,6 @@ class EthernetInterface(Elaboratable): self._seq_num = 0 self._max_retries = 3 - self._max_read_len = 126 - self._max_write_len = 126 def _check_config(self): # Make sure UDP port is an integer in the range 0-65535 @@ -583,54 +581,36 @@ class EthernetInterface(Elaboratable): return m - @staticmethod - def _read_request_bytes(seq_num, addr, length): - message = [ - (length << 16) | (seq_num << 3) | MessageTypes.READ_REQUEST, - addr, - ] - - return b"".join([i.to_bytes(4, "little") for i in message]) - - @staticmethod - def _write_request_bytes(seq_num, addr, datas): - message = [ - (seq_num << 3) | MessageTypes.WRITE_REQUEST, - addr, - *datas, - ] - - return b"".join([i.to_bytes(4, "little") for i in message]) - def _read_request(self, addr, length): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind((self._host_ip_addr, self._udp_port)) - retry_count = 0 - while retry_count < self._max_retries: - request = self._read_request_bytes(self._seq_num, addr, length) + for _ in range(self._max_retries): + header = EthernetMessageHeader.from_params( + MessageTypes.READ_REQUEST, self._seq_num, length + ) + request = bytestring_from_ints([header.as_bits(), addr]) + sock.sendto(request, (self._fpga_ip_addr, self._udp_port)) data, (ip_addr, port) = sock.recvfrom(4 + (length * 4)) + response = ints_from_bytestring(data) if ip_addr != self._fpga_ip_addr: raise ValueError("Non-Manta traffic detected on this UDP port!") - data = [ - int.from_bytes(data[i : i + 4], "little") - for i in range(0, len(data), 4) - ] + header = EthernetMessageHeader.from_bits(response[0]) + read_data = response[1:] - response_type = MessageTypes(part_select(data[0], 29, 31)) - if response_type == MessageTypes.READ_RESPONSE: - assert len(data) - 1 == length + if header.msg_type == MessageTypes.READ_RESPONSE: + assert len(read_data) == length self._seq_num += 1 - return data[1:] + return read_data - elif response_type == MessageTypes.NACK: - self._seq_num = part_select(data[0], 16, 28) - retry_count += 1 + elif header.msg_type == MessageTypes.NACK: + self._seq_num = header.seq_num else: + print(MessageTypes(header.msg_type).name) raise ValueError("Unexpected message format received!") raise ValueError("Maximum number of retries exceeded!") @@ -639,30 +619,28 @@ class EthernetInterface(Elaboratable): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind((self._host_ip_addr, self._udp_port)) - retry_count = 0 - while retry_count < self._max_retries: - request = self._write_request_bytes(self._seq_num, addr, datas) + for _ in range(self._max_retries): + header = EthernetMessageHeader.from_params( + MessageTypes.WRITE_REQUEST, self._seq_num + ) + request = bytestring_from_ints([header.as_bits(), addr, *datas]) + sock.sendto(request, (self._fpga_ip_addr, self._udp_port)) data, (ip_addr, port) = sock.recvfrom(4) + response = ints_from_bytestring(data) assert port == self._udp_port if ip_addr != self._fpga_ip_addr: raise ValueError("Non-Manta traffic detected on this UDP port!") - data = [ - int.from_bytes(data[i : i + 4], "little") - for i in range(0, len(data), 4) - ] - - response_type = MessageTypes(part_select(data[0], 29, 31)) - if response_type == MessageTypes.WRITE_RESPONSE: + header = EthernetMessageHeader.from_bits(response[0]) + if header.msg_type == MessageTypes.WRITE_RESPONSE: self._seq_num += 1 return - elif response_type == MessageTypes.NACK: - self._seq_num = part_select(data[0], 16, 28) - retry_count += 1 + elif header.msg_type == MessageTypes.NACK: + self._seq_num = header.seq_num else: raise ValueError("Unexpected message format received!") @@ -674,7 +652,7 @@ class EthernetInterface(Elaboratable): offset = 0 while offset < length: - chunk_size = min(self._max_read_len, length - offset) + chunk_size = min(EthernetMessageHeader.MAX_READ_LENGTH, length - offset) data += self._read_request(base_addr + offset, chunk_size) offset += chunk_size @@ -682,10 +660,12 @@ class EthernetInterface(Elaboratable): return data def write_block(self, base_addr, data): - data_chunks = split_into_chunks(data, self._max_write_len) + data_chunks = split_into_chunks(data, EthernetMessageHeader.MAX_WRITE_LENGTH) for i, chunk in enumerate(data_chunks): - self._write_request(base_addr + (i * self._max_write_len), chunk) + self._write_request( + base_addr + (i * EthernetMessageHeader.MAX_WRITE_LENGTH), chunk + ) def read(self, addrs): """ diff --git a/src/manta/ethernet/bridge.py b/src/manta/ethernet/bridge.py index 2aab2cb..dd3e856 100644 --- a/src/manta/ethernet/bridge.py +++ b/src/manta/ethernet/bridge.py @@ -49,10 +49,9 @@ class EthernetBridge(Elaboratable): # Otherwise, NACK immediately with m.Else(): m.d.sync += self.data_o.eq( - Cat( - C(0, unsigned(16)), - seq_num_expected, + EthernetMessageHeader.concat_signals( MessageTypes.NACK, + seq_num_expected, ) ) m.d.sync += self.valid_o.eq(1) @@ -65,10 +64,9 @@ class EthernetBridge(Elaboratable): m.d.sync += read_len.eq(self.data_i[16:23] - 1) m.d.sync += self.data_o.eq( - Cat( - C(0, unsigned(16)), - seq_num_expected, + EthernetMessageHeader.concat_signals( MessageTypes.READ_RESPONSE, + seq_num_expected, ) ) m.d.sync += self.valid_o.eq(1) @@ -167,12 +165,10 @@ class EthernetBridge(Elaboratable): with m.If(self.bus_i.last): m.d.sync += seq_num_expected.eq(seq_num_expected + 1) - m.d.sync += self.data_o.eq( - Cat( - C(0, unsigned(16)), - seq_num_expected, + EthernetMessageHeader.concat_signals( MessageTypes.WRITE_RESPONSE, + seq_num_expected, ) ) m.d.sync += self.valid_o.eq(1) @@ -182,7 +178,10 @@ class EthernetBridge(Elaboratable): with m.State("NACK_WAIT_FOR_LAST"): with m.If(self.last_i): m.d.sync += self.data_o.eq( - Cat(C(0, unsigned(16)), seq_num_expected, MessageTypes.NACK) + EthernetMessageHeader.concat_signals( + MessageTypes.NACK, + seq_num_expected, + ) ) m.d.sync += self.valid_o.eq(1) m.d.sync += self.last_o.eq(1) diff --git a/src/manta/utils.py b/src/manta/utils.py index b6cb325..62a7712 100644 --- a/src/manta/utils.py +++ b/src/manta/utils.py @@ -3,8 +3,9 @@ from abc import ABC, abstractmethod from pathlib import Path from random import sample -from amaranth import Elaboratable, unsigned +from amaranth import Cat, Const, Elaboratable, Signal, unsigned from amaranth.lib import data +from amaranth.lib.data import Struct from amaranth.lib.enum import IntEnum from amaranth.sim import Simulator @@ -111,6 +112,47 @@ class MessageTypes(IntEnum, shape=unsigned(3)): NACK = 4 +class EthernetMessageHeader(Struct): + msg_type: MessageTypes + seq_num: 13 + length: 7 = 0 + zero_padding: 9 = 0 + + MAX_READ_LENGTH = 126 + MAX_WRITE_LENGTH = 126 + + @classmethod + def from_params(cls, msg_type, seq_num, length=0): + return cls.const( + init={"msg_type": msg_type, "seq_num": seq_num, "length": length} + ) + + @classmethod + def concat_signals( + cls, msg_type: MessageTypes, seq_num: Signal, length: Signal = None + ): + # Make sure each signal is the right width! + widths = cls.from_bits(0).shape().members + + if Const(msg_type).shape().width != MessageTypes.as_shape().width: + raise TypeError + + if seq_num.shape().width != widths["seq_num"]: + raise TypeError + + zp_width = widths["zero_padding"] + len_width = widths["length"] + + if length is None: + return Cat(msg_type, seq_num, Const(0, len_width), Const(0, zp_width)) + + else: + if length.shape().width != len_width: + raise TypeError + + return Cat(msg_type, seq_num, length, Const(0, zp_width)) + + def warn(message): """ Prints a warning to the user's terminal. Originally the warn() method @@ -121,20 +163,6 @@ def warn(message): print("Warning: " + message) -def part_select(value, start, end): - # Ensure the start bit is less than or equal to the end bit - if start > end: - raise ValueError( - "Start bit position must be less than or equal to end bit position." - ) - - # Create a mask to isolate the bits from `start` to `end` - mask = (1 << (end - start + 1)) - 1 - - # Shift the number to the right by `start` bits and apply the mask - return (value >> start) & mask - - def parse_sequences(numbers): """ Takes a list of integers and identifies runs of sequential numbers @@ -203,6 +231,22 @@ def check_value_fits_in_bits(value, n_bits): raise ValueError("Signed integer too large.") +def ints_from_bytestring(bytes, byteorder="little"): + """ + Takes a list of ints, interprets them as 32-bit integers, and returns a + bytestring of the constituent bytes joined together. + """ + return [int.from_bytes(chunk, byteorder) for chunk in split_into_chunks(bytes, 4)] + + +def bytestring_from_ints(ints, byteorder="little"): + """ + Takes a list of ints, interprets them as 32-bit integers, and returns a + bytestring of the constituent bytes joined together. + """ + return b"".join(i.to_bytes(4, byteorder) for i in ints) + + def split_into_chunks(data, chunk_size): """ Split a list into a list of lists, where each sublist has length `chunk_size`.