Mercurial > libervia-backend
comparison libervia/backend/core/patches.py @ 4071:4b842c1fb686
refactoring: renamed `sat` package to `libervia.backend`
author | Goffi <goffi@goffi.org> |
---|---|
date | Fri, 02 Jun 2023 11:49:51 +0200 |
parents | sat/core/patches.py@524856bd7b19 |
children | a1e7e82a8921 |
comparison
equal
deleted
inserted
replaced
4070:d10748475025 | 4071:4b842c1fb686 |
---|---|
1 import copy | |
2 from twisted.words.protocols.jabber import xmlstream, sasl, client as tclient, jid | |
3 from wokkel import client | |
4 from libervia.backend.core.constants import Const as C | |
5 from libervia.backend.core.log import getLogger | |
6 | |
# module logger; XmlStream.add_hook uses it to warn when a hook is registered twice
log = getLogger(__name__)
8 | |
9 """This module applies monkey patches to Twisted and Wokkel | |
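10 The first part handles certificate validation during XMPP connection and is temporary | |
11 (until merged upstream). | |
12 The second part adds trigger points (hooks) to the send and onElement methods of XmlStream; the remaining parts make resource binding optional (needed for XEP-0198) and make internJID return copies of cached JIDs. | |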
13 """ | |
14 | |
15 | |
16 ## certificate validation patches | |
17 | |
18 | |
class XMPPClient(client.XMPPClient):
    """Wokkel XMPPClient whose stream factory honours TLS options.

    Only the constructor is overridden: ``tls_required`` and
    ``configurationForTLS`` are forwarded to the factory built by
    HybridClientFactory.
    """

    def __init__(self, jid, password, host=None, port=5222,
                 tls_required=True, configurationForTLS=None):
        """
        @param jid: JID of the account; its host, IDNA-encoded, is stored
            as ``self.domain``
        @param password: password given to the authenticator
        @param host: host stored as-is on the instance (None by default)
        @param port: port stored as-is on the instance (5222 by default)
        @param tls_required: whether TLS negotiation is required
        @param configurationForTLS: TLS configuration forwarded to the
            TLS initializer
        """
        self.jid = jid
        self.domain = jid.host.encode('idna')
        self.host, self.port = host, port

        tls_options = {
            "tls_required": tls_required,
            "configurationForTLS": configurationForTLS,
        }
        # StreamManager is initialised directly, bypassing the parent
        # client.XMPPClient.__init__, so the factory built here is the one
        # handed to StreamManager
        client.StreamManager.__init__(
            self, HybridClientFactory(jid, password, **tls_options))
33 | |
34 | |
def HybridClientFactory(jid, password, tls_required=True, configurationForTLS=None):
    """Create a stream factory authenticating with HybridAuthenticator.

    @param jid: JID of the account to authenticate
    @param password: password used by the authenticator
    @param tls_required: forwarded to HybridAuthenticator
    @param configurationForTLS: forwarded to HybridAuthenticator
    @return: a new xmlstream.XmlStreamFactory wrapping the authenticator
    """
    return xmlstream.XmlStreamFactory(
        HybridAuthenticator(jid, password, tls_required, configurationForTLS)
    )
39 | |
40 | |
class HybridAuthenticator(client.HybridAuthenticator):
    """Authenticator with configurable TLS and optional resource binding.

    Sets up the stream initializers itself so that ``tls_required`` and
    ``configurationForTLS`` reach the TLS initializer. It also uses this
    module's CheckAuthInitializer, so resource binding can be switched off
    through ``res_binding``.
    """

    # forwarded to CheckAuthInitializer: when False, no BindInitializer is
    # added after SASL authentication
    res_binding = True

    def __init__(self, jid, password, tls_required=True, configurationForTLS=None):
        """
        @param jid: JID of the account; jid.host is passed to
            ConnectAuthenticator as the other host
        @param password: password stored for authentication
        @param tls_required: value used as ``required`` for the TLS
            initializer
        @param configurationForTLS: TLS configuration forwarded to
            xmlstream.TLSInitiatingInitializer
        """
        # ConnectAuthenticator is initialised directly, bypassing the parent
        # client.HybridAuthenticator.__init__
        xmlstream.ConnectAuthenticator.__init__(self, jid.host)
        self.jid = jid
        self.password = password
        self.tls_required = tls_required
        self.configurationForTLS = configurationForTLS

    def associateWithStream(self, xs):
        """Associate with xs and set its initializers.

        The initializers list is set to: version check, TLS, then
        authentication (with optional resource binding).
        """
        xmlstream.ConnectAuthenticator.associateWithStream(self, xs)

        tlsInit = xmlstream.TLSInitiatingInitializer(
            xs, required=self.tls_required, configurationForTLS=self.configurationForTLS)
        # Twisted's jabber client module is used through the ``tclient``
        # alias imported by this module, rather than reaching it through
        # wokkel's own namespace (``client.client``), which is a wokkel
        # implementation detail
        xs.initializers = [tclient.CheckVersionInitializer(xs),
                           tlsInit,
                           CheckAuthInitializer(xs, self.res_binding)]
59 | |
60 | |
61 # XmlStream triggers | |
62 | |
63 | |
class XmlStream(xmlstream.XmlStream):
    """XmlStream variant calling registered hooks on received and sent data.

    Each hook is a callable invoked with the element/object before the
    regular XmlStream processing. Hooks are kept in registration order.
    """

    def __init__(self, authenticator):
        xmlstream.XmlStream.__init__(self, authenticator)
        # hooks are observers that must not alter content, so unlike
        # triggers there is no priority handling: plain ordered lists
        self._onElementHooks = []
        self._sendHooks = []

    def add_hook(self, hook_type, callback):
        """Register a hook for received or sent data.

        @param hook_type: C.STREAM_HOOK_RECEIVE (called from onElement) or
            C.STREAM_HOOK_SEND (called from send)
        @param callback: callable receiving the element/object; if it is
            already registered for this type, it is not added again and a
            warning is logged
        @raise ValueError: hook_type is neither of the two constants above
        """
        if hook_type == C.STREAM_HOOK_RECEIVE:
            registry = self._onElementHooks
        elif hook_type == C.STREAM_HOOK_SEND:
            registry = self._sendHooks
        else:
            raise ValueError(f"Invalid hook type: {hook_type}")

        if callback in registry:
            log.warning(f"Hook conflict: can't add {hook_type} hook {callback}")
            return
        registry.append(callback)

    def onElement(self, element):
        """Call every receive hook with element, then let XmlStream handle it."""
        for receive_hook in self._onElementHooks:
            receive_hook(element)
        xmlstream.XmlStream.onElement(self, element)

    def send(self, obj):
        """Call every send hook with obj, then let XmlStream send it."""
        for send_hook in self._sendHooks:
            send_hook(obj)
        xmlstream.XmlStream.send(self, obj)
99 | |
100 | |
101 # Binding activation (needed for stream management, XEP-0198) | |
102 | |
103 | |
class CheckAuthInitializer(client.CheckAuthInitializer):
    """Authentication initializer with optional resource binding.

    Variant of client.CheckAuthInitializer: resource binding is only added
    when ``res_binding`` is True, and the deprecated SessionInitializer is
    never added.
    """

    def __init__(self, xs, res_binding):
        """
        @param xs: the stream whose initializers will be extended
        @param res_binding: if True, a required BindInitializer is added
            after SASL authentication
        """
        super().__init__(xs)
        self.res_binding = res_binding

    def initialize(self):
        """Append the authentication initializers matching the stream features.

        SASL (plus optional resource binding) is used when advertised;
        otherwise legacy IQ authentication is used if advertised.

        @raise Exception: no supported authentication method is advertised
        """
        if (sasl.NS_XMPP_SASL, 'mechanisms') in self.xmlstream.features:
            inits = [(sasl.SASLInitiatingInitializer, True)]
            if self.res_binding:
                # no trailing comma here: it previously built a throwaway tuple
                inits.append((tclient.BindInitializer, True))

            for initClass, required in inits:
                init = initClass(self.xmlstream)
                init.required = required
                self.xmlstream.initializers.append(init)
        elif (tclient.NS_IQ_AUTH_FEATURE, 'auth') in self.xmlstream.features:
            self.xmlstream.initializers.append(
                tclient.IQAuthInitializer(self.xmlstream))
        else:
            raise Exception("No available authentication method found")
128 | |
129 | |
130 # jid fix | |
131 | |
def internJID(jidstring):
    """
    Return interned JID.

    Unlike the function it replaces, this returns a copy of the cached JID:
    JID instances are mutable, so handing out the cached object itself would
    let a modification by one caller affect every later lookup of the same
    jidstring.
    TODO: propose this upstream

    @param jidstring: string representation of the JID
    @rtype: L{JID}
    """
    # ``jid.__internJIDs`` is Twisted's module-level cache; no name mangling
    # happens here because this is a module-level function, not a method
    cache = jid.__internJIDs
    # single lookup instead of a membership test followed by indexing
    cached = cache.get(jidstring)
    if cached is None:
        cached = jid.JID(jidstring)
        cache[jidstring] = cached
    return copy.copy(cached)
148 | |
149 | |
def apply():
    """Install this module's monkey patches.

    Replaces, in this order: wokkel's client.XMPPClient, the protocol class
    of Twisted's xmlstream.XmlStreamFactory, and Twisted's jid.internJID.
    NOTE(review): code that bound any of these names directly (e.g.
    ``from ... import internJID``) before this runs presumably keeps the
    original objects.
    """
    patches = (
        # certificate validation
        (client, "XMPPClient", XMPPClient),
        # XmlStream triggers
        (xmlstream.XmlStreamFactory, "protocol", XmlStream),
        # jid fix
        (jid, "internJID", internJID),
    )
    for target, attribute, replacement in patches:
        setattr(target, attribute, replacement)