"""
A remote debugging module for the python debugger pdb
"""
import pdb
import socket
import sys
import os
import selectors
import fcntl
import logging
from typing import Any, Optional, Iterator, TextIO, Mapping
from io import RawIOBase
from contextlib import contextmanager
from pytb.config import current_config as pytb_config
@contextmanager
def _run_mainsafe() -> Iterator[None]:
    """
    Back up the ``__main__`` module's ``__dict__`` before entering the context
    and restore that state when the context exits, even if the body raises.

    This enables :meth:`_runscript` and :meth:`_runmodule` to be called from
    ``__main__`` which would otherwise not work as those methods clear the
    original ``__dict__``.
    """
    # ``globals()`` is the namespace of the module this function is defined in,
    # which is only ``__main__`` when this very file is executed as a script.
    # Look the module up explicitly so the backup always targets ``__main__``.
    main_dict = sys.modules["__main__"].__dict__
    main_backup = main_dict.copy()
    try:
        yield
    finally:
        # restore in place so that existing references to the dict stay valid
        main_dict.clear()
        main_dict.update(main_backup)
class Rdb(pdb.Pdb):
    """
    A :class:`pdb.Pdb` subclass that is driven over a TCP connection.

    Creating an instance binds a listening socket, blocks until one client
    has connected and then uses that connection as the debugger's ``stdin``
    and ``stdout``.

    :param host: Host interface to bind the remote socket to.
        If None, the key `bind_to` from the current :class:`pytb.config.Config` s
        `[rdb]` section is used
    :param port: Port to listen for incoming connections
        If None, the key `port` from the current :class:`pytb.config.Config` s
        `[rdb]` section is used
    :param patch_stdio: redirect this process' stdin, stdout and stderr to the remote
        debugging client. If None, the key `patch_stdio` from the current
        :class:`pytb.config.Config` s `[rdb]` section is used
    :param **kwargs: passed to the parent Pdb class, except ``stdin`` and
        ``stdout`` are always overwritten by the remote socket
    """

    # pylint: disable=protected-access

    _std_streams = ["stdin", "stdout", "stderr"]
    """
    streams to redirect to the socket on connection
    """

    _session = None
    """
    A global session of the debugger. It is used to keep alive the session between multiple calls to
    set_trace() when the session originally was continued by the user
    """

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        patch_stdio: Optional[bool] = None,
        **kwargs: Any,
    ):
        _logger = logging.getLogger(
            f"{self.__class__.__module__}.{self.__class__.__name__}"
        )

        # load the parameters from the config
        config = pytb_config["rdb"]
        host = config.get("bind_to") if host is None else host
        port = int(config.get("port")) if port is None else port
        patch_stdio = (
            config.getboolean("patch_stdio") if patch_stdio is None else patch_stdio
        )

        # The listening socket is only needed until the single client has
        # connected; the ``with`` block closes it on success and on failure.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listen_socket:
            listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
            listen_socket.bind((host, port))
            _logger.info(
                f"Started Remote debug session on {host}:{port}. Waiting for connection..."
            )
            listen_socket.listen(1)
            connection, address = listen_socket.accept()

        _logger.info(f"new connection from {address[0]}:{address[1]}")

        try:
            self.connection_file = connection.makefile("rw")
        finally:
            # Hand ownership of the connection to the file object: once the
            # socket object itself is closed, closing ``connection_file``
            # is what actually releases the underlying connection.
            connection.close()

        kwargs["stdin"] = self.connection_file
        kwargs["stdout"] = self.connection_file
        super().__init__(**kwargs)

        self.stdio_patched = patch_stdio
        if patch_stdio:
            # remember the current streams so that _cleanup() can restore them
            self.original_streams: Mapping[str, TextIO] = {
                name: getattr(sys, name) for name in Rdb._std_streams
            }
            for name in Rdb._std_streams:
                setattr(sys, name, self.connection_file)

        self.prompt = f"RDB@{socket.gethostname()}:{port} >>> "

        # Publish the session only after it is fully set up. If anything above
        # raised, ``Rdb._session`` is left untouched and a later set_trace()
        # does not pick up a half-initialised debugger.
        Rdb._session = self

    def _flush_outputs(self) -> None:
        """
        Flush all currently installed stdio streams forwarded to the socket or not
        """
        for stream in self._std_streams:
            getattr(sys, stream).flush()

    def _cleanup(self) -> None:
        """
        Quit from the debugger. The remote connection is closed
        and the stdio streams are restored to their original state
        """
        try:
            self._flush_outputs()
        finally:
            # Run the restoration even when flushing failed, so the process is
            # not left with its stdio pointing at the remote connection.
            if self.stdio_patched:
                for stream in Rdb._std_streams:
                    setattr(sys, stream, self.original_streams[stream])

            # the parent class may still print after this point; make sure it
            # does not write to the connection that is closed below
            self.stdin = sys.stdin
            self.stdout = sys.stdout

            # reset the session first so it is cleared even if close() raises
            Rdb._session = None
            self.connection_file.close()

    # The command handlers below deliberately carry no docstring of their own
    # (as in the original implementation).

    # flush pending output before the debugged program continues to run
    def do_continue(self, arg: Any) -> Any:
        self._flush_outputs()
        return super().do_continue(arg)

    do_c = do_cont = do_continue

    # end of input on the connection: tear the session down, then quit
    def do_EOF(self, arg: Any) -> Any:
        self._cleanup()
        return super().do_EOF(arg)

    # explicit quit: tear the session down, then quit
    def do_quit(self, arg: Any) -> Any:
        self._cleanup()
        return super().do_quit(arg)

    do_q = do_exit = do_quit

    def _runscript(self, filename: str) -> None:
        """
        Run the parent's ``_runscript`` inside :func:`_run_mainsafe`.
        """
        with _run_mainsafe():
            super()._runscript(filename)

    def _runmodule(self, module_name: str) -> None:
        """
        Run the parent's ``_runmodule`` inside :func:`_run_mainsafe`.
        """
        with _run_mainsafe():
            super()._runmodule(module_name)
def set_trace(
    *args: Any,
    host: Optional[str] = None,
    port: Optional[int] = None,
    patch_stdio: Optional[bool] = None,
    **kwargs: Any,
) -> None:
    """
    Opens a remote PDB on the specified host and port if no session is running.
    If a session is already running (was started previously and a client is still connected)
    the session is reused instead.

    When a new session is created and `host` or `port` is None, the environment
    variables ``REMOTE_PDB_HOST`` and ``REMOTE_PDB_PORT`` are consulted before
    the values are handed to :class:`Rdb`.

    :param args: passed on to the session's ``set_trace`` method
    :param host: Host interface to bind to, only used when a new session is created
    :param port: Port to listen on, only used when a new session is created
    :param patch_stdio: When true, redirects stdout, stderr and stdin to the remote socket.
    :param kwargs: passed on to the session's ``set_trace`` method
    """
    # pylint: disable=protected-access
    if not args and "frame" not in kwargs:
        # Without an explicit frame the debugger would start tracing in the
        # frame that calls the session's set_trace() - i.e. this wrapper - and
        # stop on its return. Break in our caller instead.
        kwargs["frame"] = sys._getframe().f_back

    if Rdb._session is None:
        if host is None:
            host = os.environ.get("REMOTE_PDB_HOST", None)
        if port is None:
            env_port = os.environ.get("REMOTE_PDB_PORT", None)
            if env_port is not None:
                port = int(env_port)
        Rdb._session = Rdb(host=host, port=port, patch_stdio=patch_stdio)

    Rdb._session.set_trace(*args, **kwargs)
# breakpoint hook that was active before install_hook() replaced it,
# None while the remote debugger hook is not installed
_previous_breakpoint_hook = None  # pylint: disable=invalid-name


def install_hook() -> None:
    """
    Installs the remote debugger as standard debugging method and calls
    it when using the builtin `breakpoint()`
    """
    # Without the ``global`` declaration the assignment below would only bind
    # a local name and the previous hook would be lost.
    global _previous_breakpoint_hook  # pylint: disable=global-statement,invalid-name

    # Calling install_hook() twice must not record our own hook as the
    # "previous" one.
    if sys.breakpointhook is not set_trace:
        _previous_breakpoint_hook = sys.breakpointhook
    sys.breakpointhook = set_trace
def uninstall_hook() -> None:
    """
    Restore the original state of sys.breakpointhook.
    If :meth:`install_hook` was never called before, this is a noop
    """
    global _previous_breakpoint_hook  # pylint: disable=global-statement,invalid-name

    if _previous_breakpoint_hook is not None:
        sys.breakpointhook = _previous_breakpoint_hook
        # forget the saved hook so that a repeated call does not re-apply a
        # stale value over whatever has been installed in the meantime
        _previous_breakpoint_hook = None
class RdbClient:
    """
    A simple ``netcat`` like socket client that can be used as a convenience
    wrapper to connect to a remote debugger session.

    The constructor connects, then relays data between the terminal and the
    socket and only returns once the remote side has closed the connection
    and all received data has been written to stdout.

    If `host` or `port` are unspecified, they are loaded from the current
    :class:`pytb.config.Config` s `[rdb]` section
    """

    # pylint: disable=too-few-public-methods

    _selector = selectors.DefaultSelector()

    def __init__(self, host: Optional[str] = None, port: Optional[int] = None):
        # load the parameters from the config
        config = pytb_config["rdb"]
        host = config.get("host") if host is None else host
        port = int(config.get("port")) if port is None else port

        self.socket = socket.create_connection((host, port))
        try:
            self.socket_closed = False
            self.socket.setblocking(False)

            self.stdin = sys.stdin
            self.stdout = sys.stdout.buffer.raw  # type: ignore

            # data waiting to be sent to the socket / written to stdout
            self.socketbuf = bytearray()
            self.stdoutbuf = bytearray()

            self._relay()
        finally:
            self.socket.close()

    def _relay(self) -> None:
        """
        Multiplex stdin, stdout and the socket until the session is over.

        stdin is switched to non-blocking mode for the duration of the loop.
        Its original file status flags are restored and all selector
        registrations are removed again when the loop ends for any reason.
        """
        # make stdin non-blocking to multiplex reading with the socket
        orig_fl = fcntl.fcntl(self.stdin, fcntl.F_GETFL)
        fcntl.fcntl(self.stdin, fcntl.F_SETFL, orig_fl | os.O_NONBLOCK)

        registered = []
        try:
            # install the multiplexing selector on the socket, stdin and stdout
            for fileobj, events in (
                (self.socket, selectors.EVENT_READ | selectors.EVENT_WRITE),
                (self.stdin, selectors.EVENT_READ),
                (self.stdout, selectors.EVENT_WRITE),
            ):
                RdbClient._selector.register(fileobj, events, self._handle_io)
                registered.append(fileobj)

            # NOTE(review): the socket and stdout are watched for writability
            # all the time, so select() will usually return immediately and
            # this loop polls rather than sleeps.

            # loop until the socket is closed and the stdout buffer is empty
            while not self.socket_closed or self.stdoutbuf:
                # wait for I/O
                for key, mask in RdbClient._selector.select():
                    callback = key.data
                    callback(key.fileobj, mask)
        finally:
            # The selector is shared on the class: leaving the files registered
            # would make a later registration of the same stdin/stdout fail.
            for fileobj in registered:
                RdbClient._selector.unregister(fileobj)
            # do not leave stdin in non-blocking mode behind
            fcntl.fcntl(self.stdin, fcntl.F_SETFL, orig_fl)

    def _handle_io(self, stream: Any, mask: int) -> None:
        """
        Selector callback: move data for the stream that became ready.

        :param stream: the ready file object (the socket, stdin or stdout)
        :param mask: bitmask of the ``selectors.EVENT_*`` events that are ready
        """
        if stream is self.socket:
            if mask & selectors.EVENT_READ:
                data_read = self.socket.recv(1024)
                if not data_read:
                    # an empty read means the remote side closed the connection
                    self.socket_closed = True
                self.stdoutbuf += data_read
            elif mask & selectors.EVENT_WRITE:
                if self.socketbuf:
                    sent = self.socket.send(self.socketbuf)
                    del self.socketbuf[:sent]
        elif stream is self.stdin:
            typed = self.stdin.read()
            # nothing to forward on an empty read
            if typed:
                self.socketbuf += typed.encode(sys.stdout.encoding)
        elif stream is self.stdout:
            if self.stdoutbuf:
                sent = self.stdout.write(self.stdoutbuf)
                # a raw non-blocking write may report None instead of a count
                if sent:
                    del self.stdoutbuf[:sent]