diff --git a/lib/tftpy.py b/lib/tftpy.py index ca9b98a..1c9fcf1 100755 --- a/lib/tftpy.py +++ b/lib/tftpy.py @@ -11,13 +11,13 @@ verlist = sys.version_info if not verlist[0] >= 2 or not verlist[1] >= 4: raise AssertionError, "Requires at least Python 2.4" -# Change this as desired. FIXME - make this a command-line arg LOG_LEVEL = logging.NOTSET MIN_BLKSIZE = 8 DEF_BLKSIZE = 512 MAX_BLKSIZE = 65536 SOCK_TIMEOUT = 5 MAX_DUPS = 20 +TIMEOUT_RETRIES = 5 # Initialize the logger. logging.basicConfig( @@ -521,9 +521,15 @@ class TftpClient(TftpSession): host = property(gethost, sethost) - def download(self, filename, output, packethook=None): + def download(self, filename, output, packethook=None, timeout=SOCK_TIMEOUT): """This method initiates a tftp download from the configured remote - host, requesting the filename passed.""" + host, requesting the filename passed. It saves the file to a local + file specified in the output parameter. If a packethook is provided, + it must be a function that takes a single parameter, which will be a + copy of each DAT packet received in the form of a TftpPacketDAT + object. The timeout parameter may be used to override the default + SOCK_TIMEOUT setting, which is the amount of time that the client will + wait for a receive packet to arrive.""" # Open the output file. # FIXME - need to support alternate return formats than files? outputfile = open(output, "wb") @@ -535,7 +541,7 @@ class TftpClient(TftpSession): tftp_factory = TftpPacketFactory() sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.settimeout(SOCK_TIMEOUT) + sock.settimeout(timeout) logger.info("Sending tftp download request to %s" % self.host) logger.info(" filename -> %s" % filename) @@ -546,8 +552,18 @@ class TftpClient(TftpSession): sock.sendto(pkt.encode().buffer, (self.host, self.port)) self.state.state = 'rrq' + timeouts = 0 while True: - (buffer, (raddress, rport)) = sock.recvfrom(MAX_BLKSIZE) + try: + (buffer, (raddress, rport)) = sock.recvfrom(MAX_BLKSIZE) + except socket.timeout, err: + timeouts += 1 + if timeouts >= TIMEOUT_RETRIES: + raise TftpException, "Hit max timeouts, giving up." + else: + logger.warn("Timeout waiting for traffic, retrying...") + continue + recvpkt = tftp_factory.parse(buffer) logger.debug("Received %d bytes from %s:%s"