Fixing issue #3, expanding unit tests.

master
Michael P. Soulier 2011-07-24 17:37:16 -04:00
parent 40977c6f74
commit 04aaa2ef9f
5 changed files with 109 additions and 38 deletions

View File

@ -4,6 +4,7 @@ import unittest
import logging import logging
import tftpy import tftpy
import os import os
import time
log = tftpy.log log = tftpy.log
@ -140,6 +141,72 @@ class TestTftpyState(unittest.TestCase):
def setUp(self): def setUp(self):
tftpy.setLogLevel(logging.DEBUG) tftpy.setLogLevel(logging.DEBUG)
def clientServerUploadOptions(self, options):
"""Fire up a client and a server and do an upload."""
root = '/tmp'
home = os.path.dirname(os.path.abspath(__file__))
filename = '100KBFILE'
input_path = os.path.join(home, filename)
server = tftpy.TftpServer(root)
client = tftpy.TftpClient('localhost',
20001,
options)
# Fork a server and run the client in this process.
child_pid = os.fork()
if child_pid:
# parent - let the server start
try:
time.sleep(1)
client.upload(filename,
input_path)
finally:
os.kill(child_pid, 15)
os.waitpid(child_pid, 0)
else:
server.listen('localhost', 20001)
def clientServerDownloadOptions(self, options):
"""Fire up a client and a server and do a download."""
root = os.path.dirname(os.path.abspath(__file__))
server = tftpy.TftpServer(root)
client = tftpy.TftpClient('localhost',
20001,
options)
# Fork a server and run the client in this process.
child_pid = os.fork()
if child_pid:
# parent - let the server start
try:
time.sleep(1)
client.download('100KBFILE',
'/tmp/out')
finally:
os.kill(child_pid, 15)
os.waitpid(child_pid, 0)
else:
server.listen('localhost', 20001)
def testClientServerNoOptions(self):
self.clientServerDownloadOptions({})
def testClientServerBlksize(self):
for blksize in [512, 1024, 2048, 4096]:
self.clientServerDownloadOptions({'blksize': blksize})
def testClientServerUploadNoOptions(self):
self.clientServerUploadOptions({})
def testClientServerUploadOptions(self):
for blksize in [512, 1024, 2048, 4096]:
self.clientServerUploadOptions({'blksize': blksize})
def testClientServerNoOptionsDelay(self):
tftpy.TftpStates.DELAY_BLOCK = 10
self.clientServerDownloadOptions({})
tftpy.TftpStates.DELAY_BLOCK = 0
def testServerNoOptions(self): def testServerNoOptions(self):
"""Test the server states.""" """Test the server states."""
raddress = '127.0.0.2' raddress = '127.0.0.2'

View File

@ -74,10 +74,10 @@ class TftpClient(TftpSession):
setting, which is the amount of time that the client will wait for a setting, which is the amount of time that the client will wait for a
DAT packet to be ACKd by the server. DAT packet to be ACKd by the server.
The input option is the full path to the file to upload, which can
optionally be '-' to read from stdin.
Note: If output is a hyphen then stdout is used.""" Note: If output is a hyphen then stdout is used."""
# Open the input file.
# FIXME: As of the state machine, this is now broken. Need to
# implement with new state machine.
self.context = TftpContextClientUpload(self.host, self.context = TftpContextClientUpload(self.host,
self.iport, self.iport,
filename, filename,

View File

@ -50,15 +50,15 @@ class TftpMetrics(object):
for key in self.dups: for key in self.dups:
self.dupcount += self.dups[key] self.dupcount += self.dups[key]
def add_dup(self, blocknumber): def add_dup(self, pkt):
"""This method adds a dup for a block number to the metrics.""" """This method adds a dup for a packet to the metrics."""
log.debug("Recording a dup for block %d" % blocknumber) log.debug("Recording a dup of %s" % pkt)
if self.dups.has_key(blocknumber): s = str(pkt)
self.dups[blocknumber] += 1 if self.dups.has_key(s):
self.dups[s] += 1
else: else:
self.dups[blocknumber] = 1 self.dups[s] = 1
tftpassert(self.dups[blocknumber] < MAX_DUPS, tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached")
"Max duplicates for block %d reached" % blocknumber)
############################################################################### ###############################################################################
# Context classes # Context classes

View File

@ -11,6 +11,9 @@ MAX_DUPS = 20
TIMEOUT_RETRIES = 5 TIMEOUT_RETRIES = 5
DEF_TFTP_PORT = 69 DEF_TFTP_PORT = 69
# A hook for deliberately introducing delay in testing.
DELAY_BLOCK = 0
# Initialize the logger. # Initialize the logger.
logging.basicConfig() logging.basicConfig()
# The logger used by this library. Feel free to clobber it with your own, if you like, as # The logger used by this library. Feel free to clobber it with your own, if you like, as

View File

@ -124,30 +124,29 @@ class TftpState(object):
return sendoack return sendoack
def sendDAT(self, resend=False): def sendDAT(self):
"""This method sends the next DAT packet based on the data in the """This method sends the next DAT packet based on the data in the
context. It returns a boolean indicating whether the transfer is context. It returns a boolean indicating whether the transfer is
finished.""" finished."""
finished = False finished = False
blocknumber = self.context.next_block blocknumber = self.context.next_block
# Test hook
if DELAY_BLOCK and DELAY_BLOCK == blocknumber:
import time
log.debug("Deliberately delaying 10 seconds...")
time.sleep(10)
tftpassert( blocknumber > 0, "There is no block zero!" ) tftpassert( blocknumber > 0, "There is no block zero!" )
dat = None dat = None
if resend: blksize = self.context.getBlocksize()
log.warn("Resending block number %d" % blocknumber) buffer = self.context.fileobj.read(blksize)
dat = self.context.last_pkt log.debug("Read %d bytes into buffer" % len(buffer))
self.context.metrics.resent_bytes += len(dat.data) if len(buffer) < blksize:
self.context.metrics.add_dup(dat) log.info("Reached EOF on file %s"
else: % self.context.file_to_transfer)
blksize = self.context.getBlocksize() finished = True
buffer = self.context.fileobj.read(blksize) dat = TftpPacketDAT()
log.debug("Read %d bytes into buffer" % len(buffer)) dat.data = buffer
if len(buffer) < blksize: dat.blocknumber = blocknumber
log.info("Reached EOF on file %s"
% self.context.file_to_transfer)
finished = True
dat = TftpPacketDAT()
dat.data = buffer
dat.blocknumber = blocknumber
self.context.metrics.bytes += len(dat.data) self.context.metrics.bytes += len(dat.data)
log.debug("Sending DAT packet %d" % dat.blocknumber) log.debug("Sending DAT packet %d" % dat.blocknumber)
self.context.sock.sendto(dat.encode().buffer, self.context.sock.sendto(dat.encode().buffer,
@ -170,7 +169,7 @@ class TftpState(object):
self.context.sock.sendto(ackpkt.encode().buffer, self.context.sock.sendto(ackpkt.encode().buffer,
(self.context.host, (self.context.host,
self.context.tidport)) self.context.tidport))
self.last_pkt = ackpkt self.context.last_pkt = ackpkt
def sendError(self, errorcode): def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to """This method uses the socket passed, and uses the errorcode to
@ -181,7 +180,7 @@ class TftpState(object):
self.context.sock.sendto(errpkt.encode().buffer, self.context.sock.sendto(errpkt.encode().buffer,
(self.context.host, (self.context.host,
self.context.tidport)) self.context.tidport))
self.last_pkt = errpkt self.context.last_pkt = errpkt
def sendOACK(self): def sendOACK(self):
"""This method sends an OACK packet with the options from the current """This method sends an OACK packet with the options from the current
@ -192,18 +191,18 @@ class TftpState(object):
self.context.sock.sendto(pkt.encode().buffer, self.context.sock.sendto(pkt.encode().buffer,
(self.context.host, (self.context.host,
self.context.tidport)) self.context.tidport))
self.last_pkt = pkt self.context.last_pkt = pkt
def resendLast(self): def resendLast(self):
"Resend the last sent packet due to a timeout." "Resend the last sent packet due to a timeout."
log.warn("Resending packet %s on sessions %s" log.warn("Resending packet %s on sessions %s"
% (self.last_pkt, self)) % (self.context.last_pkt, self))
self.context.metrics.resent_bytes += len(self.last_pkt.data) self.context.metrics.resent_bytes += len(self.context.last_pkt.buffer)
self.context.metrics.add_dup(self.last_pkt) self.context.metrics.add_dup(self.context.last_pkt)
self.context.sock.sendto(self.last_pkt.encode().buffer, self.context.sock.sendto(self.context.last_pkt.encode().buffer,
(self.context.host, self.context.tidport)) (self.context.host, self.context.tidport))
if self.context.packethook: if self.context.packethook:
self.context.packethook(self.last_pkt) self.context.packethook(self.context.last_pkt)
def handleDat(self, pkt): def handleDat(self, pkt):
"""This method handles a DAT packet during a client download, or a """This method handles a DAT packet during a client download, or a
@ -232,7 +231,7 @@ class TftpState(object):
self.sendError(TftpErrors.IllegalTftpOp) self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "There is no block zero!" raise TftpException, "There is no block zero!"
log.warn("Dropping duplicate block %d" % pkt.blocknumber) log.warn("Dropping duplicate block %d" % pkt.blocknumber)
self.context.metrics.add_dup(pkt.blocknumber) self.context.metrics.add_dup(pkt)
log.debug("ACKing block %d again, just in case" % pkt.blocknumber) log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
self.sendACK(pkt.blocknumber) self.sendACK(pkt.blocknumber)
@ -369,7 +368,9 @@ class TftpStateExpectACK(TftpState):
self.context.pending_complete = self.sendDAT() self.context.pending_complete = self.sendDAT()
elif pkt.blocknumber < self.context.next_block: elif pkt.blocknumber < self.context.next_block:
self.context.metrics.add_dup(pkt.blocknumber) log.debug("Received duplicate ACK for block %d"
% pkt.blocknumber)
self.context.metrics.add_dup(pkt)
else: else:
log.warn("Oooh, time warp. Received ACK to packet we " log.warn("Oooh, time warp. Received ACK to packet we "