1
0
mirror of https://git.yoctoproject.org/poky synced 2026-06-02 01:19:52 +00:00

bitbake: hashserv: Add websocket connection implementation

Adds support to the hash equivalence client and server to communicate
over websockets. Since websockets are message orientated instead of
stream orientated, and new connection class is needed to handle them.

Note that websocket support does require the 3rd party websockets python
module be installed on the host, but it should not be required unless
websockets are actually being used.

(Bitbake rev: 56dd2fdbfb6350a9eef43a12aa529c8637887a7e)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
This commit is contained in:
Joshua Watt
2023-11-03 08:26:20 -06:00
committed by Richard Purdie
parent 8f8501ed40
commit 2484bd8931
6 changed files with 137 additions and 2 deletions
+10 -1
View File
@@ -10,7 +10,7 @@ import json
import os import os
import socket import socket
import sys import sys
from .connection import StreamConnection, DEFAULT_MAX_CHUNK from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
from .exceptions import ConnectionClosedError from .exceptions import ConnectionClosedError
@@ -47,6 +47,15 @@ class AsyncClient(object):
self._connect_sock = connect_sock self._connect_sock = connect_sock
async def connect_websocket(self, uri):
import websockets
async def connect_sock():
websocket = await websockets.connect(uri, ping_interval=None)
return WebsocketConnection(websocket, self.timeout)
self._connect_sock = connect_sock
async def setup_connection(self): async def setup_connection(self):
# Send headers # Send headers
await self.socket.send("%s %s" % (self.proto_name, self.proto_version)) await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
+44
View File
@@ -93,3 +93,47 @@ class StreamConnection(object):
if self.writer is not None: if self.writer is not None:
self.writer.close() self.writer.close()
self.writer = None self.writer = None
class WebsocketConnection(object):
def __init__(self, socket, timeout):
self.socket = socket
self.timeout = timeout
@property
def address(self):
return ":".join(str(s) for s in self.socket.remote_address)
async def send_message(self, msg):
await self.send(json.dumps(msg))
async def recv_message(self):
m = await self.recv()
return json.loads(m)
async def send(self, msg):
import websockets.exceptions
try:
await self.socket.send(msg)
except websockets.exceptions.ConnectionClosed:
raise ConnectionClosedError("Connection closed")
async def recv(self):
import websockets.exceptions
try:
if self.timeout < 0:
return await self.socket.recv()
try:
return await asyncio.wait_for(self.socket.recv(), self.timeout)
except asyncio.TimeoutError:
raise ConnectionError("Timed out waiting for data")
except websockets.exceptions.ConnectionClosed:
raise ConnectionClosedError("Connection closed")
async def close(self):
if self.socket is not None:
await self.socket.close()
self.socket = None
+52 -1
View File
@@ -12,7 +12,7 @@ import signal
import socket import socket
import sys import sys
import multiprocessing import multiprocessing
from .connection import StreamConnection from .connection import StreamConnection, WebsocketConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError from .exceptions import ClientError, ServerError, ConnectionClosedError
@@ -178,6 +178,54 @@ class UnixStreamServer(StreamServer):
os.unlink(self.path) os.unlink(self.path)
class WebsocketsServer(object):
def __init__(self, host, port, handler, logger):
self.host = host
self.port = port
self.handler = handler
self.logger = logger
def start(self, loop):
import websockets.server
self.server = loop.run_until_complete(
websockets.server.serve(
self.client_handler,
self.host,
self.port,
ping_interval=None,
)
)
for s in self.server.sockets:
self.logger.debug("Listening on %r" % (s.getsockname(),))
# Enable keep alives. This prevents broken client connections
# from persisting on the server for long periods of time.
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
name = self.server.sockets[0].getsockname()
if self.server.sockets[0].family == socket.AF_INET6:
self.address = "ws://[%s]:%d" % (name[0], name[1])
else:
self.address = "ws://%s:%d" % (name[0], name[1])
return [self.server.wait_closed()]
async def stop(self):
self.server.close()
def cleanup(self):
pass
async def client_handler(self, websocket):
socket = WebsocketConnection(websocket, -1)
await self.handler(socket)
class AsyncServer(object): class AsyncServer(object):
def __init__(self, logger): def __init__(self, logger):
self.logger = logger self.logger = logger
@@ -190,6 +238,9 @@ class AsyncServer(object):
def start_unix_server(self, path): def start_unix_server(self, path):
self.server = UnixStreamServer(path, self._client_handler, self.logger) self.server = UnixStreamServer(path, self._client_handler, self.logger)
def start_websocket_server(self, host, port):
self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
async def _client_handler(self, socket): async def _client_handler(self, socket):
try: try:
client = self.accept_client(socket) client = self.accept_client(socket)
+13
View File
@@ -9,11 +9,15 @@ import re
import sqlite3 import sqlite3
import itertools import itertools
import json import json
from urllib.parse import urlparse
UNIX_PREFIX = "unix://" UNIX_PREFIX = "unix://"
WS_PREFIX = "ws://"
WSS_PREFIX = "wss://"
ADDR_TYPE_UNIX = 0 ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1 ADDR_TYPE_TCP = 1
ADDR_TYPE_WS = 2
UNIHASH_TABLE_DEFINITION = ( UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"), ("method", "TEXT NOT NULL", "UNIQUE"),
@@ -84,6 +88,8 @@ def setup_database(database, sync=True):
def parse_address(addr): def parse_address(addr):
if addr.startswith(UNIX_PREFIX): if addr.startswith(UNIX_PREFIX):
return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],)) return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
return (ADDR_TYPE_WS, (addr,))
else: else:
m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr) m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
if m is not None: if m is not None:
@@ -103,6 +109,9 @@ def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
(typ, a) = parse_address(addr) (typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX: if typ == ADDR_TYPE_UNIX:
s.start_unix_server(*a) s.start_unix_server(*a)
elif typ == ADDR_TYPE_WS:
url = urlparse(a[0])
s.start_websocket_server(url.hostname, url.port)
else: else:
s.start_tcp_server(*a) s.start_tcp_server(*a)
@@ -116,6 +125,8 @@ def create_client(addr):
(typ, a) = parse_address(addr) (typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX: if typ == ADDR_TYPE_UNIX:
c.connect_unix(*a) c.connect_unix(*a)
elif typ == ADDR_TYPE_WS:
c.connect_websocket(*a)
else: else:
c.connect_tcp(*a) c.connect_tcp(*a)
@@ -128,6 +139,8 @@ async def create_async_client(addr):
(typ, a) = parse_address(addr) (typ, a) = parse_address(addr)
if typ == ADDR_TYPE_UNIX: if typ == ADDR_TYPE_UNIX:
await c.connect_unix(*a) await c.connect_unix(*a)
elif typ == ADDR_TYPE_WS:
await c.connect_websocket(*a)
else: else:
await c.connect_tcp(*a) await c.connect_tcp(*a)
+1
View File
@@ -115,6 +115,7 @@ class Client(bb.asyncrpc.Client):
super().__init__() super().__init__()
self._add_methods( self._add_methods(
"connect_tcp", "connect_tcp",
"connect_websocket",
"get_unihash", "get_unihash",
"report_unihash", "report_unihash",
"report_unihash_equiv", "report_unihash_equiv",
+17
View File
@@ -483,3 +483,20 @@ class TestHashEquivalenceTCPServer(HashEquivalenceTestSetup, HashEquivalenceComm
# If IPv6 is enabled, it should be safe to use localhost directly, in general # If IPv6 is enabled, it should be safe to use localhost directly, in general
# case it is more reliable to resolve the IP address explicitly. # case it is more reliable to resolve the IP address explicitly.
return socket.gethostbyname("localhost") + ":0" return socket.gethostbyname("localhost") + ":0"
class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def setUp(self):
try:
import websockets
except ImportError as e:
self.skipTest(str(e))
super().setUp()
def get_server_addr(self, server_idx):
# Some hosts cause asyncio module to misbehave, when IPv6 is not enabled.
# If IPv6 is enabled, it should be safe to use localhost directly, in general
# case it is more reliable to resolve the IP address explicitly.
host = socket.gethostbyname("localhost")
return "ws://%s:0" % host