Fixed broken decode, and adjusted the client options.
git-svn-id: https://tftpy.svn.sourceforge.net/svnroot/tftpy/trunk@20 63283fd4-ec1e-0410-9879-cb7f675518da
This commit is contained in:
parent
6db1b2cfda
commit
6ebd6fcbc8
3 changed files with 42 additions and 34 deletions
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import sys, logging
|
||||
import sys, logging, os
|
||||
from optparse import OptionParser
|
||||
import tftpy
|
||||
|
||||
|
@ -34,13 +34,15 @@ def main():
|
|||
'--output',
|
||||
action='store',
|
||||
dest='output',
|
||||
help='output file (default: out)',
|
||||
default='out')
|
||||
help='output file (default: same as requested filename)')
|
||||
options, args = parser.parse_args()
|
||||
if not options.host or not options.filename:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
if not options.output:
|
||||
options.output = os.path.basename(options.filename)
|
||||
|
||||
class Progress(object):
|
||||
def __init__(self, out):
|
||||
self.progress = 0
|
||||
|
@ -49,7 +51,7 @@ def main():
|
|||
self.progress += len(pkt.data)
|
||||
self.out("Downloaded %d bytes" % self.progress)
|
||||
|
||||
tftpy.setLogLevel(logging.DEBUG)
|
||||
tftpy.setLogLevel(logging.INFO)
|
||||
|
||||
progresshook = Progress(tftpy.logger.info).progresshook
|
||||
|
||||
|
|
63
lib/tftpy.py
63
lib/tftpy.py
|
@ -12,13 +12,22 @@ if not verlist[0] >= 2 or not verlist[1] >= 4:
|
|||
raise AssertionError, "Requires at least Python 2.4"
|
||||
|
||||
# Change this as desired. FIXME - make this a command-line arg
|
||||
LOG_LEVEL = logging.DEBUG
|
||||
LOG_LEVEL = logging.NOTSET
|
||||
MIN_BLKSIZE = 8
|
||||
DEF_BLKSIZE = 512
|
||||
MAX_BLKSIZE = 65536
|
||||
SOCK_TIMEOUT = 5
|
||||
MAX_DUPS = 20
|
||||
|
||||
# Initialize the logger.
|
||||
logging.basicConfig(
|
||||
level=LOG_LEVEL,
|
||||
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
|
||||
datefmt='%m-%d %H:%M:%S')
|
||||
# The logger used by this library. Feel free to clobber it with your own, if you like, as
|
||||
# long as it conforms to Python's logging.
|
||||
logger = logging.getLogger('tftpy')
|
||||
|
||||
def tftpassert(condition, msg):
|
||||
"""This function is a simple utility that will check the condition
|
||||
passed for a false state. If it finds one, it throws a TftpException
|
||||
|
@ -27,23 +36,12 @@ def tftpassert(condition, msg):
|
|||
if not condition:
|
||||
raise TftpException, msg
|
||||
|
||||
def setLogLevel(level=LOG_LEVEL):
|
||||
def setLogLevel(level):
|
||||
"""This function is a utility function for setting the internal log level.
|
||||
The log level defaults to logging.NOTSET, so unwanted output to stdout is not
|
||||
created."""
|
||||
global logger
|
||||
# Initialize the logger.
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
|
||||
datefmt='%m-%d %H:%M:%S')
|
||||
logger = logging.getLogger('tftpy')
|
||||
|
||||
# The logger used by this library. Feel free to clobber it with your own, if you like, as
|
||||
# long as it conforms to Python's logging.
|
||||
logger = None
|
||||
# Set up the default logger.
|
||||
setLogLevel()
|
||||
logger.setLevel(level)
|
||||
|
||||
class TftpException(Exception):
|
||||
"""This class is the parent class of all exceptions regarding the handling
|
||||
|
@ -84,23 +82,26 @@ class TftpPacket(object):
|
|||
format = "!"
|
||||
options = {}
|
||||
|
||||
logger.debug("buffer is: " + buffer)
|
||||
logger.debug("buffer is: " + buffer.__repr__())
|
||||
logger.debug("size of buffer is %d bytes" % len(buffer))
|
||||
# Count the nulls in the buffer. Each one terminates a string.
|
||||
logger.debug("about to iterate options buffer counting nulls")
|
||||
length = 0
|
||||
for c in buffer:
|
||||
# When we iterate, skip the first 2 bytes where the opcode is.
|
||||
subbuf = buffer[2:]
|
||||
for c in subbuf:
|
||||
#logger.debug("iterating this byte: " + c.__repr__())
|
||||
if ord(c) == 0:
|
||||
logger.debug("found a null at length %d" % length)
|
||||
if length > 0:
|
||||
format += "%dsx" % length
|
||||
length = -1
|
||||
else:
|
||||
raise TftpException, "Invalid options buffer"
|
||||
raise TftpException, "Invalid options in buffer"
|
||||
length += 1
|
||||
|
||||
logger.debug("about to unpack, format is: %s" % format)
|
||||
mystruct = struct.unpack(format, buffer)
|
||||
mystruct = struct.unpack(format, subbuf)
|
||||
|
||||
tftpassert(len(mystruct) % 2 == 0,
|
||||
"packet with odd number of option/value pairs")
|
||||
|
@ -139,12 +140,12 @@ class TftpPacketInitial(TftpPacket):
|
|||
format += "%dsx" % len(str(self.options[key]))
|
||||
options_list.append(key)
|
||||
options_list.append(str(self.options[key]))
|
||||
#format += "B"
|
||||
|
||||
logger.debug("format is %s" % format)
|
||||
logger.debug("size of struct is %d" % struct.calcsize(format))
|
||||
|
||||
self.buffer = struct.pack(format, self.opcode, self.filename, self.mode, *options_list)
|
||||
logger.debug("buffer is " + self.buffer)
|
||||
logger.debug("buffer is " + self.buffer.__repr__())
|
||||
return self
|
||||
|
||||
def decode(self):
|
||||
|
@ -152,35 +153,38 @@ class TftpPacketInitial(TftpPacket):
|
|||
|
||||
# FIXME - this shares a lot of code with decode_options
|
||||
nulls = 0
|
||||
# 2 byte opcode, followed by filename and mode strings, optionally
|
||||
# followed by options.
|
||||
format = ""
|
||||
nulls = length = tlength = 0
|
||||
logger.debug("about to iterate buffer counting nulls")
|
||||
for c in self.buffer:
|
||||
subbuf = self.buffer[2:]
|
||||
for c in subbuf:
|
||||
#logger.debug("iterating this byte: " + c.__repr__())
|
||||
if ord(c) == 0:
|
||||
nulls += 1
|
||||
logger.debug("found a null at length %d, now have %d"
|
||||
% (length, nulls))
|
||||
length = 0
|
||||
format += "%dsx" % length
|
||||
length = -1
|
||||
# At 2 nulls, we want to mark that position for decoding.
|
||||
if nulls == 2:
|
||||
break
|
||||
length += 1
|
||||
tlength += 1
|
||||
|
||||
logger.debug("hopefully found end of mode at length %d" % tlength)
|
||||
# length should now be the end of the mode.
|
||||
tftpassert(nulls == 2, "malformed packet")
|
||||
shortbuf = self.buffer[2:tlength]
|
||||
shortbuf = subbuf[:tlength+1]
|
||||
logger.debug("about to unpack buffer with format: %s" % format)
|
||||
logger.debug("unpacking buffer: " + shortbuf.__repr__())
|
||||
mystruct = struct.unpack(format, shortbuf)
|
||||
for key in mystruct:
|
||||
logger.debug("option name is %s, value is %s"
|
||||
% (key, mystruct[key]))
|
||||
|
||||
for i in range(0, len(mystruct), 2):
|
||||
logger.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
|
||||
|
||||
tftpassert(len(mystruct) == 2, "malformed packet")
|
||||
self.options = self.decode_options(self.buffer[tlength:])
|
||||
#self.options = self.decode_options(self.buffer[tlength:])
|
||||
self.options = self.decode_options(subbuf[tlength+1:])
|
||||
return self
|
||||
|
||||
class TftpPacketRRQ(TftpPacketInitial):
|
||||
|
@ -600,6 +604,7 @@ class TftpClient(TftpSession):
|
|||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
outputfile.close()
|
||||
logger.info('')
|
||||
logger.info("Downloaded %d bytes in %d seconds" % (bytes, duration))
|
||||
bps = (bytes * 8.0) / duration
|
||||
kbps = bps / 1024.0
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
"""Unit tests for tftpy."""
|
||||
|
||||
import unittest
|
||||
import logging
|
||||
import tftpy
|
||||
|
||||
class TestTftpy(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
tftpy.setLogLevel(logging.INFO)
|
||||
|
||||
def testTftpPacketRRQ(self):
|
||||
options = {}
|
||||
|
|
Reference in a new issue