import asyncio
import contextlib
import functools
import logging
from typing import Callable, Optional, Mapping, Any, Union, Tuple
from wpull.cache import FIFOCache
from wpull.errors import NetworkError
from wpull.network.connection import Connection, SSLConnection
from wpull.network.dns import Resolver, ResolveResult
_logger = logging.getLogger(__name__)
[docs]class HostPool(object):
'''Connection pool for a host.
Attributes:
ready (Queue): Connections not in use.
busy (set): Connections in use.
'''
def __init__(self, connection_factory: Callable[[], Connection],
max_connections: int=6):
assert max_connections > 0, \
'num must be positive. got {}'.format(max_connections)
self._connection_factory = connection_factory
self.max_connections = max_connections
self.ready = set()
self.busy = set()
self._lock = asyncio.Lock()
self._condition = asyncio.Condition(lock=self._lock)
self._closed = False
[docs] def empty(self) -> bool:
'''Return whether the pool is empty.'''
return not self.ready and not self.busy
@asyncio.coroutine
[docs] def clean(self, force: bool=False):
'''Clean closed connections.
Args:
force: Clean connected and idle connections too.
Coroutine.
'''
with (yield from self._lock):
for connection in tuple(self.ready):
if force or connection.closed():
connection.close()
self.ready.remove(connection)
[docs] def close(self):
'''Forcibly close all connections.
This instance will not be usable after calling this method.
'''
for connection in self.ready:
connection.close()
for connection in self.busy:
connection.close()
self._closed = True
[docs] def count(self) -> int:
'''Return total number of connections.'''
return len(self.ready) + len(self.busy)
@asyncio.coroutine
[docs] def acquire(self) -> Connection:
'''Register and return a connection.
Coroutine.
'''
assert not self._closed
yield from self._condition.acquire()
while True:
if self.ready:
connection = self.ready.pop()
break
elif len(self.busy) < self.max_connections:
connection = self._connection_factory()
break
else:
yield from self._condition.wait()
self.busy.add(connection)
self._condition.release()
return connection
@asyncio.coroutine
[docs] def release(self, connection: Connection, reuse: bool=True):
'''Unregister a connection.
Args:
connection: Connection instance returned from :meth:`acquire`.
reuse: If True, the connection is made available for reuse.
Coroutine.
'''
yield from self._condition.acquire()
self.busy.remove(connection)
if reuse:
self.ready.add(connection)
self._condition.notify()
self._condition.release()
[docs]class ConnectionPool(object):
'''Connection pool.
Args:
max_host_count: Number of connections per host.
resolver: DNS resolver.
connection_factory: A function that accepts ``address`` and
``hostname`` arguments and returns a :class:`Connection` instance.
ssl_connection_factory: A function that returns a
:class:`SSLConnection` instance. See `connection_factory`.
max_count: Limit on number of connections
'''
def __init__(self, max_host_count: int=6,
resolver: Optional[Resolver]=None,
connection_factory:
Optional[Callable[[tuple, str], Connection]]=None,
ssl_connection_factory:
Optional[Callable[[tuple, str], SSLConnection]]=None,
max_count: int=100):
self._max_host_count = max_host_count
self._resolver = resolver or Resolver()
self._connection_factory = connection_factory or Connection
self._ssl_connection_factory = ssl_connection_factory or SSLConnection
self._max_count = max_count
self._host_pools = {}
self._host_pool_waiters = {}
self._host_pools_lock = asyncio.Lock()
self._release_tasks = set()
self._closed = False
self._happy_eyeballs_table = HappyEyeballsTable()
@property
def host_pools(self) -> Mapping[tuple, HostPool]:
return self._host_pools
@asyncio.coroutine
[docs] def acquire(self, host: str, port: int, use_ssl: bool=False,
host_key: Optional[Any]=None) \
-> Union[Connection, SSLConnection]:
'''Return an available connection.
Args:
host: A hostname or IP address.
port: Port number.
use_ssl: Whether to return a SSL connection.
host_key: If provided, it overrides the key used for per-host
connection pooling. This is useful for proxies for example.
Coroutine.
'''
assert isinstance(port, int), 'Expect int. Got {}'.format(type(port))
assert not self._closed
yield from self._process_no_wait_releases()
if use_ssl:
connection_factory = functools.partial(
self._ssl_connection_factory, hostname=host)
else:
connection_factory = functools.partial(
self._connection_factory, hostname=host)
connection_factory = functools.partial(
HappyEyeballsConnection, (host, port), connection_factory,
self._resolver, self._happy_eyeballs_table,
is_ssl=use_ssl
)
key = host_key or (host, port, use_ssl)
with (yield from self._host_pools_lock):
if key not in self._host_pools:
host_pool = self._host_pools[key] = HostPool(
connection_factory,
max_connections=self._max_host_count
)
self._host_pool_waiters[key] = 1
else:
host_pool = self._host_pools[key]
self._host_pool_waiters[key] += 1
_logger.debug('Check out %s', key)
connection = yield from host_pool.acquire()
connection.key = key
# TODO: Verify this assert is always true
# assert host_pool.count() <= host_pool.max_connections
# assert key in self._host_pools
# assert self._host_pools[key] == host_pool
with (yield from self._host_pools_lock):
self._host_pool_waiters[key] -= 1
return connection
@asyncio.coroutine
[docs] def release(self, connection: Connection):
'''Put a connection back in the pool.
Coroutine.
'''
assert not self._closed
key = connection.key
host_pool = self._host_pools[key]
_logger.debug('Check in %s', key)
yield from host_pool.release(connection)
force = self.count() > self._max_count
yield from self.clean(force=force)
[docs] def no_wait_release(self, connection: Connection):
'''Synchronous version of :meth:`release`.'''
_logger.debug('No wait check in.')
release_task = asyncio.get_event_loop().create_task(
self.release(connection)
)
self._release_tasks.add(release_task)
@asyncio.coroutine
def _process_no_wait_releases(self):
'''Process check in tasks.'''
while True:
try:
release_task = self._release_tasks.pop()
except KeyError:
return
else:
yield from release_task
@asyncio.coroutine
[docs] def session(self, host: str, port: int, use_ssl: bool=False):
'''Return a context manager that returns a connection.
Usage::
session = yield from connection_pool.session('example.com', 80)
with session as connection:
connection.write(b'blah')
connection.close()
Coroutine.
'''
connection = yield from self.acquire(host, port, use_ssl)
@contextlib.contextmanager
def context_wrapper():
try:
yield connection
finally:
self.no_wait_release(connection)
return context_wrapper()
@asyncio.coroutine
[docs] def clean(self, force: bool=False):
'''Clean all closed connections.
Args:
force: Clean connected and idle connections too.
Coroutine.
'''
assert not self._closed
with (yield from self._host_pools_lock):
for key, pool in tuple(self._host_pools.items()):
yield from pool.clean(force=force)
if not self._host_pool_waiters[key] and pool.empty():
del self._host_pools[key]
del self._host_pool_waiters[key]
[docs] def close(self):
'''Close all the connections and clean up.
This instance will not be usable after calling this method.
'''
for key, pool in tuple(self._host_pools.items()):
pool.close()
del self._host_pools[key]
del self._host_pool_waiters[key]
self._closed = True
[docs] def count(self) -> int:
'''Return number of connections.'''
counter = 0
for pool in self._host_pools.values():
counter += pool.count()
return counter
[docs]class HappyEyeballsTable(object):
def __init__(self, max_items=100, time_to_live=600):
'''Happy eyeballs connection cache table.'''
self._cache = FIFOCache(max_items=max_items, time_to_live=time_to_live)
[docs] def set_preferred(self, preferred_addr, addr_1, addr_2):
'''Set the preferred address.'''
if addr_1 > addr_2:
addr_1, addr_2 = addr_2, addr_1
self._cache[(addr_1, addr_2)] = preferred_addr
[docs] def get_preferred(self, addr_1, addr_2):
'''Return the preferred address.'''
if addr_1 > addr_2:
addr_1, addr_2 = addr_2, addr_1
return self._cache.get((addr_1, addr_2))
[docs]class HappyEyeballsConnection(object):
'''Wrapper for happy eyeballs connection.'''
def __init__(self, address, connection_factory, resolver,
happy_eyeballs_table, is_ssl=False):
self._address = address
self._connection_factory = connection_factory
self._resolver = resolver
self._happy_eyeballs_table = happy_eyeballs_table
self._primary_connection = None
self._secondary_connection = None
self._active_connection = None
self.key = None
self.proxied = False
self.tunneled = False
self.ssl = is_ssl
def __getattr__(self, item):
return getattr(self._active_connection, item)
[docs] def closed(self):
if self._active_connection:
return self._active_connection.closed()
else:
return True
[docs] def close(self):
if self._active_connection:
self._active_connection.close()
[docs] def reset(self):
if self._active_connection:
self._active_connection.reset()
@asyncio.coroutine
[docs] def connect(self):
if self._active_connection:
yield from self._active_connection.connect()
return
result = yield from self._resolver.resolve(self._address[0])
primary_host, secondary_host = self._get_preferred_host(result)
if not secondary_host:
self._primary_connection = self._active_connection = \
self._connection_factory((primary_host, self._address[1]))
yield from self._primary_connection.connect()
else:
yield from self._connect_dual_stack(
(primary_host, self._address[1]),
(secondary_host, self._address[1])
)
@asyncio.coroutine
def _connect_dual_stack(self, primary_address, secondary_address):
'''Connect using happy eyeballs.'''
self._primary_connection = self._connection_factory(primary_address)
self._secondary_connection = self._connection_factory(secondary_address)
@asyncio.coroutine
def connect_primary():
yield from self._primary_connection.connect()
return self._primary_connection
@asyncio.coroutine
def connect_secondary():
yield from self._secondary_connection.connect()
return self._secondary_connection
primary_fut = connect_primary()
secondary_fut = connect_secondary()
failed = False
for fut in asyncio.as_completed((primary_fut, secondary_fut)):
if not self._active_connection:
try:
self._active_connection = yield from fut
except NetworkError:
if not failed:
_logger.debug('Original dual stack exception', exc_info=True)
failed = True
else:
raise
else:
_logger.debug('Got first of dual stack.')
else:
@asyncio.coroutine
def cleanup():
try:
conn = yield from fut
except NetworkError:
pass
else:
conn.close()
_logger.debug('Closed abandoned connection.')
asyncio.get_event_loop().create_task(cleanup())
preferred_host = self._active_connection.host
self._happy_eyeballs_table.set_preferred(
preferred_host, primary_address[0], secondary_address[0])
def _get_preferred_host(self, result: ResolveResult) -> Tuple[str, str]:
'''Get preferred host from DNS results.'''
host_1 = result.first_ipv4.ip_address if result.first_ipv4 else None
host_2 = result.first_ipv6.ip_address if result.first_ipv6 else None
if not host_2:
return host_1, None
elif not host_1:
return host_2, None
preferred_host = self._happy_eyeballs_table.get_preferred(
host_1, host_2)
if preferred_host:
return preferred_host, None
else:
return host_1, host_2