ethernet: fix host-side UDP socket leak

This commit is contained in:
Fischer Moseley 2026-01-14 14:51:41 -07:00
parent 82f289aa74
commit 6aea352fba
1 changed files with 46 additions and 46 deletions

View File

@ -574,70 +574,70 @@ class EthernetInterface(Elaboratable):
return m
def _read_request(self, addr, length):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind((self._host_ip_addr, self._udp_port))
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.bind((self._host_ip_addr, self._udp_port))
for _ in range(self._max_retries):
header = EthernetMessageHeader.from_params(
MessageTypes.READ_REQUEST, self._seq_num, length
)
request = bytestring_from_ints([header.as_bits(), addr])
for _ in range(self._max_retries):
header = EthernetMessageHeader.from_params(
MessageTypes.READ_REQUEST, self._seq_num, length
)
request = bytestring_from_ints([header.as_bits(), addr])
sock.sendto(request, (self._fpga_ip_addr, self._udp_port))
data, (ip_addr, port) = sock.recvfrom(4 + (length * 4))
response = ints_from_bytestring(data)
sock.sendto(request, (self._fpga_ip_addr, self._udp_port))
data, (ip_addr, port) = sock.recvfrom(4 + (length * 4))
response = ints_from_bytestring(data)
if ip_addr != self._fpga_ip_addr:
raise ValueError("Non-Manta traffic detected on this UDP port!")
if ip_addr != self._fpga_ip_addr:
raise ValueError("Non-Manta traffic detected on this UDP port!")
header = EthernetMessageHeader.from_bits(response[0])
read_data = response[1:]
header = EthernetMessageHeader.from_bits(response[0])
read_data = response[1:]
if header.msg_type == MessageTypes.READ_RESPONSE:
assert len(read_data) == length
self._seq_num += 1
return read_data
if header.msg_type == MessageTypes.READ_RESPONSE:
assert len(read_data) == length
self._seq_num += 1
return read_data
elif header.msg_type == MessageTypes.NACK:
self._seq_num = header.seq_num
elif header.msg_type == MessageTypes.NACK:
self._seq_num = header.seq_num
else:
print(MessageTypes(header.msg_type).name)
raise ValueError("Unexpected message format received!")
else:
print(MessageTypes(header.msg_type).name)
raise ValueError("Unexpected message format received!")
raise ValueError("Maximum number of retries exceeded!")
raise ValueError("Maximum number of retries exceeded!")
def _write_request(self, addr, datas):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind((self._host_ip_addr, self._udp_port))
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.bind((self._host_ip_addr, self._udp_port))
for _ in range(self._max_retries):
header = EthernetMessageHeader.from_params(
MessageTypes.WRITE_REQUEST, self._seq_num
)
request = bytestring_from_ints([header.as_bits(), addr, *datas])
for _ in range(self._max_retries):
header = EthernetMessageHeader.from_params(
MessageTypes.WRITE_REQUEST, self._seq_num
)
request = bytestring_from_ints([header.as_bits(), addr, *datas])
sock.sendto(request, (self._fpga_ip_addr, self._udp_port))
data, (ip_addr, port) = sock.recvfrom(4)
response = ints_from_bytestring(data)
sock.sendto(request, (self._fpga_ip_addr, self._udp_port))
data, (ip_addr, port) = sock.recvfrom(4)
response = ints_from_bytestring(data)
assert port == self._udp_port
assert port == self._udp_port
if ip_addr != self._fpga_ip_addr:
raise ValueError("Non-Manta traffic detected on this UDP port!")
if ip_addr != self._fpga_ip_addr:
raise ValueError("Non-Manta traffic detected on this UDP port!")
header = EthernetMessageHeader.from_bits(response[0])
if header.msg_type == MessageTypes.WRITE_RESPONSE:
self._seq_num += 1
return
header = EthernetMessageHeader.from_bits(response[0])
if header.msg_type == MessageTypes.WRITE_RESPONSE:
self._seq_num += 1
return
elif header.msg_type == MessageTypes.NACK:
self._seq_num = header.seq_num
elif header.msg_type == MessageTypes.NACK:
self._seq_num = header.seq_num
else:
raise ValueError("Unexpected message format received!")
else:
raise ValueError("Unexpected message format received!")
raise ValueError("Maximum number of retries exceeded!")
raise ValueError("Maximum number of retries exceeded!")
def read_block(self, base_addr, length):
data = []