##
# Copyright (c) 2009-2010 ATMAIL. All rights reserved
# See http://atmail.com/license.php for license agreement
##

"""
Generic SQL database access object.
"""

__all__ = [
    "AbstractSQLDatabase",
    "db_prefix"
]

import os
from configobj import ConfigObj
import MySQLdb
import time

try:
    import sqlite3 as sqlite
except ImportError:
    from pysqlite2 import dbapi2 as sqlite

from twistedcaldav.log import Logger

log = Logger()

db_prefix = ".db."

class AbstractSQLDatabase(object):
    """
    A generic SQL database.
    """

    def __init__(self, dbpath, persistent, autocommit=False):
        """
        
        @param dbpath: the path where the db file is stored.
        @type dbpath: str
        @param persistent: C{True} if the data in the DB must be perserved during upgrades,
            C{False} if the DB data can be re-created from an external source.
        @type persistent: bool
        @param autocommit: C{True} if auto-commit mode is desired, C{False} otherwise
        @type autocommit: bool
        """
        self.dbpath = dbpath
        self.persistent = persistent
        self.autocommit = autocommit
	self.reconnect = None

    def __repr__(self):
        return "<%s %r>" % (self.__class__.__name__, self.dbpath)

    def _db_version(self):
        """
        @return: the schema version assigned to this index.
        """
        raise NotImplementedError
        
    def _db_type(self):
        """
        @return: the collection type assigned to this index.
        """
        raise NotImplementedError
        
    def _db_reconnect(self):
	self.reconnect = 1
	return 1

    def _db(self):
        """
        Access the underlying database.
        @return: a db2 connection object for this index's underlying data store.
        """
        if not hasattr(self, "_db_connection"):
            db_filename = self.dbpath
            try:
		if self.dbpath.find(".db.accounts") != -1:
		    self.mysql = 1
		else:
		    self.mysql = 0
		if self.mysql:
		    dbpath = ConfigObj("atmail.ini");
		    config = ConfigObj(dbpath['production']['dbconfig.path'] + "/dbconfig.ini");
		    self.config = config
		    self._db_connection = MySQLdb.connect(passwd=self.config['production']['database.params.password'], db=self.config['production']['database.params.dbname'], host=self.config['production']['database.params.host'], user=self.config['production']['database.params.username'])
		else:
		    if self.autocommit:
	            	self._db_connection = sqlite.connect(db_filename, isolation_level=None)
        	    else:
			self._db_connection = sqlite.connect(db_filename)
            except:
                log.err("Unable to open database: %s" % (self,))
                raise

            #
            # Set up the schema
            #
            q = self._db_connection.cursor()
            try:
                # Create CALDAV table if needed

                if self._test_schema_table(q):
		    if self.mysql:
  	                q.execute("""select `VALUE` from CalDav where `KEY` = 'SCHEMA_VERSION' and EXTRA = %s""", db_filename)
                    else:
			q.execute("""select `VALUE` from CALDAV where `KEY` = 'SCHEMA_VERSION'""")

                    version = q.fetchone()
			
                    if version is not None: version = version[0]

		    if self.mysql:
                        q.execute("""select `VALUE` from CalDav where `KEY` = 'TYPE' and EXTRA = %s""", db_filename)
                    else:
			q.execute("""select `VALUE` from CALDAV where `KEY` = 'TYPE'""")
                    dbtype = q.fetchone()

                    if dbtype is not None: dbtype = dbtype[0]

                    if (version != self._db_version()) or (dbtype != self._db_type()):

                        # Clean-up first
                        q.close()
                        q = None
                        self._db_connection.close()
                        del(self._db_connection)

                        if version != self._db_version():
                            log.err("Database %s has different schema (v.%s vs. v.%s)"
                                    % (db_filename, version, self._db_version()))
                            
                            # Upgrade the DB
                            return self._db_upgrade(version)

                        if dbtype != self._db_type():
                            log.err("Database %s has different type (%s vs. %s)"
                                    % (db_filename, dbtype, self._db_type()))

                            # Delete this index and start over
                            os.remove(db_filename)
                            return self._db()

                else:
                    self._db_init(db_filename, q)

                self._db_connection.commit()
            finally:
                if q is not None: q.close()

	# test the mysql connection to make sure
	# it doesn't 'go away'
	if self.mysql:
	    if self._db_connection:
		try:
		    self._db_connection.ping()
		except MySQLdb.OperationalError, message: # loss of connection!
		    del self._db_connection
		    self._db_connection = None
	    if self.reconnect is not None:
	    	self.reconnect = None
	    	del self._db_connection
	    	self._db_connection = None
	    if self._db_connection is None:
		cont = 1
		while cont > 0:
		    try:
			self._db_connection = MySQLdb.connect(passwd=self.config['production']['database.params.password'], db=self.config['production']['database.params.dbname'], host=self.config['production']['database.params.host'], user=self.config['production']['database.params.username'])
			cont = 0
		    except MySQLdb.OperationalError, message: # loss of connection!
			log.err("Having trouble with database connection....")
			time.sleep(0.5)
			
        return self._db_connection

    def _test_schema_table(self, q):
        try:
	    if self.mysql:
		q.execute("""select (1) from CalDav where EXTRA = %s""", self.dbpath)
            	returnvar = q.fetchone()
	    else:
	        q.execute("""select (1) from SQLITE_MASTER where TYPE = 'table' and NAME = 'CALDAV'""")
	        returnvar = q.fetchone()
        except:
		returnvar = 0
	
        return returnvar

    def _db_init(self, db_filename, q):
        """
        Initialise the underlying database tables.
        @param db_filename: the file name of the index database.
        @param q:           a database cursor to use.
        """
        log.msg("Initializing database %s" % (db_filename,))

        # We need an exclusive lock here as we are making a big change to the database and we don't
        # want other processes to get stomped on or stomp on us.
	if not self.mysql:
            old_isolation = self._db_connection.isolation_level
            self._db_connection.isolation_level = None
	if self.mysql:
	    q.execute("begin")
	else:
	    q.execute("begin exclusive transaction")
        
        # We re-check whether the schema table is present again AFTER we've got an exclusive
        # lock as some other server process may have snuck in and already created it
        # before we got the lock, or whilst we were waiting for it.
        if not self._test_schema_table(q):
            self._db_init_schema_table(q)
            self._db_init_data_tables(q)
            self._db_recreate(False)

        q.execute("commit")
	
	if not self.mysql:
	    self._db_connection.isolation_level = old_isolation

    def _db_init_schema_table(self, q):
        """
        Initialise the underlying database tables.
        @param db_filename: the file name of the index database.
        @param q:           a database cursor to use.
        """

        #
        # CALDAV table keeps track of our schema version and type
        #
	if not self.mysql:
            q.execute(
                """
                create table IF NOT EXISTS CALDAV (
                    `KEY` varchar(255) primary key, `VALUE` text, `EXTRA` text
                )
                """
            )

	# above table will already be created by installer

	if self.mysql:
	    q.execute("""select (1) from CalDav where `KEY` = 'SCHEMA_VERSION' and `VALUE` =  %s and `EXTRA` = %s""", [self._db_version(), self.dbpath])
	    if q.fetchone() == None:
		q.execute("""insert IGNORE into CalDav (`KEY`, `VALUE`, `EXTRA`) values ('SCHEMA_VERSION', %s, %s)""", [self._db_version(), self.dbpath])
	    q.execute("""select (1) from CalDav where `KEY` = 'TYPE' and `VALUE` = %s and `EXTRA` = %s""", [self._db_type(), self.dbpath])
	    if q.fetchone() == None:
		q.execute("""insert IGNORE into CalDav (`KEY`, `VALUE`, `EXTRA`) values ('TYPE', %s, %s)""", [self._db_type(), self.dbpath])
	else:
	    q.execute("""insert into CALDAV (KEY, VALUE) values ('SCHEMA_VERSION', :1)""", [self._db_version()])
	    q.execute("""insert into CALDAV (KEY, VALUE) values ('TYPE', :1)""", [self._db_type()])

    def _db_init_data_tables(self, q):
        """
        Initialise the underlying database tables.
        @param db_filename: the file name of the index database.
        @param q:           a database cursor to use.
        """
        raise NotImplementedError

    def _db_recreate(self, do_commit=True):
        """
        Recreate the database tables.
        """

        # Always commit at the end of this method as we have an open transaction from previous methods.
        if do_commit:
            self._db_commit()

    def _db_upgrade(self, old_version):
        """
        Upgrade the database tables.
        """

        if self.persistent:
	    if self.mysql:
		dbpath = ConfigObj("atmail.ini");
	    	config = ConfigObj(dbpath['production']['dbconfig.path'] + "/dbconfig.ini");
           	self._db_connection = MySQLdb.connect(passwd=config['production']['database.params.password'], db=config['production']['database.params.dbname'], host=config['production']['database.params.host'], user=config['production']['database.params.username'])
            else:
		self._db_connection = sqlite.connect(self.dbpath, isolation_level=None)

	    q = self._db_connection.cursor()
            self._db_upgrade_data_tables(q, old_version)
            self._db_upgrade_schema(q)
            self._db_close()
            return self._db()
        else:
            # Non-persistent DB's by default can be removed and re-created. However, for simple
            # DB upgrades they SHOULD override this method and handle those for better performance.
            os.remove(self.dbpath)
            return self._db()
    
    def _db_upgrade_data_tables(self, q, old_version):
        """
        Upgrade the data from an older version of the DB.
        """
        # Persistent DB's MUST override this method and do a proper upgrade. Their data
        # cannot be thrown away.
        raise NotImplementedError("Persistent databases MUST support an upgrade method.")

    def _db_upgrade_schema(self, q):
        """
        Upgrade the stored schema version to the current one.
        """
	if self.mysql:
	    q.execute("""insert or replace into CalDav (`KEY`, `VALUE`, `EXTRA`) values ('SCHEMA_VERSION', %s, %s)""", [self._db_version(), db_filename])
        else:
	    q.execute("""insert or replace into CALDAV (KEY, VALUE) values ('SCHEMA_VERSION', :1)""", [self._db_version()])

    def _db_close(self):
        if hasattr(self, "_db_connection"):
            self._db_connection.close()
            del self._db_connection

    def _db_values_for_sql(self, sql, *query_params):
        """
        Execute an SQL query and obtain the resulting values.
        @param sql: the SQL query to execute.
        @param query_params: parameters to C{sql}.
        @return: an interable of values in the first column of each row
            resulting from executing C{sql} with C{query_params}.
        @raise AssertionError: if the query yields multiple columns.
        """

        return (row[0] for row in self._db_execute(sql, *query_params))

    def _db_value_for_sql(self, sql, *query_params):
        """
        Execute an SQL query and obtain a single value.
        @param sql: the SQL query to execute.
        @param query_params: parameters to C{sql}.
        @return: the value resulting from the executing C{sql} with
            C{query_params}.
        @raise AssertionError: if the query yields multiple rows or columns.
        """
        value = None
        for row in self._db_values_for_sql(sql, *query_params):
            assert value is None, "Multiple values in DB for %s %s" % (sql, query_params)
            value = row
        return value

    def _db_execute(self, sql, *query_params):
        """
        Execute an SQL query and obtain the resulting values.
        @param sql: the SQL query to execute.
        @param query_params: parameters to C{sql}.
        @return: an interable of tuples for each row resulting from executing
            C{sql} with C{query_params}.
        """
        q = self._db().cursor()
        try:
	    q.execute(sql, query_params)
            return q.fetchall()
	except:
	    log.err("Exception while executing SQL on DB %s: %r %r" % (self, sql, query_params))
	    raise
	finally:
	    q.close()

    def _db_commit  (self): self._db_connection.commit()
    def _db_rollback(self): self._db_connection.rollback()
