##
# Copyright (c) 2009-2010 ATMAIL. All rights reserved
# See http://atmail.com/license.php for license agreement
##

"""
SQL (sqlite) based user/group/resource directory service implementation.
"""

"""
SCHEMA:

User Database:

ROW: RECORD_TYPE, SHORT_NAME (unique), PASSWORD, NAME

Group Database:

ROW: SHORT_NAME, MEMBER_SHORT_NAME

CUAddress database:

ROW: ADDRESS (unqiue), SHORT_NAME

"""

__all__ = [
    "SQLDirectoryService",
]

from twisted.cred.credentials import UsernamePassword
from twisted.python.filepath import FilePath

from twistedcaldav.directory.directory import DirectoryService, DirectoryRecord
from twistedcaldav.directory.xmlaccountsparser import XMLAccountsParser
from twistedcaldav.sql import AbstractSQLDatabase
from twistedcaldav.sql import db_prefix

import os
import time
import hashlib
import md5crypt

class SQLDirectoryManager(AbstractSQLDatabase):
    """
    House keeping operations on the SQL DB, including loading from XML file,
    and record dumping. This can be used as a standalong DB management tool.
    """
    dbType = "DIRECTORYSERVICE"
    dbFilename = db_prefix + "accounts"
    dbFormatVersion = "3"
    cacheMD5 = 0
    cache = None
    rebuild = 1;
    groupware_enabled = 0;
    passwordmode = "PLAIN";

    def __init__(self, path):
        path = os.path.join(path, SQLDirectoryManager.dbFilename)
        super(SQLDirectoryManager, self).__init__(path, True)
	self.cacheMD5 = 0
	self.cache = None
	self.rebuild = 1;
	self.groupware_enabled = self._db_execute("""select groupwareZone from Groups where GroupName='default'""")[0][0]
	if self.groupware_enabled == "Off":
		self.groupware_enabled = 0;
	else:
		self.groupware_enabled = 1;
	self.passwordmode = self._db_execute("""select keyValue from Config where keyName = 'userPasswordEncryptionType'""")[0][0]
	print "Password mode = " + self.passwordmode
	print "Groupware enabled = " + str(self.groupware_enabled)

    def loadFromXML(self, xmlFile):
        parser = XMLAccountsParser(xmlFile)
       
        # Totally wipe existing DB and start from scratch
        if os.path.exists(self.dbpath):
            os.remove(self.dbpath)

        self._db_execute("insert ignore into CalDavService (Realm) values (%s)", parser.realm)

        # Now add records to db
        for item in parser.items.values():
            for entry in item.itervalues():
                self._add_to_db(entry)
        self._db_commit()

    def getRealm(self):
        for realm in self._db_execute("select Realm from CalDavService"):
            result = realm[0].decode("utf-8")
        else:
            result = ""
	return result

    def passwordmodeFetch(self):
    	return self.passwordmode

    def cacheRefreshRequired(self, recordType):
    	self._db_reconnect()
    	if (self.cacheMD5 is None):
    	    self.rebuild = 1
    	    return self.rebuild

	calcmd5=self._db_execute("""select md5(concat(sum(crc32(UserSession.Account)), sum(crc32(UserSession.CalendarUserStatus)), sum(crc32(UserSession.Password)))) from UserSession""")[0][0]
	
	# do we need to rebuild?
	# self.rebuild = 0
	if (self.cacheMD5 != calcmd5):
	    self.rebuild = 1
	    self.cacheMD5 = calcmd5
	return self.rebuild

    def listRecords(self, recordType):
	# RIGHT, this is wicked slow, so lets add caching !
	# Brett Embery 2009

	accounts = set()
	
	if self.rebuild is 1:
	    self.rebuild = 0
	    # grab the accounts and calc md5 to see if anything has changed
	    if self.cache is not None:
		del self.cache
	    self.cache = set()

	    for shortName, password in self._db_execute("""select UserSession.Account, UserSession.Password from UserSession where (CalendarUserStatus != '1' OR CalendarUserStatus IS NULL)"""):
        	members = set()
        	groups = set()
 		calendarUserAddresses = set()

		# one large select from sql
		#for (member_record, member_short_name, name, address,)  in self._db_execute("""select CalGroups.MemberRecordType, CalGroups.MemberShortName, CalAddresses.Address from CalGroups, CalAddresses where CalGroups.ShortName=%s and CalGroups.ShortName=CalAddresses.ShortName""", shortName):
                #    members.add(tuple((member_record, member_short_name,)))
                #    groups.add(name)
                #    calendarUserAddresses.add(address)
		
		accounttuple = shortName.partition("@")
		if accounttuple[2] or accounttuple[1] is not "@" :
		    guid = shortName.replace("@", "_")
		    name = shortName + " Calendar";

		    # cache the result!
		    self.cache.add((shortName, guid, password, name,frozenset(members), frozenset(groups), frozenset(calendarUserAddresses),))

	for shortName, guid, password, name, members, groups, calendarUserAddresses in self.cache:
	    yield shortName, guid, password, name, members, groups, calendarUserAddresses

    def getRecord(self, recordType, shortName):	
	self._db_reconnect()
    	
	for shortName, password in self._db_execute("""select UserSession.Account, UserSession.Password from UserSession where UserSession.Account = %s and (CalendarUserStatus != '1' OR CalendarUserStatus IS NULL)""", shortName):
    	    break
    	else:
    	    return None

        # See if we have members
        # See if we are a member of any groups
        # Get calendar user addresses
        members = set()
        groups = set()
	calendarUserAddresses = set()
	if self.groupware_enabled is not 'off':
	    for (member_record, member_short_name, name, address,)  in self._db_execute("""select CalGroups.MemberRecordType, CalGroups.MemberShortName, CalAddresses.Address from CalGroups, CalAddresses where CalGroups.ShortName=%s and CalGroups.ShortName=CalAddresses.ShortName""", shortName):
		members.add(tuple((member_record, member_short_name,)))
		groups.add(name)
		calendarUserAddresses.add(address)

	guid = shortName.replace("@", "_")
	name = shortName + " Calendar";
        
        return shortName, guid, password, name, members, groups, calendarUserAddresses
            
    def members(self, shortName):
        members = set()
        for member in self._db_execute("""select MemberRecordType, MemberShortName from CalGroups where ShortName = %s""", shortName):
            members.add(tuple(member))

        return members

    def groups(self, shortName):
        groups = set()
        for (name,) in self._db_execute("""select ShortName from CalGroups where MemberShortName = %s""", shortName):
            groups.add(name)

        return groups

    def calendarUserAddresses(self, shortName):
        calendarUserAddresses = set()
        for (address,) in self._db_execute("""select Address from CalAddresses where ShortName = %s""", shortName):
            calendarUserAddresses.add(address)

        return calendarUserAddresses

    def _add_to_db(self, record):
        # Do regular account entry
        recordType = record.recordType
        shortName = record.shortNames[0]
        guid = record.guid
        password = record.password
        name = record.fullName

        self._db_execute("""insert ignore into UserSession (Account, Password) values (%s, %s)""", shortName, password)
        
        # Check for members
        for memberRecordType, memberShortName in record.members:
            self._db_execute("""insert ignore into CalGroups (ShortName, MemberRecordType, MemberShortName) values (%s, %s, %s)""", shortName, memberRecordType, memberShortName)
                
        # CUAddress
        for cuaddr in record.calendarUserAddresses:
            self._db_execute("""insert ignore into CalAddresses (Address, ShortName) values (%s, %s)""", cuaddr, shortName)
       
    def _delete_from_db(self, shortName):
        """
        Deletes the specified entry from all dbs.
        @param name: the name of the resource to delete.
        @param shortName: the short name of the resource to delete.
        """
        self._db_execute("delete from CalGroups where ShortName = %s", shortName)
        self._db_execute("delete from CalGroups where MemberShortName = %s", shortName)
        self._db_execute("delete from CalAddresses where ShortName = %s", shortName)
    
    def _db_version(self):
        """
        @return: the schema version assigned to this index.
        """
        return SQLDirectoryManager.dbFormatVersion
        
    def _db_type(self):
        """
        @return: the collection type assigned to this index.
        """
        return SQLDirectoryManager.dbType
        
    def _db_init_data_tables(self, q):
        """
        Initialise the underlying database tables.
        @param q:           a database cursor to use.
        """
	# we must populate at least one thing in here or the md5 calculation will break.
        self._db_execute("""insert ignore into CalGroups (ShortName, MemberRecordType, MemberShortName) values ("placeholder", "for", "mysqlbug")""")

class SQLDirectoryService(DirectoryService):
    """
    XML based implementation of L{IDirectoryService}.
    """
    baseGUID = "8256E464-35E0-4DBB-A99C-F0E30C231675"
    realmName = None
    cache = None

    def __repr__(self):
        return "<%s %r: %r>" % (self.__class__.__name__, self.realmName, self.manager.dbpath)

    def __init__(self, dbParentPath, xmlFile=None):
        super(SQLDirectoryService, self).__init__()

        if type(dbParentPath) is str:
            dbParentPath = FilePath(dbParentPath)
            
	try:
	        self.manager = SQLDirectoryManager(dbParentPath.path)
	except:
	        self.manager = SQLDirectoryManager(dbParentPath.dbParentPath)		
        if xmlFile:
            self.manager.loadFromXML(xmlFile)
	self.cache = None
        self.realmName = self.manager.getRealm()

    def cacheRefreshRequired(self, recordType):
	return self.manager.cacheRefreshRequired(recordType)

    def passwordmodeFetch(self):
	return self.manager.passwordmodeFetch()

    def recordTypes(self):
	recordTypes = (
            DirectoryService.recordType_users,
            DirectoryService.recordType_groups,
            DirectoryService.recordType_locations,
            DirectoryService.recordType_resources,
        )
        return recordTypes

    def listRecords(self, recordType):

      	# if self.cacheRefreshRequired() is 1 or self.cache is None:
	if self.manager.rebuild is 1 or self.cache is None:
	    if self.cache is not None:
		del self.cache
	    self.cache = set()

	    for result in self.manager.listRecords(recordType):
		self.cache.add(SQLDirectoryRecord(service = self, recordType = recordType, shortName = result[0], guid = result[1], password = result[2], name = result[3], members = result[4], groups = result[5], calendarUserAddresses = result[6]))

	for result in self.cache:
            yield result

    def recordWithShortName(self, recordType, shortName):

        result = self.manager.getRecord(recordType, shortName)
        if result:
            return SQLDirectoryRecord(
                service               = self,
                recordType            = recordType,
                shortName             = result[0],
                guid                  = result[1],
                password              = result[2],
                name                  = result[3],
                members               = result[4],
                groups                = result[5],
                calendarUserAddresses = result[6],
            )

        return None

class SQLDirectoryRecord(DirectoryRecord):
    """
    XML based implementation implementation of L{IDirectoryRecord}.
    """
    def __init__(self, service, recordType, shortName, guid, password, name, members, groups, calendarUserAddresses):

        super(SQLDirectoryRecord, self).__init__(
            service               = service,
            recordType            = recordType,
            guid                  = guid,
            shortNames            = (shortName,),
            fullName              = name,
            calendarUserAddresses = calendarUserAddresses,
        )
	self.password = password
        self._members = members
        self._groups  = groups
        self.members_cache = None
        self.groups_cache = None
	self.passwordmode = self.service.passwordmodeFetch()

    def members(self):
	
	if self.members_cache is None:
	    if self.members_cache is not None:
		del self.members_cache
	    self.members_cache = set()

	    for recordType, shortName in self._members:
        	 self.members_cache.add(self.service.recordWithShortName(recordType, shortName))
        	 
        for result in self.members_cache:
            yield result

    def groups(self):
	
	if self.groups_cache is None:
	    if self.groups_cache is not None:
		del self.groups_cache
	    self.groups_cache = set()

	    for shortName in self._groups:
        	 self.groups_cache.add(self.service.recordWithShortName(DirectoryService.recordType_groups, shortName))

        for result in self.groups_cache:
            yield result

    def verifyCredentials(self, credentials):
        if isinstance(credentials, UsernamePassword):
	    valid = False
	    if self.passwordmode == "PLAIN":
		    valid = credentials.password == self.password
	    if self.passwordmode == "MD5":
		    valid = self.password == hashlib.md5(credentials.password).hexdigest()
	    if self.passwordmode == 'MD5-CRYPT':
		    valid = md5crypt.test(credentials.password, self.password)
	    return valid

	#self.service.cacheRefreshRequired();
        result = super(SQLDirectoryRecord, self).verifyCredentials(credentials)

        return result

if __name__ == '__main__':
    mgr = SQLDirectoryManager("./")
    mgr.loadFromXML("test/accounts.xml")
