Adding retries on timeouts, still have to exhaustively test.

Should close issue #21 on github.
master
Michael P. Soulier 2011-07-23 19:40:53 -04:00
parent 6fd9391ad8
commit 1e74abf010
3 changed files with 74 additions and 25 deletions

View File

@ -149,9 +149,17 @@ class TftpServer(TftpSession):
for key in self.sessions:
try:
self.sessions[key].checkTimeout(now)
except TftpException, err:
except TftpTimeout, err:
log.error(str(err))
deletion_list.append(key)
self.sessions[key].retry_count += 1
if self.sessions[key].retry_count >= TIMEOUT_RETRIES:
log.debug("hit max retries on %s, giving up"
% self.sessions[key])
deletion_list.append(key)
else:
log.debug("resending on session %s"
% self.sessions[key])
self.sessions[key].state.resendLast()
log.debug("Iterating deletion list.")
for key in deletion_list:

View File

@ -49,3 +49,8 @@ class TftpException(Exception):
"""This class is the parent class of all exceptions regarding the handling
of the TFTP protocol."""
pass
class TftpTimeout(TftpException):
    """Raised when a timeout occurs while waiting for a response from the
    other end of the transfer.

    It derives from TftpException, so handlers that catch the parent class
    also catch timeouts."""

View File

@ -86,14 +86,16 @@ class TftpContext(object):
self.tidport = None
# Metrics
self.metrics = TftpMetrics()
# Flag when the transfer is pending completion.
# Flag when the transfer is pending completion.
self.pending_complete = False
# Time when this context last received any traffic.
# FIXME: does this belong in metrics?
self.last_update = 0
# The last DAT packet we sent, if applicable, to make resending easy.
self.last_dat_pkt = None
# The last packet we sent, if applicable, to make resending easy.
self.last_pkt = None
self.dyn_file_func = dyn_file_func
# Count the number of retry attempts.
self.retry_count = 0
def __del__(self):
"""Simple destructor to try to call housekeeping in the end method if
@ -104,8 +106,9 @@ class TftpContext(object):
def checkTimeout(self, now):
"""Compare current time with last_update time, and raise an exception
if we're over SOCK_TIMEOUT time."""
log.debug("checking for timeout on session %s" % self)
if now - self.last_update > SOCK_TIMEOUT:
raise TftpException, "Timeout waiting for traffic"
raise TftpTimeout, "Timeout waiting for traffic"
def start(self):
raise NotImplementedError, "Abstract method"
@ -145,19 +148,11 @@ class TftpContext(object):
def cycle(self):
"""Here we wait for a response from the server after sending it
something, and dispatch appropriate action to that response."""
# FIXME: This won't work very well in a server context with multiple
# sessions running.
for i in range(TIMEOUT_RETRIES):
log.debug("In cycle, receive attempt %d" % i)
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
except socket.timeout, err:
log.warn("Timeout waiting for traffic, retrying...")
continue
break
else:
self.sock.close()
raise TftpException, "Hit max timeouts, giving up."
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
except socket.timeout, err:
log.warn("Timeout waiting for traffic, retrying...")
raise TftpTimeout, "Timed-out waiting for traffic"
# Ok, we've received a packet. Log it.
log.debug("Received %d bytes from %s:%s"
@ -188,6 +183,9 @@ class TftpContext(object):
# And handle it, possibly changing state.
self.state = self.state.handle(recvpkt, raddress, rport)
# If we didn't throw any exceptions here, reset the retry_count to
# zero.
self.retry_count = 0
class TftpContextServer(TftpContext):
"""The context for the server."""
@ -279,12 +277,25 @@ class TftpContextClientUpload(TftpContext):
pkt.options = self.options
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
self.next_block = 1
self.last_pkt = pkt
# FIXME: should we centralize sendto operations so we can refactor all
# saving of the packet to the last_pkt field?
self.state = TftpStateSentWRQ(self)
while self.state:
log.debug("State is %s" % self.state)
self.cycle()
try:
log.debug("State is %s" % self.state)
self.cycle()
except TftpTimeout, err:
log.error(str(err))
self.retry_count += 1
if self.retry_count >= TIMEOUT_RETRIES:
log.debug("hit max retries, giving up")
raise
else:
log.warn("resending last packet")
self.state.resendLast()
def end(self):
"""Finish up the context."""
@ -343,12 +354,23 @@ class TftpContextClientDownload(TftpContext):
pkt.options = self.options
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
self.next_block = 1
self.last_pkt = pkt
self.state = TftpStateSentRRQ(self)
while self.state:
log.debug("State is %s" % self.state)
self.cycle()
try:
log.debug("State is %s" % self.state)
self.cycle()
except TftpTimeout, err:
log.error(str(err))
self.retry_count += 1
if self.retry_count >= TIMEOUT_RETRIES:
log.debug("hit max retries, giving up")
raise
else:
log.warn("resending last packet")
self.state.resendLast()
def end(self):
"""Finish up the context."""
@ -479,7 +501,7 @@ class TftpState(object):
dat = None
if resend:
log.warn("Resending block number %d" % blocknumber)
dat = self.context.last_dat_pkt
dat = self.context.last_pkt
self.context.metrics.resent_bytes += len(dat.data)
self.context.metrics.add_dup(dat)
else:
@ -499,7 +521,7 @@ class TftpState(object):
(self.context.host, self.context.tidport))
if self.context.packethook:
self.context.packethook(dat)
self.context.last_dat_pkt = dat
self.context.last_pkt = dat
return finished
def sendACK(self, blocknumber=None):
@ -515,6 +537,7 @@ class TftpState(object):
self.context.sock.sendto(ackpkt.encode().buffer,
(self.context.host,
self.context.tidport))
self.last_pkt = ackpkt
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
@ -525,6 +548,7 @@ class TftpState(object):
self.context.sock.sendto(errpkt.encode().buffer,
(self.context.host,
self.context.tidport))
self.last_pkt = errpkt
def sendOACK(self):
"""This method sends an OACK packet with the options from the current
@ -535,6 +559,18 @@ class TftpState(object):
self.context.sock.sendto(pkt.encode().buffer,
(self.context.host,
self.context.tidport))
self.last_pkt = pkt
def resendLast(self):
    """Resend the last sent packet due to a timeout.

    The packet is looked up on this state object first and then on the
    context, because the send helpers do not agree on where they record it:
    some store it as ``self.last_pkt`` and others (and the contexts
    themselves, for the initial request) as ``self.context.last_pkt``.
    If no packet has been recorded anywhere, a warning is logged and
    nothing is sent.
    """
    pkt = getattr(self, 'last_pkt', None)
    if pkt is None:
        pkt = getattr(self.context, 'last_pkt', None)
    if pkt is None:
        log.warn("No packet to resend on session %s" % self)
        return

    log.warn("Resending packet %s on sessions %s"
             % (pkt, self))

    # Only count resent payload bytes when the packet actually has a
    # payload; the last packet may be a request, ACK, OACK or error
    # packet rather than a DAT packet.
    payload = getattr(pkt, 'data', None)
    if payload is not None:
        self.context.metrics.resent_bytes += len(payload)
    self.context.metrics.add_dup(pkt)

    # The context initialises tidport to None and the initial request is
    # sent to context.port, so fall back to that port when the peer's
    # transfer port has not been recorded yet.
    port = self.context.tidport
    if port is None:
        port = self.context.port

    self.context.sock.sendto(pkt.encode().buffer,
                             (self.context.host, port))
    if self.context.packethook:
        self.context.packethook(pkt)
def handleDat(self, pkt):
"""This method handles a DAT packet during a client download, or a