comparison sat/memory/encryption.py @ 3226:2f406b762788

core (memory/encryption): encryption sessions are now restored on client connection
author Goffi <goffi@goffi.org>
date Sun, 22 Mar 2020 18:39:12 +0100
parents 0469c53ed5dd
children cc3fea71c365
comparison
equal deleted inserted replaced
3225:843a9279fb5a 3226:2f406b762788
15 # GNU Affero General Public License for more details. 15 # GNU Affero General Public License for more details.
16 16
17 # You should have received a copy of the GNU Affero General Public License 17 # You should have received a copy of the GNU Affero General Public License
18 # along with this program. If not, see <http://www.gnu.org/licenses/>. 18 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 19
20 import copy
20 from functools import partial 21 from functools import partial
22 from collections import namedtuple
23 from twisted.words.protocols.jabber import jid
24 from twisted.internet import defer
25 from twisted.python import failure
21 from sat.core.i18n import D_, _ 26 from sat.core.i18n import D_, _
22 from sat.core.constants import Const as C 27 from sat.core.constants import Const as C
23 from sat.core import exceptions 28 from sat.core import exceptions
24 from collections import namedtuple
25 from sat.core.log import getLogger 29 from sat.core.log import getLogger
26 from sat.tools.common import data_format 30 from sat.tools.common import data_format
27 from twisted.words.protocols.jabber import jid 31 from sat.tools import utils
28 from twisted.internet import defer 32 from sat.memory import persistent
29 from twisted.python import failure 33
30 import copy
31 log = getLogger(__name__)
32 34
33 log = getLogger(__name__) 35 log = getLogger(__name__)
34 36
35 EncryptionPlugin = namedtuple("EncryptionPlugin", ("instance", 37 EncryptionPlugin = namedtuple("EncryptionPlugin", ("instance",
36 "name", 38 "name",
37 "namespace", 39 "namespace",
38 "priority", 40 "priority",
39 "directed")) 41 "directed"))
40 42
41 43
42 class EncryptionHandler(object): 44 class EncryptionHandler:
43 """Class to handle encryption sessions for a client""" 45 """Class to handle encryption sessions for a client"""
44 plugins = [] # plugin able to encrypt messages 46 plugins = [] # plugin able to encrypt messages
45 47
46 def __init__(self, client): 48 def __init__(self, client):
47 self.client = client 49 self.client = client
48 self._sessions = {} # bare_jid ==> encryption_data 50 self._sessions = {} # bare_jid ==> encryption_data
51 self._stored_session = persistent.PersistentDict(
52 "core:encryption", profile=client.profile)
49 53
50 @property 54 @property
51 def host(self): 55 def host(self):
52 return self.client.host_app 56 return self.client.host_app
57
58 async def loadSessions(self):
59 """Load persistent sessions"""
60 await self._stored_session.load()
61 start_d_list = []
62 for entity_jid_s, namespace in self._stored_session.items():
63 entity = jid.JID(entity_jid_s)
64 start_d_list.append(defer.ensureDeferred(self.start(entity, namespace)))
65
66 if start_d_list:
67 result = await defer.DeferredList(start_d_list)
68 for idx, (success, err) in enumerate(result):
69 if not success:
70 entity_jid_s, namespace = list(self._stored_session.items())[idx]
71 log.warning(_(
72 "Could not restart {namespace!r} encryption with {entity}: {err}"
73 ).format(namespace=namespace, entity=entity_jid_s, err=err))
74 log.info(_("encryption sessions restored"))
53 75
54 @classmethod 76 @classmethod
55 def registerPlugin(cls, plg_instance, name, namespace, priority=0, directed=False): 77 def registerPlugin(cls, plg_instance, name, namespace, priority=0, directed=False):
56 """Register a plugin handling an encryption algorithm 78 """Register a plugin handling an encryption algorithm
57 79
140 if 'directed_devices' in session: 162 if 'directed_devices' in session:
141 bridge_data['directed_devices'] = session['directed_devices'] 163 bridge_data['directed_devices'] = session['directed_devices']
142 164
143 return data_format.serialise(bridge_data) 165 return data_format.serialise(bridge_data)
144 166
145 def _startEncryption(self, plugin, entity): 167 async def _startEncryption(self, plugin, entity):
146 """Start encryption with a plugin 168 """Start encryption with a plugin
147 169
148 This method must be called just before adding a plugin session. 170 This method must be called just before adding a plugin session.
149 StartEncryptionn method of plugin will be called if it exists. 171 StartEncryptionn method of plugin will be called if it exists.
150 """ 172 """
173 if not plugin.directed:
174 await self._stored_session.aset(entity.userhost(), plugin.namespace)
151 try: 175 try:
152 start_encryption = plugin.instance.startEncryption 176 start_encryption = plugin.instance.startEncryption
153 except AttributeError: 177 except AttributeError:
154 log.debug("No startEncryption method found for {plugin}".format( 178 log.debug(f"No startEncryption method found for {plugin.namespace}")
155 plugin = plugin.namespace))
156 return defer.succeed(None)
157 else: 179 else:
158 # we copy entity to avoid having the resource changed by stop_encryption 180 # we copy entity to avoid having the resource changed by stop_encryption
159 return defer.maybeDeferred(start_encryption, self.client, copy.copy(entity)) 181 await utils.asDeferred(start_encryption, self.client, copy.copy(entity))
160 182
161 def _stopEncryption(self, plugin, entity): 183 async def _stopEncryption(self, plugin, entity):
162 """Stop encryption with a plugin 184 """Stop encryption with a plugin
163 185
164 This method must be called just before removing a plugin session. 186 This method must be called just before removing a plugin session.
165 StopEncryptionn method of plugin will be called if it exists. 187 StopEncryptionn method of plugin will be called if it exists.
166 """ 188 """
189 try:
190 await self._stored_session.adel(entity.userhost())
191 except KeyError:
192 pass
167 try: 193 try:
168 stop_encryption = plugin.instance.stopEncryption 194 stop_encryption = plugin.instance.stopEncryption
169 except AttributeError: 195 except AttributeError:
170 log.debug("No stopEncryption method found for {plugin}".format( 196 log.debug(f"No stopEncryption method found for {plugin.namespace}")
171 plugin = plugin.namespace))
172 return defer.succeed(None)
173 else: 197 else:
174 # we copy entity to avoid having the resource changed by stop_encryption 198 # we copy entity to avoid having the resource changed by stop_encryption
175 return defer.maybeDeferred(stop_encryption, self.client, copy.copy(entity)) 199 return utils.asDeferred(stop_encryption, self.client, copy.copy(entity))
176 200
177 @defer.inlineCallbacks 201 async def start(self, entity, namespace=None, replace=False):
178 def start(self, entity, namespace=None, replace=False):
179 """Start an encryption session with an entity 202 """Start an encryption session with an entity
180 203
181 @param entity(jid.JID): entity to start an encryption session with 204 @param entity(jid.JID): entity to start an encryption session with
182 must be bare jid is the algorithm encrypt for all devices 205 must be bare jid is the algorithm encrypt for all devices
183 @param namespace(unicode, None): namespace of the encryption algorithm 206 @param namespace(unicode, None): namespace of the encryption algorithm
207 230
208 if replace: 231 if replace:
209 # there is a conflict, but replacement is requested 232 # there is a conflict, but replacement is requested
210 # so we stop previous encryption to use new one 233 # so we stop previous encryption to use new one
211 del self._sessions[bare_jid] 234 del self._sessions[bare_jid]
212 yield self._stopEncryption(former_plugin, entity) 235 await self._stopEncryption(former_plugin, entity)
213 else: 236 else:
214 msg = (_("Session with {bare_jid} is already encrypted with {name}. " 237 msg = (_("Session with {bare_jid} is already encrypted with {name}. "
215 "Please stop encryption session before changing algorithm.") 238 "Please stop encryption session before changing algorithm.")
216 .format(bare_jid=bare_jid, name=plugin.name)) 239 .format(bare_jid=bare_jid, name=plugin.name))
217 log.warning(msg) 240 log.warning(msg)
231 # indicate that we encrypt only for some devices 254 # indicate that we encrypt only for some devices
232 directed_devices = data['directed_devices'] = [entity.resource] 255 directed_devices = data['directed_devices'] = [entity.resource]
233 elif entity.resource: 256 elif entity.resource:
234 raise ValueError(_("{name} encryption must be used with bare jids.")) 257 raise ValueError(_("{name} encryption must be used with bare jids."))
235 258
236 yield self._startEncryption(plugin, entity) 259 await self._startEncryption(plugin, entity)
237 self._sessions[entity.userhostJID()] = data 260 self._sessions[entity.userhostJID()] = data
238 log.info(_("Encryption session has been set for {entity_jid} with " 261 log.info(_("Encryption session has been set for {entity_jid} with "
239 "{encryption_name}").format( 262 "{encryption_name}").format(
240 entity_jid=entity.full(), encryption_name=plugin.name)) 263 entity_jid=entity.full(), encryption_name=plugin.name))
241 self.host.bridge.messageEncryptionStarted( 264 self.host.bridge.messageEncryptionStarted(
252 nb_devices=len(directed_devices), 275 nb_devices=len(directed_devices),
253 devices_list = ', '.join(directed_devices)) 276 devices_list = ', '.join(directed_devices))
254 277
255 self.client.feedback(bare_jid, msg) 278 self.client.feedback(bare_jid, msg)
256 279
257 @defer.inlineCallbacks 280 async def stop(self, entity, namespace=None):
258 def stop(self, entity, namespace=None):
259 """Stop an encryption session with an entity 281 """Stop an encryption session with an entity
260 282
261 @param entity(jid.JID): entity with who the encryption session must be stopped 283 @param entity(jid.JID): entity with who the encryption session must be stopped
262 must be bare jid if the algorithm encrypt for all devices 284 must be bare jid if the algorithm encrypt for all devices
263 @param namespace(unicode): namespace of the session to stop 285 @param namespace(unicode): namespace of the session to stop
264 when specified, used to check we stop the right encryption session 286 when specified, used to check that we stop the right encryption session
265 """ 287 """
266 session = self.getSession(entity.userhostJID()) 288 session = self.getSession(entity.userhostJID())
267 if not session: 289 if not session:
268 raise failure.Failure( 290 raise failure.Failure(
269 exceptions.NotFound(_("There is no encryption session with this " 291 exceptions.NotFound(_("There is no encryption session with this "
293 if not directed_devices: 315 if not directed_devices:
294 # if we have no more directed device sessions, 316 # if we have no more directed device sessions,
295 # we stop the whole session 317 # we stop the whole session
296 # see comment below for deleting session before stopping encryption 318 # see comment below for deleting session before stopping encryption
297 del self._sessions[entity.userhostJID()] 319 del self._sessions[entity.userhostJID()]
298 yield self._stopEncryption(plugin, entity) 320 await self._stopEncryption(plugin, entity)
299 else: 321 else:
300 # plugin's stopEncryption may call stop again (that's the case with OTR) 322 # plugin's stopEncryption may call stop again (that's the case with OTR)
301 # so we need to remove plugin from session before calling self._stopEncryption 323 # so we need to remove plugin from session before calling self._stopEncryption
302 del self._sessions[entity.userhostJID()] 324 del self._sessions[entity.userhostJID()]
303 yield self._stopEncryption(plugin, entity) 325 await self._stopEncryption(plugin, entity)
304 326
305 log.info(_("encryption session stopped with entity {entity}").format( 327 log.info(_("encryption session stopped with entity {entity}").format(
306 entity=entity.full())) 328 entity=entity.full()))
307 self.host.bridge.messageEncryptionStopped( 329 self.host.bridge.messageEncryptionStopped(
308 entity.full(), 330 entity.full(),
388 410
389 @classmethod 411 @classmethod
390 def _onMenuUnencrypted(cls, data, host, profile): 412 def _onMenuUnencrypted(cls, data, host, profile):
391 client = host.getClient(profile) 413 client = host.getClient(profile)
392 peer_jid = jid.JID(data['jid']).userhostJID() 414 peer_jid = jid.JID(data['jid']).userhostJID()
393 d = client.encryption.stop(peer_jid) 415 d = defer.ensureDeferred(client.encryption.stop(peer_jid))
394 d.addCallback(lambda __: {}) 416 d.addCallback(lambda __: {})
395 return d 417 return d
396 418
397 @classmethod 419 @classmethod
398 def _onMenuName(cls, data, host, plg, profile): 420 def _onMenuName(cls, data, host, plg, profile):
399 client = host.getClient(profile) 421 client = host.getClient(profile)
400 peer_jid = jid.JID(data['jid']) 422 peer_jid = jid.JID(data['jid'])
401 if not plg.directed: 423 if not plg.directed:
402 peer_jid = peer_jid.userhostJID() 424 peer_jid = peer_jid.userhostJID()
403 d = client.encryption.start(peer_jid, plg.namespace, replace=True) 425 d = defer.ensureDeferred(
426 client.encryption.start(peer_jid, plg.namespace, replace=True))
404 d.addCallback(lambda __: {}) 427 d.addCallback(lambda __: {})
405 return d 428 return d
406 429
407 @classmethod 430 @classmethod
408 @defer.inlineCallbacks 431 @defer.inlineCallbacks