diff src/plugins/plugin_xep_0065.py @ 394:8f3551ceee17

plugin XEP-0065: refactored and misc stuff fixed. Still not finished plugins XEP-0096: XEP-0065 (Socks5 stream method) managed
author Goffi <goffi@goffi.org>
date Mon, 03 Oct 2011 18:05:15 +0200
parents 7c79d4a8c9e6
children cb0285372818
line wrap: on
line diff
--- a/src/plugins/plugin_xep_0065.py	Sun Oct 02 00:29:04 2011 +0200
+++ b/src/plugins/plugin_xep_0065.py	Mon Oct 03 18:05:15 2011 +0200
@@ -55,8 +55,10 @@
 THE SOFTWARE.
 """
 
-from logging import debug, info, error
+from logging import debug, info, warning, error
 from twisted.internet import protocol, reactor
+from twisted.internet import error as jab_error
+from twisted.words.protocols.jabber import client, jid
 from twisted.protocols.basic import FileSender
 from twisted.words.xish import domish
 from twisted.web.client import getPage
@@ -76,6 +78,7 @@
 IQ_SET = '/iq[@type="set"]'
 NS_BS = 'http://jabber.org/protocol/bytestreams'
 BS_REQUEST = IQ_SET + '/query[@xmlns="' + NS_BS + '"]'
+TIMEOUT = 60 #timeout for workflow
 
 
 
@@ -127,6 +130,13 @@
 REPLY_ADDR_NOT_SUPPORTED = 0x08
 
 
+def calculateHash(from_jid, to_jid, sid):
+    """Calculate SHA1 Hash according to XEP-0065
+    @param from_jid: jid of the requester
+    @param to_jid: jid of the target
+    @param sid: session id
+    @return: hash (string)"""
+    return hashlib.sha1((sid + from_jid.full() + to_jid.full()).encode('utf-8')).hexdigest()
 
 
 
@@ -141,9 +151,6 @@
         self.peersock = None
         self.addressType = 0
         self.requestType = 0
-        self.activeConns = {}
-        self.pendingConns = {}
-        self.transfered = 0 #nb of bytes already copied
 
     def _startNegotiation(self):
         debug("_startNegotiation")
@@ -213,27 +220,6 @@
         self.transport.write(result)
         self.transport.loseConnection()
     
-    def addConnection(self, address, connection):
-        info(_("Adding connection: %(address)s, %(connection)s") % {'address':address, 'connection':connection})
-        olist = self.pendingConns.get(address, [])
-        if len(olist) <= 1:
-            olist.append(connection)
-            self.pendingConns[address] = olist
-            return True
-        else:
-            return False
-
-    def removePendingConnection(self, address, connection):
-        olist = self.pendingConns[address]
-        if len(olist) == 1:
-            del self.pendingConns[address]
-        else:
-            olist.remove(connection)
-            self.pendingConns[address] = olist
-
-    def removeActiveConnection(self, address):
-        del self.activeConns[address]
-
     def _parseRequest(self):
         debug("_parseRequest")
         try:
@@ -279,7 +265,7 @@
     def _makeRequest(self):
         debug("_makeRequest")
         self.state = STATE_TARGET_REQUEST
-        sha1 = hashlib.sha1(self.sid + self.initiator_jid + self.target_jid).hexdigest()
+        sha1 = calculateHash(self.data["from"], self.data["to"], self.sid)
         request = struct.pack('!5B%dsH' % len(sha1), SOCKS5_VER, CMD_CONNECT, 0, ADDR_DOMAINNAME, len(sha1), sha1, 0)
         self.transport.write(request)
 
@@ -310,67 +296,43 @@
                 self.loseConnection()
                 return
 
-            debug(_("Saving file in %s."), self.data["dest_path"])
-            self.dest_file = open(self.data["dest_path"], 'w')
             self.state = STATE_TARGET_READY
-            self.activateCB(self.target_jid, self.initiator_jid, self.sid, self.IQ_id, self.xmlstream)
-
+            self.factory.activateCb(self.sid, self.factory.iq_id)
 
         except struct.error, why:
             return None
 
     def connectionMade(self):
-        debug("connectionMade (mode = %s)" % self.mode)
-        self.host.registerProgressCB(self.transfert_id, self.getProgress)
-
-        if self.mode == "target":
+        debug("connectionMade (mode = %s)" % "requester" if isinstance(self.factory, Socks5ServerFactory) else "target")
+        
+        if isinstance(self.factory, Socks5ClientFactory):
+            self.sid = self.factory.sid
+            self.data = self.factory.data
             self.state = STATE_TARGET_INITIAL
             self._startNegotiation()
 
     def connectRequested(self, addr, port):
         debug("connectRequested")
-        # Check for special connect to the namespace -- this signifies that the client
-        # is just checking to ensure it can connect to the streamhost
-        if addr == "http://jabber.org/protocol/bytestreams":
-            self.connectCompleted(addr, 0)
-            self.transport.loseConnection()
+        
+        # Check that this session if expected
+        if not self.factory.hash_sid_map.has_key(addr):
+            #no: we refuse it
+            self.sendErrorReply(socks5.REPLY_CONN_REFUSED)
             return
-            
-        # Save addr, for cleanup
-        self.addr = addr
-        
-        # Check to see if the requested address is already
-        # activated -- send an error if so
-        if addr in self.activeConns:
-            self.sendErrorReply(socks5.REPLY_CONN_NOT_ALLOWED)
-            return
+        self.sid = self.factory.hash_sid_map[addr]
+        self.factory.current_stream[self.sid]["start_transfer_cb"] = self.startTransfer
+        self.connectCompleted(addr, 0)
+        self.transport.stopReading()
 
-        # Add this address to the pending connections
-        if self.addConnection(addr, self):
-            self.connectCompleted(addr, 0)
-            self.transport.stopReading()
-        else:
-            self.sendErrorReply(socks5.REPLY_CONN_REFUSED)
-
-    def getProgress(self, data):
-        """Fill data with position of current transfert"""
-        try:
-            data["position"] = str(self.dest_file.tell())
-            data["size"] = self.filesize
-        except (ValueError, AttributeError):
-            pass
-
+    def startTransfer(self, file_obj):
+        """Callback called when the result iq is received"""
+        d = self.beginFileTransfer(file_obj, self.transport)
+        d.addCallback(self.fileTransfered)
+    
     def fileTransfered(self, d):
         info(_("File transfer completed, closing connection"))
         self.transport.loseConnection()
-        try:
-            self.dest_file.close()
-        except:
-            pass
-
-    def updateTransfered(self, data):
-        self.transfered+=len(data)
-        return data
+        self.factory.finishedCb(self.sid, True)
 
     def connectCompleted(self, remotehost, remoteport):
         debug("connectCompleted")
@@ -381,9 +343,6 @@
                                  ADDR_DOMAINNAME, len(remotehost), remotehost, remoteport)
         self.transport.write(result)
         self.state = STATE_READY
-        self.dest_file=open(self.filepath)
-        d=self.beginFileTransfer(self.dest_file, self.transport, self.updateTransfered)
-        d.addCallback(self.fileTransfered)
     
     def bindRequested(self, addr, port):
         pass
@@ -394,8 +353,7 @@
 
     def dataReceived(self, buf):
         if self.state == STATE_TARGET_READY:
-            self.dest_file.write(buf)
-            self.transfered+=len(buf)
+            self.data["file_obj"].write(buf)
             return
 
         self.buf = self.buf + buf
@@ -422,21 +380,21 @@
 
     def connectionLost(self, reason):
         debug("connectionLost")
-        self.host.removeProgressCB(self.transfert_id)
-        if self.state == STATE_CONNECT_PENDING:
-            self.removePendingConnection(self.addr, self)
-        else:
+        if self.state != STATE_CONNECT_PENDING:
             self.transport.unregisterProducer()
             if self.peersock != None:
                 self.peersock.peersock = None
                 self.peersock.transport.unregisterProducer()
                 self.peersock = None
-                self.removeActiveConnection(self.addr)
+
 
 class Socks5ServerFactory(protocol.ServerFactory):
     protocol = SOCKSv5
-    protocol.mode = "initiator"  #FIXME: Q&D way, fix it 
 
+    def __init__(self, current_stream, hash_sid_map, finishedCb):
+        self.current_stream = current_stream
+        self.hash_sid_map = hash_sid_map
+        self.finishedCb = finishedCb
 
     def startedConnecting(self, connector):
         debug (_("Socks 5 server connection started"))
@@ -446,21 +404,30 @@
 
 class Socks5ClientFactory(protocol.ClientFactory):
     protocol = SOCKSv5
-    protocol.mode = "target"  #FIXME: Q&D way, fix it 
+
+    def __init__(self, current_stream, sid, iq_id, activateCb, finishedCb):
+        self.data = current_stream[sid]
+        self.sid = sid
+        self.iq_id = iq_id
+        self.activateCb = activateCb
+        self.finishedCb = finishedCb
 
     def startedConnecting(self, connector):
         debug (_("Socks 5 client connection started"))
 
     def clientConnectionLost(self, connector, reason):
-        debug (_("Socks 5 client connection lost (reason: %s)"), reason)
+        debug (_("Socks 5 client connection lost"))
+        self.finishedCb(self.sid, reason.type == jab_error.ConnectionDone) #TODO: really check if the state is actually successful
 
 
 class XEP_0065():
     
+    NAMESPACE = NS_BS
+
     params = """
     <params>
     <general>
-    <category name="File Transfert">
+    <category name="File Transfer">
         <param name="IP" value='0.0.0.0' default_cb='yes' type="string" />
         <param name="Port" value="28915" type="string" />
     </category>
@@ -470,17 +437,20 @@
 
     def __init__(self, host):
         info(_("Plugin XEP_0065 initialization"))
+        
+        #session data
+        self.current_stream = {} #key: stream_id, value: data(dict)
+        self.hash_sid_map = {}  #key: hash of the transfer session, value: session id
+        
         self.host = host
         debug(_("registering"))
-        self.server_factory = Socks5ServerFactory()
-        self.server_factory.protocol.host = self.host #needed for progress CB
-        self.client_factory = Socks5ClientFactory()
+        self.server_factory = Socks5ServerFactory(self.current_stream, self.hash_sid_map, self._killId)
 
         #parameters
         host.memory.importParams(XEP_0065.params)
-        host.memory.setDefault("IP", "File Transfert", self.getExternalIP)
+        host.memory.setDefault("IP", "File Transfer", self.getExternalIP)
+        port = int(self.host.memory.getParamA("Port", "File Transfer"))
         
-        port = int(self.host.memory.getParamA("Port", "File Transfert"))
         info(_("Launching Socks5 Stream server on port %d"), port)
         reactor.listenTCP(port, self.server_factory)
     
@@ -491,53 +461,217 @@
         """Return IP visible from outside, by asking to a website"""
         return getPage("http://www.goffi.org/sat_tools/get_ip.php")
 
+    def getProgress(self, sid, data):
+        """Fill data with position of current transfer"""
+        try:
+            file_obj = self.current_stream[sid]["file_obj"]
+            data["position"] = str(file_obj.tell())
+            data["size"] = str(self.current_stream[sid]["size"])
+        except:
+            pass
+    
+    def _timeOut(self, sid):
+        """Delecte current_stream id, called after timeout
+        @param id: id of self.current_stream"""
+        info(_("Socks5 Bytestream: TimeOut reached for id %s") % sid);
+        self._killId(sid, False, "TIMEOUT")
+    
+    def _killId(self, sid, success=False, failure_reason="UNKNOWN"):
+        """Delete an current_stream id, clean up associated observers
+        @param sid: id of self.current_stream"""
+        if not self.current_stream.has_key(sid):
+            warning(_("kill id called on a non existant id"))
+            return
+        if self.current_stream[sid].has_key("observer_cb"):
+            xmlstream = self.current_stream[sid]["xmlstream"]
+            xmlstream.removeObserver(self.current_stream[sid]["event_data"], self.current_stream[sid]["observer_cb"])
+        if self.current_stream[sid]['timer'].active():
+            self.current_stream[sid]['timer'].cancel()
+        if self.current_stream[sid].has_key("size"):
+            self.host.removeProgressCB(sid)
+       
+        file_obj = self.current_stream[sid]['file_obj']
+        success_cb = self.current_stream[sid]['success_cb']
+        failure_cb = self.current_stream[sid]['failure_cb']
+        
+        del self.current_stream[sid]
+        if self.hash_sid_map.has_key(sid):
+            del self.hash_sid_map[sid]
+
+        if success:
+            success_cb(sid, file_obj, NS_BS)
+        else:
+            failure_cb(sid, file_obj, NS_BS, failure_reason)
+
     def setData(self, data, id):
         self.data = data
-        self.transfert_id = id
+        self.transfer_id = id
         
     def sendFile(self, id, filepath, size):
-        #lauching socks5 initiator
-        debug(_("Launching socks5 initiator"))
-        self.server_factory.protocol.mode = "initiator"
+        #lauching socks5 requester
+        debug(_("Launching socks5 requester"))
+        self.server_factory.protocol.mode = "requester"
         self.server_factory.protocol.filepath = filepath
         self.server_factory.protocol.filesize = size
-        self.server_factory.protocol.transfert_id = id
+        self.server_factory.protocol.transfer_id = id
+
+
+    def startStream(self, file_obj, to_jid, sid, length, successCb, failureCb, size = None, profile='@NONE@'):
+        """Launch the stream workflow
+        @param file_obj: file_obj to send
+        @param to_jid: JID of the recipient
+        @param sid: Stream session id
+        @param length: number of byte to send, or None to send until the end
+        @param successCb: method to call when stream successfuly finished
+        @param failureCb: method to call when something goes wrong
+        @param profile: %(doc_profile)s"""
+        if length != None:
+            error(_('stream length not managed yet'))
+            return;
+        profile_jid, xmlstream = self.host.getJidNStream(profile)
+        data = self.current_stream[sid] = {}
+        data["timer"] = reactor.callLater(TIMEOUT, self._timeOut, sid)
+        data["file_obj"] = file_obj
+        data["to"] = to_jid
+        data["success_cb"] = successCb
+        data["failure_cb"] = failureCb
+        data["xmlstream"] = xmlstream
+        data["hash"] = calculateHash(profile_jid, to_jid, sid)
+        self.hash_sid_map[data["hash"]] = sid
+        if size:
+            data["size"] = size
+            self.host.registerProgressCB(sid, self.getProgress)
+        iq_elt = client.IQ(xmlstream,'set')
+        iq_elt["from"] = profile_jid.full()
+        iq_elt["to"] = to_jid.full()
+        query_elt = iq_elt.addElement('query', NS_BS)
+        query_elt['mode'] = 'tcp'
+        query_elt['sid'] = sid
+        streamhost = query_elt.addElement('streamhost')
+        streamhost['host'] = "127.0.0.1" #self.host.memory.getParamA("IP", "File Transfer")
+        streamhost['port'] = self.host.memory.getParamA("Port", "File Transfer")
+        streamhost['jid'] = profile_jid.full()
+        iq_elt.addCallback(self.iqResult, sid)
+        iq_elt.send()
 
-    def getFile(self, iq, profile_key='@DEFAULT@'):
+    def iqResult(self, sid, iq_elt):
+        """Called when the result of open iq is received"""
+        if iq_elt["type"] == "error":
+            warning(_("Transfer failed"))
+            return
+        
+        try: 
+            data = self.current_stream[sid]
+            callback = data["start_transfer_cb"]
+            file_obj = data["file_obj"]
+            timer = data["timer"]
+        except KeyError:
+            error(_("Internal error, can't do transfer"))
+            return
+        
+        if timer.active():
+            timer.cancel()
+
+        callback(file_obj)
+
+
+    def prepareToReceive(self, from_jid, sid, file_obj, size, success_cb, failure_cb):
+        """Called when a bytestream is imminent
+        @param from_jid: jid of the sender
+        @param sid: Stream id
+        @param file_obj: File object where data will be written
+        @param size: full size of the data, or None if unknown
+        @param success_cb: method to call when successfuly finished
+        @param failure_cb: method to call when something goes wrong"""
+        data = self.current_stream[sid] = {}
+        data["from"] = from_jid
+        data["file_obj"] = file_obj
+        data["seq"] = -1
+        if size:
+            data["size"] = size
+            self.host.registerProgressCB(sid, self.getProgress)
+        data["timer"] = reactor.callLater(TIMEOUT, self._timeOut, sid)
+        data["success_cb"] = success_cb
+        data["failure_cb"] = failure_cb
+    
+    
+    def streamQuery(self, iq_elt, profile):
         """Get file using byte stream"""
-        client = self.host.getClient(profile_key)
-        assert(client)
-        iq.handled = True
-        SI_elem = iq.firstChildElement()
-        IQ_id = iq['id']
-        for element in SI_elem.elements():
-            if element.name == "streamhost":
-                info (_("Stream proposed: host=[%(host)s] port=[%(port)s]") % {'host':element['host'], 'port':element['port']})
-                factory = self.client_factory
-                self.server_factory.protocol.mode = "target"
-                factory.protocol.host = self.host #needed for progress CB
-                factory.protocol.xmlstream = client.xmlstream
-                factory.protocol.data = self.data
-                factory.protocol.transfert_id = self.transfert_id
-                factory.protocol.filesize = self.data["size"]
-                factory.protocol.sid = SI_elem['sid']
-                factory.protocol.initiator_jid = element['jid']
-                factory.protocol.target_jid = client.jid.full()
-                factory.protocol.IQ_id = IQ_id
-                factory.protocol.activateCB = self.activateStream
-                reactor.connectTCP(element['host'], int(element['port']), factory)
+        debug(_("BS stream query"))
+        profile_jid, xmlstream = self.host.getJidNStream(profile)
+        iq_elt.handled = True
+        query_elt = iq_elt.firstChildElement()
+        sid = query_elt.getAttribute("sid")
+        streamhost_elts = filter(lambda elt: elt.name == 'streamhost', query_elt.elements())
+        
+        if not sid in self.current_stream:
+            warning(_("Ignoring unexpected BS transfer: %s" % sid))
+            self.sendNotAcceptableError(iq_elt['id'], iq_elt['from'], xmlstream)
+            return
+
+        self.current_stream[sid]["to"] = jid.JID(iq_elt["to"])
+        self.current_stream[sid]["xmlstream"] = xmlstream
+
+        if not streamhost_elts:
+            warning(_("No streamhost found in stream query %s" % sid))
+            self.sendBadRequestError(iq_elt['id'], iq_elt['from'], xmlstream)
+            return
+
+        streamhost_elt = streamhost_elts[0] #TODO: manage several streamhost elements case
+        sh_host = streamhost_elt.getAttribute("host")
+        sh_port = streamhost_elt.getAttribute("port")
+        sh_jid = streamhost_elt.getAttribute("jid")
+        if not sh_host or not sh_port or not sh_jid:
+            warning(_("incomplete streamhost element"))
+            self.sendBadRequestError(iq_elt['id'], iq_elt['from'], xmlstream)
+            return
+
+        self.current_stream[sid]["streamhost"] = (sh_host, sh_port, sh_jid)
+
+        info (_("Stream proposed: host=[%(host)s] port=[%(port)s]") % {'host':sh_host, 'port':sh_port})
+        factory = Socks5ClientFactory(self.current_stream, sid, iq_elt["id"], self.activateStream, self._killId)
+        reactor.connectTCP(sh_host, int(sh_port), factory)
                 
-    def activateStream(self, from_jid, to_jid, sid, IQ_id, xmlstream):
+    def activateStream(self, sid, iq_id):
         debug(_("activating stream"))
         result = domish.Element(('', 'iq'))
+        data = self.current_stream[sid]
         result['type'] = 'result'
-        result['id'] = IQ_id
-        result['from'] = from_jid
-        result['to'] = to_jid
-        query = result.addElement('query', 'http://jabber.org/protocol/bytestreams')
+        result['id'] = iq_id
+        result['from'] = data["to"].full()
+        result['to'] = data["from"].full()
+        query = result.addElement('query', NS_BS)
         query['sid'] = sid
         streamhost = query.addElement('streamhost-used')
-        streamhost['jid'] = to_jid  #FIXME: use real streamhost
+        streamhost['jid'] = data["streamhost"][2]
+        data["xmlstream"].send(result)
+
+    def sendNotAcceptableError(self, iq_id, to_jid, xmlstream):
+        """Not acceptable error used when the stream is not expected or something is going wrong
+        @param iq_id: IQ id
+        @param to_jid: addressee
+        @param xmlstream: XML stream to use to send the error"""
+        result = domish.Element(('', 'iq'))
+        result['type'] = 'result'
+        result['id'] = iq_id
+        result['to'] = to_jid 
+        error_el = result.addElement('error')
+        error_el['type'] = 'modify'
+        error_el.addElement(('urn:ietf:params:xml:ns:xmpp-stanzas','not-acceptable'))
+        xmlstream.send(result)
+
+    def sendBadRequestError(self, iq_id, to_jid, xmlstream):
+        """Not acceptable error used when the stream is not expected or something is going wrong
+        @param iq_id: IQ id
+        @param to_jid: addressee
+        @param xmlstream: XML stream to use to send the error"""
+        result = domish.Element(('', 'iq'))
+        result['type'] = 'result'
+        result['id'] = iq_id
+        result['to'] = to_jid 
+        error_el = result.addElement('error')
+        error_el['type'] = 'cancel'
+        error_el.addElement(('urn:ietf:params:xml:ns:xmpp-stanzas','bad-request'))
         xmlstream.send(result)
 
 class XEP_0065_handler(XMPPHandler):
@@ -548,7 +682,7 @@
         self.host = plugin_parent.host
     
     def connectionInitialized(self):
-        self.xmlstream.addObserver(BS_REQUEST, self.plugin_parent.getFile)
+        self.xmlstream.addObserver(BS_REQUEST, self.plugin_parent.streamQuery, profile = self.parent.profile)
 
 
     def getDiscoInfo(self, requestor, target, nodeIdentifier=''):