Fixed broken decode, and adjusted the client options.

git-svn-id: https://tftpy.svn.sourceforge.net/svnroot/tftpy/trunk@20 63283fd4-ec1e-0410-9879-cb7f675518da
master
msoulier 2006-10-09 02:44:27 +00:00
parent 6db1b2cfda
commit 6ebd6fcbc8
3 changed files with 42 additions and 34 deletions

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
import sys, logging
import sys, logging, os
from optparse import OptionParser
import tftpy
@ -34,13 +34,15 @@ def main():
'--output',
action='store',
dest='output',
help='output file (default: out)',
default='out')
help='output file (default: same as requested filename)')
options, args = parser.parse_args()
if not options.host or not options.filename:
parser.print_help()
sys.exit(1)
if not options.output:
options.output = os.path.basename(options.filename)
class Progress(object):
def __init__(self, out):
self.progress = 0
@ -49,7 +51,7 @@ def main():
self.progress += len(pkt.data)
self.out("Downloaded %d bytes" % self.progress)
tftpy.setLogLevel(logging.DEBUG)
tftpy.setLogLevel(logging.INFO)
progresshook = Progress(tftpy.logger.info).progresshook

View File

@ -12,13 +12,22 @@ if not verlist[0] >= 2 or not verlist[1] >= 4:
raise AssertionError, "Requires at least Python 2.4"
# Change this as desired. FIXME - make this a command-line arg
LOG_LEVEL = logging.DEBUG
LOG_LEVEL = logging.NOTSET
MIN_BLKSIZE = 8
DEF_BLKSIZE = 512
MAX_BLKSIZE = 65536
SOCK_TIMEOUT = 5
MAX_DUPS = 20
# Initialize the logger.
logging.basicConfig(
level=LOG_LEVEL,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
datefmt='%m-%d %H:%M:%S')
# The logger used by this library. Feel free to clobber it with your own, if you like, as
# long as it conforms to Python's logging.
logger = logging.getLogger('tftpy')
def tftpassert(condition, msg):
"""This function is a simple utility that will check the condition
passed for a false state. If it finds one, it throws a TftpException
@ -27,23 +36,12 @@ def tftpassert(condition, msg):
if not condition:
raise TftpException, msg
def setLogLevel(level=LOG_LEVEL):
def setLogLevel(level):
"""This function is a utility function for setting the internal log level.
The log level defaults to logging.NOTSET, so unwanted output to stdout is not
created."""
global logger
# Initialize the logger.
logging.basicConfig(
level=level,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
datefmt='%m-%d %H:%M:%S')
logger = logging.getLogger('tftpy')
# The logger used by this library. Feel free to clobber it with your own, if you like, as
# long as it conforms to Python's logging.
logger = None
# Set up the default logger.
setLogLevel()
logger.setLevel(level)
class TftpException(Exception):
"""This class is the parent class of all exceptions regarding the handling
@ -84,23 +82,26 @@ class TftpPacket(object):
format = "!"
options = {}
logger.debug("buffer is: " + buffer)
logger.debug("buffer is: " + buffer.__repr__())
logger.debug("size of buffer is %d bytes" % len(buffer))
# Count the nulls in the buffer. Each one terminates a string.
logger.debug("about to iterate options buffer counting nulls")
length = 0
for c in buffer:
# When we iterate, skip the first 2 bytes where the opcode is.
subbuf = buffer[2:]
for c in subbuf:
#logger.debug("iterating this byte: " + c.__repr__())
if ord(c) == 0:
logger.debug("found a null at length %d" % length)
if length > 0:
format += "%dsx" % length
length = -1
else:
raise TftpException, "Invalid options buffer"
raise TftpException, "Invalid options in buffer"
length += 1
logger.debug("about to unpack, format is: %s" % format)
mystruct = struct.unpack(format, buffer)
mystruct = struct.unpack(format, subbuf)
tftpassert(len(mystruct) % 2 == 0,
"packet with odd number of option/value pairs")
@ -139,12 +140,12 @@ class TftpPacketInitial(TftpPacket):
format += "%dsx" % len(str(self.options[key]))
options_list.append(key)
options_list.append(str(self.options[key]))
#format += "B"
logger.debug("format is %s" % format)
logger.debug("size of struct is %d" % struct.calcsize(format))
self.buffer = struct.pack(format, self.opcode, self.filename, self.mode, *options_list)
logger.debug("buffer is " + self.buffer)
logger.debug("buffer is " + self.buffer.__repr__())
return self
def decode(self):
@ -152,35 +153,38 @@ class TftpPacketInitial(TftpPacket):
# FIXME - this shares a lot of code with decode_options
nulls = 0
# 2 byte opcode, followed by filename and mode strings, optionally
# followed by options.
format = ""
nulls = length = tlength = 0
logger.debug("about to iterate buffer counting nulls")
for c in self.buffer:
subbuf = self.buffer[2:]
for c in subbuf:
#logger.debug("iterating this byte: " + c.__repr__())
if ord(c) == 0:
nulls += 1
logger.debug("found a null at length %d, now have %d"
% (length, nulls))
length = 0
format += "%dsx" % length
length = -1
# At 2 nulls, we want to mark that position for decoding.
if nulls == 2:
break
length += 1
tlength += 1
logger.debug("hopefully found end of mode at length %d" % tlength)
# length should now be the end of the mode.
tftpassert(nulls == 2, "malformed packet")
shortbuf = self.buffer[2:tlength]
shortbuf = subbuf[:tlength+1]
logger.debug("about to unpack buffer with format: %s" % format)
logger.debug("unpacking buffer: " + shortbuf.__repr__())
mystruct = struct.unpack(format, shortbuf)
for key in mystruct:
logger.debug("option name is %s, value is %s"
% (key, mystruct[key]))
for i in range(0, len(mystruct), 2):
logger.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
tftpassert(len(mystruct) == 2, "malformed packet")
self.options = self.decode_options(self.buffer[tlength:])
#self.options = self.decode_options(self.buffer[tlength:])
self.options = self.decode_options(subbuf[tlength+1:])
return self
class TftpPacketRRQ(TftpPacketInitial):
@ -600,6 +604,7 @@ class TftpClient(TftpSession):
end_time = time.time()
duration = end_time - start_time
outputfile.close()
logger.info('')
logger.info("Downloaded %d bytes in %d seconds" % (bytes, duration))
bps = (bytes * 8.0) / duration
kbps = bps / 1024.0

View File

@ -1,12 +1,13 @@
"""Unit tests for tftpy."""
import unittest
import logging
import tftpy
class TestTftpy(unittest.TestCase):
def setUp(self):
pass
tftpy.setLogLevel(logging.INFO)
def testTftpPacketRRQ(self):
options = {}