Fixing issue #3, expanding unit tests.
This commit is contained in:
parent
40977c6f74
commit
04aaa2ef9f
5 changed files with 109 additions and 38 deletions
67
t/test.py
67
t/test.py
|
@ -4,6 +4,7 @@ import unittest
|
|||
import logging
|
||||
import tftpy
|
||||
import os
|
||||
import time
|
||||
|
||||
log = tftpy.log
|
||||
|
||||
|
@ -140,6 +141,72 @@ class TestTftpyState(unittest.TestCase):
|
|||
def setUp(self):
|
||||
tftpy.setLogLevel(logging.DEBUG)
|
||||
|
||||
def clientServerUploadOptions(self, options):
|
||||
"""Fire up a client and a server and do an upload."""
|
||||
root = '/tmp'
|
||||
home = os.path.dirname(os.path.abspath(__file__))
|
||||
filename = '100KBFILE'
|
||||
input_path = os.path.join(home, filename)
|
||||
server = tftpy.TftpServer(root)
|
||||
client = tftpy.TftpClient('localhost',
|
||||
20001,
|
||||
options)
|
||||
# Fork a server and run the client in this process.
|
||||
child_pid = os.fork()
|
||||
if child_pid:
|
||||
# parent - let the server start
|
||||
try:
|
||||
time.sleep(1)
|
||||
client.upload(filename,
|
||||
input_path)
|
||||
finally:
|
||||
os.kill(child_pid, 15)
|
||||
os.waitpid(child_pid, 0)
|
||||
|
||||
else:
|
||||
server.listen('localhost', 20001)
|
||||
|
||||
def clientServerDownloadOptions(self, options):
|
||||
"""Fire up a client and a server and do a download."""
|
||||
root = os.path.dirname(os.path.abspath(__file__))
|
||||
server = tftpy.TftpServer(root)
|
||||
client = tftpy.TftpClient('localhost',
|
||||
20001,
|
||||
options)
|
||||
# Fork a server and run the client in this process.
|
||||
child_pid = os.fork()
|
||||
if child_pid:
|
||||
# parent - let the server start
|
||||
try:
|
||||
time.sleep(1)
|
||||
client.download('100KBFILE',
|
||||
'/tmp/out')
|
||||
finally:
|
||||
os.kill(child_pid, 15)
|
||||
os.waitpid(child_pid, 0)
|
||||
|
||||
else:
|
||||
server.listen('localhost', 20001)
|
||||
|
||||
def testClientServerNoOptions(self):
|
||||
self.clientServerDownloadOptions({})
|
||||
|
||||
def testClientServerBlksize(self):
|
||||
for blksize in [512, 1024, 2048, 4096]:
|
||||
self.clientServerDownloadOptions({'blksize': blksize})
|
||||
|
||||
def testClientServerUploadNoOptions(self):
|
||||
self.clientServerUploadOptions({})
|
||||
|
||||
def testClientServerUploadOptions(self):
|
||||
for blksize in [512, 1024, 2048, 4096]:
|
||||
self.clientServerUploadOptions({'blksize': blksize})
|
||||
|
||||
def testClientServerNoOptionsDelay(self):
|
||||
tftpy.TftpStates.DELAY_BLOCK = 10
|
||||
self.clientServerDownloadOptions({})
|
||||
tftpy.TftpStates.DELAY_BLOCK = 0
|
||||
|
||||
def testServerNoOptions(self):
|
||||
"""Test the server states."""
|
||||
raddress = '127.0.0.2'
|
||||
|
|
|
@ -74,10 +74,10 @@ class TftpClient(TftpSession):
|
|||
setting, which is the amount of time that the client will wait for a
|
||||
DAT packet to be ACKd by the server.
|
||||
|
||||
The input option is the full path to the file to upload, which can
|
||||
optionally be '-' to read from stdin.
|
||||
|
||||
Note: If output is a hyphen then stdout is used."""
|
||||
# Open the input file.
|
||||
# FIXME: As of the state machine, this is now broken. Need to
|
||||
# implement with new state machine.
|
||||
self.context = TftpContextClientUpload(self.host,
|
||||
self.iport,
|
||||
filename,
|
||||
|
|
|
@ -50,15 +50,15 @@ class TftpMetrics(object):
|
|||
for key in self.dups:
|
||||
self.dupcount += self.dups[key]
|
||||
|
||||
def add_dup(self, blocknumber):
|
||||
"""This method adds a dup for a block number to the metrics."""
|
||||
log.debug("Recording a dup for block %d" % blocknumber)
|
||||
if self.dups.has_key(blocknumber):
|
||||
self.dups[blocknumber] += 1
|
||||
def add_dup(self, pkt):
|
||||
"""This method adds a dup for a packet to the metrics."""
|
||||
log.debug("Recording a dup of %s" % pkt)
|
||||
s = str(pkt)
|
||||
if self.dups.has_key(s):
|
||||
self.dups[s] += 1
|
||||
else:
|
||||
self.dups[blocknumber] = 1
|
||||
tftpassert(self.dups[blocknumber] < MAX_DUPS,
|
||||
"Max duplicates for block %d reached" % blocknumber)
|
||||
self.dups[s] = 1
|
||||
tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached")
|
||||
|
||||
###############################################################################
|
||||
# Context classes
|
||||
|
|
|
@ -11,6 +11,9 @@ MAX_DUPS = 20
|
|||
TIMEOUT_RETRIES = 5
|
||||
DEF_TFTP_PORT = 69
|
||||
|
||||
# A hook for deliberately introducing delay in testing.
|
||||
DELAY_BLOCK = 0
|
||||
|
||||
# Initialize the logger.
|
||||
logging.basicConfig()
|
||||
# The logger used by this library. Feel free to clobber it with your own, if you like, as
|
||||
|
|
|
@ -124,30 +124,29 @@ class TftpState(object):
|
|||
|
||||
return sendoack
|
||||
|
||||
def sendDAT(self, resend=False):
|
||||
def sendDAT(self):
|
||||
"""This method sends the next DAT packet based on the data in the
|
||||
context. It returns a boolean indicating whether the transfer is
|
||||
finished."""
|
||||
finished = False
|
||||
blocknumber = self.context.next_block
|
||||
# Test hook
|
||||
if DELAY_BLOCK and DELAY_BLOCK == blocknumber:
|
||||
import time
|
||||
log.debug("Deliberately delaying 10 seconds...")
|
||||
time.sleep(10)
|
||||
tftpassert( blocknumber > 0, "There is no block zero!" )
|
||||
dat = None
|
||||
if resend:
|
||||
log.warn("Resending block number %d" % blocknumber)
|
||||
dat = self.context.last_pkt
|
||||
self.context.metrics.resent_bytes += len(dat.data)
|
||||
self.context.metrics.add_dup(dat)
|
||||
else:
|
||||
blksize = self.context.getBlocksize()
|
||||
buffer = self.context.fileobj.read(blksize)
|
||||
log.debug("Read %d bytes into buffer" % len(buffer))
|
||||
if len(buffer) < blksize:
|
||||
log.info("Reached EOF on file %s"
|
||||
% self.context.file_to_transfer)
|
||||
finished = True
|
||||
dat = TftpPacketDAT()
|
||||
dat.data = buffer
|
||||
dat.blocknumber = blocknumber
|
||||
blksize = self.context.getBlocksize()
|
||||
buffer = self.context.fileobj.read(blksize)
|
||||
log.debug("Read %d bytes into buffer" % len(buffer))
|
||||
if len(buffer) < blksize:
|
||||
log.info("Reached EOF on file %s"
|
||||
% self.context.file_to_transfer)
|
||||
finished = True
|
||||
dat = TftpPacketDAT()
|
||||
dat.data = buffer
|
||||
dat.blocknumber = blocknumber
|
||||
self.context.metrics.bytes += len(dat.data)
|
||||
log.debug("Sending DAT packet %d" % dat.blocknumber)
|
||||
self.context.sock.sendto(dat.encode().buffer,
|
||||
|
@ -170,7 +169,7 @@ class TftpState(object):
|
|||
self.context.sock.sendto(ackpkt.encode().buffer,
|
||||
(self.context.host,
|
||||
self.context.tidport))
|
||||
self.last_pkt = ackpkt
|
||||
self.context.last_pkt = ackpkt
|
||||
|
||||
def sendError(self, errorcode):
|
||||
"""This method uses the socket passed, and uses the errorcode to
|
||||
|
@ -181,7 +180,7 @@ class TftpState(object):
|
|||
self.context.sock.sendto(errpkt.encode().buffer,
|
||||
(self.context.host,
|
||||
self.context.tidport))
|
||||
self.last_pkt = errpkt
|
||||
self.context.last_pkt = errpkt
|
||||
|
||||
def sendOACK(self):
|
||||
"""This method sends an OACK packet with the options from the current
|
||||
|
@ -192,18 +191,18 @@ class TftpState(object):
|
|||
self.context.sock.sendto(pkt.encode().buffer,
|
||||
(self.context.host,
|
||||
self.context.tidport))
|
||||
self.last_pkt = pkt
|
||||
self.context.last_pkt = pkt
|
||||
|
||||
def resendLast(self):
|
||||
"Resend the last sent packet due to a timeout."
|
||||
log.warn("Resending packet %s on sessions %s"
|
||||
% (self.last_pkt, self))
|
||||
self.context.metrics.resent_bytes += len(self.last_pkt.data)
|
||||
self.context.metrics.add_dup(self.last_pkt)
|
||||
self.context.sock.sendto(self.last_pkt.encode().buffer,
|
||||
% (self.context.last_pkt, self))
|
||||
self.context.metrics.resent_bytes += len(self.context.last_pkt.buffer)
|
||||
self.context.metrics.add_dup(self.context.last_pkt)
|
||||
self.context.sock.sendto(self.context.last_pkt.encode().buffer,
|
||||
(self.context.host, self.context.tidport))
|
||||
if self.context.packethook:
|
||||
self.context.packethook(self.last_pkt)
|
||||
self.context.packethook(self.context.last_pkt)
|
||||
|
||||
def handleDat(self, pkt):
|
||||
"""This method handles a DAT packet during a client download, or a
|
||||
|
@ -232,7 +231,7 @@ class TftpState(object):
|
|||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "There is no block zero!"
|
||||
log.warn("Dropping duplicate block %d" % pkt.blocknumber)
|
||||
self.context.metrics.add_dup(pkt.blocknumber)
|
||||
self.context.metrics.add_dup(pkt)
|
||||
log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
|
||||
self.sendACK(pkt.blocknumber)
|
||||
|
||||
|
@ -369,7 +368,9 @@ class TftpStateExpectACK(TftpState):
|
|||
self.context.pending_complete = self.sendDAT()
|
||||
|
||||
elif pkt.blocknumber < self.context.next_block:
|
||||
self.context.metrics.add_dup(pkt.blocknumber)
|
||||
log.debug("Received duplicate ACK for block %d"
|
||||
% pkt.blocknumber)
|
||||
self.context.metrics.add_dup(pkt)
|
||||
|
||||
else:
|
||||
log.warn("Oooh, time warp. Received ACK to packet we "
|
||||
|
|
Reference in a new issue