From 1e74abf010088abd4bab27de74778e41393911dd Mon Sep 17 00:00:00 2001 From: "Michael P. Soulier" Date: Sat, 23 Jul 2011 19:40:53 -0400 Subject: [PATCH] Adding retries on timeouts, still have to exhaustively test. Should close issue #21 on github. --- tftpy/TftpServer.py | 12 +++++-- tftpy/TftpShared.py | 5 +++ tftpy/TftpStates.py | 82 ++++++++++++++++++++++++++++++++------------- 3 files changed, 74 insertions(+), 25 deletions(-) diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index 9e64d83..46c662b 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -149,9 +149,17 @@ class TftpServer(TftpSession): for key in self.sessions: try: self.sessions[key].checkTimeout(now) - except TftpException, err: + except TftpTimeout, err: log.error(str(err)) - deletion_list.append(key) + self.sessions[key].retry_count += 1 + if self.sessions[key].retry_count >= TIMEOUT_RETRIES: + log.debug("hit max retries on %s, giving up" + % self.sessions[key]) + deletion_list.append(key) + else: + log.debug("resending on session %s" + % self.sessions[key]) + self.sessions[key].state.resendLast() log.debug("Iterating deletion list.") for key in deletion_list: diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py index 69ade90..1039ed2 100644 --- a/tftpy/TftpShared.py +++ b/tftpy/TftpShared.py @@ -49,3 +49,8 @@ class TftpException(Exception): """This class is the parent class of all exceptions regarding the handling of the TFTP protocol.""" pass + +class TftpTimeout(TftpException): + """This class represents a timeout error waiting for a response from the + other end.""" + pass diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 6c77499..0992d6c 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -86,14 +86,16 @@ class TftpContext(object): self.tidport = None # Metrics self.metrics = TftpMetrics() - # Flag when the transfer is pending completion. + # Flag when the transfer is pending completion. 
self.pending_complete = False # Time when this context last received any traffic. # FIXME: does this belong in metrics? self.last_update = 0 - # The last DAT packet we sent, if applicable, to make resending easy. - self.last_dat_pkt = None + # The last packet we sent, if applicable, to make resending easy. + self.last_pkt = None self.dyn_file_func = dyn_file_func + # Count the number of retry attempts. + self.retry_count = 0 def __del__(self): """Simple destructor to try to call housekeeping in the end method if @@ -104,8 +106,9 @@ class TftpContext(object): def checkTimeout(self, now): """Compare current time with last_update time, and raise an exception if we're over SOCK_TIMEOUT time.""" + log.debug("checking for timeout on session %s" % self) if now - self.last_update > SOCK_TIMEOUT: - raise TftpException, "Timeout waiting for traffic" + raise TftpTimeout, "Timeout waiting for traffic" def start(self): raise NotImplementedError, "Abstract method" @@ -145,19 +148,11 @@ class TftpContext(object): def cycle(self): """Here we wait for a response from the server after sending it something, and dispatch appropriate action to that response.""" - # FIXME: This won't work very well in a server context with multiple - # sessions running. - for i in range(TIMEOUT_RETRIES): - log.debug("In cycle, receive attempt %d" % i) - try: - (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) - except socket.timeout, err: - log.warn("Timeout waiting for traffic, retrying...") - continue - break - else: - self.sock.close() - raise TftpException, "Hit max timeouts, giving up." + try: + (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) + except socket.timeout, err: + log.warn("Timeout waiting for traffic, retrying...") + raise TftpTimeout, "Timed-out waiting for traffic" # Ok, we've received a packet. Log it. log.debug("Received %d bytes from %s:%s" @@ -188,6 +183,9 @@ class TftpContext(object): # And handle it, possibly changing state. 
self.state = self.state.handle(recvpkt, raddress, rport) + # If we didn't throw any exceptions here, reset the retry_count to + # zero. + self.retry_count = 0 class TftpContextServer(TftpContext): """The context for the server.""" @@ -279,12 +277,25 @@ class TftpContextClientUpload(TftpContext): pkt.options = self.options self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) self.next_block = 1 + self.last_pkt = pkt + # FIXME: should we centralize sendto operations so we can refactor all + # saving of the packet to the last_pkt field? self.state = TftpStateSentWRQ(self) while self.state: - log.debug("State is %s" % self.state) - self.cycle() + try: + log.debug("State is %s" % self.state) + self.cycle() + except TftpTimeout, err: + log.error(str(err)) + self.retry_count += 1 + if self.retry_count >= TIMEOUT_RETRIES: + log.debug("hit max retries, giving up") + raise + else: + log.warn("resending last packet") + self.state.resendLast() def end(self): """Finish up the context.""" @@ -343,12 +354,23 @@ class TftpContextClientDownload(TftpContext): pkt.options = self.options self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) self.next_block = 1 + self.last_pkt = pkt self.state = TftpStateSentRRQ(self) while self.state: - log.debug("State is %s" % self.state) - self.cycle() + try: + log.debug("State is %s" % self.state) + self.cycle() + except TftpTimeout, err: + log.error(str(err)) + self.retry_count += 1 + if self.retry_count >= TIMEOUT_RETRIES: + log.debug("hit max retries, giving up") + raise + else: + log.warn("resending last packet") + self.state.resendLast() def end(self): """Finish up the context.""" @@ -479,7 +501,7 @@ class TftpState(object): dat = None if resend: log.warn("Resending block number %d" % blocknumber) - dat = self.context.last_dat_pkt + dat = self.context.last_pkt self.context.metrics.resent_bytes += len(dat.data) self.context.metrics.add_dup(dat) else: @@ -499,7 +521,7 @@ class TftpState(object): (self.context.host, 
self.context.tidport)) if self.context.packethook: self.context.packethook(dat) - self.context.last_dat_pkt = dat + self.context.last_pkt = dat return finished def sendACK(self, blocknumber=None): @@ -515,6 +537,7 @@ class TftpState(object): self.context.sock.sendto(ackpkt.encode().buffer, (self.context.host, self.context.tidport)) + self.last_pkt = ackpkt def sendError(self, errorcode): """This method uses the socket passed, and uses the errorcode to @@ -525,6 +548,7 @@ class TftpState(object): self.context.sock.sendto(errpkt.encode().buffer, (self.context.host, self.context.tidport)) + self.last_pkt = errpkt def sendOACK(self): """This method sends an OACK packet with the options from the current @@ -535,6 +559,18 @@ class TftpState(object): self.context.sock.sendto(pkt.encode().buffer, (self.context.host, self.context.tidport)) + self.last_pkt = pkt + + def resendLast(self): + "Resend the last sent packet due to a timeout." + log.warn("Resending packet %s on sessions %s" + % (self.last_pkt, self)) + self.context.metrics.resent_bytes += len(self.last_pkt.data) + self.context.metrics.add_dup(self.last_pkt) + self.context.sock.sendto(self.last_pkt.encode().buffer, + (self.context.host, self.context.tidport)) + if self.context.packethook: + self.context.packethook(self.last_pkt) def handleDat(self, pkt): """This method handles a DAT packet during a client download, or a