Adding retries on timeouts, still have to exhaustively test.
Should close issue #21 on github.
This commit is contained in:
parent
6fd9391ad8
commit
1e74abf010
3 changed files with 74 additions and 25 deletions
|
@ -149,9 +149,17 @@ class TftpServer(TftpSession):
|
|||
for key in self.sessions:
|
||||
try:
|
||||
self.sessions[key].checkTimeout(now)
|
||||
except TftpException, err:
|
||||
except TftpTimeout, err:
|
||||
log.error(str(err))
|
||||
deletion_list.append(key)
|
||||
self.sessions[key].retry_count += 1
|
||||
if self.sessions[key].retry_count >= TIMEOUT_RETRIES:
|
||||
log.debug("hit max retries on %s, giving up"
|
||||
% self.sessions[key])
|
||||
deletion_list.append(key)
|
||||
else:
|
||||
log.debug("resending on session %s"
|
||||
% self.sessions[key])
|
||||
self.sessions[key].state.resendLast()
|
||||
|
||||
log.debug("Iterating deletion list.")
|
||||
for key in deletion_list:
|
||||
|
|
|
@ -49,3 +49,8 @@ class TftpException(Exception):
|
|||
"""This class is the parent class of all exceptions regarding the handling
|
||||
of the TFTP protocol."""
|
||||
pass
|
||||
|
||||
class TftpTimeout(TftpException):
|
||||
"""This class represents a timeout error waiting for a response from the
|
||||
other end."""
|
||||
pass
|
||||
|
|
|
@ -86,14 +86,16 @@ class TftpContext(object):
|
|||
self.tidport = None
|
||||
# Metrics
|
||||
self.metrics = TftpMetrics()
|
||||
# Flag when the transfer is pending completion.
|
||||
# Fluag when the transfer is pending completion.
|
||||
self.pending_complete = False
|
||||
# Time when this context last received any traffic.
|
||||
# FIXME: does this belong in metrics?
|
||||
self.last_update = 0
|
||||
# The last DAT packet we sent, if applicable, to make resending easy.
|
||||
self.last_dat_pkt = None
|
||||
# The last packet we sent, if applicable, to make resending easy.
|
||||
self.last_pkt = None
|
||||
self.dyn_file_func = dyn_file_func
|
||||
# Count the number of retry attempts.
|
||||
self.retry_count = 0
|
||||
|
||||
def __del__(self):
|
||||
"""Simple destructor to try to call housekeeping in the end method if
|
||||
|
@ -104,8 +106,9 @@ class TftpContext(object):
|
|||
def checkTimeout(self, now):
|
||||
"""Compare current time with last_update time, and raise an exception
|
||||
if we're over SOCK_TIMEOUT time."""
|
||||
log.debug("checking for timeout on session %s" % self)
|
||||
if now - self.last_update > SOCK_TIMEOUT:
|
||||
raise TftpException, "Timeout waiting for traffic"
|
||||
raise TftpTimeout, "Timeout waiting for traffic"
|
||||
|
||||
def start(self):
|
||||
raise NotImplementedError, "Abstract method"
|
||||
|
@ -145,19 +148,11 @@ class TftpContext(object):
|
|||
def cycle(self):
|
||||
"""Here we wait for a response from the server after sending it
|
||||
something, and dispatch appropriate action to that response."""
|
||||
# FIXME: This won't work very well in a server context with multiple
|
||||
# sessions running.
|
||||
for i in range(TIMEOUT_RETRIES):
|
||||
log.debug("In cycle, receive attempt %d" % i)
|
||||
try:
|
||||
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
|
||||
except socket.timeout, err:
|
||||
log.warn("Timeout waiting for traffic, retrying...")
|
||||
continue
|
||||
break
|
||||
else:
|
||||
self.sock.close()
|
||||
raise TftpException, "Hit max timeouts, giving up."
|
||||
try:
|
||||
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
|
||||
except socket.timeout, err:
|
||||
log.warn("Timeout waiting for traffic, retrying...")
|
||||
raise TftpTimeout, "Timed-out waiting for traffic"
|
||||
|
||||
# Ok, we've received a packet. Log it.
|
||||
log.debug("Received %d bytes from %s:%s"
|
||||
|
@ -188,6 +183,9 @@ class TftpContext(object):
|
|||
|
||||
# And handle it, possibly changing state.
|
||||
self.state = self.state.handle(recvpkt, raddress, rport)
|
||||
# If we didn't throw any exceptions here, reset the retry_count to
|
||||
# zero.
|
||||
self.retry_count = 0
|
||||
|
||||
class TftpContextServer(TftpContext):
|
||||
"""The context for the server."""
|
||||
|
@ -279,12 +277,25 @@ class TftpContextClientUpload(TftpContext):
|
|||
pkt.options = self.options
|
||||
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
|
||||
self.next_block = 1
|
||||
self.last_pkt = pkt
|
||||
# FIXME: should we centralize sendto operations so we can refactor all
|
||||
# saving of the packet to the last_pkt field?
|
||||
|
||||
self.state = TftpStateSentWRQ(self)
|
||||
|
||||
while self.state:
|
||||
log.debug("State is %s" % self.state)
|
||||
self.cycle()
|
||||
try:
|
||||
log.debug("State is %s" % self.state)
|
||||
self.cycle()
|
||||
except TftpTimeout, err:
|
||||
log.error(str(err))
|
||||
self.retry_count += 1
|
||||
if self.retry_count >= TIMEOUT_RETRIES:
|
||||
log.debug("hit max retries, giving up")
|
||||
raise
|
||||
else:
|
||||
log.warn("resending last packet")
|
||||
self.state.resendLast()
|
||||
|
||||
def end(self):
|
||||
"""Finish up the context."""
|
||||
|
@ -343,12 +354,23 @@ class TftpContextClientDownload(TftpContext):
|
|||
pkt.options = self.options
|
||||
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
|
||||
self.next_block = 1
|
||||
self.last_pkt = pkt
|
||||
|
||||
self.state = TftpStateSentRRQ(self)
|
||||
|
||||
while self.state:
|
||||
log.debug("State is %s" % self.state)
|
||||
self.cycle()
|
||||
try:
|
||||
log.debug("State is %s" % self.state)
|
||||
self.cycle()
|
||||
except TftpTimeout, err:
|
||||
log.error(str(err))
|
||||
self.retry_count += 1
|
||||
if self.retry_count >= TIMEOUT_RETRIES:
|
||||
log.debug("hit max retries, giving up")
|
||||
raise
|
||||
else:
|
||||
log.warn("resending last packet")
|
||||
self.state.resendLast()
|
||||
|
||||
def end(self):
|
||||
"""Finish up the context."""
|
||||
|
@ -479,7 +501,7 @@ class TftpState(object):
|
|||
dat = None
|
||||
if resend:
|
||||
log.warn("Resending block number %d" % blocknumber)
|
||||
dat = self.context.last_dat_pkt
|
||||
dat = self.context.last_pkt
|
||||
self.context.metrics.resent_bytes += len(dat.data)
|
||||
self.context.metrics.add_dup(dat)
|
||||
else:
|
||||
|
@ -499,7 +521,7 @@ class TftpState(object):
|
|||
(self.context.host, self.context.tidport))
|
||||
if self.context.packethook:
|
||||
self.context.packethook(dat)
|
||||
self.context.last_dat_pkt = dat
|
||||
self.context.last_pkt = dat
|
||||
return finished
|
||||
|
||||
def sendACK(self, blocknumber=None):
|
||||
|
@ -515,6 +537,7 @@ class TftpState(object):
|
|||
self.context.sock.sendto(ackpkt.encode().buffer,
|
||||
(self.context.host,
|
||||
self.context.tidport))
|
||||
self.last_pkt = ackpkt
|
||||
|
||||
def sendError(self, errorcode):
|
||||
"""This method uses the socket passed, and uses the errorcode to
|
||||
|
@ -525,6 +548,7 @@ class TftpState(object):
|
|||
self.context.sock.sendto(errpkt.encode().buffer,
|
||||
(self.context.host,
|
||||
self.context.tidport))
|
||||
self.last_pkt = errpkt
|
||||
|
||||
def sendOACK(self):
|
||||
"""This method sends an OACK packet with the options from the current
|
||||
|
@ -535,6 +559,18 @@ class TftpState(object):
|
|||
self.context.sock.sendto(pkt.encode().buffer,
|
||||
(self.context.host,
|
||||
self.context.tidport))
|
||||
self.last_pkt = pkt
|
||||
|
||||
def resendLast(self):
|
||||
"Resend the last sent packet due to a timeout."
|
||||
log.warn("Resending packet %s on sessions %s"
|
||||
% (self.last_pkt, self))
|
||||
self.context.metrics.resent_bytes += len(self.last_pkt.data)
|
||||
self.context.metrics.add_dup(self.last_pkt)
|
||||
self.context.sock.sendto(self.last_pkt.encode().buffer,
|
||||
(self.context.host, self.context.tidport))
|
||||
if self.context.packethook:
|
||||
self.context.packethook(self.last_pkt)
|
||||
|
||||
def handleDat(self, pkt):
|
||||
"""This method handles a DAT packet during a client download, or a
|
||||
|
|
Reference in a new issue