Source code for system.webrtc.webrtcd

#!/usr/bin/env python3

import argparse
import asyncio
import json
import uuid
import logging
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING

# aiortc and its dependencies have lots of internal warnings :(
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import capnp
from aiohttp import web
if TYPE_CHECKING:
  from aiortc.rtcdatachannel import RTCDataChannel

from openpilot.system.webrtc.schema import generate_field
from cereal import messaging, log


class CerealOutgoingMessageProxy:
  """Forward freshly received cereal messages to WebRTC data channels as JSON."""

  def __init__(self, sm: messaging.SubMaster):
    self.sm = sm
    # Every registered data channel receives every encoded message.
    self.channels: list['RTCDataChannel'] = []

  def add_channel(self, channel: 'RTCDataChannel'):
    """Register one more data channel as a destination for outgoing messages."""
    self.channels.append(channel)

  def to_json(self, msg_content: Any):
    """Convert a cereal value into a structure that ``json.dumps`` accepts.

    Struct readers become dicts, list readers become lists (converted element by
    element), bytes are decoded with ``bytes.decode()`` defaults (UTF-8), and any
    other value is returned unchanged.
    """
    if isinstance(msg_content, capnp._DynamicStructReader):
      return msg_content.to_dict()
    if isinstance(msg_content, capnp._DynamicListReader):
      return [self.to_json(item) for item in msg_content]
    if isinstance(msg_content, bytes):
      return msg_content.decode()
    return msg_content

  def update(self):
    """Poll the SubMaster once and broadcast each updated service to all channels."""
    # NOTE: sm.update(0) is a synchronous call even though callers run it from async code.
    self.sm.update(0)
    fresh_services = (name for name, is_new in self.sm.updated.items() if is_new)
    for name in fresh_services:
      data = self.to_json(self.sm[name])
      envelope = {
        "type": name,
        "logMonoTime": self.sm.logMonoTime[name],
        "valid": self.sm.valid[name],
        "data": data,
      }
      wire_bytes = json.dumps(envelope).encode()
      for destination in self.channels:
        destination.send(wire_bytes)
class CerealIncomingMessageProxy:
  """Publish JSON messages received from a peer onto cereal through a PubMaster."""

  def __init__(self, pm: messaging.PubMaster):
    self.pm = pm

  def send(self, message: bytes):
    """Decode one ``{"type": ..., "data": ...}`` JSON message and publish it.

    When ``data`` is not a dict, its length is passed as ``size`` to
    ``messaging.new_message`` (presumably for list-typed services; confirm).
    """
    decoded = json.loads(message)
    service = decoded["type"]
    payload = decoded["data"]
    size = None if isinstance(payload, dict) else len(payload)
    event = messaging.new_message(service, size=size)
    setattr(event, service, payload)
    self.pm.send(service, event)
class CerealProxyRunner:
  """Drive a CerealOutgoingMessageProxy from an asyncio task, polling every 10 ms."""

  def __init__(self, proxy: 'CerealOutgoingMessageProxy'):
    self.proxy = proxy
    # Kept for backward compatibility; nothing in this class reads or updates it.
    self.is_running = False
    self.task: asyncio.Task | None = None
    self.logger = logging.getLogger("webrtcd")

  def start(self):
    """Start the polling loop; a previous task must have been cleared by ``stop``."""
    assert self.task is None
    self.task = asyncio.create_task(self.run())

  def stop(self):
    """Cancel the polling loop if it is still running and forget the task.

    The reference is cleared even when the task already finished (for example
    after the data channel closed), so ``start`` can be called again afterwards.
    """
    task, self.task = self.task, None
    if task is not None and not task.done():
      task.cancel()

  async def run(self):
    """Poll the proxy until the data channel reports InvalidStateError.

    Any other exception from ``proxy.update`` is logged and polling continues.
    """
    # Imported lazily so aiortc is only loaded when a proxy actually runs.
    from aiortc.exceptions import InvalidStateError

    while True:
      try:
        self.proxy.update()
      except InvalidStateError:
        self.logger.warning("Cereal outgoing proxy invalid state (connection closed)")
        break
      except Exception as ex:
        self.logger.error("Cereal outgoing proxy failure: %s", ex)
      await asyncio.sleep(0.01)
class DynamicPubMaster(messaging.PubMaster):
  """PubMaster whose set of publish sockets can grow while the daemon runs."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Serializes concurrent add_services_if_needed callers.
    self.lock = asyncio.Lock()

  async def add_services_if_needed(self, services):
    """Create a pub socket for every service that does not have one yet."""
    async with self.lock:
      for name in services:
        if name in self.sock:
          continue
        self.sock[name] = messaging.pub_sock(name)
class StreamSession:
  """One WebRTC peer session: media tracks plus optional cereal data bridges.

  The session builds an SDP answer with local camera/audio tracks. Once the peer
  connects, it bridges cereal messages over the data channel in both directions
  and plays back incoming audio.
  """

  # A single PubMaster shared by every session; services are added on demand.
  shared_pub_master = DynamicPubMaster([])

  def __init__(self, sdp: str, cameras: list[str], incoming_services: list[str], outgoing_services: list[str], debug_mode: bool = False):
    # Media/device dependencies are imported lazily, when a session is created.
    from aiortc.mediastreams import VideoStreamTrack, AudioStreamTrack
    from aiortc.contrib.media import MediaBlackhole
    from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack
    from openpilot.system.webrtc.device.audio import AudioInputStreamTrack, AudioOutputSpeaker
    from teleoprtc import WebRTCAnswerBuilder
    from teleoprtc.info import parse_info_from_offer

    config = parse_info_from_offer(sdp)
    builder = WebRTCAnswerBuilder(sdp)

    assert len(cameras) == config.n_expected_camera_tracks, "Incoming stream has misconfigured number of video tracks"
    for cam in cameras:
      # debug mode substitutes synthetic aiortc tracks for real devices
      track = LiveStreamVideoStreamTrack(cam) if not debug_mode else VideoStreamTrack()
      builder.add_video_stream(cam, track)
    if config.expected_audio_track:
      track = AudioInputStreamTrack() if not debug_mode else AudioStreamTrack()
      builder.add_audio_stream(track)
    # Sink class for peer audio; stays None when the offer carries no incoming audio,
    # so run() can check it instead of hitting an AttributeError.
    self.audio_output_cls: type | None = None
    if config.incoming_audio_track:
      self.audio_output_cls = AudioOutputSpeaker if not debug_mode else MediaBlackhole
      builder.offer_to_receive_audio_stream()

    self.stream = builder.stream()
    self.identifier = str(uuid.uuid4())

    self.incoming_bridge: CerealIncomingMessageProxy | None = None
    self.incoming_bridge_services = incoming_services
    self.outgoing_bridge: CerealOutgoingMessageProxy | None = None
    self.outgoing_bridge_runner: CerealProxyRunner | None = None
    if incoming_services:
      self.incoming_bridge = CerealIncomingMessageProxy(self.shared_pub_master)
    if outgoing_services:
      self.outgoing_bridge = CerealOutgoingMessageProxy(messaging.SubMaster(outgoing_services))
      self.outgoing_bridge_runner = CerealProxyRunner(self.outgoing_bridge)

    self.audio_output: AudioOutputSpeaker | MediaBlackhole | None = None
    self.run_task: asyncio.Task | None = None
    self.logger = logging.getLogger("webrtcd")
    self.logger.info("New stream session (%s), cameras %s, audio in %s out %s, incoming services %s, outgoing services %s",
                     self.identifier, cameras, config.incoming_audio_track, config.expected_audio_track, incoming_services, outgoing_services)

  def start(self):
    """Schedule ``run`` as a task on the running event loop."""
    self.run_task = asyncio.create_task(self.run())

  def stop(self):
    """Cancel a still-running session and release its resources.

    Safe to call repeatedly or before ``start``. When called from inside a running
    event loop (e.g. an aiohttp shutdown handler), cleanup is scheduled as a task and
    that task is returned so async callers may await it; ``asyncio.run`` cannot be
    used there. Returns None when there was nothing to stop or cleanup ran inline.
    """
    if self.run_task is None or self.run_task.done():
      return None
    self.run_task.cancel()
    self.run_task = None
    try:
      loop = asyncio.get_running_loop()
    except RuntimeError:
      # No loop running in this thread: fall back to the original blocking cleanup.
      asyncio.run(self.post_run_cleanup())
      return None
    return loop.create_task(self.post_run_cleanup())

  async def get_answer(self):
    """Start the underlying stream and return its SDP answer."""
    return await self.stream.start()

  async def message_handler(self, message: bytes):
    """Publish one data-channel message to cereal; failures are logged, not raised."""
    assert self.incoming_bridge is not None
    try:
      self.incoming_bridge.send(message)
    except Exception as ex:
      self.logger.error("Cereal incoming proxy failure: %s", ex)

  async def run(self):
    """Wait for the peer, wire up bridges and audio, then wait for disconnection.

    Exceptions are logged; resources set up before a failure are still released.
    Cancellation (from ``stop``) propagates, and ``stop`` performs the cleanup.
    """
    cleanup_attempted = False
    try:
      await self.stream.wait_for_connection()
      if self.stream.has_messaging_channel():
        if self.incoming_bridge is not None:
          await self.shared_pub_master.add_services_if_needed(self.incoming_bridge_services)
          self.stream.set_message_handler(self.message_handler)
        if self.outgoing_bridge_runner is not None:
          channel = self.stream.get_messaging_channel()
          self.outgoing_bridge_runner.proxy.add_channel(channel)
          self.outgoing_bridge_runner.start()
      if self.stream.has_incoming_audio_track() and self.audio_output_cls is not None:
        track = self.stream.get_incoming_audio_track(buffered=False)
        self.audio_output = self.audio_output_cls()
        self.audio_output.addTrack(track)
        started = self.audio_output.start()
        # aiortc's MediaBlackhole.start is a coroutine; the device speaker's is not.
        if asyncio.iscoroutine(started):
          await started
      self.logger.info("Stream session (%s) connected", self.identifier)

      await self.stream.wait_for_disconnection()
      cleanup_attempted = True
      await self.post_run_cleanup()

      self.logger.info("Stream session (%s) ended", self.identifier)
    except Exception as ex:
      self.logger.error("Stream session failure: %s", ex)
      if not cleanup_attempted:
        try:
          await self.post_run_cleanup()
        except Exception as cleanup_ex:
          self.logger.error("Stream session cleanup failure: %s", cleanup_ex)

  async def post_run_cleanup(self):
    """Stop the stream, the outgoing bridge runner and any audio sink."""
    await self.stream.stop()
    if self.outgoing_bridge_runner is not None:
      self.outgoing_bridge_runner.stop()
    if self.audio_output is not None:
      stopped = self.audio_output.stop()
      # aiortc's MediaBlackhole.stop is a coroutine; awaiting it actually stops it.
      if asyncio.iscoroutine(stopped):
        await stopped
@dataclass
class StreamRequestBody:
  """JSON body accepted by the ``/stream`` endpoint."""

  # SDP offer sent by the remote peer.
  sdp: str
  # Camera names to stream; StreamSession requires one per video track in the offer.
  cameras: list[str]
  # Services the peer publishes into cereal over the data channel.
  bridge_services_in: list[str] = field(default_factory=list)
  # Cereal services forwarded to the peer over the data channel.
  bridge_services_out: list[str] = field(default_factory=list)
async def get_stream(request: 'web.Request'):
  """Answer a WebRTC offer and start a new StreamSession for it.

  Responds 400 when the body is not valid JSON or does not match StreamRequestBody.
  The session is registered in ``app['streams']`` and removed automatically once its
  run task finishes, so ended sessions are not retained for the life of the daemon.
  """
  stream_dict, debug_mode = request.app['streams'], request.app['debug']
  try:
    raw_body = await request.json()
    # TypeError covers non-object bodies as well as missing/unexpected keys.
    body = StreamRequestBody(**raw_body)
  except (ValueError, TypeError) as ex:
    raise web.HTTPBadRequest(text=f"Invalid stream request: {ex}") from ex

  session = StreamSession(body.sdp, body.cameras, body.bridge_services_in, body.bridge_services_out, debug_mode)
  answer = await session.get_answer()
  session.start()

  stream_dict[session.identifier] = session
  # Done callbacks run via the event loop, never synchronously inside cancel(), so this
  # cannot mutate the dict while on_shutdown is iterating it.
  session.run_task.add_done_callback(lambda _task: stream_dict.pop(session.identifier, None))

  return web.json_response({"sdp": answer.sdp, "type": answer.type})
async def get_schema(request: 'web.Request'):
  """Return generated schemas for the requested cereal services.

  Query parameter ``services`` is a comma-separated list of ``log.Event`` field
  names; empty entries are ignored. Responds 400 if the parameter is missing or
  any name is unknown or deprecated. Validation uses explicit checks rather than
  ``assert``, so it also holds under ``python -O``.
  """
  event_fields = log.Event.schema.fields
  raw_services = request.query.get("services")
  if raw_services is None:
    raise web.HTTPBadRequest(text="Missing 'services' query parameter")

  services = [s for s in raw_services.split(",") if s]
  invalid = [s for s in services if s not in event_fields or s.endswith("DEPRECATED")]
  if invalid:
    raise web.HTTPBadRequest(text=f"Invalid service name: {', '.join(invalid)}")

  schema_dict = {s: generate_field(event_fields[s]) for s in services}
  return web.json_response(schema_dict)
async def on_shutdown(app: 'web.Application'):
  """Stop every registered stream session, then drop the session registry.

  Each session is stopped independently: a failure is logged and does not prevent
  the remaining sessions from being stopped. If ``stop()`` returns a coroutine or
  future (pending cleanup), it is awaited before moving on.
  """
  logger = logging.getLogger("webrtcd")
  # Iterate over a snapshot so the registry may change while sessions stop.
  for session in list(app['streams'].values()):
    try:
      cleanup = session.stop()
      if asyncio.iscoroutine(cleanup) or asyncio.isfuture(cleanup):
        await cleanup
    except Exception as ex:
      logger.error("Failed to stop stream session (%s): %s", getattr(session, "identifier", "?"), ex)
  del app['streams']
def webrtcd_thread(host: str, port: int, debug: bool):
  """Configure logging, build the aiohttp app and serve it on host:port (blocking)."""
  # Root logging is set to CRITICAL; only the two daemon loggers below are raised.
  logging.basicConfig(level=logging.CRITICAL, handlers=[logging.StreamHandler()])
  verbosity = logging.INFO
  if debug:
    verbosity = logging.DEBUG
  for logger_name in ("WebRTCStream", "webrtcd"):
    logging.getLogger(logger_name).setLevel(verbosity)

  app = web.Application()
  app['streams'] = {}
  app['debug'] = debug
  app.on_shutdown.append(on_shutdown)
  app.router.add_post("/stream", get_stream)
  app.router.add_get("/schema", get_schema)

  web.run_app(app, host=host, port=port)
def main():
  """Parse command-line options and run the webrtcd HTTP server."""
  parser = argparse.ArgumentParser(description="WebRTC daemon")
  cli_options = (
    ("--host", {"type": str, "default": "0.0.0.0", "help": "Host to listen on"}),
    ("--port", {"type": int, "default": 5001, "help": "Port to listen on"}),
    ("--debug", {"action": "store_true", "help": "Enable debug mode"}),
  )
  for flag, spec in cli_options:
    parser.add_argument(flag, **spec)
  opts = parser.parse_args()
  webrtcd_thread(opts.host, opts.port, opts.debug)
# Run the daemon only when executed as a script, not when imported.
if __name__ == "__main__":
  main()