Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
postgres.py 19.62 KiB
#!/usr/bin/env python3

# Copyright (C) 2012-2015  ASTRON (Netherlands Institute for Radio Astronomy)
# P.O. Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.

# $Id$

'''
Module with nice postgres helper methods and classes.
'''

import logging
from threading import Thread, Lock
from queue import Queue, Empty
from datetime import  datetime
import collections
import time
import re
import select
import psycopg2
import psycopg2.extras
import psycopg2.extensions
from lofar.common.datetimeutils import totalSeconds
from lofar.common.dbcredentials import DBCredentials

# module-level logger, named after this module
logger = logging.getLogger(__name__)

def makePostgresNotificationQueries(schema, table, action, column_name='id'):
    '''Build the sql which (re)installs a row-level notification trigger on a table.

    The returned sql drops any previous version of the trigger and its function,
    creates a plpgsql function which calls pg_notify, and creates an
    AFTER <action> ... FOR EACH ROW trigger which executes that function.

    :param schema: name of the schema which holds the table.
    :param table: name of the table to put the trigger on.
    :param action: one of 'INSERT', 'UPDATE', 'DELETE' (case insensitive).
    :param column_name: the column whose value (cast to text) is sent as payload.
                        Any name other than 'id' becomes part of the channel name.
    :return: a single string with all sql statements, each line stripped of
             surrounding whitespace, ending with a newline.
    :raises ValueError: if action is not INSERT, UPDATE or DELETE.
    '''
    action = action.upper()
    if action not in ('INSERT', 'UPDATE', 'DELETE'):
        raise ValueError('''trigger_type '%s' not in ('INSERT', 'UPDATE', 'DELETE')''' % action)

    # a DELETE trigger reports (and returns) the OLD row, the other actions use the NEW row
    row_alias = 'OLD' if action == 'DELETE' else 'NEW'
    column_value = row_alias + '.' + column_name

    change_name = '{}_{}'.format(table, action)
    if column_name != 'id':
        change_name = change_name + '_column_' + column_name
    function_name = 'NOTIFY_' + change_name
    trigger_name = 'T_' + function_name

    notify_lines = ['SELECT CAST({column_value} AS text) INTO payload;',
                    "PERFORM pg_notify(CAST('{channel}' AS text), payload);"]
    if action == 'UPDATE':
        # only notify on updates which actually changed the row
        notify_lines = (['IF ROW(NEW.*) IS DISTINCT FROM ROW(OLD.*) THEN'] +
                        notify_lines +
                        ['END IF;'])

    template_lines = [
        # drop what a previous call may have installed
        '',
        'DROP TRIGGER IF EXISTS {trigger_name} ON {schema}.{table} CASCADE;',
        'DROP FUNCTION IF EXISTS {schema}.{function_name}();',
        '',
        '',
        # the function doing the actual pg_notify
        'CREATE OR REPLACE FUNCTION {schema}.{function_name}()',
        'RETURNS TRIGGER AS $$',
        'DECLARE payload text;',
        'BEGIN',
    ] + notify_lines + [
        'RETURN {row_alias};',
        'END;',
        '$$ LANGUAGE plpgsql;',
        '',
        '',
        # the trigger which fires the function
        'CREATE TRIGGER {trigger_name}',
        'AFTER {action} ON {schema}.{table}',
        'FOR EACH ROW',
        'EXECUTE PROCEDURE {schema}.{function_name}();',
        '',
    ]

    sql = '\n'.join(template_lines).format(schema=schema,
                                           table=table,
                                           action=action,
                                           trigger_name=trigger_name,
                                           function_name=function_name,
                                           column_value=column_value,
                                           row_alias=row_alias,
                                           channel=change_name.lower())

    return '\n'.join(line.strip() for line in sql.split('\n')) + '\n'

# fetch modes for the 'fetch' argument of PostgresDatabaseConnection.executeQuery
FETCH_NONE, FETCH_ONE, FETCH_ALL = 0, 1, 2

class PostgresDBError(Exception):
    '''Base class for the exceptions raised by this module.'''

class PostgresDBConnectionError(PostgresDBError):
    '''Raised when a connection to the database could not be established.'''

class PostgresDBQueryExecutionError(PostgresDBError):
    '''Raised when a query could not be executed.'''

class PostgresDatabaseConnection:
    '''Wrapper around one psycopg2 connection and its RealDictCursor.

    connect() retries on connection-like errors, and executeQuery() reconnects
    and re-executes the query when the connection got lost.
    Can be used as a context manager: connects upon enter, disconnects upon exit.
    '''
    def __init__(self,
                 dbcreds: DBCredentials,
                 log_queries: bool=False,
                 auto_commit_selects: bool=False,
                 num_connect_retries: int=5,
                 connect_retry_interval: float=1.0):
        '''
        :param dbcreds: the credentials (host, port, user, password, database) to connect with.
        :param log_queries: when True, log each query with its duration, and log the database notices.
        :param auto_commit_selects: when True, commit after each query which contains 'select'.
        :param num_connect_retries: number of extra connect attempts after a connection-like error.
        :param connect_retry_interval: seconds to sleep between connect attempts.
        '''
        self.dbcreds = dbcreds
        self._connection = None
        # the cursor is (re)created in connect()
        self._cursor = None
        self._log_queries = log_queries
        self.__auto_commit_selects = auto_commit_selects
        self.__num_connect_retries = num_connect_retries
        self.__connect_retry_interval = connect_retry_interval

    @staticmethod
    def _is_connection_error(error) -> bool:
        '''Does the given psycopg2 error indicate a lost/refused/unavailable connection?

        True for an OperationalError which mentions 'connection' in its message
        (client side errors carry no pgcode), or for any error whose pgcode is in
        class 08 (connection exception), 57P (operator intervention) or 53 (insufficient resources).
        See https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
        '''
        if isinstance(error, psycopg2.OperationalError) and re.search('connection', str(error), re.IGNORECASE):
            return True

        pgcode = getattr(error, 'pgcode', None)
        return pgcode is not None and pgcode.startswith(('08', '57P', '53'))

    def connect(self):
        '''Connect to the database, unless already connected.

        Connection-like errors are retried num_connect_retries times,
        sleeping connect_retry_interval seconds in between.

        :raises PostgresDBConnectionError: when connecting still fails after the last retry.
        :raises PostgresDBError: upon any other database error while connecting.
        '''
        if self.is_connected:
            logger.debug("already connected to database: %s", self.dbcreds.stringWithHiddenPassword())
            return

        for retry_cntr in range(self.__num_connect_retries+1):
            try:
                logger.debug("trying to connect to database using: %s", self.dbcreds.stringWithHiddenPassword())

                self._connection = psycopg2.connect(host=self.dbcreds.host,
                                                    user=self.dbcreds.user,
                                                    password=self.dbcreds.password,
                                                    database=self.dbcreds.database,
                                                    port=self.dbcreds.port,
                                                    connect_timeout=5)

                if self._connection:
                    self._cursor = self._connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

                    logger.log(logging.INFO if self._log_queries else logging.DEBUG,
                               "connected to database: %s", self.dbcreds.stringWithHiddenPassword())

                    # see http://initd.org/psycopg/docs/connection.html#connection.notices
                    # try to set the notices attribute with a non-list collection,
                    # so we can log more than 50 messages. Is only available since 2.7, so encapsulate in try/except.
                    try:
                        self._connection.notices = collections.deque()
                    except TypeError:
                        logger.warning("Cannot overwrite self._connection.notices with a deque... only max 50 notifications available per query. (That's ok, no worries.)")

                    # we have a proper connection, so return
                    return
            except psycopg2.DatabaseError as dbe:
                error_string = str(dbe).replace('\n', ' ')
                logger.error(error_string)

                if not self._is_connection_error(dbe):
                    # non-connection-error, raise generic PostgresDBError
                    raise PostgresDBError(error_string) from dbe

                # try to reconnect on connection-like-errors
                if retry_cntr == self.__num_connect_retries:
                    raise PostgresDBConnectionError("%s Error while connecting to %s. error=%s" % (self.__class__.__name__,
                                                                                                   self.dbcreds.stringWithHiddenPassword(),
                                                                                                   error_string)) from dbe

                logger.info('retrying to connect to %s in %s seconds', self.database, self.__connect_retry_interval)
                time.sleep(self.__connect_retry_interval)

        # only reached when no attempt yielded a connection, and none of them raised
        raise PostgresDBConnectionError("%s Could not connect to %s" % (self.__class__.__name__,
                                                                        self.dbcreds.stringWithHiddenPassword()))

    def disconnect(self):
        '''Close the cursor and the connection. Does nothing when not connected.'''
        if self.is_connected:
            logger.debug("%s disconnecting from db: %s", self.__class__.__name__, self.database)
            self._cursor.close()
            self._cursor = None
            self._connection.close()
            self._connection = None
            logger.debug("%s disconnected from db: %s", self.__class__.__name__, self.database)

    @property
    def database(self) -> str:
        '''returns the database name'''
        return self.dbcreds.database

    @property
    def is_connected(self) -> bool:
        '''True when there is a connection object which is not closed.'''
        return self._connection is not None and self._connection.closed==0

    def reconnect(self):
        '''Disconnect (if connected) and connect again.'''
        self.disconnect()
        self.connect()

    def __enter__(self):
        '''connects to the database'''
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        '''disconnects from the database'''
        self.disconnect()

    @staticmethod
    def _queryAsSingleLine(query, qargs=None):
        '''Return the query as one whitespace-normalized line with the qargs filled in, for logging.

        String arguments are put between single quotes. This is NOT how the query
        is sent to the database, so it never has to be safe, only readable.
        When the arguments cannot be filled in (wrong number of arguments for example),
        they are appended to the line instead, so logging never raises on them.
        '''
        line = ' '.join(query.split())
        if qargs:
            def quoted(arg):
                return '\'%s\'' % arg if isinstance(arg, str) else arg

            try:
                if isinstance(qargs, dict):
                    # named placeholders, like %(name)s
                    line = line % {key: quoted(value) for key, value in qargs.items()}
                else:
                    line = line % tuple(quoted(arg) for arg in qargs)
            except (TypeError, ValueError, KeyError):
                line = '%s with args %r' % (line, qargs)
        return line

    def executeQuery(self, query, qargs=None, fetch=FETCH_NONE):
        '''execute the query and reconnect upon OperationalError

        :param query: the sql query, optionally with placeholders for the qargs.
        :param qargs: the arguments to fill into the query's placeholders.
        :param fetch: FETCH_NONE, FETCH_ONE or FETCH_ALL.
        :return: the result of cursor.fetchone() for FETCH_ONE, of cursor.fetchall() for FETCH_ALL,
                 and an empty list for FETCH_NONE, or when fetching the results failed (which is logged).
        :raises PostgresDBQueryExecutionError: when the query could not be executed.
        :raises PostgresDBConnectionError: when the connection could not be (re)established.
        :raises PostgresDBError: upon a non-connection database error while (re)connecting.
        '''
        # create the log line up front, because each of the error handlers below needs it
        query_log_line = self._queryAsSingleLine(query, qargs)

        try:
            # make sure we're connected
            if not self.is_connected:
                self.connect()

            if self._log_queries:
                logger.debug('executing query: %s', query_log_line)

            # use the monotonic clock: durations should not be affected by wall clock adjustments
            start = time.monotonic()
            self._cursor.execute(query, qargs)
            if self._log_queries:
                elapsed_ms = 1000.0 * (time.monotonic() - start)
                logger.info('executed query in %.1fms%s yielding %s rows: %s', elapsed_ms,
                                                                               ' (SLOW!)' if elapsed_ms > 250 else '', # for easy log grep'ing
                                                                               self._cursor.rowcount,
                                                                               query_log_line)

            self._log_database_notifications()

            try:
                result = []
                if fetch == FETCH_ONE:
                    result = self._cursor.fetchone()
                elif fetch == FETCH_ALL:
                    result = self._cursor.fetchall()

                if self.__auto_commit_selects and re.search('select', query, re.IGNORECASE):
                    #prevent dangling in idle transaction on server
                    self.commit()

                return result
            except Exception as e:
                # best effort: log it, and return the empty list at the end of this method
                logger.error("error while fetching result(s) for %s: %s", query_log_line, e)

        except PostgresDBError:
            # raised by connect(), and already one of our own exception types. Pass it on as is.
            raise

        except psycopg2.OperationalError as oe:
            error_string = str(oe).replace('\n', ' ')

            if not self._is_connection_error(oe):
                # raise psycopg2 wrapped in our own PostgresDBQueryExecutionError
                raise PostgresDBQueryExecutionError("Could not execute query '%s' error=%s" % (query_log_line, error_string)) from oe

            # some connection error occured.
            # try to reconnect a few times and execute the query again when the connection returns
            logger.warning(error_string)

            for retry_cntr in range(self.__num_connect_retries + 1):
                try:
                    self.reconnect()
                except PostgresDBConnectionError as e:
                    logger.error(e)
                    time.sleep(self.__connect_retry_interval)
                else:
                    # hey, reconnect worked, and re-execute the query
                    # WARNING: possible stack-overflow, but that's the least of our problems compared to a lost db connection...
                    return self.executeQuery(query, qargs, fetch)

            # the query was never executed, so do not pretend it yielded an empty result
            raise PostgresDBConnectionError("Could not execute query '%s' because the connection to %s could not be restored. error=%s" % (
                                            query_log_line, self.database, error_string)) from oe

        except (psycopg2.IntegrityError, psycopg2.ProgrammingError, psycopg2.InternalError, psycopg2.DataError) as e:
            self._log_database_notifications()
            error_string = str(e).replace('\n', ' ')
            logger.error("Rolling back query=\'%s\' due to error: \'%s\'", query_log_line, error_string)
            self.rollback()
            raise PostgresDBQueryExecutionError("Could not execute query '%s' error=%s" % (query_log_line, error_string)) from e
        except Exception as e:
            error_string = str(e).replace('\n', ' ')
            raise PostgresDBQueryExecutionError("Could not execute query '%s' error=%s" % (query_log_line, error_string)) from e

        return []

    def _log_database_notifications(self):
        '''Log the notices of the connection (only when log_queries is on), and then clear them.

        The notices are always cleared, also when they are not logged,
        because connect() may have replaced them by an unbounded deque which would grow forever otherwise.
        Never raises: errors are logged.
        '''
        try:
            notices = self._connection.notices
            if not notices:
                return

            if self._log_queries:
                for notice in notices:
                    logger.info('database log message %s', notice.strip())

            if isinstance(notices, collections.deque):
                notices.clear()
            else:
                del notices[:]
        except Exception as e:
            logger.error(str(e))

    def commit(self):
        '''Commit the current transaction.'''
        if self._log_queries:
            logger.info('commit')
        self._connection.commit()

    def rollback(self):
        '''Roll back the current transaction.'''
        if self._log_queries:
            logger.info('rollback')
        self._connection.rollback()


class PostgresListener(PostgresDatabaseConnection):
    ''' This class lets you listen to postgres notifications.
    It executes callbacks when a notification occurs.
    Make your own subclass with your callbacks and subscribe them to the appropriate channel.
    Example:

    class MyListener(PostgresListener):
        def __init__(self, dbcreds):
            super(MyListener, self).__init__(dbcreds=dbcreds)
            self.subscribe('foo', self.foo)
            self.subscribe('bar', self.bar)

        def foo(self, payload = None):
            print("Foo called with payload: ", payload)

        def bar(self, payload = None):
            print("Bar called with payload: ", payload)

    with MyListener(dbcreds) as listener:
        #either listen like below in a loop doing stuff...
        while True:
            #do stuff or wait,
            #the listener calls the callbacks meanwhile in another thread

        #... or listen like below blocking
        #while the listener calls the callbacks meanwhile in this thread
        listener.waitWhileListening()
    '''
    def __init__(self, dbcreds: DBCredentials):
        '''Create a new PostgresListener'''
        super(PostgresListener, self).__init__(dbcreds=dbcreds,
                                               auto_commit_selects=True)
        self.__listening = False
        self.__lock = Lock()
        self.__callbacks = {}
        self.__waiting = False
        self.__queue = Queue()
        # the listener thread, created in start()
        self.__thread = None

    def connect(self):
        '''Connect to the database and put the connection in autocommit mode.'''
        super(PostgresListener, self).connect()
        self._connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

    def subscribe(self, notification, callback):
        '''Subscribe to a certain postgres notification.
        Call callback method in case such a notification is received.'''
        logger.info("Subscribed %sto %s", 'and listening ' if self.isListening() else '', notification)
        with self.__lock:
            self.executeQuery("LISTEN %s;", (psycopg2.extensions.AsIs(notification),))
            self.__callbacks[notification] = callback

    def unsubscribe(self, notification):
        '''Unsubscribe from a certain postgres notification.'''
        logger.info("Unsubscribed from %s", notification)
        with self.__lock:
            self.executeQuery("UNLISTEN %s;", (psycopg2.extensions.AsIs(notification),))
            self.__callbacks.pop(notification, None)

    def isListening(self):
        '''Are we listening? Has the listener been started?'''
        with self.__lock:
            return self.__listening

    def _restoreSubscriptions(self):
        '''Reconnect to the database, and issue the LISTEN again for each subscribed channel.

        LISTEN registrations belong to a database session, so a new connection starts without any.
        Never raises: a failure is logged, so the caller can just try again later.
        '''
        try:
            self.reconnect()
            with self.__lock:
                for channel in self.__callbacks:
                    self.executeQuery("LISTEN %s;", (psycopg2.extensions.AsIs(channel),))
            logger.info("Restored the listener connection to %s", self.database)
        except Exception as e:
            logger.error("Could not restore the listener connection to %s: %s", self.database, e)
            # not every failure path of reconnect sleeps, so prevent a busy retry loop
            time.sleep(1)

    def start(self):
        '''Start listening. Does nothing if already listening.
        When using the listener in a context start() and stop()
        are called upon __enter__ and __exit__

        This method return immediately.
        Listening and calling callbacks takes place on another thread.
        If you want to block processing and call the callbacks on the main thread,
        then call waitWhileListening() after start.
        '''
        if self.isListening():
            return

        self.connect()

        logger.info("Started listening to %s", ', '.join([str(x) for x in list(self.__callbacks.keys())]))

        def eventLoop():
            while self.isListening():
                try:
                    # wait at most 2 seconds for activity, so a stop() is noticed in time
                    if select.select([self._connection],[],[],2) != ([],[],[]):
                        self._connection.poll()
                        while self._connection.notifies:
                            try:
                                notification = self._connection.notifies.pop(0)
                                logger.debug("Received notification on channel %s payload %s", notification.channel, notification.payload)

                                if self.isWaiting():
                                    # put notification on Queue
                                    # let waiting thread handle the callback
                                    self.__queue.put((notification.channel, notification.payload))
                                else:
                                    # call callback on this listener thread
                                    self._callCallback(notification.channel, notification.payload)
                            except Exception as e:
                                logger.error(str(e))
                except Exception as e:
                    # select/poll raise when the connection got lost.
                    # Do not let that end this thread (we would be 'listening' without ever
                    # receiving anything), but try to restore the connection and subscriptions.
                    logger.error("Error while waiting for notifications: %s", e)
                    if self.isListening():
                        self._restoreSubscriptions()

        self.__thread = Thread(target=eventLoop)
        self.__thread.daemon = True
        with self.__lock:
            self.__listening = True
        self.__thread.start()

    def stop(self):
        '''Stop listening. (Can be restarted)'''
        with self.__lock:
            if not self.__listening:
                return
            self.__listening = False

        if self.__thread is not None:
            self.__thread.join()
            self.__thread = None

        logger.info("Stopped listening")
        self.stopWaiting()
        self.disconnect()

    def __enter__(self):
        '''starts the listener upon context enter'''
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        '''stops the listener upon context exit'''
        self.stop()

    def _callCallback(self, channel, payload = None):
        '''call the appropriate callback based on channel.
        The callback is called without arguments when the payload is empty.
        Never raises: errors in the callback are logged.'''
        try:
            with self.__lock:
                callback = self.__callbacks.get(channel)

            # call the callback outside of the lock, so it may (un)subscribe itself
            if callback:
                if payload:
                    callback(payload)
                else:
                    callback()
        except Exception as e:
            logger.error(str(e))

    def isWaiting(self):
        '''Are we waiting in the waitWhileListening() method?'''
        with self.__lock:
            return self.__waiting

    def stopWaiting(self):
        '''break from the blocking waitWhileListening() method'''
        with self.__lock:
            if self.__waiting:
                self.__waiting = False
                logger.info("Continuing from blocking waitWhileListening")

    def waitWhileListening(self):
        '''
        block calling thread until interrupted or
        until stopWaiting is called from another thread
        meanwhile, handle the callbacks on this thread
        '''
        logger.info("Waiting while listening to %s", ', '.join([str(x) for x in list(self.__callbacks.keys())]))

        with self.__lock:
            self.__waiting = True

        while self.isWaiting():
            try:
                # wait at most 1 second, so a stopWaiting() is noticed in time
                channel, payload = self.__queue.get(True, 1)
                self._callCallback(channel, payload)
            except KeyboardInterrupt:
                # break
                break
            except Empty:
                pass

        self.stopWaiting()