# Source code for lsst.sims.maf.db.dbObj

import numpy
from lsst.utils import getPackageDir
import os
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.engine import reflection, url
from sqlalchemy import (create_engine, MetaData, event, inspect)
import warnings
from io import BytesIO
# Alias for the built-in str type.
str_cast = str


# Names exported by a star-import of this module.
__all__ = ['DBObject']


def valueOfPi():
    """
    Return numpy's value of pi.

    Takes no arguments so that it can be registered as the zero-argument
    PI() function in sqlite databases.
    """
    pi_value = numpy.pi
    return pi_value


def declareTrigFunctions(conn, connection_rec, connection_proxy):
    """
    Database event listener that registers math functions on a connection.

    The functions registered are the ones necessary for evaluating the
    Haversine function in sqlite databases (where they are not otherwise
    defined).  Only ``conn`` is used; the other two arguments are accepted
    because the listener is called with them.

    see:    http://docs.sqlalchemy.org/en/latest/core/events.html
    """

    # (SQL function name, number of arguments, python implementation)
    sql_functions = (
        ("COS", 1, numpy.cos),
        ("SIN", 1, numpy.sin),
        ("ASIN", 1, numpy.arcsin),
        ("SQRT", 1, numpy.sqrt),
        ("POWER", 2, numpy.power),
        ("PI", 0, valueOfPi),
    )
    for sql_name, n_args, implementation in sql_functions:
        conn.create_function(sql_name, n_args, implementation)


class ChunkIterator(object):
    """
    Iterator over the results of a query, delivered in chunks.

    The query is executed once, at construction time, through
    ``dbobj.connection.session``.  Each step of the iteration fetches rows
    from that result and hands them to one of the post-processing methods
    of ``dbobj``.
    """

    def __init__(self, dbobj, query, chunk_size, arbitrarySQL=False):
        """
        @param [in] dbobj is the object whose connection runs the query and
        whose _postprocess_* methods convert each chunk

        @param [in] query is the query to execute

        @param [in] chunk_size is the number of rows per chunk; if None,
        all rows are returned in a single chunk

        @param [in] arbitrarySQL selects the post-processing method:
        _postprocess_arbitrary_results if True (the case where a
        CatalogDBObject calls get_arbitrary_chunk_iterator), otherwise
        _postprocess_results
        """
        self.dbobj = dbobj
        self.exec_query = dbobj.connection.session.execute(query)
        self.chunk_size = chunk_size
        self.arbitrarySQL = arbitrarySQL

    def __iter__(self):
        return self

    def __next__(self):
        """Fetch and post-process the next chunk; stop when no rows remain."""
        if self.chunk_size is not None:
            rows = self.exec_query.fetchmany(self.chunk_size)
        elif not self.exec_query.closed:
            # no chunk size requested: deliver everything in one go
            rows = self.exec_query.fetchall()
        else:
            raise StopIteration
        return self._postprocess_results(rows)

    def _postprocess_results(self, chunk):
        """Convert a chunk of rows using the appropriate method of dbobj."""
        if len(chunk) == 0:
            raise StopIteration
        if self.arbitrarySQL:
            converter = self.dbobj._postprocess_arbitrary_results
        else:
            converter = self.dbobj._postprocess_results
        return converter(chunk)


class DBConnection(object):
    """
    This is a class that will hold the engine, session, and metadata for a
    DBObject.  This will allow multiple DBObjects to share the same
    sqlalchemy connection, when appropriate.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False):
        """
        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity
        """

        self._database = database
        self._driver = driver
        self._host = host
        self._port = port
        self._verbose = verbose

        self._validate_conn_params()
        self._connect_to_engine()

    def __del__(self):
        # Drop the sqlalchemy objects one by one; any of them may be missing
        # if construction failed part-way through.
        for attr_name in ('_metadata', '_engine', '_session'):
            try:
                delattr(self, attr_name)
            except AttributeError:
                pass

    def _connect_to_engine(self):
        """
        Create the sqlalchemy engine, session and metadata for this connection.

        Only the driver and database name are used to build the URL; host and
        port are not passed on (we assume a local database).
        """

        # Remove dbAuth things. Assume we are only connecting to a local database.
        # Newer sqlalchemy versions provide URL.create; fall back to calling
        # the URL constructor directly where that does not exist.
        build_url = getattr(url.URL, 'create', url.URL)
        dbUrl = build_url(self._driver, database=self._database)

        self._engine = create_engine(dbUrl, echo=self._verbose)

        if self._engine.dialect.name == 'sqlite':
            # sqlite lacks the trig functions; register them on every checkout
            event.listen(self._engine, 'checkout', declareTrigFunctions)

        self._session = scoped_session(sessionmaker(autoflush=True,
                                                    bind=self._engine))
        self._metadata = MetaData(bind=self._engine)

    def _validate_conn_params(self):
        """Validate connection parameters

        - Check if user passed dbAddress instead of a database. Convert and warn.
        - Check that required connection parameters are present
        - Replace default host/port if driver is 'sqlite'

        Raises AttributeError if the database or the driver is missing.
        """

        if self._database is None:
            raise AttributeError("Cannot instantiate DBConnection; database is 'None'")

        if '//' in self._database:
            warnings.warn("Database name '%s' is invalid but looks like a dbAddress. "
                          "Attempting to convert to database, driver, host, "
                          "and port parameters. Any usernames and passwords are ignored and must "
                          "be in the db-auth.paf policy file. " % (self.database), FutureWarning)

            dbUrl = url.make_url(self._database)
            dialect = dbUrl.get_dialect()
            self._driver = dialect.name + '+' + dialect.driver if dialect.driver else dialect.name
            # Copy every non-None piece of the parsed address onto the
            # matching private attribute (e.g. 'host' -> self._host).
            for key, value in dbUrl.translate_connect_args().items():
                if value is not None:
                    setattr(self, '_'+key, value)

        errMessage = "Please supply a 'driver' kwarg to the constructor or in class definition. "
        errMessage += "'driver' is formatted as dialect+driver, such as 'sqlite' or 'mssql+pymssql'."
        if not hasattr(self, '_driver'):
            raise AttributeError("%s has no attribute 'driver'. " % (self.__class__.__name__) + errMessage)
        elif self._driver is None:
            raise AttributeError("%s.driver is None. " % (self.__class__.__name__) + errMessage)

        # No second check of the database is needed here: it was verified to
        # be non-None above and the dbAddress conversion never sets it to None.

        if 'sqlite' in self._driver:
            # When passed sqlite database, override default host/port
            self._host = None
            self._port = None

    def __eq__(self, other):
        """
        Two connections are equal when database, driver, host and port all
        match after conversion to strings.  Objects that do not carry those
        attributes are reported as not comparable instead of raising.
        """
        try:
            theirs = (other._database, other._driver, other._host, other._port)
        except AttributeError:
            return NotImplemented
        mine = (self._database, self._driver, self._host, self._port)
        return all(str(aa) == str(bb) for aa, bb in zip(mine, theirs))

    @property
    def engine(self):
        return self._engine

    @property
    def session(self):
        return self._session

    @property
    def metadata(self):
        return self._metadata

    @property
    def database(self):
        return self._database

    @property
    def driver(self):
        return self._driver

    @property
    def host(self):
        return self._host

    @property
    def port(self):
        return self._port

    @property
    def verbose(self):
        return self._verbose


class DBObject(object):
    """
    Thin wrapper around a DBConnection that runs user-supplied queries and
    returns the rows as numpy record arrays.
    """

    def __init__(self, database=None, driver=None, host=None, port=None, verbose=False,
                 connection=None, cache_connection=True):
        """
        Initialize DBObject.

        @param [in] database is the name of the database file being connected to

        @param [in] driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        @param [in] host is the URL of the remote host, if appropriate

        @param [in] port is the port on the remote host to connect to, if appropriate

        @param [in] verbose is a boolean controlling sqlalchemy's verbosity (default False)

        @param [in] connection is an optional instance of DBConnection, in the event that
        this DBObject can share a database connection with another DBObject.  This is only
        necessary or even possible in a few specialized cases and should be used carefully.

        @param [in] cache_connection is a boolean.  If True, DBObject will use a cache of
        DBConnections (if available) to get the connection to this database.
        """

        # this is a cache for the query, so that any one query does not have
        # to guess dtype multiple times
        self.dtype = None

        if connection is None:
            # Explicit constructor to DBObject preferred
            kwargDict = dict(database=database,
                             driver=driver,
                             host=host,
                             port=port,
                             verbose=verbose)

            # A keyword only overrides an attribute that already exists
            # (e.g. one defined on a subclass) when it is not None.
            for key, value in kwargDict.items():
                if value is not None or not hasattr(self, key):
                    setattr(self, key, value)

            self.connection = self._get_connection(self.database, self.driver, self.host, self.port,
                                                   use_cache=cache_connection)
        else:
            self.connection = connection
            self.database = connection.database
            self.driver = connection.driver
            self.host = connection.host
            self.port = connection.port
            self.verbose = connection.verbose

    def _get_connection(self, database, driver, host, port, use_cache=True):
        """
        Search self._connection_cache (if it exists; it won't for DBObject, but
        will for CatalogDBObject) for a DBConnection matching the specified
        parameters.  If it exists, return it.  If not, open a connection to
        the specified database, add it to the cache, and return the connection.

        Parameters
        ----------
        database is the name of the database file being connected to

        driver is the dialect of the database (e.g. 'sqlite', 'mssql', etc.)

        host is the URL of the remote host, if appropriate

        port is the port on the remote host to connect to, if appropriate

        use_cache is a boolean specifying whether or not we try to use the
        cache of database connections (you don't want to if opening many
        connections in many threads).
        """
        cache_available = use_cache and hasattr(self, '_connection_cache')

        if cache_available:
            wanted = (str(database), str(driver), str(host), str(port))
            for conn in self._connection_cache:
                found = (str(conn.database), str(conn.driver), str(conn.host), str(conn.port))
                if found == wanted:
                    return conn

        # Forward the requested verbosity so that verbose=True actually
        # reaches the new connection.
        conn = DBConnection(database=database, driver=driver, host=host, port=port,
                            verbose=getattr(self, 'verbose', False))

        if cache_available:
            self._connection_cache.append(conn)

        return conn

    def get_table_names(self):
        """Return a list of the names of the tables (and views) in the database"""
        inspector = inspect(self.connection.engine)
        return [str(xx) for xx in inspector.get_table_names()] + \
               [str(xx) for xx in inspector.get_view_names()]

    def get_column_names(self, tableName=None):
        """
        Return a list of the names of the columns in the specified table.
        If the specified table does not exist, return an empty list.
        If no table is specified, return a dict of lists.  The dict will be keyed
        to the table names.  The lists will be of the column names in that table
        """
        tableNameList = self.get_table_names()
        inspector = inspect(self.connection.engine)

        if tableName is not None:
            if tableName not in tableNameList:
                return []
            return [str(xx['name']) for xx in inspector.get_columns(tableName)]

        return {name: [str(xx['name']) for xx in inspector.get_columns(name)]
                for name in tableNameList}

    def _final_pass(self, results):
        """
        Make final modifications to a set of data before returning it to the user

        **Parameters**

            * results : a structured array constructed from the result set from a query

        **Returns**

            * results : a potentially modified structured array.  The default is to do nothing.

        """
        return results

    def _guess_dtype(self, results):
        """
        Work out a numpy dtype for a non-empty list of rows.

        Column types are detected from the first row by reading it back with
        numpy.genfromtxt; field names come from the first row's keys().  The
        width of string fields is the longest string representation found in
        any row of ``results``, so that no value in this result set is
        truncated.

        Raises RuntimeError if no usable delimiter character can be found.
        """
        first_row = results[0]
        values = [str(xx) for xx in first_row]

        # We are going to detect the dtype by reading in a single row
        # of data with np.genfromtxt.  To do this, we must pass the
        # row as a string delimited by a specified character.  Here we
        # select a character that does not occur anywhere in the data.
        delimit_char = None
        for cc in [',', ';', '|', ':', '/', '\\']:
            if not any(cc in vv for vv in values):
                delimit_char = cc
                break

        if delimit_char is None:
            raise RuntimeError("DBObject could not detect the dtype of your return rows\n"
                               "Please specify a dtype with the 'dtype' kwarg.")

        # join() keeps every delimiter, even when a value is an empty string
        dataString = delimit_char.join(values)
        names = [str(ww) for ww in first_row.keys()]
        dataArr = numpy.genfromtxt(BytesIO(dataString.encode()), dtype=None,
                                   names=names, delimiter=delimit_char,
                                   encoding='utf-8')

        dt_list = []
        for idx, name in enumerate(dataArr.dtype.names):
            field_type = dataArr.dtype[name]
            if field_type.kind in ('S', 'U'):
                # numpy stores unicode as 4 bytes per character
                if field_type.kind == 'U':
                    width = field_type.itemsize // 4
                else:
                    width = field_type.itemsize
                # widen to the longest value in this column
                width = max([width, 1] + [len(str(row[idx])) for row in results])
                dt_list.append((name, 'U%d' % width))
            else:
                dt_list.append((name, field_type))

        return numpy.dtype(dt_list)

    def _convert_results_to_numpy_recarray_dbobj(self, results):
        """
        Convert a list of rows into a numpy record array.

        If self.dtype is None it is guessed from the rows and stored on the
        instance so that the guess is not repeated on every chunk.  An empty
        list of rows gives an empty record array; when no dtype is known in
        that case the array has no fields and self.dtype is left untouched.
        """
        if len(results) == 0:
            # nothing to inspect, so the dtype cannot be guessed here
            if self.dtype is None:
                return numpy.recarray((0,), dtype=[])
            return numpy.recarray((0,), dtype=self.dtype)

        if self.dtype is None:
            self.dtype = self._guess_dtype(results)

        return numpy.rec.fromrecords([tuple(xx) for xx in results], dtype=self.dtype)

    def _postprocess_results(self, results):
        """
        This wrapper exists so that a ChunkIterator built from a DBObject
        can have the same API as a ChunkIterator built from a CatalogDBObject
        """
        return self._postprocess_arbitrary_results(results)

    def _postprocess_arbitrary_results(self, results):
        """
        Convert results to a numpy record array (unless they already are one)
        and apply _final_pass to them.
        """
        if not isinstance(results, numpy.recarray):
            retresults = self._convert_results_to_numpy_recarray_dbobj(results)
        else:
            retresults = results

        return self._final_pass(retresults)

    def execute_arbitrary(self, query, dtype=None):
        """
        Executes an arbitrary query.  Returns a recarray of the results.

        dtype will be the dtype of the output recarray.  If it is None, then
        the code will guess the datatype from the returned rows, taking the
        field names from the keys of the first row.

        Raises RuntimeError if query is not a string, or if it contains any of
        the words delete, drop, insert or update (case-insensitive substring
        match).
        """
        if not isinstance(query, str):
            raise RuntimeError("DBObject execute must be called with a string query")

        lowered_query = query.lower()
        for badCommand in ["delete", "drop", "insert", "update"]:
            if badCommand in lowered_query:
                raise RuntimeError("query made to DBObject execute contained %s " % badCommand)

        self.dtype = dtype
        rows = self.connection.session.execute(query).fetchall()
        return self._postprocess_arbitrary_results(rows)

    def get_arbitrary_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        This wrapper exists so that CatalogDBObjects can refer to
        get_arbitrary_chunk_iterator and DBObjects can refer to
        get_chunk_iterator
        """
        return self.get_chunk_iterator(query, chunk_size=chunk_size, dtype=dtype)

    def get_chunk_iterator(self, query, chunk_size=None, dtype=None):
        """
        Take an arbitrary, user-specified query and return a ChunkIterator that
        executes that query

        dtype will tell the ChunkIterator what datatype to expect for this query.
        This information gets passed to _postprocess_results.

        If 'None', then _postprocess_results will guess the datatype from the
        first chunk of rows.
        """
        self.dtype = dtype
        return ChunkIterator(self, query, chunk_size, arbitrarySQL=True)