ethernet: use EthernetMessageHeader class
This commit is contained in:
parent
c5d23aba76
commit
f888303ca2
|
|
@ -94,8 +94,6 @@ class EthernetInterface(Elaboratable):
|
|||
|
||||
self._seq_num = 0
|
||||
self._max_retries = 3
|
||||
self._max_read_len = 126
|
||||
self._max_write_len = 126
|
||||
|
||||
def _check_config(self):
|
||||
# Make sure UDP port is an integer in the range 0-65535
|
||||
|
|
@ -583,54 +581,36 @@ class EthernetInterface(Elaboratable):
|
|||
|
||||
return m
|
||||
|
||||
@staticmethod
|
||||
def _read_request_bytes(seq_num, addr, length):
|
||||
message = [
|
||||
(length << 16) | (seq_num << 3) | MessageTypes.READ_REQUEST,
|
||||
addr,
|
||||
]
|
||||
|
||||
return b"".join([i.to_bytes(4, "little") for i in message])
|
||||
|
||||
@staticmethod
|
||||
def _write_request_bytes(seq_num, addr, datas):
|
||||
message = [
|
||||
(seq_num << 3) | MessageTypes.WRITE_REQUEST,
|
||||
addr,
|
||||
*datas,
|
||||
]
|
||||
|
||||
return b"".join([i.to_bytes(4, "little") for i in message])
|
||||
|
||||
def _read_request(self, addr, length):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.bind((self._host_ip_addr, self._udp_port))
|
||||
|
||||
retry_count = 0
|
||||
while retry_count < self._max_retries:
|
||||
request = self._read_request_bytes(self._seq_num, addr, length)
|
||||
for _ in range(self._max_retries):
|
||||
header = EthernetMessageHeader.from_params(
|
||||
MessageTypes.READ_REQUEST, self._seq_num, length
|
||||
)
|
||||
request = bytestring_from_ints([header.as_bits(), addr])
|
||||
|
||||
sock.sendto(request, (self._fpga_ip_addr, self._udp_port))
|
||||
data, (ip_addr, port) = sock.recvfrom(4 + (length * 4))
|
||||
response = ints_from_bytestring(data)
|
||||
|
||||
if ip_addr != self._fpga_ip_addr:
|
||||
raise ValueError("Non-Manta traffic detected on this UDP port!")
|
||||
|
||||
data = [
|
||||
int.from_bytes(data[i : i + 4], "little")
|
||||
for i in range(0, len(data), 4)
|
||||
]
|
||||
header = EthernetMessageHeader.from_bits(response[0])
|
||||
read_data = response[1:]
|
||||
|
||||
response_type = MessageTypes(part_select(data[0], 29, 31))
|
||||
if response_type == MessageTypes.READ_RESPONSE:
|
||||
assert len(data) - 1 == length
|
||||
if header.msg_type == MessageTypes.READ_RESPONSE:
|
||||
assert len(read_data) == length
|
||||
self._seq_num += 1
|
||||
return data[1:]
|
||||
return read_data
|
||||
|
||||
elif response_type == MessageTypes.NACK:
|
||||
self._seq_num = part_select(data[0], 16, 28)
|
||||
retry_count += 1
|
||||
elif header.msg_type == MessageTypes.NACK:
|
||||
self._seq_num = header.seq_num
|
||||
|
||||
else:
|
||||
print(MessageTypes(header.msg_type).name)
|
||||
raise ValueError("Unexpected message format received!")
|
||||
|
||||
raise ValueError("Maximum number of retries exceeded!")
|
||||
|
|
@ -639,30 +619,28 @@ class EthernetInterface(Elaboratable):
|
|||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.bind((self._host_ip_addr, self._udp_port))
|
||||
|
||||
retry_count = 0
|
||||
while retry_count < self._max_retries:
|
||||
request = self._write_request_bytes(self._seq_num, addr, datas)
|
||||
for _ in range(self._max_retries):
|
||||
header = EthernetMessageHeader.from_params(
|
||||
MessageTypes.WRITE_REQUEST, self._seq_num
|
||||
)
|
||||
request = bytestring_from_ints([header.as_bits(), addr, *datas])
|
||||
|
||||
sock.sendto(request, (self._fpga_ip_addr, self._udp_port))
|
||||
data, (ip_addr, port) = sock.recvfrom(4)
|
||||
response = ints_from_bytestring(data)
|
||||
|
||||
assert port == self._udp_port
|
||||
|
||||
if ip_addr != self._fpga_ip_addr:
|
||||
raise ValueError("Non-Manta traffic detected on this UDP port!")
|
||||
|
||||
data = [
|
||||
int.from_bytes(data[i : i + 4], "little")
|
||||
for i in range(0, len(data), 4)
|
||||
]
|
||||
|
||||
response_type = MessageTypes(part_select(data[0], 29, 31))
|
||||
if response_type == MessageTypes.WRITE_RESPONSE:
|
||||
header = EthernetMessageHeader.from_bits(response[0])
|
||||
if header.msg_type == MessageTypes.WRITE_RESPONSE:
|
||||
self._seq_num += 1
|
||||
return
|
||||
|
||||
elif response_type == MessageTypes.NACK:
|
||||
self._seq_num = part_select(data[0], 16, 28)
|
||||
retry_count += 1
|
||||
elif header.msg_type == MessageTypes.NACK:
|
||||
self._seq_num = header.seq_num
|
||||
|
||||
else:
|
||||
raise ValueError("Unexpected message format received!")
|
||||
|
|
@ -674,7 +652,7 @@ class EthernetInterface(Elaboratable):
|
|||
offset = 0
|
||||
|
||||
while offset < length:
|
||||
chunk_size = min(self._max_read_len, length - offset)
|
||||
chunk_size = min(EthernetMessageHeader.MAX_READ_LENGTH, length - offset)
|
||||
data += self._read_request(base_addr + offset, chunk_size)
|
||||
offset += chunk_size
|
||||
|
||||
|
|
@ -682,10 +660,12 @@ class EthernetInterface(Elaboratable):
|
|||
return data
|
||||
|
||||
def write_block(self, base_addr, data):
|
||||
data_chunks = split_into_chunks(data, self._max_write_len)
|
||||
data_chunks = split_into_chunks(data, EthernetMessageHeader.MAX_WRITE_LENGTH)
|
||||
|
||||
for i, chunk in enumerate(data_chunks):
|
||||
self._write_request(base_addr + (i * self._max_write_len), chunk)
|
||||
self._write_request(
|
||||
base_addr + (i * EthernetMessageHeader.MAX_WRITE_LENGTH), chunk
|
||||
)
|
||||
|
||||
def read(self, addrs):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -49,10 +49,9 @@ class EthernetBridge(Elaboratable):
|
|||
# Otherwise, NACK immediately
|
||||
with m.Else():
|
||||
m.d.sync += self.data_o.eq(
|
||||
Cat(
|
||||
C(0, unsigned(16)),
|
||||
seq_num_expected,
|
||||
EthernetMessageHeader.concat_signals(
|
||||
MessageTypes.NACK,
|
||||
seq_num_expected,
|
||||
)
|
||||
)
|
||||
m.d.sync += self.valid_o.eq(1)
|
||||
|
|
@ -65,10 +64,9 @@ class EthernetBridge(Elaboratable):
|
|||
m.d.sync += read_len.eq(self.data_i[16:23] - 1)
|
||||
|
||||
m.d.sync += self.data_o.eq(
|
||||
Cat(
|
||||
C(0, unsigned(16)),
|
||||
seq_num_expected,
|
||||
EthernetMessageHeader.concat_signals(
|
||||
MessageTypes.READ_RESPONSE,
|
||||
seq_num_expected,
|
||||
)
|
||||
)
|
||||
m.d.sync += self.valid_o.eq(1)
|
||||
|
|
@ -167,12 +165,10 @@ class EthernetBridge(Elaboratable):
|
|||
|
||||
with m.If(self.bus_i.last):
|
||||
m.d.sync += seq_num_expected.eq(seq_num_expected + 1)
|
||||
|
||||
m.d.sync += self.data_o.eq(
|
||||
Cat(
|
||||
C(0, unsigned(16)),
|
||||
seq_num_expected,
|
||||
EthernetMessageHeader.concat_signals(
|
||||
MessageTypes.WRITE_RESPONSE,
|
||||
seq_num_expected,
|
||||
)
|
||||
)
|
||||
m.d.sync += self.valid_o.eq(1)
|
||||
|
|
@ -182,7 +178,10 @@ class EthernetBridge(Elaboratable):
|
|||
with m.State("NACK_WAIT_FOR_LAST"):
|
||||
with m.If(self.last_i):
|
||||
m.d.sync += self.data_o.eq(
|
||||
Cat(C(0, unsigned(16)), seq_num_expected, MessageTypes.NACK)
|
||||
EthernetMessageHeader.concat_signals(
|
||||
MessageTypes.NACK,
|
||||
seq_num_expected,
|
||||
)
|
||||
)
|
||||
m.d.sync += self.valid_o.eq(1)
|
||||
m.d.sync += self.last_o.eq(1)
|
||||
|
|
|
|||
|
|
@ -3,8 +3,9 @@ from abc import ABC, abstractmethod
|
|||
from pathlib import Path
|
||||
from random import sample
|
||||
|
||||
from amaranth import Elaboratable, unsigned
|
||||
from amaranth import Cat, Const, Elaboratable, Signal, unsigned
|
||||
from amaranth.lib import data
|
||||
from amaranth.lib.data import Struct
|
||||
from amaranth.lib.enum import IntEnum
|
||||
from amaranth.sim import Simulator
|
||||
|
||||
|
|
@ -111,6 +112,47 @@ class MessageTypes(IntEnum, shape=unsigned(3)):
|
|||
NACK = 4
|
||||
|
||||
|
||||
class EthernetMessageHeader(Struct):
|
||||
msg_type: MessageTypes
|
||||
seq_num: 13
|
||||
length: 7 = 0
|
||||
zero_padding: 9 = 0
|
||||
|
||||
MAX_READ_LENGTH = 126
|
||||
MAX_WRITE_LENGTH = 126
|
||||
|
||||
@classmethod
|
||||
def from_params(cls, msg_type, seq_num, length=0):
|
||||
return cls.const(
|
||||
init={"msg_type": msg_type, "seq_num": seq_num, "length": length}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def concat_signals(
|
||||
cls, msg_type: MessageTypes, seq_num: Signal, length: Signal = None
|
||||
):
|
||||
# Make sure each signal is the right width!
|
||||
widths = cls.from_bits(0).shape().members
|
||||
|
||||
if Const(msg_type).shape().width != MessageTypes.as_shape().width:
|
||||
raise TypeError
|
||||
|
||||
if seq_num.shape().width != widths["seq_num"]:
|
||||
raise TypeError
|
||||
|
||||
zp_width = widths["zero_padding"]
|
||||
len_width = widths["length"]
|
||||
|
||||
if length is None:
|
||||
return Cat(msg_type, seq_num, Const(0, len_width), Const(0, zp_width))
|
||||
|
||||
else:
|
||||
if length.shape().width != len_width:
|
||||
raise TypeError
|
||||
|
||||
return Cat(msg_type, seq_num, length, Const(0, zp_width))
|
||||
|
||||
|
||||
def warn(message):
|
||||
"""
|
||||
Prints a warning to the user's terminal. Originally the warn() method
|
||||
|
|
@ -121,20 +163,6 @@ def warn(message):
|
|||
print("Warning: " + message)
|
||||
|
||||
|
||||
def part_select(value, start, end):
|
||||
# Ensure the start bit is less than or equal to the end bit
|
||||
if start > end:
|
||||
raise ValueError(
|
||||
"Start bit position must be less than or equal to end bit position."
|
||||
)
|
||||
|
||||
# Create a mask to isolate the bits from `start` to `end`
|
||||
mask = (1 << (end - start + 1)) - 1
|
||||
|
||||
# Shift the number to the right by `start` bits and apply the mask
|
||||
return (value >> start) & mask
|
||||
|
||||
|
||||
def parse_sequences(numbers):
|
||||
"""
|
||||
Takes a list of integers and identifies runs of sequential numbers
|
||||
|
|
@ -203,6 +231,22 @@ def check_value_fits_in_bits(value, n_bits):
|
|||
raise ValueError("Signed integer too large.")
|
||||
|
||||
|
||||
def ints_from_bytestring(bytes, byteorder="little"):
|
||||
"""
|
||||
Takes a list of ints, interprets them as 32-bit integers, and returns a
|
||||
bytestring of the constituent bytes joined together.
|
||||
"""
|
||||
return [int.from_bytes(chunk, byteorder) for chunk in split_into_chunks(bytes, 4)]
|
||||
|
||||
|
||||
def bytestring_from_ints(ints, byteorder="little"):
|
||||
"""
|
||||
Takes a list of ints, interprets them as 32-bit integers, and returns a
|
||||
bytestring of the constituent bytes joined together.
|
||||
"""
|
||||
return b"".join(i.to_bytes(4, byteorder) for i in ints)
|
||||
|
||||
|
||||
def split_into_chunks(data, chunk_size):
|
||||
"""
|
||||
Split a list into a list of lists, where each sublist has length `chunk_size`.
|
||||
|
|
|
|||
Loading…
Reference in New Issue