ethernet: use EthernetMessageHeader class

This commit is contained in:
Fischer Moseley 2026-01-13 02:58:45 -07:00
parent c5d23aba76
commit f888303ca2
3 changed files with 100 additions and 77 deletions

View File

@ -94,8 +94,6 @@ class EthernetInterface(Elaboratable):
self._seq_num = 0
self._max_retries = 3
self._max_read_len = 126
self._max_write_len = 126
def _check_config(self):
# Make sure UDP port is an integer in the range 0-65535
@ -583,54 +581,36 @@ class EthernetInterface(Elaboratable):
return m
@staticmethod
def _read_request_bytes(seq_num, addr, length):
message = [
(length << 16) | (seq_num << 3) | MessageTypes.READ_REQUEST,
addr,
]
return b"".join([i.to_bytes(4, "little") for i in message])
@staticmethod
def _write_request_bytes(seq_num, addr, datas):
message = [
(seq_num << 3) | MessageTypes.WRITE_REQUEST,
addr,
*datas,
]
return b"".join([i.to_bytes(4, "little") for i in message])
def _read_request(self, addr, length):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind((self._host_ip_addr, self._udp_port))
retry_count = 0
while retry_count < self._max_retries:
request = self._read_request_bytes(self._seq_num, addr, length)
for _ in range(self._max_retries):
header = EthernetMessageHeader.from_params(
MessageTypes.READ_REQUEST, self._seq_num, length
)
request = bytestring_from_ints([header.as_bits(), addr])
sock.sendto(request, (self._fpga_ip_addr, self._udp_port))
data, (ip_addr, port) = sock.recvfrom(4 + (length * 4))
response = ints_from_bytestring(data)
if ip_addr != self._fpga_ip_addr:
raise ValueError("Non-Manta traffic detected on this UDP port!")
data = [
int.from_bytes(data[i : i + 4], "little")
for i in range(0, len(data), 4)
]
header = EthernetMessageHeader.from_bits(response[0])
read_data = response[1:]
response_type = MessageTypes(part_select(data[0], 29, 31))
if response_type == MessageTypes.READ_RESPONSE:
assert len(data) - 1 == length
if header.msg_type == MessageTypes.READ_RESPONSE:
assert len(read_data) == length
self._seq_num += 1
return data[1:]
return read_data
elif response_type == MessageTypes.NACK:
self._seq_num = part_select(data[0], 16, 28)
retry_count += 1
elif header.msg_type == MessageTypes.NACK:
self._seq_num = header.seq_num
else:
print(MessageTypes(header.msg_type).name)
raise ValueError("Unexpected message format received!")
raise ValueError("Maximum number of retries exceeded!")
@ -639,30 +619,28 @@ class EthernetInterface(Elaboratable):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind((self._host_ip_addr, self._udp_port))
retry_count = 0
while retry_count < self._max_retries:
request = self._write_request_bytes(self._seq_num, addr, datas)
for _ in range(self._max_retries):
header = EthernetMessageHeader.from_params(
MessageTypes.WRITE_REQUEST, self._seq_num
)
request = bytestring_from_ints([header.as_bits(), addr, *datas])
sock.sendto(request, (self._fpga_ip_addr, self._udp_port))
data, (ip_addr, port) = sock.recvfrom(4)
response = ints_from_bytestring(data)
assert port == self._udp_port
if ip_addr != self._fpga_ip_addr:
raise ValueError("Non-Manta traffic detected on this UDP port!")
data = [
int.from_bytes(data[i : i + 4], "little")
for i in range(0, len(data), 4)
]
response_type = MessageTypes(part_select(data[0], 29, 31))
if response_type == MessageTypes.WRITE_RESPONSE:
header = EthernetMessageHeader.from_bits(response[0])
if header.msg_type == MessageTypes.WRITE_RESPONSE:
self._seq_num += 1
return
elif response_type == MessageTypes.NACK:
self._seq_num = part_select(data[0], 16, 28)
retry_count += 1
elif header.msg_type == MessageTypes.NACK:
self._seq_num = header.seq_num
else:
raise ValueError("Unexpected message format received!")
@ -674,7 +652,7 @@ class EthernetInterface(Elaboratable):
offset = 0
while offset < length:
chunk_size = min(self._max_read_len, length - offset)
chunk_size = min(EthernetMessageHeader.MAX_READ_LENGTH, length - offset)
data += self._read_request(base_addr + offset, chunk_size)
offset += chunk_size
@ -682,10 +660,12 @@ class EthernetInterface(Elaboratable):
return data
def write_block(self, base_addr, data):
data_chunks = split_into_chunks(data, self._max_write_len)
data_chunks = split_into_chunks(data, EthernetMessageHeader.MAX_WRITE_LENGTH)
for i, chunk in enumerate(data_chunks):
self._write_request(base_addr + (i * self._max_write_len), chunk)
self._write_request(
base_addr + (i * EthernetMessageHeader.MAX_WRITE_LENGTH), chunk
)
def read(self, addrs):
"""

View File

@ -49,10 +49,9 @@ class EthernetBridge(Elaboratable):
# Otherwise, NACK immediately
with m.Else():
m.d.sync += self.data_o.eq(
Cat(
C(0, unsigned(16)),
seq_num_expected,
EthernetMessageHeader.concat_signals(
MessageTypes.NACK,
seq_num_expected,
)
)
m.d.sync += self.valid_o.eq(1)
@ -65,10 +64,9 @@ class EthernetBridge(Elaboratable):
m.d.sync += read_len.eq(self.data_i[16:23] - 1)
m.d.sync += self.data_o.eq(
Cat(
C(0, unsigned(16)),
seq_num_expected,
EthernetMessageHeader.concat_signals(
MessageTypes.READ_RESPONSE,
seq_num_expected,
)
)
m.d.sync += self.valid_o.eq(1)
@ -167,12 +165,10 @@ class EthernetBridge(Elaboratable):
with m.If(self.bus_i.last):
m.d.sync += seq_num_expected.eq(seq_num_expected + 1)
m.d.sync += self.data_o.eq(
Cat(
C(0, unsigned(16)),
seq_num_expected,
EthernetMessageHeader.concat_signals(
MessageTypes.WRITE_RESPONSE,
seq_num_expected,
)
)
m.d.sync += self.valid_o.eq(1)
@ -182,7 +178,10 @@ class EthernetBridge(Elaboratable):
with m.State("NACK_WAIT_FOR_LAST"):
with m.If(self.last_i):
m.d.sync += self.data_o.eq(
Cat(C(0, unsigned(16)), seq_num_expected, MessageTypes.NACK)
EthernetMessageHeader.concat_signals(
MessageTypes.NACK,
seq_num_expected,
)
)
m.d.sync += self.valid_o.eq(1)
m.d.sync += self.last_o.eq(1)

View File

@ -3,8 +3,9 @@ from abc import ABC, abstractmethod
from pathlib import Path
from random import sample
from amaranth import Elaboratable, unsigned
from amaranth import Cat, Const, Elaboratable, Signal, unsigned
from amaranth.lib import data
from amaranth.lib.data import Struct
from amaranth.lib.enum import IntEnum
from amaranth.sim import Simulator
@ -111,6 +112,47 @@ class MessageTypes(IntEnum, shape=unsigned(3)):
NACK = 4
class EthernetMessageHeader(Struct):
msg_type: MessageTypes
seq_num: 13
length: 7 = 0
zero_padding: 9 = 0
MAX_READ_LENGTH = 126
MAX_WRITE_LENGTH = 126
@classmethod
def from_params(cls, msg_type, seq_num, length=0):
return cls.const(
init={"msg_type": msg_type, "seq_num": seq_num, "length": length}
)
@classmethod
def concat_signals(
cls, msg_type: MessageTypes, seq_num: Signal, length: Signal = None
):
# Make sure each signal is the right width!
widths = cls.from_bits(0).shape().members
if Const(msg_type).shape().width != MessageTypes.as_shape().width:
raise TypeError
if seq_num.shape().width != widths["seq_num"]:
raise TypeError
zp_width = widths["zero_padding"]
len_width = widths["length"]
if length is None:
return Cat(msg_type, seq_num, Const(0, len_width), Const(0, zp_width))
else:
if length.shape().width != len_width:
raise TypeError
return Cat(msg_type, seq_num, length, Const(0, zp_width))
def warn(message):
"""
Prints a warning to the user's terminal. Originally the warn() method
@ -121,20 +163,6 @@ def warn(message):
print("Warning: " + message)
def part_select(value, start, end):
# Ensure the start bit is less than or equal to the end bit
if start > end:
raise ValueError(
"Start bit position must be less than or equal to end bit position."
)
# Create a mask to isolate the bits from `start` to `end`
mask = (1 << (end - start + 1)) - 1
# Shift the number to the right by `start` bits and apply the mask
return (value >> start) & mask
def parse_sequences(numbers):
"""
Takes a list of integers and identifies runs of sequential numbers
@ -203,6 +231,22 @@ def check_value_fits_in_bits(value, n_bits):
raise ValueError("Signed integer too large.")
def ints_from_bytestring(bytes, byteorder="little"):
"""
Takes a list of ints, interprets them as 32-bit integers, and returns a
bytestring of the constituent bytes joined together.
"""
return [int.from_bytes(chunk, byteorder) for chunk in split_into_chunks(bytes, 4)]
def bytestring_from_ints(ints, byteorder="little"):
"""
Takes a list of ints, interprets them as 32-bit integers, and returns a
bytestring of the constituent bytes joined together.
"""
return b"".join(i.to_bytes(4, byteorder) for i in ints)
def split_into_chunks(data, chunk_size):
"""
Split a list into a list of lists, where each sublist has length `chunk_size`.