Fixing issue #3, expanding unit tests.

master
Michael P. Soulier 2011-07-24 17:37:16 -04:00
parent 40977c6f74
commit 04aaa2ef9f
5 changed files with 109 additions and 38 deletions

View File

@ -4,6 +4,7 @@ import unittest
import logging
import tftpy
import os
import time
log = tftpy.log
@ -140,6 +141,72 @@ class TestTftpyState(unittest.TestCase):
def setUp(self):
tftpy.setLogLevel(logging.DEBUG)
def clientServerUploadOptions(self, options):
"""Fire up a client and a server and do an upload."""
root = '/tmp'
home = os.path.dirname(os.path.abspath(__file__))
filename = '100KBFILE'
input_path = os.path.join(home, filename)
server = tftpy.TftpServer(root)
client = tftpy.TftpClient('localhost',
20001,
options)
# Fork a server and run the client in this process.
child_pid = os.fork()
if child_pid:
# parent - let the server start
try:
time.sleep(1)
client.upload(filename,
input_path)
finally:
os.kill(child_pid, 15)
os.waitpid(child_pid, 0)
else:
server.listen('localhost', 20001)
def clientServerDownloadOptions(self, options):
"""Fire up a client and a server and do a download."""
root = os.path.dirname(os.path.abspath(__file__))
server = tftpy.TftpServer(root)
client = tftpy.TftpClient('localhost',
20001,
options)
# Fork a server and run the client in this process.
child_pid = os.fork()
if child_pid:
# parent - let the server start
try:
time.sleep(1)
client.download('100KBFILE',
'/tmp/out')
finally:
os.kill(child_pid, 15)
os.waitpid(child_pid, 0)
else:
server.listen('localhost', 20001)
def testClientServerNoOptions(self):
self.clientServerDownloadOptions({})
def testClientServerBlksize(self):
for blksize in [512, 1024, 2048, 4096]:
self.clientServerDownloadOptions({'blksize': blksize})
def testClientServerUploadNoOptions(self):
self.clientServerUploadOptions({})
def testClientServerUploadOptions(self):
for blksize in [512, 1024, 2048, 4096]:
self.clientServerUploadOptions({'blksize': blksize})
def testClientServerNoOptionsDelay(self):
tftpy.TftpStates.DELAY_BLOCK = 10
self.clientServerDownloadOptions({})
tftpy.TftpStates.DELAY_BLOCK = 0
def testServerNoOptions(self):
"""Test the server states."""
raddress = '127.0.0.2'

View File

@ -74,10 +74,10 @@ class TftpClient(TftpSession):
setting, which is the amount of time that the client will wait for a
DAT packet to be ACKd by the server.
The input option is the full path to the file to upload, which can
optionally be '-' to read from stdin.
Note: If output is a hyphen then stdout is used."""
# Open the input file.
# FIXME: As of the state machine, this is now broken. Need to
# implement with new state machine.
self.context = TftpContextClientUpload(self.host,
self.iport,
filename,

View File

@ -50,15 +50,15 @@ class TftpMetrics(object):
for key in self.dups:
self.dupcount += self.dups[key]
def add_dup(self, blocknumber):
"""This method adds a dup for a block number to the metrics."""
log.debug("Recording a dup for block %d" % blocknumber)
if self.dups.has_key(blocknumber):
self.dups[blocknumber] += 1
def add_dup(self, pkt):
"""This method adds a dup for a packet to the metrics."""
log.debug("Recording a dup of %s" % pkt)
s = str(pkt)
if self.dups.has_key(s):
self.dups[s] += 1
else:
self.dups[blocknumber] = 1
tftpassert(self.dups[blocknumber] < MAX_DUPS,
"Max duplicates for block %d reached" % blocknumber)
self.dups[s] = 1
tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached")
###############################################################################
# Context classes

View File

@ -11,6 +11,9 @@ MAX_DUPS = 20
TIMEOUT_RETRIES = 5
DEF_TFTP_PORT = 69
# A hook for deliberately introducing delay in testing.
DELAY_BLOCK = 0
# Initialize the logger.
logging.basicConfig()
# The logger used by this library. Feel free to clobber it with your own, if you like, as

View File

@ -124,30 +124,29 @@ class TftpState(object):
return sendoack
def sendDAT(self, resend=False):
def sendDAT(self):
"""This method sends the next DAT packet based on the data in the
context. It returns a boolean indicating whether the transfer is
finished."""
finished = False
blocknumber = self.context.next_block
# Test hook
if DELAY_BLOCK and DELAY_BLOCK == blocknumber:
import time
log.debug("Deliberately delaying 10 seconds...")
time.sleep(10)
tftpassert( blocknumber > 0, "There is no block zero!" )
dat = None
if resend:
log.warn("Resending block number %d" % blocknumber)
dat = self.context.last_pkt
self.context.metrics.resent_bytes += len(dat.data)
self.context.metrics.add_dup(dat)
else:
blksize = self.context.getBlocksize()
buffer = self.context.fileobj.read(blksize)
log.debug("Read %d bytes into buffer" % len(buffer))
if len(buffer) < blksize:
log.info("Reached EOF on file %s"
% self.context.file_to_transfer)
finished = True
dat = TftpPacketDAT()
dat.data = buffer
dat.blocknumber = blocknumber
blksize = self.context.getBlocksize()
buffer = self.context.fileobj.read(blksize)
log.debug("Read %d bytes into buffer" % len(buffer))
if len(buffer) < blksize:
log.info("Reached EOF on file %s"
% self.context.file_to_transfer)
finished = True
dat = TftpPacketDAT()
dat.data = buffer
dat.blocknumber = blocknumber
self.context.metrics.bytes += len(dat.data)
log.debug("Sending DAT packet %d" % dat.blocknumber)
self.context.sock.sendto(dat.encode().buffer,
@ -170,7 +169,7 @@ class TftpState(object):
self.context.sock.sendto(ackpkt.encode().buffer,
(self.context.host,
self.context.tidport))
self.last_pkt = ackpkt
self.context.last_pkt = ackpkt
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
@ -181,7 +180,7 @@ class TftpState(object):
self.context.sock.sendto(errpkt.encode().buffer,
(self.context.host,
self.context.tidport))
self.last_pkt = errpkt
self.context.last_pkt = errpkt
def sendOACK(self):
"""This method sends an OACK packet with the options from the current
@ -192,18 +191,18 @@ class TftpState(object):
self.context.sock.sendto(pkt.encode().buffer,
(self.context.host,
self.context.tidport))
self.last_pkt = pkt
self.context.last_pkt = pkt
def resendLast(self):
"Resend the last sent packet due to a timeout."
log.warn("Resending packet %s on sessions %s"
% (self.last_pkt, self))
self.context.metrics.resent_bytes += len(self.last_pkt.data)
self.context.metrics.add_dup(self.last_pkt)
self.context.sock.sendto(self.last_pkt.encode().buffer,
% (self.context.last_pkt, self))
self.context.metrics.resent_bytes += len(self.context.last_pkt.buffer)
self.context.metrics.add_dup(self.context.last_pkt)
self.context.sock.sendto(self.context.last_pkt.encode().buffer,
(self.context.host, self.context.tidport))
if self.context.packethook:
self.context.packethook(self.last_pkt)
self.context.packethook(self.context.last_pkt)
def handleDat(self, pkt):
"""This method handles a DAT packet during a client download, or a
@ -232,7 +231,7 @@ class TftpState(object):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "There is no block zero!"
log.warn("Dropping duplicate block %d" % pkt.blocknumber)
self.context.metrics.add_dup(pkt.blocknumber)
self.context.metrics.add_dup(pkt)
log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
self.sendACK(pkt.blocknumber)
@ -369,7 +368,9 @@ class TftpStateExpectACK(TftpState):
self.context.pending_complete = self.sendDAT()
elif pkt.blocknumber < self.context.next_block:
self.context.metrics.add_dup(pkt.blocknumber)
log.debug("Received duplicate ACK for block %d"
% pkt.blocknumber)
self.context.metrics.add_dup(pkt)
else:
log.warn("Oooh, time warp. Received ACK to packet we "