# Source code for zcached.asyncio.connection

from __future__ import annotations
from typing import Any, Type, Generic, TypeVar

import asyncio
import logging as logger

from ..connection import Connection
from ..result import Result
from ..enums import Errors

ProtocolT = TypeVar("ProtocolT", bound=asyncio.StreamReaderProtocol)


class AsyncConnection(Connection, Generic[ProtocolT]):
    """
    An asynchronous connection object to manage communication with the server.

    Parameters
    ----------
    host:
        The host address of the server to connect to.
    port:
        The port number of the server to connect to.
    connection_attempts:
        The maximum number of attempts to establish a connection with the server.
        If every attempt fails, the failure is logged and the connection stays
        disconnected (check ``is_connected()``); ``connect()`` does not raise.
    reconnect:
        Flag indicating whether an automatic reconnection attempt should be made
        in case of a broken connection.
    timeout_limit:
        The maximum time in seconds to wait for a response from the server.
    buffer_size:
        The size of the buffer for receiving data from the server, in bytes.
        Larger values may allow more data to be received in a single read operation,
        while smaller values can be more memory-efficient but slower.
    loop:
        The event loop to run asynchronous tasks.
        If None, ``asyncio.get_event_loop()`` is used.
    protocol_type:
        The protocol class used to build the protocol that manages the connection.
        Defaults to ``asyncio.StreamReaderProtocol``.

    Attributes
    ----------
    connection_attempts:
        The maximum number of attempts to establish a connection with the server.
    reconnect:
        Flag indicating whether an automatic reconnection attempt should be made
        in case of a broken connection.
    timeout_limit:
        The maximum time in seconds to wait for a response from the server.
    buffer_size:
        The size of the buffer for receiving data from the server, in bytes.
    loop:
        The event loop to run asynchronous tasks.
    """

    # NOTE(review): "_lock" is assigned in __init__ but not listed here; presumably it is
    # declared in Connection.__slots__ (or Connection has a __dict__) — confirm.
    __slots__ = (
        "loop",
        "_protocol_type",
        "_reader",
        "_writer",
        "_protocol",
    )

    def __init__(
        self,
        host: str,
        port: int,
        connection_attempts: int = 3,
        reconnect: bool = True,
        timeout_limit: int = 15,
        buffer_size: int = 2048,
        loop: asyncio.AbstractEventLoop | None = None,
        protocol_type: Type[ProtocolT] | None = None,
    ) -> None:
        super().__init__(
            host=host,
            port=port,
            connection_attempts=connection_attempts,
            reconnect=reconnect,
            timeout_limit=timeout_limit,
            buffer_size=buffer_size,
        )
        self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
        self._protocol_type: Type[ProtocolT] = (  # pyright: ignore
            protocol_type or asyncio.StreamReaderProtocol
        )
        self._protocol: ProtocolT | None = None
        self._reader: asyncio.StreamReader | None = None
        self._writer: asyncio.StreamWriter | None = None
        # Serialises request/response round-trips issued through send().
        self._lock: asyncio.Lock = asyncio.Lock()

    def __repr__(self) -> str:
        return f"<AsyncConnection(host={self.host}, port={self.port}, buffer_size={self.buffer_size}, id={self.id})>"

    @property
    def protocol(self) -> ProtocolT | None:
        """The protocol for managing the connection. If available."""
        return self._protocol

    @property
    def protocol_type(self) -> Type[ProtocolT]:
        """The type of protocol used for managing the connection."""
        return self._protocol_type

    @property
    def reader(self) -> asyncio.StreamReader | None:
        """The asyncio.StreamReader object for reading data from the server. If available."""
        return self._reader

    @property
    def writer(self) -> asyncio.StreamWriter | None:
        """The asyncio.StreamWriter object for writing data to the server. If available."""
        return self._writer

    @property
    def transport(self) -> None | asyncio.WriteTransport:
        """The transport object for the connection, if a StreamWriter is available."""
        if self._writer is None:
            return None
        return self._writer.transport

    async def connect(self) -> None:
        """
        Coroutine to establish a connection with the server asynchronously.

        Iterates over ``self._backoff`` (defined in Connection), sleeping for the
        corresponding delay between failed attempts. It stops after
        ``connection_attempts`` failures, or after the first failure when
        ``reconnect`` is False. Failures are logged, never raised; on success
        ``_connected`` is set to True.
        """
        logger.debug(f"{self.id} -> Connecting to {self.host}:{self.port}")
        for attempt, timeout in enumerate(self._backoff):
            try:
                self._reader, self._writer = await self.open_connection(host=self.host, port=self.port)
            except Exception as exception:
                logger.exception(exception)
                if attempt + 1 >= self.connection_attempts or not self.reconnect:
                    break
                logger.warning(f"{self.id} -> Connecting to the server failed. Retrying...")
                await asyncio.sleep(timeout)
            else:
                logger.info(f"{self.id} -> Connected to the server.")
                self._connected = True
                break

    async def open_connection(
        self, host: str, port: int, **kwargs: Any
    ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """
        Coroutine to open a connection to the server.

        Builds the reader/protocol/writer trio manually so that ``protocol_type``
        can be customised. The created protocol is stored in ``self._protocol``.

        Parameters
        ----------
        host:
            The host address of the server.
        port:
            The port number of the server.
        **kwargs:
            Additional keyword arguments passed to ``loop.create_connection``.

        Returns
        -------
        tuple[asyncio.StreamReader, asyncio.StreamWriter]
            The reader and writer bound to the new transport.
        """
        logger.debug(f"{self.id} -> Creating a new connection...")
        reader: asyncio.StreamReader = asyncio.StreamReader(loop=self.loop)
        protocol = self.protocol_type(stream_reader=reader, loop=self.loop)
        transport, _ = await self.loop.create_connection(
            protocol_factory=lambda: protocol,  # pyright: ignore
            host=host,
            port=port,
            **kwargs,
        )
        writer: asyncio.StreamWriter = asyncio.StreamWriter(
            transport=transport, protocol=protocol, reader=reader, loop=self.loop
        )
        self._protocol = protocol
        logger.debug(f"{self.id} -> Created a new connection.")
        return reader, writer

    async def try_reconnect(self) -> Result[bytes]:
        """
        A method to attempt to reconnect to the server if the connection is broken.

        Always returns a failed Result: ``Errors.ConnectionReestablished`` when the
        reconnection succeeded (the in-flight request was lost and should be retried),
        otherwise ``Errors.ConnectionClosed``.
        """
        logger.debug(f"{self.id} -> Attempting to reconnect to the server...")
        # connect() overwrites self._reader/self._writer; close the old writer first
        # so its transport (and socket) is not leaked.
        await self._close_writer_quietly()
        # Without this, it somehow manages to establish a non-working connection.
        await asyncio.sleep(1)
        self._connected = False
        await self.connect()

        if self.is_connected() is True:
            return Result.fail(Errors.ConnectionReestablished.value)
        return Result.fail(Errors.ConnectionClosed.value)

    async def _close_writer_quietly(self) -> None:
        """
        Close the current StreamWriter, ignoring errors from an already-broken connection.

        The reader/writer references are deliberately left in place. If reconnecting
        fails, a later send() is then expected to fail in write/drain and take the
        reconnect path again. It does not hit the "Did you forget to connect?" early
        return.
        """
        writer = self._writer
        if writer is None:
            return
        writer.close()
        try:
            await writer.wait_closed()
        except (ConnectionError, OSError):
            # The peer already dropped the connection; there is nothing left to flush.
            logger.debug(f"{self.id} -> The previous connection was already broken while closing it.")

    async def send(self, data: bytes) -> Result:
        """
        Coroutine to send data to the server and wait for its response. TASK SAFE.

        Concurrent callers are serialised by an internal lock. While a caller waits
        for that lock it is counted in ``_pending_requests``.

        Parameters
        ----------
        data:
            Bytes to send.

        Returns
        -------
        Result
            The server response, or a failed Result describing a closed connection,
            a timeout, a server error, or a re-established connection.
        """
        if self._writer is None:
            logger.error(
                f"{self.id} -> Missing StreamWriter object! Did you forget to connect? Aborting the send method..."
            )
            return Result.fail(Errors.ConnectionClosed.value)

        waiting = self._lock.locked()
        if waiting:
            logger.debug("Waiting for the task lock to be released...")
            self._pending_requests += 1
        try:
            await self._lock.acquire()
        finally:
            # Stop counting this task as pending once it has the lock, or if it was
            # cancelled while waiting. This keeps the counter from drifting on error paths.
            if waiting and self._pending_requests >= 1:
                self._pending_requests -= 1

        try:
            return await self._send_locked(data)
        finally:
            self._lock.release()

    async def _send_locked(self, data: bytes) -> Result:
        """Write ``data`` and wait for the response. The caller must hold ``self._lock``."""
        writer = self._writer
        if writer is None:
            # close() may have run while this task was waiting for the lock.
            return Result.fail(Errors.ConnectionClosed.value)

        try:
            logger.debug(f"{self.id} -> Sending data: %s.", data)
            writer.write(data)
            await writer.drain()
        except (ConnectionError, OSError):
            logger.debug(f"{self.id} -> The connection has been terminated.")
            if not self.reconnect:
                return Result.fail(Errors.ConnectionClosed.value)
            return await self.try_reconnect()

        result: Result = await self.wait_for_response()
        # Result.fail() is always given Errors.X.value in this class, so compare against .value.
        if self.reconnect and result.error == Errors.ConnectionClosed.value:
            return await self.try_reconnect()
        return result

    async def receive(self, timeout_limit: float | None = None) -> bytes | None:
        """
        Coroutine to receive data from the reader. NOT TASK SAFE.

        Parameters
        ----------
        timeout_limit:
            The maximum time in seconds to wait for data. If None, wait indefinitely.

        Returns
        -------
        bytes | None
            Up to ``buffer_size`` bytes (``b""`` at end of stream), or None when
            there is no StreamReader (not connected).

        Raises
        ------
        asyncio.TimeoutError
            If the timeout limit has been exceeded.
        """
        reader = self._reader
        if reader is None:
            logger.error(
                f"{self.id} -> Missing StreamReader object! Did you forget to connect? Aborting the receive method..."
            )
            return None

        if timeout_limit is None:
            # If there is no specified time limit, and if there is no data to receive,
            # the reader will wait for it as long as needed.
            data: bytes = await reader.read(self.buffer_size)
        else:
            data = await asyncio.wait_for(reader.read(self.buffer_size), timeout=timeout_limit)

        logger.debug(f"{self.id} -> Received data: %s.", data)
        return data

    async def wait_for_response(self) -> Result:
        """
        Coroutine to wait for a complete response from the server asynchronously. NOT TASK SAFE.

        Responses are not length-framed here. The first chunk is awaited for up to
        ``timeout_limit`` seconds. After that, the response is considered complete
        once no further data arrives within 0.1 s. A response starting with ``b"-"``
        is turned into a failed Result.
        """
        if self._reader is None:
            return Result.fail(Errors.ConnectionClosed.value)

        try:
            data: bytes | None = await self.receive(timeout_limit=self.timeout_limit)
        except asyncio.TimeoutError:
            return Result.fail(Errors.TimeoutLimit.value)
        if data is None:
            self._connected = False
            return Result.fail(Errors.ConnectionClosed.value)

        # Collect chunks and join once instead of repeatedly concatenating bytes.
        chunks: list[bytes] = [data]
        while True:
            try:
                data = await self.receive(timeout_limit=0.1)
            except asyncio.TimeoutError:
                break  # Transfer complete.

            if not data:
                # When the socket loses the connection to the server it receives empty bytes.
                self._connected = False
                return Result.fail(Errors.ConnectionClosed.value)
            chunks.append(data)

        complete_data = b"".join(chunks)
        # If the first byte is "-", it means that the response is an error.
        if complete_data.startswith(b"-"):
            # Drops the leading "-" and the last two bytes (presumably the b"\r\n" terminator).
            error_message: str = complete_data.decode()[1:-2]
            return Result.fail(error_message)
        return Result.ok(complete_data)

    async def close(self) -> None:
        """
        Close the connection by closing the writer and waiting until it is fully closed.

        State (``_connected``, ``_writer``, ``_pending_requests``, ``_reader``) is reset
        before awaiting. If ``wait_closed()`` raises on an already-broken connection,
        the object is still left in a consistent closed state.
        """
        writer = self._writer
        if writer is not None:
            self._connected = False
            self._writer = None
            self._pending_requests = 0
            self._reader = None
            writer.close()
            await writer.wait_closed()
        self._reader = None