1
0
mirror of https://git.yoctoproject.org/poky synced 2026-06-01 13:09:50 +00:00

bitbake: asyncrpc: Abstract sockets

Rewrites the asyncrpc client and server code to make it possible to have
other transport backends that are not stream based (e.g. websockets
which are message based). The connection handling classes are now shared
between both the client and server to make it easier to implement new
transport mechanisms

(Bitbake rev: 2aaeae53696e4c2f13a169830c3b7089cbad6eca)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
This commit is contained in:
Joshua Watt
2023-11-03 08:26:19 -06:00
committed by Richard Purdie
parent f97b686884
commit 8f8501ed40
10 changed files with 396 additions and 362 deletions
+7 -25
View File
@@ -4,30 +4,12 @@
# SPDX-License-Identifier: GPL-2.0-only
#
import itertools
import json
# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that will slow down the
# connection setup with a round trip delay so I'd rather not do that unless it
# is necessary
DEFAULT_MAX_CHUNK = 32 * 1024
def chunkify(msg, max_chunk):
if len(msg) < max_chunk - 1:
yield ''.join((msg, "\n"))
else:
yield ''.join((json.dumps({
'chunk-stream': None
}), "\n"))
args = [iter(msg)] * (max_chunk - 1)
for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
yield ''.join(itertools.chain(m, "\n"))
yield "\n"
from .client import AsyncClient, Client
from .serv import AsyncServer, AsyncServerConnection, ClientError, ServerError
from .serv import AsyncServer, AsyncServerConnection
from .connection import DEFAULT_MAX_CHUNK
from .exceptions import (
ClientError,
ServerError,
ConnectionClosedError,
)
+23 -55
View File
@@ -10,13 +10,13 @@ import json
import os
import socket
import sys
from . import chunkify, DEFAULT_MAX_CHUNK
from .connection import StreamConnection, DEFAULT_MAX_CHUNK
from .exceptions import ConnectionClosedError
class AsyncClient(object):
def __init__(self, proto_name, proto_version, logger, timeout=30):
self.reader = None
self.writer = None
self.socket = None
self.max_chunk = DEFAULT_MAX_CHUNK
self.proto_name = proto_name
self.proto_version = proto_version
@@ -25,7 +25,8 @@ class AsyncClient(object):
async def connect_tcp(self, address, port):
async def connect_sock():
return await asyncio.open_connection(address, port)
reader, writer = await asyncio.open_connection(address, port)
return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
@@ -40,27 +41,27 @@ class AsyncClient(object):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.connect(os.path.basename(path))
finally:
os.chdir(cwd)
return await asyncio.open_unix_connection(sock=sock)
os.chdir(cwd)
reader, writer = await asyncio.open_unix_connection(sock=sock)
return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
async def setup_connection(self):
s = '%s %s\n\n' % (self.proto_name, self.proto_version)
self.writer.write(s.encode("utf-8"))
await self.writer.drain()
# Send headers
await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
# End of headers
await self.socket.send("")
async def connect(self):
if self.reader is None or self.writer is None:
(self.reader, self.writer) = await self._connect_sock()
if self.socket is None:
self.socket = await self._connect_sock()
await self.setup_connection()
async def close(self):
self.reader = None
if self.writer is not None:
self.writer.close()
self.writer = None
if self.socket is not None:
await self.socket.close()
self.socket = None
async def _send_wrapper(self, proc):
count = 0
@@ -71,6 +72,7 @@ class AsyncClient(object):
except (
OSError,
ConnectionError,
ConnectionClosedError,
json.JSONDecodeError,
UnicodeDecodeError,
) as e:
@@ -82,49 +84,15 @@ class AsyncClient(object):
await self.close()
count += 1
async def send_message(self, msg):
async def get_line():
try:
line = await asyncio.wait_for(self.reader.readline(), self.timeout)
except asyncio.TimeoutError:
raise ConnectionError("Timed out waiting for server")
if not line:
raise ConnectionError("Connection closed")
line = line.decode("utf-8")
if not line.endswith("\n"):
raise ConnectionError("Bad message %r" % (line))
return line
async def invoke(self, msg):
async def proc():
for c in chunkify(json.dumps(msg), self.max_chunk):
self.writer.write(c.encode("utf-8"))
await self.writer.drain()
l = await get_line()
m = json.loads(l)
if m and "chunk-stream" in m:
lines = []
while True:
l = (await get_line()).rstrip("\n")
if not l:
break
lines.append(l)
m = json.loads("".join(lines))
return m
await self.socket.send_message(msg)
return await self.socket.recv_message()
return await self._send_wrapper(proc)
async def ping(self):
return await self.send_message(
{'ping': {}}
)
return await self.invoke({"ping": {}})
class Client(object):
@@ -142,7 +110,7 @@ class Client(object):
# required (but harmless) with it.
asyncio.set_event_loop(self.loop)
self._add_methods('connect_tcp', 'ping')
self._add_methods("connect_tcp", "ping")
@abc.abstractmethod
def _get_async_client(self):
+95
View File
@@ -0,0 +1,95 @@
#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#
import asyncio
import itertools
import json
from .exceptions import ClientError, ConnectionClosedError
# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that will slow down the
# connection setup with a round trip delay so I'd rather not do that unless it
# is necessary
DEFAULT_MAX_CHUNK = 32 * 1024
def chunkify(msg, max_chunk):
if len(msg) < max_chunk - 1:
yield "".join((msg, "\n"))
else:
yield "".join((json.dumps({"chunk-stream": None}), "\n"))
args = [iter(msg)] * (max_chunk - 1)
for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
yield "".join(itertools.chain(m, "\n"))
yield "\n"
class StreamConnection(object):
def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
self.reader = reader
self.writer = writer
self.timeout = timeout
self.max_chunk = max_chunk
@property
def address(self):
return self.writer.get_extra_info("peername")
async def send_message(self, msg):
for c in chunkify(json.dumps(msg), self.max_chunk):
self.writer.write(c.encode("utf-8"))
await self.writer.drain()
async def recv_message(self):
l = await self.recv()
m = json.loads(l)
if not m:
return m
if "chunk-stream" in m:
lines = []
while True:
l = await self.recv()
if not l:
break
lines.append(l)
m = json.loads("".join(lines))
return m
async def send(self, msg):
self.writer.write(("%s\n" % msg).encode("utf-8"))
await self.writer.drain()
async def recv(self):
if self.timeout < 0:
line = await self.reader.readline()
else:
try:
line = await asyncio.wait_for(self.reader.readline(), self.timeout)
except asyncio.TimeoutError:
raise ConnectionError("Timed out waiting for data")
if not line:
raise ConnectionClosedError("Connection closed")
line = line.decode("utf-8")
if not line.endswith("\n"):
raise ConnectionError("Bad message %r" % (line))
return line.rstrip()
async def close(self):
self.reader = None
if self.writer is not None:
self.writer.close()
self.writer = None
+17
View File
@@ -0,0 +1,17 @@
#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#
class ClientError(Exception):
pass
class ServerError(Exception):
pass
class ConnectionClosedError(Exception):
pass
+165 -157
View File
@@ -12,241 +12,248 @@ import signal
import socket
import sys
import multiprocessing
from . import chunkify, DEFAULT_MAX_CHUNK
class ClientError(Exception):
pass
class ServerError(Exception):
pass
from .connection import StreamConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError
class AsyncServerConnection(object):
def __init__(self, reader, writer, proto_name, logger):
self.reader = reader
self.writer = writer
# If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
# return message will be automatically be sent back to the client
NO_RESPONSE = object()
def __init__(self, socket, proto_name, logger):
self.socket = socket
self.proto_name = proto_name
self.max_chunk = DEFAULT_MAX_CHUNK
self.handlers = {
'chunk-stream': self.handle_chunk,
'ping': self.handle_ping,
"ping": self.handle_ping,
}
self.logger = logger
async def close(self):
await self.socket.close()
async def process_requests(self):
try:
self.addr = self.writer.get_extra_info('peername')
self.logger.debug('Client %r connected' % (self.addr,))
self.logger.info("Client %r connected" % (self.socket.address,))
# Read protocol and version
client_protocol = await self.reader.readline()
client_protocol = await self.socket.recv()
if not client_protocol:
return
(client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
(client_proto_name, client_proto_version) = client_protocol.split()
if client_proto_name != self.proto_name:
self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
return
self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
if not self.validate_proto_version():
self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
self.logger.debug(
"Rejecting invalid protocol version %s" % (client_proto_version)
)
return
# Read headers. Currently, no headers are implemented, so look for
# an empty line to signal the end of the headers
while True:
line = await self.reader.readline()
if not line:
return
line = line.decode('utf-8').rstrip()
if not line:
header = await self.socket.recv()
if not header:
break
# Handle messages
while True:
d = await self.read_message()
d = await self.socket.recv_message()
if d is None:
break
await self.dispatch_message(d)
await self.writer.drain()
except ClientError as e:
response = await self.dispatch_message(d)
if response is not self.NO_RESPONSE:
await self.socket.send_message(response)
except ConnectionClosedError as e:
self.logger.info(str(e))
except (ClientError, ConnectionError) as e:
self.logger.error(str(e))
finally:
self.writer.close()
await self.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
self.logger.debug('Handling %s' % k)
await self.handlers[k](msg[k])
return
self.logger.debug("Handling %s" % k)
return await self.handlers[k](msg[k])
raise ClientError("Unrecognized command %r" % msg)
def write_message(self, msg):
for c in chunkify(json.dumps(msg), self.max_chunk):
self.writer.write(c.encode('utf-8'))
async def read_message(self):
l = await self.reader.readline()
if not l:
return None
try:
message = l.decode('utf-8')
if not message.endswith('\n'):
return None
return json.loads(message)
except (json.JSONDecodeError, UnicodeDecodeError) as e:
self.logger.error('Bad message from client: %r' % message)
raise e
async def handle_chunk(self, request):
lines = []
try:
while True:
l = await self.reader.readline()
l = l.rstrip(b"\n").decode("utf-8")
if not l:
break
lines.append(l)
msg = json.loads(''.join(lines))
except (json.JSONDecodeError, UnicodeDecodeError) as e:
self.logger.error('Bad message from client: %r' % lines)
raise e
if 'chunk-stream' in msg:
raise ClientError("Nested chunks are not allowed")
await self.dispatch_message(msg)
async def handle_ping(self, request):
response = {'alive': True}
self.write_message(response)
return {"alive": True}
class StreamServer(object):
def __init__(self, handler, logger):
self.handler = handler
self.logger = logger
self.closed = False
async def handle_stream_client(self, reader, writer):
# writer.transport.set_write_buffer_limits(0)
socket = StreamConnection(reader, writer, -1)
if self.closed:
await socket.close()
return
await self.handler(socket)
async def stop(self):
self.closed = True
class TCPStreamServer(StreamServer):
def __init__(self, host, port, handler, logger):
super().__init__(handler, logger)
self.host = host
self.port = port
def start(self, loop):
self.server = loop.run_until_complete(
asyncio.start_server(self.handle_stream_client, self.host, self.port)
)
for s in self.server.sockets:
self.logger.debug("Listening on %r" % (s.getsockname(),))
# Newer python does this automatically. Do it manually here for
# maximum compatibility
s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
# Enable keep alives. This prevents broken client connections
# from persisting on the server for long periods of time.
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
name = self.server.sockets[0].getsockname()
if self.server.sockets[0].family == socket.AF_INET6:
self.address = "[%s]:%d" % (name[0], name[1])
else:
self.address = "%s:%d" % (name[0], name[1])
return [self.server.wait_closed()]
async def stop(self):
await super().stop()
self.server.close()
def cleanup(self):
pass
class UnixStreamServer(StreamServer):
def __init__(self, path, handler, logger):
super().__init__(handler, logger)
self.path = path
def start(self, loop):
cwd = os.getcwd()
try:
# Work around path length limits in AF_UNIX
os.chdir(os.path.dirname(self.path))
self.server = loop.run_until_complete(
asyncio.start_unix_server(
self.handle_stream_client, os.path.basename(self.path)
)
)
finally:
os.chdir(cwd)
self.logger.debug("Listening on %r" % self.path)
self.address = "unix://%s" % os.path.abspath(self.path)
return [self.server.wait_closed()]
async def stop(self):
await super().stop()
self.server.close()
def cleanup(self):
os.unlink(self.path)
class AsyncServer(object):
def __init__(self, logger):
self._cleanup_socket = None
self.logger = logger
self.start = None
self.address = None
self.loop = None
self.run_tasks = []
def start_tcp_server(self, host, port):
def start_tcp():
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client, host, port)
)
for s in self.server.sockets:
self.logger.debug('Listening on %r' % (s.getsockname(),))
# Newer python does this automatically. Do it manually here for
# maximum compatibility
s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
# Enable keep alives. This prevents broken client connections
# from persisting on the server for long periods of time.
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
name = self.server.sockets[0].getsockname()
if self.server.sockets[0].family == socket.AF_INET6:
self.address = "[%s]:%d" % (name[0], name[1])
else:
self.address = "%s:%d" % (name[0], name[1])
self.start = start_tcp
self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
def start_unix_server(self, path):
def cleanup():
os.unlink(path)
self.server = UnixStreamServer(path, self._client_handler, self.logger)
def start_unix():
cwd = os.getcwd()
try:
# Work around path length limits in AF_UNIX
os.chdir(os.path.dirname(path))
self.server = self.loop.run_until_complete(
asyncio.start_unix_server(self.handle_client, os.path.basename(path))
)
finally:
os.chdir(cwd)
self.logger.debug('Listening on %r' % path)
self._cleanup_socket = cleanup
self.address = "unix://%s" % os.path.abspath(path)
self.start = start_unix
@abc.abstractmethod
def accept_client(self, reader, writer):
pass
async def handle_client(self, reader, writer):
# writer.transport.set_write_buffer_limits(0)
async def _client_handler(self, socket):
try:
client = self.accept_client(reader, writer)
client = self.accept_client(socket)
await client.process_requests()
except Exception as e:
import traceback
self.logger.error('Error from client: %s' % str(e), exc_info=True)
traceback.print_exc()
writer.close()
self.logger.debug('Client disconnected')
def run_loop_forever(self):
try:
self.loop.run_forever()
except KeyboardInterrupt:
pass
self.logger.error("Error from client: %s" % str(e), exc_info=True)
traceback.print_exc()
await socket.close()
self.logger.debug("Client disconnected")
@abc.abstractmethod
def accept_client(self, socket):
pass
async def stop(self):
self.logger.debug("Stopping server")
await self.server.stop()
def start(self):
tasks = self.server.start(self.loop)
self.address = self.server.address
return tasks
def signal_handler(self):
self.logger.debug("Got exit signal")
self.loop.stop()
self.loop.create_task(self.stop())
def _serve_forever(self):
def _serve_forever(self, tasks):
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
self.run_loop_forever()
self.server.close()
self.loop.run_until_complete(asyncio.gather(*tasks))
self.loop.run_until_complete(self.server.wait_closed())
self.logger.debug('Server shutting down')
self.logger.debug("Server shutting down")
finally:
if self._cleanup_socket is not None:
self._cleanup_socket()
self.server.cleanup()
def serve_forever(self):
"""
Serve requests in the current process
"""
self._create_loop()
tasks = self.start()
self._serve_forever(tasks)
self.loop.close()
def _create_loop(self):
# Create loop and override any loop that may have existed in
# a parent process. It is possible that the usecases of
# serve_forever might be constrained enough to allow using
# get_event_loop here, but better safe than sorry for now.
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.start()
self._serve_forever()
def serve_as_process(self, *, prefunc=None, args=()):
"""
Serve requests in a child process
"""
def run(queue):
# Create loop and override any loop that may have existed
# in a parent process. Without doing this and instead
@@ -259,18 +266,19 @@ class AsyncServer(object):
# more general, though, as any potential use of asyncio in
# Cooker could create a loop that needs to replaced in this
# new process.
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self._create_loop()
try:
self.start()
self.address = None
tasks = self.start()
finally:
# Always put the server address to wake up the parent task
queue.put(self.address)
queue.close()
if prefunc is not None:
prefunc(self, *args)
self._serve_forever()
self._serve_forever(tasks)
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())