Added OACK packet, and factored-out client code.

git-svn-id: https://tftpy.svn.sourceforge.net/svnroot/tftpy/trunk@5 63283fd4-ec1e-0410-9879-cb7f675518da
This commit is contained in:
msoulier 2006-10-04 02:43:05 +00:00
parent 430f4f2a63
commit 88c387b1ec
2 changed files with 213 additions and 158 deletions

View file

@ -0,0 +1,72 @@
#!/usr/bin/env python
import sys
from optparse import OptionParser
from tftpy import *
def main():
usage=""
parser = OptionParser(usage=usage)
parser.add_option('-t',
'--test',
action='store_true',
dest='test',
help='run test case(s)',
default=False)
parser.add_option('-H',
'--host',
action='store',
dest='host',
help='remote host or ip address')
parser.add_option('-p',
'--port',
action='store',
dest='port',
help='remote port to use (default: 69)',
default=69)
parser.add_option('-f',
'--filename',
action='store',
dest='filename',
help='filename to fetch')
parser.add_option('-b',
'--blocksize',
action='store',
dest='blocksize',
help='udp packet size to use (default: 512)',
default=512)
parser.add_option('-o',
'--output',
action='store',
dest='output',
help='output file (default: out)',
default='out')
options, args = parser.parse_args()
if options.test:
options.host = "216.191.234.113"
options.port = 20001
options.filename = 'ipp510main.bin'
options.output = 'ipp510main.bin'
if not options.host or not options.filename:
parser.print_help()
sys.exit(1)
class Progress(object):
def __init__(self, out):
self.progress = 0
self.out = out
def progresshook(self, pkt):
self.progress += len(pkt.data)
self.out("Downloaded %d bytes" % self.progress)
progresshook = Progress(logger.info).progresshook
tclient = TftpClient(options.host,
options.port,
options.blocksize)
tclient.download(options.filename,
options.output,
progresshook)
if __name__ == '__main__':
main()

View file

@ -1,13 +1,10 @@
#!/usr/bin/python
"""This library implements the tftp protocol, based on rfc 1350.
http://www.faqs.org/rfcs/rfc1350.html
At the moment it implements only a client class, but will include a server,
with support for variable block sizes.
"""
import struct, random, socket, sys, logging, time
from optparse import OptionParser
import struct, socket, logging, time, sys
# Make sure that this is at least Python 2.4
verlist = sys.version_info
@ -27,16 +24,25 @@ logging.basicConfig(
level=LOG_LEVEL,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
datefmt='%m-%d %H:%M')
logger = logging.getLogger('tftplib')
logger = logging.getLogger('tftpy')
class TftpException(Exception):
"""This class is the parent class of all exceptions regarding the handling
of the TFTP protocol."""
pass
def tftpassert(condition, msg):
"""This function is a simple utility that will check the condition
passed for a false state. If it finds one, it throws a TftpException
with the message passed. This just makes the code throughout cleaner
by refactoring."""
if not condition:
raise TftpException, msg
class TftpPacket(object):
"""This class is the parent class of all tftp packet classes. It is an
abstract class, providing an interface, and should not be instantiated
directly."""
def __init__(self):
self.opcode = 0
self.buffer = None
@ -58,70 +64,8 @@ class TftpPacket(object):
This is an abstract method."""
raise NotImplementedError, "Abstract method"
class TftpPacketInitial(TftpPacket):
"""This class is a common parent class for the RRQ and WRQ packets, as they share
quite a bit of code."""
def __init__(self):
TftpPacket.__init__(self)
self.filename = None
self.mode = None
self.options = {}
def encode(self):
"""Encode the packet's buffer from the instance variables."""
tftpassert(self.filename, "filename required in initial packet")
tftpassert(self.mode, "mode required in initial packet")
format = "!H"
length = len(self.filename)
format += "%ds" % length
format += "B"
if self.mode == "octet":
format += "5s"
else:
raise AssertionError, "Unsupported mode: %s" % mode
format += "B"
logger.debug("format is %s" % format)
logger.debug("size of struct is %d" % struct.calcsize(format))
self.buffer = struct.pack(format, self.opcode, self.filename, 0, self.mode, 0)
return self
def decode(self):
tftpassert(self.buffer, "Can't decode, buffer is empty")
# FIXME - this shares a lot of code with decode_with_options
nulls = 0
# 2 byte opcode, followed by filename and mode strings, optionally followed
# by options.
format = ""
nulls = length = tlength = 0
logger.debug("about to iterate buffer counting nulls")
for c in self.buffer:
if ord(c) == 0:
nulls += 1
logger.debug("found a null at length %d, now have %d" % (length, nulls))
length = 0
format += "%dsx" % length
# At 2 nulls, we want to mark that position for decoding.
if nulls == 2:
break
length += 1
tlength += 1
logger.debug("hopefully found end of mode at length %d" % tlength)
# length should now be the end of the mode.
tftpassert(nulls == 2, "malformed packet")
shortbuf = self.buffer[2:tlength]
mystruct = struct.unpack(format, shortbuf)
for key in mystruct:
logger.debug("option name is %s, value is %s" % (key, mystruct[key]))
tftpassert(len(mystruct) == 2, "malformed packet")
self.options = self.decode_with_options(self.buffer[tlength:])
return self
def decode_with_options(self, buffer):
def decode_options(self, buffer):
"""This method decodes the section of the buffer that contains an
unknown number of options. It returns a dictionary of option names and
values."""
@ -145,15 +89,87 @@ class TftpPacketInitial(TftpPacket):
# Unpack the buffer.
mystruct = struct.unpack(format, buffer)
for key in mystruct:
self.debug("option name is %s, value is %s" % (key, mystruct[key]))
self.debug("option name is %s, value is %s"
% (key, mystruct[key]))
tftpassert(len(mystruct) % 2 == 0, "packet with odd number of option/value pairs")
tftpassert(len(mystruct) % 2 == 0,
"packet with odd number of option/value pairs")
for i in range(0, len(mystruct), 2):
options[mystruct[i]] = mystruct[i+1]
return options
class TftpPacketInitial(TftpPacket):
"""This class is a common parent class for the RRQ and WRQ packets, as
they share quite a bit of code."""
def __init__(self):
TftpPacket.__init__(self)
self.filename = None
self.mode = None
self.options = {}
def encode(self):
"""Encode the packet's buffer from the instance variables."""
tftpassert(self.filename, "filename required in initial packet")
tftpassert(self.mode, "mode required in initial packet")
format = "!H"
format += "%dsx" % len(self.filename)
if self.mode == "octet":
format += "5sx"
else:
raise AssertionError, "Unsupported mode: %s" % mode
# Add options.
options_list = []
if self.options.keys() > 0:
for key in self.options:
format += "%dsx" % len(key)
format += "%dsx" % len(self.options[key])
options_list.append(key)
options_list.append(self.options[key])
#format += "B"
logger.debug("format is %s" % format)
logger.debug("size of struct is %d" % struct.calcsize(format))
self.buffer = struct.pack(format, self.opcode, self.filename, self.mode, *options_list)
return self
def decode(self):
tftpassert(self.buffer, "Can't decode, buffer is empty")
# FIXME - this shares a lot of code with decode_options
nulls = 0
# 2 byte opcode, followed by filename and mode strings, optionally
# followed by options.
format = ""
nulls = length = tlength = 0
logger.debug("about to iterate buffer counting nulls")
for c in self.buffer:
if ord(c) == 0:
nulls += 1
logger.debug("found a null at length %d, now have %d"
% (length, nulls))
length = 0
format += "%dsx" % length
# At 2 nulls, we want to mark that position for decoding.
if nulls == 2:
break
length += 1
tlength += 1
logger.debug("hopefully found end of mode at length %d" % tlength)
# length should now be the end of the mode.
tftpassert(nulls == 2, "malformed packet")
shortbuf = self.buffer[2:tlength]
mystruct = struct.unpack(format, shortbuf)
for key in mystruct:
logger.debug("option name is %s, value is %s"
% (key, mystruct[key]))
tftpassert(len(mystruct) == 2, "malformed packet")
self.options = self.decode_options(self.buffer[tlength:])
return self
class TftpPacketRRQ(TftpPacketInitial):
"""
2 bytes string 1 byte string 1 byte
@ -190,11 +206,14 @@ DATA | 03 | Block # | Data |
self.data = None
def encode(self):
"""Encode the DAT packet. This method populates self.buffer, and returns
self for easy method chaining."""
"""Encode the DAT packet. This method populates self.buffer, and
returns self for easy method chaining."""
tftpassert(len(self.data) > 0, "no point encoding empty data packet")
format = "!HH%ds" % len(self.data)
self.buffer = struct.pack(format, self.opcode, self.blocknumber, self.data)
self.buffer = struct.pack(format,
self.opcode,
self.blocknumber,
self.data)
return self
def decode(self):
@ -204,11 +223,12 @@ DATA | 03 | Block # | Data |
# block number.
(self.blocknumber,) = struct.unpack("!H", self.buffer[2:4])
logger.info("decoding DAT packet, block number %d" % self.blocknumber)
logger.debug("should be %d bytes in the packet total" % len(self.buffer))
logger.debug("should be %d bytes in the packet total"
% len(self.buffer))
# Everything else is data.
self.data = self.buffer[4:]
logger.debug("found %d bytes of data" \
% len(self.data))
logger.debug("found %d bytes of data"
% len(self.data))
return self
class TftpPacketACK(TftpPacket):
@ -224,15 +244,15 @@ ACK | 04 | Block # |
self.blocknumber = 0
def encode(self):
logger.debug("encoding ACK: opcode = %d, block = %d" \
% (self.opcode, self.blocknumber))
logger.debug("encoding ACK: opcode = %d, block = %d"
% (self.opcode, self.blocknumber))
self.buffer = struct.pack("!HH", self.opcode, self.blocknumber)
return self
def decode(self):
self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer)
logger.debug("decoded ACK packet: opcode = %d, block = %d" \
% (self.opcode, self.blocknumber))
logger.debug("decoded ACK packet: opcode = %d, block = %d"
% (self.opcode, self.blocknumber))
return self
class TftpPacketERR(TftpPacket):
@ -270,8 +290,8 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
}
def encode(self):
"""Encode the DAT packet based on instance variables, populating self.buffer,
returning self."""
"""Encode the DAT packet based on instance variables, populating
self.buffer, returning self."""
format = "!HH%dsx" % len(self.errmsgs[self.errorcode])
self.debug("encoding ERR packet with format %s" % format)
self.buffer = struct.pack(format,
@ -284,9 +304,36 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
"Decode self.buffer, populating instance variables and return self."
tftpassert(len(self.buffer) >= 5, "malformed ERR packet")
format = "!HH%dsx" % len(self.buffer)-5
self.opcode, self.errorcode, self.errmsg = struct.unpack(format, self.buffer)
logger.error("ERR packet - errorcode: %d, message: %s" \
% (errorcode, self.errmsg))
self.opcode, self.errorcode, self.errmsg = struct.unpack(format,
self.buffer)
logger.error("ERR packet - errorcode: %d, message: %s"
% (errorcode, self.errmsg))
return self
class TftpPacketOACK(TftpPacket):
"""
# +-------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
# | opc | opt1 | 0 | value1 | 0 | optN | 0 | valueN | 0 |
# +-------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
"""
def __init__(self):
TftpPacket.__init__(self)
self.opcode = 6
self.options = {}
def encode(self):
format = "!H" # opcode
options_list = []
for key in self.options:
format += "%dsx" % len(key)
format += "%dsx" % len(self.options[key])
options_list.append(key)
options_list.append(self.options[key])
self.buffer = struct.pack(format, self.opcode, *options_list)
return self
def decode(self):
self.options = self.decode_options(self.buffer[2:])
return self
class TftpPacketFactory(object):
@ -301,7 +348,8 @@ class TftpPacketFactory(object):
}
def create(self, opcode):
tftpassert(self.classes.has_key(opcode), "Unsupported opcode: %d" % opcode)
tftpassert(self.classes.has_key(opcode),
"Unsupported opcode: %d" % opcode)
packet = self.classes[opcode]()
logger.debug("packet is %s" % packet)
return packet
@ -367,7 +415,8 @@ class TftpClient(Tftp):
logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber)
logger.debug("curblock = %d" % curblock)
if recvpkt.blocknumber == curblock+1:
logger.debug("good, received block %d in sequence" % recvpkt.blocknumber)
logger.debug("good, received block %d in sequence"
% recvpkt.blocknumber)
curblock += 1
# ACK the packet, and save the data.
logger.info("sending ACK to block %d" % curblock)
@ -376,7 +425,8 @@ class TftpClient(Tftp):
ackpkt.blocknumber = curblock
sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
logger.debug("writing %d bytes to output file" % len(recvpkt.data))
logger.debug("writing %d bytes to output file"
% len(recvpkt.data))
outputfile.write(recvpkt.data)
bytes += len(recvpkt.data)
# If there is a packethook defined, call it.
@ -393,8 +443,8 @@ class TftpClient(Tftp):
dups[curblock] += 1
else:
dups[curblock] = 1
if dups[curblock] >= MAX_DUPS:
raise TftpException, "Max duplicates for block %d reached" % curblock
tftpassert(dups[curblock] < MAX_DUPS,
"Max duplicates for block %d reached" % curblock)
logger.debug("ACKing block %d again, just in case" % curblock)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = curblock
@ -428,71 +478,4 @@ class TftpClient(Tftp):
dupcount = 0
for key in dups:
dupcount += dups[key]
logger.info("Received %d duplicate packets" % dupcount)
def main():
usage=""
parser = OptionParser(usage=usage)
parser.add_option('-t',
'--test',
action='store_true',
dest='test',
help='run test case(s)',
default=False)
parser.add_option('-H',
'--host',
action='store',
dest='host',
help='remote host or ip address')
parser.add_option('-p',
'--port',
action='store',
dest='port',
help='remote port to use (default: 69)',
default=69)
parser.add_option('-f',
'--filename',
action='store',
dest='filename',
help='filename to fetch')
parser.add_option('-b',
'--blocksize',
action='store',
dest='blocksize',
help='udp packet size to use (default: 512)',
default=512)
parser.add_option('-o',
'--output',
action='store',
dest='output',
help='output file (default: out)',
default='out')
options, args = parser.parse_args()
if options.test:
options.host = "216.191.234.113"
options.port = 20001
options.filename = 'ipp510main.bin'
options.output = 'ipp510main.bin'
if not options.host or not options.filename:
parser.print_help()
sys.exit(1)
class Progress(object):
def __init__(self, out):
self.progress = 0
self.out = out
def progresshook(self, pkt):
self.progress += len(pkt.data)
self.out("Downloaded %d bytes" % self.progress)
progresshook = Progress(logger.info).progresshook
tclient = TftpClient(options.host,
options.port,
options.blocksize)
tclient.download(options.filename,
options.output,
progresshook)
if __name__ == '__main__':
main()
logger.info("Received %d duplicate packets" % dupcount)