Source code for aiospamc.connections

"""ConnectionManager classes for TCP and Unix sockets."""

from __future__ import annotations

import asyncio
import ssl
from enum import Enum, auto
from getpass import getpass
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

import certifi
import loguru
from loguru import logger

from .exceptions import AIOSpamcConnectionFailed, ClientTimeoutException


class Timeout:
    """Container object for defining timeouts.

    Every value is a number of seconds. ``connection`` and ``response`` are
    handed to :func:`asyncio.wait_for`, so ``None`` means "no limit" for that
    phase.
    """

    def __init__(
        self,
        total: float = 600,
        connection: Optional[float] = None,
        response: Optional[float] = None,
    ) -> None:
        """Timeout constructor.

        :param total: Upper bound in seconds for the whole request (connect,
            send and receive). Always stored as a ``float``.
        :param connection: Seconds allowed for establishing the connection.
        :param response: Seconds allowed for reading the server's response.
        """
        self.connection = connection
        self.response = response
        self.total = float(total)

    def __repr__(self):
        name = type(self).__qualname__
        return (
            f"{name}(total={self.total}, "
            f"connection={self.connection}, "
            f"response={self.response})"
        )
class ConnectionManagerBuilder:
    """Builder for connection managers.

    Exactly one transport must be selected with :meth:`with_unix_socket` or
    :meth:`with_tcp` before calling :meth:`build`; the last one called wins.
    """

    class ManagerType(Enum):
        """Define connection manager type during build."""

        Undefined = auto()
        Tcp = auto()
        Unix = auto()

    def __init__(self):
        """ConnectionManagerBuilder constructor."""
        self._manager_type = self.ManagerType.Undefined
        self._tcp_builder = TcpConnectionManagerBuilder()
        self._unix_builder = UnixConnectionManagerBuilder()
        self._ssl_builder = SSLContextBuilder()
        # Only consulted for TCP connections (see ``build``).
        self._ssl = False
        self._timeout = None

    def build(self) -> Union[UnixConnectionManager, TcpConnectionManager]:
        """Builds the :class:`aiospamc.connections.ConnectionManager`.

        :return: An instance of :class:`aiospamc.connections.TcpConnectionManager`
            or :class:`aiospamc.connections.UnixConnectionManager`
        :raises ValueError: If no transport was selected.
        """
        if self._manager_type is self.ManagerType.Undefined:
            raise ValueError(
                "Connection type is undefined, builder must be called with 'with_unix_socket' or 'with_tcp'"
            )
        elif self._manager_type is self.ManagerType.Tcp:
            # SSL is only applied when ``add_ssl_context`` was called.
            ssl_context = None if not self._ssl else self._ssl_builder.build()
            self._tcp_builder.set_ssl_context(ssl_context)
            return self._tcp_builder.set_timeout(self._timeout).build()
        else:
            return self._unix_builder.set_timeout(self._timeout).build()

    def with_unix_socket(self, path: Path) -> ConnectionManagerBuilder:
        """Configures the builder to use a Unix socket connection.

        :param path: Path to the Unix socket.
        :return: This builder instance.
        """
        self._manager_type = self.ManagerType.Unix
        self._unix_builder.set_path(path)
        return self

    def with_tcp(self, host: str, port: int = 783) -> ConnectionManagerBuilder:
        """Configures the builder to use a TCP connection.

        :param host: Hostname to use.
        :param port: Port to use.
        :return: This builder instance.
        """
        self._manager_type = self.ManagerType.Tcp
        self._tcp_builder.set_host(host).set_port(port)
        return self

    def add_ssl_context(self, context: ssl.SSLContext) -> ConnectionManagerBuilder:
        """Adds an SSL context when a TCP connection is being used.

        The context is ignored for Unix socket connections.

        :param context: :class:`ssl.SSLContext` instance.
        :return: This builder instance.
        """
        self._ssl_builder.with_context(context)
        self._ssl = True
        return self

    def set_timeout(self, timeout: Timeout) -> ConnectionManagerBuilder:
        """Sets the timeout for the connection.

        :param timeout: Timeout object.
        :return: This builder instance.
        """
        self._timeout = timeout
        return self
class ConnectionManager:
    """Stores connection parameters and creates connections.

    Subclasses implement :meth:`open`. This base class wraps it with the
    total, connection and response timeouts taken from :class:`Timeout`.
    """

    def __init__(
        self, connection_string: str, timeout: Optional[Timeout] = None
    ) -> None:
        """ConnectionManager constructor.

        :param connection_string: Description of the endpoint; exposed via
            :attr:`connection_string` and bound into the logger context.
        :param timeout: Timeout configuration. A default :class:`Timeout`
            is used when omitted.
        """
        self._connection_string = connection_string
        self.timeout = timeout or Timeout()
        self._logger = logger.bind(
            connection_string=self.connection_string,
            timeout=self.timeout,
        )

    @property
    def logger(self) -> loguru.Logger:
        """Return the logger object."""
        return self._logger

    async def request(self, data: bytes) -> bytes:
        """Send bytes data and receive a response.

        :raises: AIOSpamcConnectionFailed if the connection cannot be opened.
        :raises: ClientTimeoutException if connecting or receiving exceeds
            ``timeout.connection`` / ``timeout.response``.
        :raises: asyncio.TimeoutError if the whole exchange exceeds
            ``timeout.total``.

        :param data: Data to send.
        :return: Byte data from the response.
        """
        try:
            response = await asyncio.wait_for(self._send(data), self.timeout.total)
        except asyncio.TimeoutError:
            self.logger.exception("Total timeout reached")
            raise

        return response

    async def _send(self, data: bytes) -> bytes:
        """Opens a connection, sends data to the writer, waits for the reader, then returns the response.

        The writer is closed even if sending/receiving fails or the task is
        cancelled (e.g. by the total timeout in :meth:`request`).

        :param data: Data to send.
        :return: Byte data from the response.
        """
        reader, writer = await self._connect()
        try:
            writer.write(data)
            # Signal end of request so the server knows the body is complete.
            if writer.can_write_eof():
                writer.write_eof()
            await writer.drain()
            response = await self._receive(reader)
        except BaseException:
            # BaseException so cancellation also releases the socket; the
            # synchronous close() avoids blocking on wait_closed() here.
            writer.close()
            raise

        writer.close()
        await writer.wait_closed()
        return response

    async def _receive(self, reader: asyncio.StreamReader) -> bytes:
        """Takes a reader and returns the response.

        :param reader: asyncio reader.
        :return: Byte data from the response (read until EOF).
        :raises ClientTimeoutException: If ``timeout.response`` is exceeded.
        """
        try:
            response = await asyncio.wait_for(reader.read(), self.timeout.response)
        except asyncio.TimeoutError as error:
            self.logger.exception("Timed out receiving data")
            raise ClientTimeoutException from error

        self.logger.success("Successfully received data")
        return response

    async def _connect(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Opens a connection from the connection manager.

        :return: Tuple of asyncio reader and writer.
        :raises ClientTimeoutException: If ``timeout.connection`` is exceeded.
        """
        try:
            reader, writer = await asyncio.wait_for(
                self.open(), self.timeout.connection
            )
        except asyncio.TimeoutError as error:
            self.logger.exception("Timeout when connecting")
            raise ClientTimeoutException from error

        self.logger.success("Successfully connected")
        return reader, writer

    async def open(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Opens a connection, returning the reader and writer objects."""
        raise NotImplementedError

    @property
    def connection_string(self) -> str:
        """String representation of the connection."""
        return self._connection_string
class TcpConnectionManagerBuilder:
    """Builder for :class:`aiospamc.connections.TcpConnectionManager`"""

    def __init__(self):
        """`TcpConnectionManagerBuilder` constructor.

        Starts with no keyword arguments collected.
        """
        self._args: dict = {}

    def _store(self, key: str, value) -> TcpConnectionManagerBuilder:
        """Record one constructor keyword argument and return the builder."""
        self._args[key] = value
        return self

    def build(self) -> TcpConnectionManager:
        """Create the manager from the collected keyword arguments.

        :return: An instance of :class:`aiospamc.connections.TcpConnectionManager`.
        """
        kwargs = self._args
        return TcpConnectionManager(**kwargs)

    def set_host(self, host: str) -> TcpConnectionManagerBuilder:
        """Choose the host to connect to.

        :param host: Hostname or address.
        :return: This builder instance.
        """
        return self._store("host", host)

    def set_port(self, port: int) -> TcpConnectionManagerBuilder:
        """Choose the TCP port to connect to.

        :param port: Port number.
        :return: This builder instance.
        """
        return self._store("port", port)

    def set_ssl_context(self, context: ssl.SSLContext) -> TcpConnectionManagerBuilder:
        """Attach an SSL context (or ``None`` for plain TCP).

        :param context: An instance of :class:`ssl.SSLContext`.
        :return: This builder instance.
        """
        return self._store("ssl_context", context)

    def set_timeout(self, timeout: Timeout) -> TcpConnectionManagerBuilder:
        """Attach the timeout configuration.

        :param timeout: Timeout object.
        :return: This builder instance.
        """
        return self._store("timeout", timeout)
class TcpConnectionManager(ConnectionManager):
    """Connection manager for TCP connections, optionally wrapped in SSL."""

    def __init__(
        self,
        host: str,
        port: int,
        ssl_context: Optional[ssl.SSLContext] = None,
        timeout: Optional[Timeout] = None,
    ) -> None:
        """TcpConnectionManager constructor.

        :param host: Hostname or IP address.
        :param port: TCP port.
        :param ssl_context: SSL context; ``None`` means an unencrypted connection.
        :param timeout: Timeout configuration.
        """
        description = f"{host}:{port}"
        super().__init__(description, timeout)
        self.host = host
        self.port = port
        self.ssl_context = ssl_context

    async def open(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Open the TCP connection.

        :raises: AIOSpamcConnectionFailed
        :return: Reader and writer for the connection.
        """
        try:
            reader, writer = await asyncio.open_connection(
                self.host, self.port, ssl=self.ssl_context
            )
        except OSError as error:
            # ConnectionRefusedError (and ssl.SSLError) are OSError subclasses.
            self.logger.exception("Exception occurred when connecting")
            raise AIOSpamcConnectionFailed from error
        else:
            return reader, writer
class UnixConnectionManagerBuilder:
    """Builder for :class:`aiospamc.connections.UnixConnectionManager`."""

    def __init__(self):
        """`UnixConnectionManagerBuilder` constructor.

        Starts with no keyword arguments collected.
        """
        self._args: dict = {}

    def _store(self, key: str, value) -> UnixConnectionManagerBuilder:
        """Record one constructor keyword argument and return the builder."""
        self._args[key] = value
        return self

    def build(self) -> UnixConnectionManager:
        """Create the manager from the collected keyword arguments.

        :return: An instance of :class:`aiospamc.connections.UnixConnectionManager`.
        """
        kwargs = self._args
        return UnixConnectionManager(**kwargs)

    def set_path(self, path: Path) -> UnixConnectionManagerBuilder:
        """Choose the Unix socket to connect to.

        :param path: Filesystem path of the Unix socket.
        :return: This builder instance.
        """
        return self._store("path", path)

    def set_timeout(self, timeout: Timeout) -> UnixConnectionManagerBuilder:
        """Attach the timeout configuration.

        :param timeout: Timeout object.
        :return: This builder instance.
        """
        return self._store("timeout", timeout)
class UnixConnectionManager(ConnectionManager):
    """Connection manager for Unix domain socket connections."""

    def __init__(self, path: Path, timeout: Optional[Timeout] = None):
        """UnixConnectionManager constructor.

        :param path: Filesystem path of the Unix socket.
        :param timeout: Timeout configuration.
        """
        description = str(path)
        super().__init__(description, timeout)
        self.path = path

    async def open(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Open the Unix socket connection.

        :raises: AIOSpamcConnectionFailed
        :return: Reader and writer for the connection.
        """
        try:
            reader, writer = await asyncio.open_unix_connection(self.path)
        except OSError as error:
            # ConnectionRefusedError is an OSError subclass, so this covers it.
            self.logger.exception("Exception occurred when connecting")
            raise AIOSpamcConnectionFailed from error
        else:
            return reader, writer
class SSLContextBuilder:
    """SSL context builder.

    Holds a single :class:`ssl.SSLContext` that every ``add_*`` and
    :meth:`dont_verify` call mutates in place; :meth:`build` returns that
    same object.
    """

    def __init__(self):
        """Builder constructor. Sets up a default SSL context."""
        self._context = ssl.create_default_context()

    def build(self) -> ssl.SSLContext:
        """Builds the SSL context.

        :return: An instance of :class:`ssl.SSLContext`.
        """
        return self._context

    def with_context(self, context: ssl.SSLContext) -> SSLContextBuilder:
        """Use the SSL context.

        Later ``add_*`` calls modify this provided context.

        :param context: Provided SSL context.
        :return: The builder instance.
        """
        self._context = context
        return self

    def add_ca_file(self, file: Path) -> SSLContextBuilder:
        """Add certificate authority from a file.

        :param file: File of concatenated certificates.
        :return: The builder instance.
        """
        self._context.load_verify_locations(cafile=file)
        return self

    def add_ca_dir(self, dir: Path) -> SSLContextBuilder:
        """Add certificate authority from a directory.

        :param dir: Directory of certificates. (Name shadows the builtin but
            is kept because callers may pass it by keyword.)
        :return: The builder instance.
        """
        self._context.load_verify_locations(capath=dir)
        return self

    def add_ca(self, path: Union[str, Path]) -> SSLContextBuilder:
        """Add a certificate authority.

        :param path: Directory or file of certificates; a ``str`` is accepted
            and converted to :class:`pathlib.Path`.
        :return: The builder instance.
        :raises FileNotFoundError: If ``path`` is neither a directory nor a file.
        """
        path = Path(path)
        if path.is_dir():
            return self.add_ca_dir(path)
        if path.is_file():
            return self.add_ca_file(path)
        raise FileNotFoundError(path)

    def add_default_ca(self) -> SSLContextBuilder:
        """Add default certificate authorities from the ``certifi`` bundle.

        :return: The builder instance.
        """
        self._context.load_verify_locations(cafile=certifi.where())
        return self

    def add_client(
        self,
        file: Path,
        key: Optional[Path] = None,
        password: Optional[Callable[[], Union[str, bytes, bytearray]]] = None,
    ) -> SSLContextBuilder:
        """Add client certificate.

        :param file: Path to the client certificate.
        :param key: Path to the key.
        :param password: Callable that returns the password, if any.
        :return: The builder instance.
        """
        self._context.load_cert_chain(file, key, password)
        return self

    def dont_verify(self) -> SSLContextBuilder:
        """Set the context to not verify certificates.

        :return: The builder instance.
        """
        # check_hostname must be disabled first: setting CERT_NONE while
        # hostname checking is enabled raises ValueError.
        self._context.check_hostname = False
        self._context.verify_mode = ssl.CERT_NONE
        return self