from __future__ import annotations
import logging
from socket import socket, SOCK_STREAM, AF_INET
from threading import Lock
from time import sleep
from string import ascii_uppercase
from random import choice
from .backoff import ExponentialBackoff
from .result import Result
from .enums import Errors
class Connection:
    """
    An object to establish and manage a connection with the database server.

    Parameters
    ----------
    host:
        The host address of the server to connect to.
    port:
        The port number of the server to connect to.
    connection_attempts:
        The maximum number of attempts to establish a connection with the server.
        When the last attempt fails, `connect` logs the error and returns without
        raising; use `is_connected` to find out whether it succeeded.
    reconnect:
        Flag indicating whether automatic reconnection attempt should be made
        in case of a broken connection. When it is False, `connect` also makes
        a single attempt only.
    timeout_limit:
        The maximum time in seconds to wait for a response from the server.
    buffer_size:
        The size of the buffer for receiving data from the server, in bytes.
        Larger values for buffer_size may allow for more data to be received in a
        single operation, while smaller values can be more memory-efficient but slower.

    Attributes
    ----------
    socket:
        The socket object for communicating with the server.
    buffer_size:
        The size of the buffer for receiving data from the server.
    connection_attempts:
        The maximum number of attempts to establish a connection with the server.
    reconnect:
        Flag indicating whether automatic reconnection attempt should be made
        in case of a broken connection.
    timeout_limit:
        The maximum time in seconds to wait for a response from the server.
    """

    __slots__ = (
        "socket",
        "buffer_size",
        "connection_attempts",
        "reconnect",
        "timeout_limit",
        "_backoff",
        "_port",
        "_host",
        "_connected",
        "_lock",
        "_pending_requests",
        "_id",
    )

    def __init__(
        self,
        host: str,
        port: int,
        connection_attempts: int = 3,
        reconnect: bool = True,
        timeout_limit: int = 15,
        buffer_size: int = 2048,
    ) -> None:
        self.socket: socket = socket(AF_INET, SOCK_STREAM)
        self.buffer_size: int = buffer_size
        self.connection_attempts: int = connection_attempts
        self.reconnect: bool = reconnect
        self.timeout_limit: int = timeout_limit
        self._host: str = host
        self._port: int = port
        self._pending_requests: int = 0
        self._connected: bool = False
        # NOTE(review): the arguments presumably mean (initial delay, multiplier,
        # upper bound) - confirm against ExponentialBackoff.
        self._backoff = ExponentialBackoff(0.5, 2, 3)
        self._lock: Lock = Lock()
        # Random tag that only serves to tell connections apart in log output.
        self._id: str = "".join([choice(ascii_uppercase) for _ in range(6)])

    def __repr__(self) -> str:
        return f"<Connection(host={self.host}, port={self.port}, buffer_size={self.buffer_size})>"

    @property
    def host(self) -> str:
        """Connection host."""
        return self._host

    @property
    def port(self) -> int:
        """Connection port."""
        return self._port

    @property
    def pending_requests(self) -> int:
        """The number of requests that were started by `send` and are not finished yet."""
        return self._pending_requests

    @property
    def id(self) -> str:
        """Identifier of the connection used in log messages (random tag plus port)."""
        return f"#{self._id}-{self.port}"

    def is_locked(self) -> bool:
        """Whether the connection is locked."""
        return self._lock.locked()

    def is_connected(self) -> bool:
        """
        A boolean indicating whether the `connect` method was successfully invoked.

        .. note::
            This does not mean that the socket has a connection to the server.
            The socket connection may be broken, and we do not update it constantly.
        """
        return self._connected

    def connect(self) -> None:
        """
        Method to connect a socket to the database server.

        Makes at most `connection_attempts` attempts (a single one when `reconnect`
        is False), sleeping for the backoff delay between them. Every failure is
        logged; nothing is raised, so use `is_connected` to check the outcome.
        On success the socket is switched to non-blocking mode.
        """
        logging.debug(f"{self.id} -> Connecting to {self.host}:{self.port}...")
        for attempt, timeout in enumerate(self._backoff):
            try:
                self.socket.connect((self.host, self.port))
                self.socket.setblocking(False)
            except Exception as exception:
                # Deliberately broad: connecting is best-effort and must never raise.
                logging.exception(exception)
                if attempt + 1 >= self.connection_attempts or not self.reconnect:
                    break
                logging.warning(f"{self.id} -> Connecting to the server failed. Retrying...")
                sleep(timeout)
                # The state of a socket is unspecified after a failed connect(),
                # and some platforms refuse to reuse it, so retry on a fresh one.
                self._replace_socket()
            else:
                logging.info(f"{self.id} -> Connected to the server.")
                self._connected = True
                break

    def receive(self) -> bytes | None:
        """
        Method to receive the response from the server.

        Returns
        -------
        The received bytes; None if there is no data in the socket yet;
        empty bytes if the connection has been closed or broken.
        """
        try:
            data: bytes = self.socket.recv(self.buffer_size)
        except BlockingIOError:
            # Non-blocking socket with nothing to read yet.
            return None
        except ConnectionError:
            # Reset/aborted connection: report it the same way an orderly
            # shutdown is reported (empty bytes) instead of "no data yet",
            # so the caller does not keep waiting until the timeout.
            return b""
        except OSError:
            return None
        logging.debug(f"{self.id} -> Received a response -> %s", data)
        return data

    def send(self, data: bytes) -> Result:
        """
        Method to send data to the server and wait for its response.

        Sending and waiting happen while holding the internal lock, so only one
        request is in flight on this connection at a time.

        Parameters
        ----------
        data:
            Bytes to send.

        Returns
        -------
        The result of `wait_for_response`, or the result of `try_reconnect`
        when the connection turned out to be closed and `reconnect` is enabled.
        """
        if self._lock.locked():
            logging.debug(f"{self.id} -> Waiting for the thread lock to become available.")
        self._pending_requests += 1
        with self._lock:
            try:
                logging.debug(f"{self.id} -> Sending data to the server -> %s", data)
                self._send_all(data)
            except OSError:
                # BrokenPipeError and friends: the request is over, so it must
                # stop counting as pending.
                self._release_pending()
                if not self.reconnect:
                    return Result.fail(Errors.ConnectionClosed.value)
                return self.try_reconnect()
            except BaseException:
                # E.g. TypeError for non-bytes data: do not leak the counter.
                self._release_pending()
                raise
            # wait_for_response releases the pending counter on every exit.
            result: Result = self.wait_for_response()
            if not self.reconnect or result.error is None:
                return result
            # NOTE(review): failures are created from `Errors.<X>.value`, so the
            # stored error is presumably the value and not the enum member;
            # accept both until Result/Errors are confirmed.
            if result.error in (Errors.ConnectionClosed, Errors.ConnectionClosed.value):
                return self.try_reconnect()
            return result

    def try_reconnect(self) -> Result[bytes]:
        """
        A method to attempt to reconnect to the server if the connection is broken.

        The old socket is closed and replaced before `connect` is called again.

        .. note::
            If the connection is successfully reestablished, the method return a Result object
            with a failure status and an informational message indicating that the connection
            was terminated but managed to reestablish it.
        """
        logging.debug(f"{self.id} -> Attempting to reconnect to the server...")
        self._replace_socket()
        self._connected = False
        self.connect()
        if self.is_connected() is True:
            return Result.fail(Errors.ConnectionReestablished.value)
        return Result.fail(Errors.ConnectionClosed.value)

    def wait_for_response(self) -> Result:
        """
        A loop to wait for the response from the server.

        Chunks are collected until `receive` reports that no more data is
        available. A response starting with ``-`` is returned as a failure whose
        message is the decoded response without the leading ``-`` and the last
        two characters. One pending request is released on every exit.

        NOT THREAD SAFE.
        """
        backoff: ExponentialBackoff = ExponentialBackoff(0.1, 1.5, 0.5)
        # bytearray avoids re-copying the whole response for every chunk.
        received = bytearray()
        try:
            for timeout in backoff:
                data: bytes | None = self.receive()
                if data is None:
                    if received:
                        # We already have some data and this iteration gave us
                        # nothing, so the data transfer has been completed.
                        response = bytes(received)
                        # If the first byte is "-", the response is an error.
                        if response.startswith(b"-"):
                            error_message: str = response.decode()[1:-2]
                            return Result.fail(error_message)
                        return Result.ok(response)
                    # We haven't received any data yet.
                    logging.debug(f"{self.id} -> There is no data in the socket. Timeout: {timeout}s.")
                    if backoff.total >= float(self.timeout_limit):
                        logging.error(f"{self.id} -> The waiting time limit for a response has been reached.")
                        return Result.fail(Errors.TimeoutLimit.value)
                    sleep(timeout)
                    continue
                if not data:
                    # When socket lose connection to the server it receives empty bytes.
                    return Result.fail(Errors.ConnectionClosed.value)
                received += data
                # ExponentialBackoff should be increased only when we receive None.
                backoff.reset()
            # This should never happen, but the type checker yells.
            return Result.fail(Errors.LibraryBug.value)
        finally:
            self._release_pending()

    def close(self) -> None:
        """Method to close the connection."""
        self._connected = False
        self.socket.close()

    def _send_all(self, data: bytes) -> None:
        """Write all of `data` to the socket, continuing after partial sends."""
        remaining = memoryview(data).cast("B")
        while True:
            sent = self.socket.send(remaining)
            if sent >= len(remaining):
                return
            remaining = remaining[sent:]

    def _replace_socket(self) -> None:
        """Close the current socket and install a new, unconnected one."""
        try:
            self.socket.close()
        except OSError:
            # Best-effort: the old socket is being discarded anyway.
            pass
        self.socket = socket(AF_INET, SOCK_STREAM)

    def _release_pending(self) -> None:
        """Mark one pending request as finished, never going below zero."""
        if self._pending_requests >= 1:
            self._pending_requests -= 1