diff --git a/t/test.py b/t/test.py index 1052db1..d044ac8 100644 --- a/t/test.py +++ b/t/test.py @@ -4,6 +4,7 @@ import unittest import logging import tftpy import os +import time log = tftpy.log @@ -140,6 +141,72 @@ class TestTftpyState(unittest.TestCase): def setUp(self): tftpy.setLogLevel(logging.DEBUG) + def clientServerUploadOptions(self, options): + """Fire up a client and a server and do an upload.""" + root = '/tmp' + home = os.path.dirname(os.path.abspath(__file__)) + filename = '100KBFILE' + input_path = os.path.join(home, filename) + server = tftpy.TftpServer(root) + client = tftpy.TftpClient('localhost', + 20001, + options) + # Fork a server and run the client in this process. + child_pid = os.fork() + if child_pid: + # parent - let the server start + try: + time.sleep(1) + client.upload(filename, + input_path) + finally: + os.kill(child_pid, 15) + os.waitpid(child_pid, 0) + + else: + server.listen('localhost', 20001) + + def clientServerDownloadOptions(self, options): + """Fire up a client and a server and do a download.""" + root = os.path.dirname(os.path.abspath(__file__)) + server = tftpy.TftpServer(root) + client = tftpy.TftpClient('localhost', + 20001, + options) + # Fork a server and run the client in this process. + child_pid = os.fork() + if child_pid: + # parent - let the server start + try: + time.sleep(1) + client.download('100KBFILE', + '/tmp/out') + finally: + os.kill(child_pid, 15) + os.waitpid(child_pid, 0) + + else: + server.listen('localhost', 20001) + + def testClientServerNoOptions(self): + self.clientServerDownloadOptions({}) + + def testClientServerBlksize(self): + for blksize in [512, 1024, 2048, 4096]: + self.clientServerDownloadOptions({'blksize': blksize}) + + def testClientServerUploadNoOptions(self): + self.clientServerUploadOptions({}) + + def testClientServerUploadOptions(self): + for blksize in [512, 1024, 2048, 4096]: + self.clientServerUploadOptions({'blksize': blksize}) + + def testClientServerNoOptionsDelay(self): + tftpy.TftpStates.DELAY_BLOCK = 10 + self.clientServerDownloadOptions({}) + tftpy.TftpStates.DELAY_BLOCK = 0 + def testServerNoOptions(self): """Test the server states.""" raddress = '127.0.0.2' diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index 16753f1..e9e46e7 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -74,10 +74,10 @@ class TftpClient(TftpSession): setting, which is the amount of time that the client will wait for a DAT packet to be ACKd by the server. + The input option is the full path to the file to upload, which can + optionally be '-' to read from stdin. + Note: If output is a hyphen then stdout is used.""" - # Open the input file. - # FIXME: As of the state machine, this is now broken. Need to - # implement with new state machine. self.context = TftpContextClientUpload(self.host, self.iport, filename, diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py index a76b686..c3a1bd4 100644 --- a/tftpy/TftpContexts.py +++ b/tftpy/TftpContexts.py @@ -50,15 +50,15 @@ class TftpMetrics(object): for key in self.dups: self.dupcount += self.dups[key] - def add_dup(self, blocknumber): - """This method adds a dup for a block number to the metrics.""" - log.debug("Recording a dup for block %d" % blocknumber) - if self.dups.has_key(blocknumber): - self.dups[blocknumber] += 1 + def add_dup(self, pkt): + """This method adds a dup for a packet to the metrics.""" + log.debug("Recording a dup of %s" % pkt) + s = str(pkt) + if self.dups.has_key(s): + self.dups[s] += 1 else: - self.dups[blocknumber] = 1 - tftpassert(self.dups[blocknumber] < MAX_DUPS, - "Max duplicates for block %d reached" % blocknumber) + self.dups[s] = 1 + tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached") ############################################################################### # Context classes diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py index 1039ed2..d09d8bd 100644 --- a/tftpy/TftpShared.py +++ b/tftpy/TftpShared.py @@ -11,6 +11,9 @@ MAX_DUPS = 20 TIMEOUT_RETRIES = 5 DEF_TFTP_PORT = 69 +# A hook for deliberately introducing delay in testing. +DELAY_BLOCK = 0 + # Initialize the logger. logging.basicConfig() # The logger used by this library. Feel free to clobber it with your own, if you like, as diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 716220a..c106220 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -124,30 +124,29 @@ class TftpState(object): return sendoack - def sendDAT(self, resend=False): + def sendDAT(self): """This method sends the next DAT packet based on the data in the context. It returns a boolean indicating whether the transfer is finished.""" finished = False blocknumber = self.context.next_block + # Test hook + if DELAY_BLOCK and DELAY_BLOCK == blocknumber: + import time + log.debug("Deliberately delaying 10 seconds...") + time.sleep(10) tftpassert( blocknumber > 0, "There is no block zero!" ) dat = None - if resend: - log.warn("Resending block number %d" % blocknumber) - dat = self.context.last_pkt - self.context.metrics.resent_bytes += len(dat.data) - self.context.metrics.add_dup(dat) - else: - blksize = self.context.getBlocksize() - buffer = self.context.fileobj.read(blksize) - log.debug("Read %d bytes into buffer" % len(buffer)) - if len(buffer) < blksize: - log.info("Reached EOF on file %s" - % self.context.file_to_transfer) - finished = True - dat = TftpPacketDAT() - dat.data = buffer - dat.blocknumber = blocknumber + blksize = self.context.getBlocksize() + buffer = self.context.fileobj.read(blksize) + log.debug("Read %d bytes into buffer" % len(buffer)) + if len(buffer) < blksize: + log.info("Reached EOF on file %s" + % self.context.file_to_transfer) + finished = True + dat = TftpPacketDAT() + dat.data = buffer + dat.blocknumber = blocknumber self.context.metrics.bytes += len(dat.data) log.debug("Sending DAT packet %d" % dat.blocknumber) self.context.sock.sendto(dat.encode().buffer, @@ -170,7 +169,7 @@ class TftpState(object): self.context.sock.sendto(ackpkt.encode().buffer, (self.context.host, self.context.tidport)) - self.last_pkt = ackpkt + self.context.last_pkt = ackpkt def sendError(self, errorcode): """This method uses the socket passed, and uses the errorcode to @@ -181,7 +180,7 @@ class TftpState(object): self.context.sock.sendto(errpkt.encode().buffer, (self.context.host, self.context.tidport)) - self.last_pkt = errpkt + self.context.last_pkt = errpkt def sendOACK(self): """This method sends an OACK packet with the options from the current @@ -192,18 +191,18 @@ class TftpState(object): self.context.sock.sendto(pkt.encode().buffer, (self.context.host, self.context.tidport)) - self.last_pkt = pkt + self.context.last_pkt = pkt def resendLast(self): "Resend the last sent packet due to a timeout." log.warn("Resending packet %s on sessions %s" - % (self.last_pkt, self)) - self.context.metrics.resent_bytes += len(self.last_pkt.data) - self.context.metrics.add_dup(self.last_pkt) - self.context.sock.sendto(self.last_pkt.encode().buffer, + % (self.context.last_pkt, self)) + self.context.metrics.resent_bytes += len(self.context.last_pkt.buffer) + self.context.metrics.add_dup(self.context.last_pkt) + self.context.sock.sendto(self.context.last_pkt.encode().buffer, (self.context.host, self.context.tidport)) if self.context.packethook: - self.context.packethook(self.last_pkt) + self.context.packethook(self.context.last_pkt) def handleDat(self, pkt): """This method handles a DAT packet during a client download, or a @@ -232,7 +231,7 @@ class TftpState(object): self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "There is no block zero!" log.warn("Dropping duplicate block %d" % pkt.blocknumber) - self.context.metrics.add_dup(pkt.blocknumber) + self.context.metrics.add_dup(pkt) log.debug("ACKing block %d again, just in case" % pkt.blocknumber) self.sendACK(pkt.blocknumber) @@ -369,7 +368,9 @@ class TftpStateExpectACK(TftpState): self.context.pending_complete = self.sendDAT() elif pkt.blocknumber < self.context.next_block: - self.context.metrics.add_dup(pkt.blocknumber) + log.debug("Received duplicate ACK for block %d" + % pkt.blocknumber) + self.context.metrics.add_dup(pkt) else: log.warn("Oooh, time warp. Received ACK to packet we "