view src/memory/memory.py @ 931:3b30e9f83d88

misc: sat stop would not kill all sat instances anymore
author souliane <souliane@mailoo.org>
date Sun, 23 Mar 2014 22:44:49 +0100
parents cbf4122baae7
children 5b2d2f1f05d0
line wrap: on
line source

#!/usr/bin/python
# -*- coding: utf-8 -*-

# SAT: a jabber client
# Copyright (C) 2009, 2010, 2011, 2012, 2013, 2014 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from sat.core.i18n import _

import os.path
import csv
from xdg import BaseDirectory
from ConfigParser import SafeConfigParser, NoOptionError, NoSectionError
from uuid import uuid4
from logging import debug, info, warning, error
from twisted.internet import defer, reactor
from twisted.words.protocols.jabber import jid
from sat.core import exceptions
from sat.core.constants import Const as C
from sat.memory.sqlite import SqliteStorage
from sat.memory.persistent import PersistentDict
from sat.memory.params import Params


class Sessions(object):
    """Container of temporary sessions, each one destroyed after a timeout.

    Every session is stored as a (timer, data) tuple, or as a
    (timer, data, profile) tuple when it is owned by a profile.
    Accessing a session through __getitem__ or profileGet postpones
    its destruction by self.timeout seconds.
    """
    DEFAULT_TIMEOUT = 600

    def __init__(self, timeout=None):
        """
        @param timeout: nb of seconds before session destruction
                        (DEFAULT_TIMEOUT is used when None or 0)
        """
        self._sessions = dict()
        self.timeout = timeout if timeout else Sessions.DEFAULT_TIMEOUT

    def newSession(self, session_data=None, profile=None):
        """Create a new session and schedule its destruction

        @param session_data: mutable data to use, default to a new dict
        @param profile: if set, the session is owned by the profile,
                        and profileGet must be used instead of __getitem__
        @return: (session_id, session_data)
        """
        new_id = str(uuid4())
        expiration = reactor.callLater(self.timeout, self._purgeSession, new_id)
        data = {} if session_data is None else session_data
        entry = (expiration, data)
        if profile is not None:
            # the profile is kept as third item, checked by profileGet
            entry += (profile,)
        self._sessions[new_id] = entry
        return new_id, data

    def _purgeSession(self, session_id):
        """Drop the session data (called by the timer, or by __delitem__)"""
        del self._sessions[session_id]
        debug("Session [%s] purged" % session_id)

    def __len__(self):
        return len(self._sessions)

    def __contains__(self, session_id):
        return session_id in self._sessions

    def profileGet(self, session_id, profile):
        """Return the data of a profile-owned session and reset its timer

        @param session_id: id returned by newSession
        @param profile: profile which must own the session
        @raise KeyError: the session doesn't exist (anymore)
        @raise exceptions.InternalError: the session has no owner, or another one
        """
        entry = self._sessions[session_id]
        try:
            expiration, data, owner = entry
        except ValueError:
            raise exceptions.InternalError("You need to use __getitem__ when profile is not set")
        if owner != profile:
            raise exceptions.InternalError("current profile differ from set profile !")
        expiration.reset(self.timeout)
        return data

    def __getitem__(self, session_id):
        """Return the data of a session without owner and reset its timer

        @raise KeyError: the session doesn't exist (anymore)
        @raise exceptions.InternalError: the session is owned by a profile
        """
        entry = self._sessions[session_id]
        try:
            expiration, data = entry
        except ValueError:
            raise exceptions.InternalError("You need to use profileGet instead of __getitem__ when profile is set")
        expiration.reset(self.timeout)
        return data

    def __setitem__(self, key, value):
        raise NotImplementedError("You need do use newSession to create a session")

    def __delitem__(self, session_id):
        """Cancel the timer, then actually delete the session data

        Deleting an unknown (e.g. already expired) session is only logged.
        """
        try:
            self._sessions[session_id][0].cancel()
            self._purgeSession(session_id)
        except KeyError:
            debug ("Session [%s] doesn't exists, timeout expired ?" % session_id)

    def keys(self):
        return self._sessions.keys()

    def iterkeys(self):
        return self._sessions.iterkeys()


class Memory(object):
    """This class manage all persistent informations"""

    def __init__(self, host):
        """Initialise the memory manager.

        The main configuration is parsed synchronously, then the storage, the
        params and the persistent "memory" dict are set up. Loading from the
        database is asynchronous: self.initialized is fired once the whole
        loading chain (general params, then the "memory" dict) is done.
        @param host: main application object; its __version__ is given to the
                     storage and it is passed to Params (NOTE(review): exact
                     type is not visible from this file)
        """
        info(_("Memory manager init"))
        self.initialized = defer.Deferred()  # chained at the end of this method
        self.host = host
        # {profile: {bare jid string: {key: value}}}, see setPresenceStatus/updateEntityData
        self.entitiesCache = {}  # XXX: keep presence/last resource/other data in cache
                                 #     /!\ an entity is not necessarily in roster
        self.subscriptions = {}  # {profile: {entity_jid: subscription type}}
        self.server_features = {}  # used to store discovery's informations: {profile: {server jid: [features]}}
        self.server_identities = {}  # {profile: {server jid: {(category, type): set(entities)}}}
        self.config = self.parseMainConf()
        # must run before local_dir is read below: it may set local_dir in self.config
        self.__fixLocalDir()
        database_file = os.path.expanduser(os.path.join(self.getConfig('', 'local_dir'), C.SAVEFILE_DATABASE))
        self.storage = SqliteStorage(database_file, host.__version__)
        # set on the class, so it applies to every PersistentDict that doesn't override it
        PersistentDict.storage = self.storage
        self.params = Params(host, self.storage)
        info(_("Loading default params template"))
        self.params.load_default_params()
        # addCallback returns the Deferred itself: d *is* self.storage.initialized
        d = self.storage.initialized.addCallback(lambda ignore: self.load())
        self.memory_data = PersistentDict("memory")
        # the "memory" dict is only loaded after the general params
        d.addCallback(lambda ignore: self.memory_data.load())
        d.chainDeferred(self.initialized)

    def parseMainConf(self):
        """Look for the main .ini configuration files, and parse them

        @return: a SafeConfigParser using C.DEFAULT_CONFIG as defaults and
                 filled with the readable files of C.CONFIG_FILES
        """
        config = SafeConfigParser(defaults=C.DEFAULT_CONFIG)
        try:
            config.read(C.CONFIG_FILES)
        except Exception as e:
            # a broken config file must not prevent startup, but we don't want to
            # swallow KeyboardInterrupt/SystemExit (as a bare except would), and
            # the reason of the failure must be visible in the logs
            error(_("Can't read main config !"))
            error("%s", e)
        return config

    # XXX: tmp update code, will be removed in the future
    # When you remove this, please also remove sat.core.constants.Const.DEFAULT_LOCAL_DIR
    # and add the default value for 'local_dir' in sat.core.constants.Const.DEFAULT_CONFIG
    def __fixLocalDir(self):
        """Retro-compatibility with the previous local_dir default value.

        Nothing is done if 'local_dir' is already configured. Otherwise, if a
        database exists in the default directory of previous versions (~/.sat),
        that directory is kept: 'local_dir' is saved in the user's personal
        config file (the existing one with the highest priority if it is in the
        home directory, a new one in the XDG config dir otherwise) AND applied
        to the running configuration. If there is no such database,
        C.DEFAULT_LOCAL_DIR is used for the running configuration.
        """
        if self.getConfig('', 'local_dir'):
            return  # nothing to do
        old_default = '~/.sat'
        old_database = os.path.join(os.path.expanduser(old_default), C.SAVEFILE_DATABASE)
        if not os.path.isfile(old_database):
            # use the new default local_dir
            self.config.set('', 'local_dir', C.DEFAULT_LOCAL_DIR)
            return

        warning(_("A database has been found in the default local_dir for previous versions (< 0.5)"))
        config = SafeConfigParser()
        target_file = None
        home_dir = os.path.expanduser('~')
        for file_ in C.CONFIG_FILES[::-1]:
            # we will eventually update the existing file with the highest priority, if it's a user personal file...
            if os.path.isfile(file_):
                if file_.startswith(home_dir):
                    config.read([file_])
                    target_file = file_
                break
        if not target_file:
            # ... otherwise we create a new config file for that user
            target_file = os.path.join(BaseDirectory.save_config_path('sat'), 'sat.conf')
        config.set('', 'local_dir', old_default)
        # NOTE: the file is rewritten by ConfigParser, comments it contained are not kept
        with open(target_file, 'wb') as configfile:
            config.write(configfile)
        # the saved file only affects the next runs: without this, local_dir would
        # stay empty for the current run and the database would be opened
        # relatively to the current working directory
        self.config.set('', 'local_dir', old_default)
        warning(_("Auto-update: local_dir set to %(path)s in the file %(config_file)s") % {'path': old_default, 'config_file': target_file})

    def getConfig(self, section, name):
        """Get an option of the main configuration

        The raw value is post-processed according to the option name's suffix:
        '~' is expanded for '_path' and '_dir' options, '_list' options are split
        on ',' and '_dict' options are split on ',' then each item on ':'
        ('"' being the quote char in both cases).
        @param section: section of the config file (None or '' for DEFAULT)
        @param name: name of the option
        @return: the processed value; a missing option or section is read as ''
        """
        section = section or 'DEFAULT'
        try:
            raw = self.config.get(section, name)
        except (NoOptionError, NoSectionError):
            raw = ''

        if name.endswith(('_path', '_dir')):
            return os.path.expanduser(raw)
        if name.endswith('_list'):
            return next(csv.reader([raw], delimiter=',', quotechar='"'))
        if name.endswith('_dict'):
            # thx to Brian (http://stackoverflow.com/questions/186857/splitting-a-semicolon-separated-string-to-a-dictionary-in-python/186873#186873)
            items = next(csv.reader([raw], delimiter=',', quotechar='"'))
            return dict(next(csv.reader([item], delimiter=':', quotechar='"'))
                        for item in items)
        return raw

    def load_xml(self, filename):
        """Load parameters template from xml file

        @param filename: path of the file ('~' is expanded), or None
        @return: True if the template has been loaded; False if filename is None,
                 if the file doesn't exist or if loading failed (error is logged)
        """
        if filename is None:
            return False
        path = os.path.expanduser(filename)
        if not os.path.exists(path):
            return False
        try:
            self.params.load_xml(path)
            debug(_("Parameters loaded from file: %s") % path)
            return True
        except Exception as e:
            error(_("Can't load parameters from file: %s") % e)
            return False

    def load(self):
        """Load from the database what is needed at startup

        For now, only the general parameters are loaded (Params.loadGenParams).
        @return: the result of Params.loadGenParams (chained as a Deferred
                 callback in __init__)
        """
        params = self.params
        return params.loadGenParams()

    def loadIndividualParams(self, profile):
        """Load individual parameters for a profile
        @param profile: %(doc_profile)s
        @return: the result of Params.loadIndParams"""
        params = self.params
        return params.loadIndParams(profile)

    def startProfileSession(self, profile):
        """Initialise session for a profile (reset its entities cache)
        @param profile: %(doc_profile)s"""
        # the format string must be given to _() untouched, else the translation is never found
        info(_("[%s] Profile session started") % profile)
        self.entitiesCache[profile] = {}

    def purgeProfileSession(self, profile):
        """Delete cache of data of profile (params cache and entities cache)
        @param profile: %(doc_profile)s"""
        # the format string must be given to _() untouched, else the translation is never found
        info(_("[%s] Profile session purge") % profile)
        self.params.purgeProfile(profile)
        try:
            del self.entitiesCache[profile]
        except KeyError:
            error(_("Trying to purge roster status cache for a profile not in memory: [%s]") % profile)

    def save_xml(self, filename=None):
        """Save parameters template to xml file

        @param filename: path of the file ('~' is expanded), or None
        @return: True if the template has been saved, False if filename is None
                 or if saving failed (the error is logged)
        """
        if filename is None:
            return False
        #TODO: need to encrypt files (at least passwords !) and set permissions
        path = os.path.expanduser(filename)
        try:
            self.params.save_xml(path)
            debug(_("Parameters saved to file: %s") % path)
        except Exception as e:
            error(_("Can't save parameters to file: %s") % e)
            return False
        return True

    def getProfilesList(self):
        """Return the list of profiles, as given by the storage backend"""
        storage = self.storage
        return storage.getProfilesList()

    def getProfileName(self, profile_key, return_profile_keys=False):
        """Return name of profile from keyword (delegates to Params.getProfileName)
        @param profile_key: can be the profile name or a keyword (like @DEFAULT@)
        @param return_profile_keys: forwarded as is to Params.getProfileName
        @return: profile name or None if it doesn't exist"""
        params = self.params
        return params.getProfileName(profile_key, return_profile_keys)

    def asyncCreateProfile(self, name):
        """Create a new profile (delegates to Params.asyncCreateProfile)
        @param name: Profile name
        @return: the result of Params.asyncCreateProfile
        """
        params = self.params
        return params.asyncCreateProfile(name)

    def asyncDeleteProfile(self, name, force=False):
        """Delete an existing profile (delegates to Params.asyncDeleteProfile)
        @param name: Name of the profile
        @param force: force the deletion even if the profile is connected.
                      To be used for direct calls only (not through the bridge).
        @return: a Deferred instance
        """
        params = self.params
        return params.asyncDeleteProfile(name, force)

    def addToHistory(self, from_jid, to_jid, message, type_='chat', extra=None, timestamp=None, profile=C.PROF_KEY_NONE):
        """Store a message in the history, through the storage backend

        @param extra: extra data of the message (a new empty dict if None)
        @param profile: profile owning the message, must be explicitly given
        @return: the result of the storage's addToHistory
        """
        assert profile != C.PROF_KEY_NONE
        extra_data = {} if extra is None else extra
        return self.storage.addToHistory(from_jid, to_jid, message, type_, extra_data, timestamp, profile)

    def getHistory(self, from_jid, to_jid, limit=0, between=True, profile=C.PROF_KEY_NONE):
        """Retrieve messages from the history, through the storage backend

        @param from_jid: sender, anything jid.JID accepts (presumably a string)
        @param to_jid: recipient, anything jid.JID accepts (presumably a string)
        @param limit: forwarded to the storage (0 is the default)
        @param between: forwarded to the storage
        @param profile: profile owning the history, must be explicitly given
        @return: the result of the storage's getHistory
        """
        assert profile != C.PROF_KEY_NONE
        from_entity = jid.JID(from_jid)
        to_entity = jid.JID(to_jid)
        return self.storage.getHistory(from_entity, to_entity, limit, between, profile)

    def addServerFeature(self, feature, jid_, profile):
        """Add a feature discovered from server
        @param feature: string of the feature
        @param jid_: the jid of the target server
        @param profile: which profile asked this server?"""
        profile_features = self.server_features.setdefault(profile, {})
        profile_features.setdefault(jid_, []).append(feature)

    def addServerIdentity(self, category, type_, entity, jid_, profile):
        """Add an identity discovered from server
        @param category: identity's category
        @param type_: identity's type
        @param entity: entity providing this identity
        @param jid_: the jid of the target server
        @param profile: which profile asked this server?"""
        # server_identities layout: {profile: {server jid: {(category, type_): set(entities)}}}
        identities = self.server_identities.setdefault(profile, {}).setdefault(jid_, {})
        identities.setdefault((category, type_), set()).add(entity)

    def getServerServiceEntities(self, category, type_, jid_=None, profile=None):
        """Return all available entities of a server for the service (category, type_)
        @param category: identity's category
        @param type_: identity's type
        @param jid_: the jid of the target server (None for profile's server)
        @param profile: which profile is asking this server?
        @return: a set of entities (empty if the server has no such service),
                 or None if no cached data were found for this server
        """
        if jid_ is None:
            jid_ = self.host.getClientHostJid(profile)
        try:
            server_identities = self.server_identities[profile][jid_]
        except KeyError:
            return None
        return server_identities.get((category, type_), set())

    def getServerServiceEntity(self, category, type_, jid_=None, profile=None):
        """Helper method to get first available entity of a server for the service (category, type_)
        @param category: identity's category
        @param type_: identity's type
        @param jid_: the jid of the target server (None for profile's server)
        @param profile: which profile is asking this server?
        @return: the first found entity, or None if no cached data were found
                 or if no entity provides this service
        """
        if jid_ is None:
            # resolved here (and not only in getServerServiceEntities) so that the
            # warning below names the actual server instead of "None"
            jid_ = self.host.getClientHostJid(profile)
        entities = self.getServerServiceEntities(category, type_, jid_, profile)
        if entities is None:
            warning(_("Entities (%(category)s/%(type)s) of %(server)s not available, maybe they haven't been asked yet?")
                    % {"category": category, "type": type_, "server": jid_})
            return None
        # "first" only means first in the set's iteration order: entities are not ranked
        return next(iter(entities), None)

    def getAllServerIdentities(self, jid_, profile):
        """Helper method to get all identities of a server
        @param jid_: the jid of the target server (None for profile's server)
        @param profile: which profile is asking this server?
        @return: a set of entities or None if no cached data were found
                 (for this server, or for the whole profile)
        """
        if jid_ is None:
            jid_ = self.host.getClientHostJid(profile)
        # .get on the profile level: a profile without any discovered identity must
        # give None (as documented, and as getServerServiceEntities does), not KeyError
        server_identities = self.server_identities.get(profile, {}).get(jid_)
        if server_identities is None:
            return None
        entities = set()
        for set_ in server_identities.values():
            entities.update(set_)
        return entities

    def hasServerFeature(self, feature, jid_=None, profile_key=C.PROF_KEY_NONE):
        """Tell if the specified server has the required feature
        @param feature: requested feature
        @param jid_: the jid of the target server (None for profile's server)
        @param profile_key: %(doc_profile_key)s
        @return: True or False, or None if the profile doesn't exist or if the
                 server's features are not cached
        """
        profile = self.getProfileName(profile_key)
        if not profile:
            error(_('Trying find server feature for a non-existant profile'))
            return None
        # NOTE(review): fails if no feature at all has been added for this profile yet
        assert profile in self.server_features
        if jid_ is None:
            jid_ = self.host.getClientHostJid(profile)
        server_features = self.server_features[profile].get(jid_)
        if server_features is None:
            warning(_("Features of %s not available, maybe they haven't been asked yet?") % jid_)
            return None
        return feature in server_features

    def _getLastResource(self, jid_s, profile_key):
        """getLastResource variant taking the jid as a string"""
        return self.getLastResource(jid.JID(jid_s), profile_key)


    def getLastResource(self, jid_, profile_key):
        """Return the last resource used by a jid_
        @param jid_: bare jid
        @param profile_key: %(doc_profile_key)s
        @return: the resource, or "" if unknown (or if the profile is not connected)"""
        profile = self.getProfileName(profile_key)
        if not profile or not self.host.isConnected(profile):
            error(_('Asking jid_s for a non-existant or not connected profile'))
            return ""
        bare = jid_.userhost()
        profile_cache = self.entitiesCache[profile]
        if bare not in profile_cache:
            info(_("Entity not in cache"))
            return ""
        return profile_cache[bare].get("last_resource", "")

    def getPresenceStatus(self, profile_key):
        """Return the cached presences of the profile's entities

        @param profile_key: profile name or keyword
        @return: dict mapping each entity (bare jid string) having presence data
                 to its presence dict ({resource: (show, priority, statuses)}),
                 or {} if the profile doesn't exist
        """
        profile = self.getProfileName(profile_key)
        if not profile:
            error(_('Asking contacts for a non-existant profile'))
            return {}
        entities_presence = dict((entity, data["presence"])
                                 for entity, data in self.entitiesCache[profile].items()
                                 if "presence" in data)
        debug("Memory getPresenceStatus (%s)", entities_presence)
        return entities_presence

    def setPresenceStatus(self, entity_jid, show, priority, statuses, profile_key):
        """Change the presence status of an entity

        The presence is stored per resource ('' when entity_jid has none), and
        a non-empty resource becomes the entity's "last_resource".
        @param entity_jid: jid.JID of the entity
        @param profile_key: profile name or keyword
        """
        profile = self.getProfileName(profile_key)
        if not profile:
            error(_('Trying to add presence status to a non-existant profile'))
            return
        entity_data = self.entitiesCache[profile].setdefault(entity_jid.userhost(), {})
        resource = jid.parse(entity_jid.full())[2] or ''
        if resource:
            entity_data["last_resource"] = resource
        else:
            entity_data.setdefault("last_resource", '')
        presences = entity_data.setdefault("presence", {})
        presences[resource] = (show, priority, statuses)

    def updateEntityData(self, entity_jid, key, value, profile_key):
        """Set a misc data for an entity
        @param entity_jid: JID of the entity, or '@ALL@' to update all entities)
        @param key: key to set (eg: "type")
        @param value: value for this key (eg: "chatroom"), or C.PROF_KEY_NONE to delete
        @param profile_key: %(doc_profile_key)s
        @raise exceptions.ProfileUnknownError: the profile doesn't exist
        @raise exceptions.ProfileNotInCacheError: the profile session has not been started
        """
        profile = self.getProfileName(profile_key)
        if not profile:
            raise exceptions.ProfileUnknownError(_('Trying to get entity data for a non-existant profile'))
        if not profile in self.entitiesCache:
            raise exceptions.ProfileNotInCacheError
        if entity_jid == "@ALL@":
            entities_map = self.entitiesCache[profile]
        else:
            entity = entity_jid.userhost()
            self.entitiesCache[profile].setdefault(entity, {})
            entities_map = {entity: self.entitiesCache[profile][entity]}
        for entity in entities_map:
            entity_map = entities_map[entity]
            if value == C.PROF_KEY_NONE:
                # deletion request: remove the key if present, but never store
                # the sentinel itself as a value
                entity_map.pop(key, None)
            else:
                entity_map[key] = value
            # only string values are sent to frontends (including the deletion sentinel)
            if isinstance(value, basestring):
                self.host.bridge.entityDataUpdated(entity, key, value, profile)

    def getEntityData(self, entity_jid, keys_list, profile_key):
        """Get a list of cached values for entity
        @param entity_jid: JID of the entity
        @param keys_list: list of keys to get, empty list for everything
        @param profile_key: %(doc_profile_key)s
        @return: dict with values for each key in keys_list.
                 if there is no value for a given key, the resulting dict
                 doesn't have that key either
        @raise: exceptions.UnknownEntityError if entity is not in cache
                exceptions.ProfileNotInCacheError if profile is not in cache
        """
        profile = self.getProfileName(profile_key)
        if not profile:
            raise exceptions.ProfileUnknownError(_('Trying to get entity data for a non-existant profile'))
        if profile not in self.entitiesCache:
            raise exceptions.ProfileNotInCacheError
        bare = entity_jid.userhost()
        profile_cache = self.entitiesCache[profile]
        if bare not in profile_cache:
            raise exceptions.UnknownEntityError(bare)
        entity_data = profile_cache[bare]
        if not keys_list:
            # NOTE(review): this is the cache dict itself, not a copy
            return entity_data
        return dict((key, entity_data[key]) for key in keys_list if key in entity_data)

    def delEntityCache(self, entity_jid, profile_key):
        """Remove cached data for entity

        Nothing is done if the profile or the entity is not in cache.
        @param entity_jid: JID of the entity
        @param profile_key: profile name or keyword
        """
        profile = self.getProfileName(profile_key)
        profile_cache = self.entitiesCache.get(profile)
        if profile_cache is not None:
            profile_cache.pop(entity_jid.userhost(), None)

    def addWaitingSub(self, type_, entity_jid, profile_key):
        """Called when a subscription request is received

        @param type_: type of the subscription request
        @param entity_jid: entity which sent the request
        @param profile_key: profile name or keyword
        """
        profile = self.getProfileName(profile_key)
        assert profile
        self.subscriptions.setdefault(profile, {})[entity_jid] = type_

    def delWaitingSub(self, entity_jid, profile_key):
        """Called when a subscription request is finished

        Unknown requests are ignored.
        @param entity_jid: entity which sent the request
        @param profile_key: profile name or keyword
        """
        profile = self.getProfileName(profile_key)
        assert profile
        self.subscriptions.get(profile, {}).pop(entity_jid, None)

    def getWaitingSub(self, profile_key):
        """Called to get a list of currently waiting subscription requests

        @param profile_key: profile name or keyword
        @return: dict mapping requesting entity to request type ({} if none, or
                 if the profile doesn't exist)
        """
        profile = self.getProfileName(profile_key)
        if not profile:
            error(_('Asking waiting subscriptions for a non-existant profile'))
            return {}
        return self.subscriptions.get(profile, {})

    def getStringParamA(self, name, category, attr="value", profile_key=C.PROF_KEY_NONE):
        """Proxy for Params.getStringParamA (arguments forwarded unchanged)"""
        params = self.params
        return params.getStringParamA(name, category, attr, profile_key)

    def getParamA(self, name, category, attr="value", profile_key=C.PROF_KEY_NONE):
        """Proxy for Params.getParamA (arguments forwarded unchanged)"""
        params = self.params
        return params.getParamA(name, category, attr, profile_key)

    def asyncGetParamA(self, name, category, attr="value", security_limit=C.NO_SECURITY_LIMIT, profile_key=C.PROF_KEY_NONE):
        """Proxy for Params.asyncGetParamA (arguments forwarded unchanged)"""
        params = self.params
        return params.asyncGetParamA(name, category, attr, security_limit, profile_key)

    def asyncGetStringParamA(self, name, category, attr="value", security_limit=C.NO_SECURITY_LIMIT, profile_key=C.PROF_KEY_NONE):
        """Proxy for Params.asyncGetStringParamA (arguments forwarded unchanged)"""
        params = self.params
        return params.asyncGetStringParamA(name, category, attr, security_limit, profile_key)

    def getParamsUI(self, security_limit=C.NO_SECURITY_LIMIT, app='', profile_key=C.PROF_KEY_NONE):
        """Proxy for Params.getParamsUI (arguments forwarded unchanged)"""
        params = self.params
        return params.getParamsUI(security_limit, app, profile_key)

    def getParams(self, security_limit=C.NO_SECURITY_LIMIT, app='', profile_key=C.PROF_KEY_NONE):
        """Proxy for Params.getParams (arguments forwarded unchanged)"""
        params = self.params
        return params.getParams(security_limit, app, profile_key)

    def getParamsForCategory(self, category, security_limit=C.NO_SECURITY_LIMIT, app='', profile_key=C.PROF_KEY_NONE):
        """Proxy for Params.getParamsForCategory (arguments forwarded unchanged)"""
        params = self.params
        return params.getParamsForCategory(category, security_limit, app, profile_key)

    def getParamsCategories(self):
        """Proxy for Params.getParamsCategories"""
        params = self.params
        return params.getParamsCategories()

    def setParam(self, name, value, category, security_limit=C.NO_SECURITY_LIMIT, profile_key=C.PROF_KEY_NONE):
        """Proxy for Params.setParam (arguments forwarded unchanged)"""
        params = self.params
        return params.setParam(name, value, category, security_limit, profile_key)

    def updateParams(self, xml):
        """Proxy for Params.updateParams"""
        params = self.params
        return params.updateParams(xml)

    def paramsRegisterApp(self, xml, security_limit=C.NO_SECURITY_LIMIT, app=''):
        """Proxy for Params.paramsRegisterApp (arguments forwarded unchanged)"""
        params = self.params
        return params.paramsRegisterApp(xml, security_limit, app)

    def setDefault(self, name, category, callback, errback=None):
        """Proxy for Params.setDefault (arguments forwarded unchanged)"""
        params = self.params
        return params.setDefault(name, category, callback, errback)