Fixing issue #3, expanding unit tests.
parent
40977c6f74
commit
04aaa2ef9f
67
t/test.py
67
t/test.py
|
@ -4,6 +4,7 @@ import unittest
|
||||||
import logging
|
import logging
|
||||||
import tftpy
|
import tftpy
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
log = tftpy.log
|
log = tftpy.log
|
||||||
|
|
||||||
|
@ -140,6 +141,72 @@ class TestTftpyState(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
tftpy.setLogLevel(logging.DEBUG)
|
tftpy.setLogLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
def clientServerUploadOptions(self, options):
|
||||||
|
"""Fire up a client and a server and do an upload."""
|
||||||
|
root = '/tmp'
|
||||||
|
home = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
filename = '100KBFILE'
|
||||||
|
input_path = os.path.join(home, filename)
|
||||||
|
server = tftpy.TftpServer(root)
|
||||||
|
client = tftpy.TftpClient('localhost',
|
||||||
|
20001,
|
||||||
|
options)
|
||||||
|
# Fork a server and run the client in this process.
|
||||||
|
child_pid = os.fork()
|
||||||
|
if child_pid:
|
||||||
|
# parent - let the server start
|
||||||
|
try:
|
||||||
|
time.sleep(1)
|
||||||
|
client.upload(filename,
|
||||||
|
input_path)
|
||||||
|
finally:
|
||||||
|
os.kill(child_pid, 15)
|
||||||
|
os.waitpid(child_pid, 0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
server.listen('localhost', 20001)
|
||||||
|
|
||||||
|
def clientServerDownloadOptions(self, options):
|
||||||
|
"""Fire up a client and a server and do a download."""
|
||||||
|
root = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
server = tftpy.TftpServer(root)
|
||||||
|
client = tftpy.TftpClient('localhost',
|
||||||
|
20001,
|
||||||
|
options)
|
||||||
|
# Fork a server and run the client in this process.
|
||||||
|
child_pid = os.fork()
|
||||||
|
if child_pid:
|
||||||
|
# parent - let the server start
|
||||||
|
try:
|
||||||
|
time.sleep(1)
|
||||||
|
client.download('100KBFILE',
|
||||||
|
'/tmp/out')
|
||||||
|
finally:
|
||||||
|
os.kill(child_pid, 15)
|
||||||
|
os.waitpid(child_pid, 0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
server.listen('localhost', 20001)
|
||||||
|
|
||||||
|
def testClientServerNoOptions(self):
|
||||||
|
self.clientServerDownloadOptions({})
|
||||||
|
|
||||||
|
def testClientServerBlksize(self):
|
||||||
|
for blksize in [512, 1024, 2048, 4096]:
|
||||||
|
self.clientServerDownloadOptions({'blksize': blksize})
|
||||||
|
|
||||||
|
def testClientServerUploadNoOptions(self):
|
||||||
|
self.clientServerUploadOptions({})
|
||||||
|
|
||||||
|
def testClientServerUploadOptions(self):
|
||||||
|
for blksize in [512, 1024, 2048, 4096]:
|
||||||
|
self.clientServerUploadOptions({'blksize': blksize})
|
||||||
|
|
||||||
|
def testClientServerNoOptionsDelay(self):
|
||||||
|
tftpy.TftpStates.DELAY_BLOCK = 10
|
||||||
|
self.clientServerDownloadOptions({})
|
||||||
|
tftpy.TftpStates.DELAY_BLOCK = 0
|
||||||
|
|
||||||
def testServerNoOptions(self):
|
def testServerNoOptions(self):
|
||||||
"""Test the server states."""
|
"""Test the server states."""
|
||||||
raddress = '127.0.0.2'
|
raddress = '127.0.0.2'
|
||||||
|
|
|
@ -74,10 +74,10 @@ class TftpClient(TftpSession):
|
||||||
setting, which is the amount of time that the client will wait for a
|
setting, which is the amount of time that the client will wait for a
|
||||||
DAT packet to be ACKd by the server.
|
DAT packet to be ACKd by the server.
|
||||||
|
|
||||||
|
The input option is the full path to the file to upload, which can
|
||||||
|
optionally be '-' to read from stdin.
|
||||||
|
|
||||||
Note: If output is a hyphen then stdout is used."""
|
Note: If output is a hyphen then stdout is used."""
|
||||||
# Open the input file.
|
|
||||||
# FIXME: As of the state machine, this is now broken. Need to
|
|
||||||
# implement with new state machine.
|
|
||||||
self.context = TftpContextClientUpload(self.host,
|
self.context = TftpContextClientUpload(self.host,
|
||||||
self.iport,
|
self.iport,
|
||||||
filename,
|
filename,
|
||||||
|
|
|
@ -50,15 +50,15 @@ class TftpMetrics(object):
|
||||||
for key in self.dups:
|
for key in self.dups:
|
||||||
self.dupcount += self.dups[key]
|
self.dupcount += self.dups[key]
|
||||||
|
|
||||||
def add_dup(self, blocknumber):
|
def add_dup(self, pkt):
|
||||||
"""This method adds a dup for a block number to the metrics."""
|
"""This method adds a dup for a packet to the metrics."""
|
||||||
log.debug("Recording a dup for block %d" % blocknumber)
|
log.debug("Recording a dup of %s" % pkt)
|
||||||
if self.dups.has_key(blocknumber):
|
s = str(pkt)
|
||||||
self.dups[blocknumber] += 1
|
if self.dups.has_key(s):
|
||||||
|
self.dups[s] += 1
|
||||||
else:
|
else:
|
||||||
self.dups[blocknumber] = 1
|
self.dups[s] = 1
|
||||||
tftpassert(self.dups[blocknumber] < MAX_DUPS,
|
tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached")
|
||||||
"Max duplicates for block %d reached" % blocknumber)
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Context classes
|
# Context classes
|
||||||
|
|
|
@ -11,6 +11,9 @@ MAX_DUPS = 20
|
||||||
TIMEOUT_RETRIES = 5
|
TIMEOUT_RETRIES = 5
|
||||||
DEF_TFTP_PORT = 69
|
DEF_TFTP_PORT = 69
|
||||||
|
|
||||||
|
# A hook for deliberately introducing delay in testing.
|
||||||
|
DELAY_BLOCK = 0
|
||||||
|
|
||||||
# Initialize the logger.
|
# Initialize the logger.
|
||||||
logging.basicConfig()
|
logging.basicConfig()
|
||||||
# The logger used by this library. Feel free to clobber it with your own, if you like, as
|
# The logger used by this library. Feel free to clobber it with your own, if you like, as
|
||||||
|
|
|
@ -124,30 +124,29 @@ class TftpState(object):
|
||||||
|
|
||||||
return sendoack
|
return sendoack
|
||||||
|
|
||||||
def sendDAT(self, resend=False):
|
def sendDAT(self):
|
||||||
"""This method sends the next DAT packet based on the data in the
|
"""This method sends the next DAT packet based on the data in the
|
||||||
context. It returns a boolean indicating whether the transfer is
|
context. It returns a boolean indicating whether the transfer is
|
||||||
finished."""
|
finished."""
|
||||||
finished = False
|
finished = False
|
||||||
blocknumber = self.context.next_block
|
blocknumber = self.context.next_block
|
||||||
|
# Test hook
|
||||||
|
if DELAY_BLOCK and DELAY_BLOCK == blocknumber:
|
||||||
|
import time
|
||||||
|
log.debug("Deliberately delaying 10 seconds...")
|
||||||
|
time.sleep(10)
|
||||||
tftpassert( blocknumber > 0, "There is no block zero!" )
|
tftpassert( blocknumber > 0, "There is no block zero!" )
|
||||||
dat = None
|
dat = None
|
||||||
if resend:
|
blksize = self.context.getBlocksize()
|
||||||
log.warn("Resending block number %d" % blocknumber)
|
buffer = self.context.fileobj.read(blksize)
|
||||||
dat = self.context.last_pkt
|
log.debug("Read %d bytes into buffer" % len(buffer))
|
||||||
self.context.metrics.resent_bytes += len(dat.data)
|
if len(buffer) < blksize:
|
||||||
self.context.metrics.add_dup(dat)
|
log.info("Reached EOF on file %s"
|
||||||
else:
|
% self.context.file_to_transfer)
|
||||||
blksize = self.context.getBlocksize()
|
finished = True
|
||||||
buffer = self.context.fileobj.read(blksize)
|
dat = TftpPacketDAT()
|
||||||
log.debug("Read %d bytes into buffer" % len(buffer))
|
dat.data = buffer
|
||||||
if len(buffer) < blksize:
|
dat.blocknumber = blocknumber
|
||||||
log.info("Reached EOF on file %s"
|
|
||||||
% self.context.file_to_transfer)
|
|
||||||
finished = True
|
|
||||||
dat = TftpPacketDAT()
|
|
||||||
dat.data = buffer
|
|
||||||
dat.blocknumber = blocknumber
|
|
||||||
self.context.metrics.bytes += len(dat.data)
|
self.context.metrics.bytes += len(dat.data)
|
||||||
log.debug("Sending DAT packet %d" % dat.blocknumber)
|
log.debug("Sending DAT packet %d" % dat.blocknumber)
|
||||||
self.context.sock.sendto(dat.encode().buffer,
|
self.context.sock.sendto(dat.encode().buffer,
|
||||||
|
@ -170,7 +169,7 @@ class TftpState(object):
|
||||||
self.context.sock.sendto(ackpkt.encode().buffer,
|
self.context.sock.sendto(ackpkt.encode().buffer,
|
||||||
(self.context.host,
|
(self.context.host,
|
||||||
self.context.tidport))
|
self.context.tidport))
|
||||||
self.last_pkt = ackpkt
|
self.context.last_pkt = ackpkt
|
||||||
|
|
||||||
def sendError(self, errorcode):
|
def sendError(self, errorcode):
|
||||||
"""This method uses the socket passed, and uses the errorcode to
|
"""This method uses the socket passed, and uses the errorcode to
|
||||||
|
@ -181,7 +180,7 @@ class TftpState(object):
|
||||||
self.context.sock.sendto(errpkt.encode().buffer,
|
self.context.sock.sendto(errpkt.encode().buffer,
|
||||||
(self.context.host,
|
(self.context.host,
|
||||||
self.context.tidport))
|
self.context.tidport))
|
||||||
self.last_pkt = errpkt
|
self.context.last_pkt = errpkt
|
||||||
|
|
||||||
def sendOACK(self):
|
def sendOACK(self):
|
||||||
"""This method sends an OACK packet with the options from the current
|
"""This method sends an OACK packet with the options from the current
|
||||||
|
@ -192,18 +191,18 @@ class TftpState(object):
|
||||||
self.context.sock.sendto(pkt.encode().buffer,
|
self.context.sock.sendto(pkt.encode().buffer,
|
||||||
(self.context.host,
|
(self.context.host,
|
||||||
self.context.tidport))
|
self.context.tidport))
|
||||||
self.last_pkt = pkt
|
self.context.last_pkt = pkt
|
||||||
|
|
||||||
def resendLast(self):
|
def resendLast(self):
|
||||||
"Resend the last sent packet due to a timeout."
|
"Resend the last sent packet due to a timeout."
|
||||||
log.warn("Resending packet %s on sessions %s"
|
log.warn("Resending packet %s on sessions %s"
|
||||||
% (self.last_pkt, self))
|
% (self.context.last_pkt, self))
|
||||||
self.context.metrics.resent_bytes += len(self.last_pkt.data)
|
self.context.metrics.resent_bytes += len(self.context.last_pkt.buffer)
|
||||||
self.context.metrics.add_dup(self.last_pkt)
|
self.context.metrics.add_dup(self.context.last_pkt)
|
||||||
self.context.sock.sendto(self.last_pkt.encode().buffer,
|
self.context.sock.sendto(self.context.last_pkt.encode().buffer,
|
||||||
(self.context.host, self.context.tidport))
|
(self.context.host, self.context.tidport))
|
||||||
if self.context.packethook:
|
if self.context.packethook:
|
||||||
self.context.packethook(self.last_pkt)
|
self.context.packethook(self.context.last_pkt)
|
||||||
|
|
||||||
def handleDat(self, pkt):
|
def handleDat(self, pkt):
|
||||||
"""This method handles a DAT packet during a client download, or a
|
"""This method handles a DAT packet during a client download, or a
|
||||||
|
@ -232,7 +231,7 @@ class TftpState(object):
|
||||||
self.sendError(TftpErrors.IllegalTftpOp)
|
self.sendError(TftpErrors.IllegalTftpOp)
|
||||||
raise TftpException, "There is no block zero!"
|
raise TftpException, "There is no block zero!"
|
||||||
log.warn("Dropping duplicate block %d" % pkt.blocknumber)
|
log.warn("Dropping duplicate block %d" % pkt.blocknumber)
|
||||||
self.context.metrics.add_dup(pkt.blocknumber)
|
self.context.metrics.add_dup(pkt)
|
||||||
log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
|
log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
|
||||||
self.sendACK(pkt.blocknumber)
|
self.sendACK(pkt.blocknumber)
|
||||||
|
|
||||||
|
@ -369,7 +368,9 @@ class TftpStateExpectACK(TftpState):
|
||||||
self.context.pending_complete = self.sendDAT()
|
self.context.pending_complete = self.sendDAT()
|
||||||
|
|
||||||
elif pkt.blocknumber < self.context.next_block:
|
elif pkt.blocknumber < self.context.next_block:
|
||||||
self.context.metrics.add_dup(pkt.blocknumber)
|
log.debug("Received duplicate ACK for block %d"
|
||||||
|
% pkt.blocknumber)
|
||||||
|
self.context.metrics.add_dup(pkt)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
log.warn("Oooh, time warp. Received ACK to packet we "
|
log.warn("Oooh, time warp. Received ACK to packet we "
|
||||||
|
|
Reference in New Issue