# Source code for pydle.client

## client.py
# Basic IRC client implementation.
import asyncio
import logging
from asyncio import new_event_loop, gather, get_event_loop, sleep

from . import connection, protocol

# Public names exported by `from <this module> import *`.
__all__ = ['Error', 'AlreadyInChannel', 'NotInChannel', 'BasicClient', 'ClientPool']
# Placeholder nickname assigned by BasicClient._reset_attributes() (on construction and after a disconnect);
# presumably replaced once the server confirms a real nickname — TODO confirm in the protocol features.
DEFAULT_NICKNAME = '<unregistered>'


class Error(Exception):
    """Root of the pydle exception hierarchy; every pydle-specific error derives from it."""


class NotInChannel(Error):
    """Error indicating the client is not in a given channel."""

    def __init__(self, channel):
        message = 'Not in channel: {}'.format(channel)
        super().__init__(message)
        # Keep the channel name so handlers can inspect it without parsing the message.
        self.channel = channel


class AlreadyInChannel(Error):
    """Error indicating the client is already in a given channel."""

    def __init__(self, channel):
        message = 'Already in channel: {}'.format(channel)
        super().__init__(message)
        # Keep the channel name so handlers can inspect it without parsing the message.
        self.channel = channel


class BasicClient:
    """
    Base IRC client class.

    This class on its own is not complete: in order to be able to run properly,
    _has_message, _parse_message and _create_message have to be overloaded
    (and _parse_user, which raises NotImplementedError here).
    """
    # Seconds without any received data before on_data_error() is fed a TimeoutError.
    PING_TIMEOUT = 300
    # Whether on_disconnect() reconnects automatically after an unexpected disconnect.
    RECONNECT_ON_ERROR = True
    # Maximum number of reconnect attempts; None means unlimited.
    RECONNECT_MAX_ATTEMPTS = 3
    # Whether to wait between reconnect attempts (delays taken from RECONNECT_DELAYS).
    RECONNECT_DELAYED = True
    # Delay in seconds per attempt; the last value is reused once the list is exhausted.
    RECONNECT_DELAYS = [5, 5, 10, 30, 120, 600]

    def __init__(self, nickname, fallback_nicknames=None, username=None, realname=None, eventloop=None, **kwargs):
        """
        Create a client.

        Args:
            nickname: preferred nickname.
            fallback_nicknames: optional iterable of alternative nicknames (copied, never mutated).
            username: defaults to the lower-cased nickname.
            realname: defaults to the nickname.
            eventloop: loop to run on; when omitted, get_event_loop() is used and the
                client considers the loop its own (see own_eventloop).
            **kwargs: unknown arguments are logged as unused and otherwise ignored.
        """
        self._nicknames = [nickname] + list(fallback_nicknames or [])
        self.username = username or nickname.lower()
        self.realname = realname or nickname
        if eventloop:
            self.eventloop = eventloop
        else:
            self.eventloop = get_event_loop()
        # Gates the connection.stop() call performed by _disconnect() on an expected disconnect.
        self.own_eventloop = not eventloop

        self._reset_connection_attributes()
        self._reset_attributes()

        if kwargs:
            self.logger.warning('Unused arguments: %s', ', '.join(kwargs.keys()))

    def _reset_attributes(self):
        """ Reset per-session state (called on construction and after every disconnect). """
        # Record-keeping.
        self.channels = {}
        self.users = {}

        # Low-level data stuff.
        self._receive_buffer = b''
        self._pending = {}
        self._handler_top_level = False
        self._ping_checker_handle = None

        # Misc.
        self.logger = logging.getLogger(__name__)

        # Public connection attributes.
        self.nickname = DEFAULT_NICKNAME
        self.network = None

    def _reset_connection_attributes(self):
        """ Reset connection attributes (kept across reconnects). """
        self.connection = None
        self.encoding = None
        self._autojoin_channels = []
        self._reconnect_attempts = 0

    ## Connection.

    def run(self, *args, **kwargs):
        """ Connect (forwarding all arguments to connect()) and run bot in event loop forever. """
        self.eventloop.run_until_complete(self.connect(*args, **kwargs))
        try:
            self.eventloop.run_forever()
        finally:
            self.eventloop.stop()

    async def connect(self, hostname=None, port=None, reconnect=False, **kwargs):
        """
        Connect to IRC server.

        Extra keyword arguments are forwarded to _connect().

        Raises:
            ValueError: if hostname or port is missing and this is not a reconnect.
        """
        if (not hostname or not port) and not reconnect:
            raise ValueError('Have to specify hostname and port if not reconnecting.')

        # Disconnect from current connection.
        if self.connected:
            await self.disconnect(expected=True)

        # Reset attributes and connect.
        if not reconnect:
            self._reset_connection_attributes()
        await self._connect(hostname=hostname, port=port, reconnect=reconnect, **kwargs)

        # Set logger name.
        if self.server_tag:
            self.logger = logging.getLogger(self.__class__.__name__ + ':' + self.server_tag)

        self.eventloop.create_task(self.handle_forever())

    async def disconnect(self, expected=True):
        """ Disconnect from server (no-op when not connected). """
        if self.connected:
            # Unschedule ping checker.
            if self._ping_checker_handle:
                self._ping_checker_handle.cancel()

            # Schedule disconnect.
            await self._disconnect(expected)

    async def _disconnect(self, expected):
        # Shutdown connection.
        await self.connection.disconnect()

        # Reset any attributes.
        self._reset_attributes()

        # Callback.
        await self.on_disconnect(expected)

        # Shut down event loop.
        if expected and self.own_eventloop:
            self.connection.stop()

    async def _connect(self, hostname, port, reconnect=False, channels=None,
                       encoding=protocol.DEFAULT_ENCODING, source_address=None):
        """
        Connect to IRC host.

        A new Connection is created unless this is a reconnect and one already exists.
        *channels* is stored as the autojoin list; a fresh list is used when omitted so
        no default object is ever shared between clients.
        """
        # Create connection if we can't reuse it.
        if not reconnect or not self.connection:
            self._autojoin_channels = [] if channels is None else channels
            self.connection = connection.Connection(hostname, port, source_address=source_address,
                                                    eventloop=self.eventloop)
            self.encoding = encoding

        # Connect.
        await self.connection.connect()

    def _reconnect_delay(self):
        """ Calculate reconnection delay in seconds (0 when delayed reconnects are disabled). """
        if self.RECONNECT_ON_ERROR and self.RECONNECT_DELAYED:
            if self._reconnect_attempts >= len(self.RECONNECT_DELAYS):
                return self.RECONNECT_DELAYS[-1]
            else:
                return self.RECONNECT_DELAYS[self._reconnect_attempts]
        else:
            return 0

    async def _perform_ping_timeout(self, delay: int):
        """
        Handle timeout gracefully.

        Args:
            delay (int): delay before raising the timeout (in seconds)
        """
        # pause for delay seconds; on_data() cancels this task whenever data arrives
        await sleep(delay)
        # We are now the live ping checker and are about to trigger a disconnect.
        # Forget our own handle first: otherwise disconnect() would cancel *this*
        # task, and the pending cancellation would abort on_disconnect() at its
        # next await (e.g. the reconnect sleep), so no reconnect would ever happen.
        self._ping_checker_handle = None
        # then continue
        error = TimeoutError(
            'Ping timeout: no data received from server in {timeout} seconds.'.format(
                timeout=self.PING_TIMEOUT))
        await self.on_data_error(error)

    ## Internal database management.

    def _create_channel(self, channel):
        self.channels[channel] = {
            'users': set(),
        }

    def _destroy_channel(self, channel):
        # Copy set to prevent a runtime error when destroying the user.
        for user in set(self.channels[channel]['users']):
            self._destroy_user(user, channel)
        del self.channels[channel]

    def _create_user(self, nickname):
        # Servers are NOT users (empty names and names containing '.' are skipped).
        if not nickname or '.' in nickname:
            return

        self.users[nickname] = {
            'nickname': nickname,
            'username': None,
            'realname': None,
            'hostname': None
        }

    def _sync_user(self, nick, metadata):
        # Create user in database.
        if nick not in self.users:
            self._create_user(nick)
            if nick not in self.users:
                return

        self.users[nick].update(metadata)

    def _rename_user(self, user, new):
        if user in self.users:
            self.users[new] = self.users[user]
            self.users[new]['nickname'] = new
            del self.users[user]
        else:
            self._create_user(new)
            if new not in self.users:
                return

        for ch in self.channels.values():
            # Rename user in channel list.
            if user in ch['users']:
                ch['users'].discard(user)
                ch['users'].add(new)

    def _destroy_user(self, nickname, channel=None):
        if channel:
            channels = [self.channels[channel]]
        else:
            channels = self.channels.values()

        for ch in channels:
            # Remove from nicklist.
            ch['users'].discard(nickname)

        # If we're not in any common channels with the user anymore, we have no reliable way to keep their info up-to-date.
        # Remove the user. pop() tolerates names that sit in a channel list but were never
        # created as users (_create_user skips server-like names).
        if not channel or not any(nickname in ch['users'] for ch in self.channels.values()):
            self.users.pop(nickname, None)

    def _parse_user(self, data):
        """ Parse user and return nickname, metadata tuple. """
        raise NotImplementedError()

    def _format_user_mask(self, nickname):
        user = self.users.get(nickname, {"nickname": nickname, "username": "*", "hostname": "*"})
        return self._format_host_mask(user['nickname'], user['username'] or '*', user['hostname'] or '*')

    def _format_host_mask(self, nick, user, host):
        return '{n}!{u}@{h}'.format(n=nick, u=user, h=host)

    ## IRC helpers.

    def is_channel(self, chan):
        """ Check if given argument is a channel name or not. """
        return True

    def in_channel(self, channel):
        """ Check if we are currently in the given channel. """
        return channel in self.channels.keys()

    def is_same_nick(self, left, right):
        """ Check if given nicknames are equal. """
        return left == right

    def is_same_channel(self, left, right):
        """ Check if given channel names are equal. """
        return left == right

    ## IRC attributes.

    @property
    def connected(self):
        """ Whether or not we are connected. """
        return self.connection and self.connection.connected

    @property
    def server_tag(self):
        """ Short tag derived from the network name or hostname (None when not connected). """
        if self.connected and self.connection.hostname:
            if self.network:
                tag = self.network.lower()
            else:
                tag = self.connection.hostname.lower()

                # Remove hostname prefix.
                if tag.startswith('irc.'):
                    tag = tag[4:]

                # Check if host is either an FQDN or IPv4.
                if '.' in tag:
                    # Attempt to cut off TLD.
                    host, suffix = tag.rsplit('.', 1)

                    # Make sure we aren't cutting off the last octet of an IPv4.
                    try:
                        int(suffix)
                    except ValueError:
                        tag = host

            return tag
        else:
            return None

    ## IRC API.

    async def raw(self, message):
        """ Send raw command. """
        await self._send(message)

    async def rawmsg(self, command, *args, **kwargs):
        """ Send raw message. """
        message = str(self._create_message(command, *args, **kwargs))
        await self._send(message)

    ## Overloadable callbacks.

    async def on_connect(self):
        """ Callback called when the client has connected successfully. """
        # Reset reconnect attempts.
        self._reconnect_attempts = 0

    async def on_disconnect(self, expected):
        """ Callback called after disconnecting; reconnects on unexpected disconnects if allowed. """
        if not expected:
            # Unexpected disconnect. Reconnect?
            if self.RECONNECT_ON_ERROR and (
                    self.RECONNECT_MAX_ATTEMPTS is None or self._reconnect_attempts < self.RECONNECT_MAX_ATTEMPTS):
                # Calculate reconnect delay.
                delay = self._reconnect_delay()
                self._reconnect_attempts += 1

                if delay > 0:
                    self.logger.error(
                        'Unexpected disconnect. Attempting to reconnect within %s seconds.', delay)
                else:
                    self.logger.error('Unexpected disconnect. Attempting to reconnect.')

                # Wait and reconnect.
                await sleep(delay)
                await self.connect(reconnect=True)
            else:
                self.logger.error('Unexpected disconnect. Giving up.')

    ## Message dispatch.

    def _has_message(self):
        """ Whether or not we have messages available for processing. """
        raise NotImplementedError()

    def _create_message(self, command, *params, **kwargs):
        raise NotImplementedError()

    def _parse_message(self):
        raise NotImplementedError()

    async def _send(self, input):
        # Accept bytes, str, or any object with a str() form (e.g. a message object).
        if not isinstance(input, (bytes, str)):
            input = str(input)
        if isinstance(input, str):
            input = input.encode(self.encoding)

        self.logger.debug('>> %s', input.decode(self.encoding))
        await self.connection.send(input)

    async def handle_forever(self):
        """ Handle data forever (until recv() returns no data or the connection drops). """
        while self.connected:
            data = await self.connection.recv()
            if not data:
                if self.connected:
                    await self.disconnect(expected=False)
                break
            await self.on_data(data)

    ## Raw message handlers.

    async def on_data(self, data):
        """ Handle received data. """
        self._receive_buffer += data

        # Schedule new timeout event.
        if self._ping_checker_handle:
            self._ping_checker_handle.cancel()

        # create a task for the ping checker
        self._ping_checker_handle = self.eventloop.create_task(
            self._perform_ping_timeout(self.PING_TIMEOUT))

        while self._has_message():
            message = self._parse_message()
            self.eventloop.create_task(self.on_raw(message))

    async def on_data_error(self, exception):
        """ Handle error. """
        self.logger.error('Encountered error on socket.',
                          exc_info=(type(exception), exception, None))
        await self.disconnect(expected=False)

    async def on_raw(self, message):
        """ Handle a single message by dispatching it to on_raw_<command>. """
        self.logger.debug('<< %s', message._raw)
        if not message._valid:
            self.logger.warning('Encountered strictly invalid IRC message from server: %s',
                                message._raw)

        if isinstance(message.command, int):
            cmd = str(message.command).zfill(3)
        else:
            cmd = message.command

        # Invoke dispatcher, if we have one.
        method = 'on_raw_' + cmd.lower()
        try:
            # Set _top_level so __getattr__() can decide whether to return on_unknown or _ignored for unknown handlers.
            # The reason for this is that features can always call super().on_raw_* safely and thus don't need to care for other features,
            # while unknown messages for which no handlers exist at all are still logged.
            self._handler_top_level = True
            try:
                handler = getattr(self, method)
            finally:
                # Never leave the flag set, even if the lookup itself fails.
                self._handler_top_level = False
            await handler(message)
        except Exception:
            # Exception (not a bare except) so cancellation, KeyboardInterrupt and
            # SystemExit are not swallowed by a failing handler.
            self.logger.exception('Failed to execute %s handler.', method)

    async def on_unknown(self, message):
        """ Unknown command. """
        self.logger.warning('Unknown command: [%s] %s %s', message.source, message.command,
                            message.params)

    async def _ignored(self, message):
        """ Ignore message. """
        pass

    def __getattr__(self, attr):
        """ Return on_unknown or _ignored for unknown handlers, depending on the invocation type. """
        # Is this a raw handler?
        if attr.startswith('on_raw_'):
            # Are we in on_raw() trying to find any message handler?
            if self._handler_top_level:
                # In that case, return the method that logs and possibly acts on unknown messages.
                return self.on_unknown
            # Are we in an existing handler calling super()?
            else:
                # Just ignore it, then.
                return self._ignored

        # This isn't a handler, just raise an error.
        raise AttributeError(attr)
class ClientPool:
    """ A pool of clients that are run and handled concurrently on one shared event loop. """

    def __init__(self, clients=None, eventloop=None):
        """
        Create a pool.

        Args:
            clients: optional initial clients. NOTE(review): clients added only here have no
                connect arguments, so handle_forever() raises KeyError for them; prefer connect().
            eventloop: loop to run all clients on; a new loop is created when omitted.
        """
        self.eventloop = eventloop if eventloop else new_event_loop()
        self.clients = set(clients or [])
        self.connect_args = {}

    def connect(self, client: BasicClient, *args, **kwargs):
        """ Add client to pool; *args/**kwargs are passed to client.connect() by handle_forever(). """
        self.clients.add(client)
        self.connect_args[client] = (args, kwargs)
        # hack the clients event loop to use the pools own event loop
        # necessary to run multiple clients in the same thread via the pool
        client.eventloop = self.eventloop

    def disconnect(self, client):
        """
        Remove client from pool and disconnect it.

        client.disconnect() is a coroutine, so it must actually be run: when the pool's
        loop is running (e.g. called from a handler) it is scheduled and the Task is
        returned so callers may await it; otherwise it is run to completion here.
        """
        self.clients.remove(client)
        del self.connect_args[client]
        coro = client.disconnect()
        if self.eventloop.is_running():
            return self.eventloop.create_task(coro)
        return self.eventloop.run_until_complete(coro)

    def __contains__(self, item):
        return item in self.clients

    ## High-level.

    def handle_forever(self):
        """ Main loop of the pool: handle clients forever, until the event loop is stopped. """
        # container for all the client connection coros
        connection_list = []
        for client in self.clients:
            args, kwargs = self.connect_args[client]
            connection_list.append(client.connect(*args, **kwargs))

        # run the connections concurrently on the pool's loop
        self.eventloop.run_until_complete(self._gather(connection_list))

        # run the clients
        self.eventloop.run_forever()

        # The loop has stopped; actually run each client's disconnect coroutine on it.
        # Best effort: one failing disconnect must not prevent the others.
        disconnects = [client.disconnect() for client in self.clients]
        self.eventloop.run_until_complete(self._gather(disconnects, return_exceptions=True))

    @staticmethod
    async def _gather(coros, return_exceptions=False):
        """ Await *coros* concurrently on the currently running loop. """
        # Calling gather() from inside a running coroutine binds it to that loop, so the
        # `loop=` argument (removed in Python 3.10) is not needed.
        return await gather(*coros, return_exceptions=return_exceptions)