Added OACK packet, and factored-out client code.
git-svn-id: https://tftpy.svn.sourceforge.net/svnroot/tftpy/trunk@5 63283fd4-ec1e-0410-9879-cb7f675518da
This commit is contained in:
parent
430f4f2a63
commit
88c387b1ec
2 changed files with 213 additions and 158 deletions
|
@ -0,0 +1,72 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import sys
|
||||
from optparse import OptionParser
|
||||
from tftpy import *
|
||||
|
||||
def main():
|
||||
usage=""
|
||||
parser = OptionParser(usage=usage)
|
||||
parser.add_option('-t',
|
||||
'--test',
|
||||
action='store_true',
|
||||
dest='test',
|
||||
help='run test case(s)',
|
||||
default=False)
|
||||
parser.add_option('-H',
|
||||
'--host',
|
||||
action='store',
|
||||
dest='host',
|
||||
help='remote host or ip address')
|
||||
parser.add_option('-p',
|
||||
'--port',
|
||||
action='store',
|
||||
dest='port',
|
||||
help='remote port to use (default: 69)',
|
||||
default=69)
|
||||
parser.add_option('-f',
|
||||
'--filename',
|
||||
action='store',
|
||||
dest='filename',
|
||||
help='filename to fetch')
|
||||
parser.add_option('-b',
|
||||
'--blocksize',
|
||||
action='store',
|
||||
dest='blocksize',
|
||||
help='udp packet size to use (default: 512)',
|
||||
default=512)
|
||||
parser.add_option('-o',
|
||||
'--output',
|
||||
action='store',
|
||||
dest='output',
|
||||
help='output file (default: out)',
|
||||
default='out')
|
||||
options, args = parser.parse_args()
|
||||
if options.test:
|
||||
options.host = "216.191.234.113"
|
||||
options.port = 20001
|
||||
options.filename = 'ipp510main.bin'
|
||||
options.output = 'ipp510main.bin'
|
||||
if not options.host or not options.filename:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
class Progress(object):
|
||||
def __init__(self, out):
|
||||
self.progress = 0
|
||||
self.out = out
|
||||
def progresshook(self, pkt):
|
||||
self.progress += len(pkt.data)
|
||||
self.out("Downloaded %d bytes" % self.progress)
|
||||
|
||||
progresshook = Progress(logger.info).progresshook
|
||||
|
||||
tclient = TftpClient(options.host,
|
||||
options.port,
|
||||
options.blocksize)
|
||||
tclient.download(options.filename,
|
||||
options.output,
|
||||
progresshook)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
299
lib/tftpy.py
299
lib/tftpy.py
|
@ -1,13 +1,10 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
"""This library implements the tftp protocol, based on rfc 1350.
|
||||
http://www.faqs.org/rfcs/rfc1350.html
|
||||
At the moment it implements only a client class, but will include a server,
|
||||
with support for variable block sizes.
|
||||
"""
|
||||
|
||||
import struct, random, socket, sys, logging, time
|
||||
from optparse import OptionParser
|
||||
import struct, socket, logging, time, sys
|
||||
|
||||
# Make sure that this is at least Python 2.4
|
||||
verlist = sys.version_info
|
||||
|
@ -27,16 +24,25 @@ logging.basicConfig(
|
|||
level=LOG_LEVEL,
|
||||
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
|
||||
datefmt='%m-%d %H:%M')
|
||||
logger = logging.getLogger('tftplib')
|
||||
logger = logging.getLogger('tftpy')
|
||||
|
||||
class TftpException(Exception):
|
||||
"""This class is the parent class of all exceptions regarding the handling
|
||||
of the TFTP protocol."""
|
||||
pass
|
||||
|
||||
def tftpassert(condition, msg):
|
||||
"""This function is a simple utility that will check the condition
|
||||
passed for a false state. If it finds one, it throws a TftpException
|
||||
with the message passed. This just makes the code throughout cleaner
|
||||
by refactoring."""
|
||||
if not condition:
|
||||
raise TftpException, msg
|
||||
|
||||
class TftpPacket(object):
|
||||
"""This class is the parent class of all tftp packet classes. It is an
|
||||
abstract class, providing an interface, and should not be instantiated
|
||||
directly."""
|
||||
def __init__(self):
|
||||
self.opcode = 0
|
||||
self.buffer = None
|
||||
|
@ -58,70 +64,8 @@ class TftpPacket(object):
|
|||
|
||||
This is an abstract method."""
|
||||
raise NotImplementedError, "Abstract method"
|
||||
|
||||
class TftpPacketInitial(TftpPacket):
|
||||
"""This class is a common parent class for the RRQ and WRQ packets, as they share
|
||||
quite a bit of code."""
|
||||
def __init__(self):
|
||||
TftpPacket.__init__(self)
|
||||
self.filename = None
|
||||
self.mode = None
|
||||
self.options = {}
|
||||
|
||||
def encode(self):
|
||||
"""Encode the packet's buffer from the instance variables."""
|
||||
tftpassert(self.filename, "filename required in initial packet")
|
||||
tftpassert(self.mode, "mode required in initial packet")
|
||||
|
||||
format = "!H"
|
||||
length = len(self.filename)
|
||||
format += "%ds" % length
|
||||
format += "B"
|
||||
if self.mode == "octet":
|
||||
format += "5s"
|
||||
else:
|
||||
raise AssertionError, "Unsupported mode: %s" % mode
|
||||
format += "B"
|
||||
logger.debug("format is %s" % format)
|
||||
logger.debug("size of struct is %d" % struct.calcsize(format))
|
||||
|
||||
self.buffer = struct.pack(format, self.opcode, self.filename, 0, self.mode, 0)
|
||||
return self
|
||||
|
||||
def decode(self):
|
||||
tftpassert(self.buffer, "Can't decode, buffer is empty")
|
||||
|
||||
# FIXME - this shares a lot of code with decode_with_options
|
||||
nulls = 0
|
||||
# 2 byte opcode, followed by filename and mode strings, optionally followed
|
||||
# by options.
|
||||
format = ""
|
||||
nulls = length = tlength = 0
|
||||
logger.debug("about to iterate buffer counting nulls")
|
||||
for c in self.buffer:
|
||||
if ord(c) == 0:
|
||||
nulls += 1
|
||||
logger.debug("found a null at length %d, now have %d" % (length, nulls))
|
||||
length = 0
|
||||
format += "%dsx" % length
|
||||
# At 2 nulls, we want to mark that position for decoding.
|
||||
if nulls == 2:
|
||||
break
|
||||
length += 1
|
||||
tlength += 1
|
||||
logger.debug("hopefully found end of mode at length %d" % tlength)
|
||||
# length should now be the end of the mode.
|
||||
tftpassert(nulls == 2, "malformed packet")
|
||||
shortbuf = self.buffer[2:tlength]
|
||||
mystruct = struct.unpack(format, shortbuf)
|
||||
for key in mystruct:
|
||||
logger.debug("option name is %s, value is %s" % (key, mystruct[key]))
|
||||
|
||||
tftpassert(len(mystruct) == 2, "malformed packet")
|
||||
self.options = self.decode_with_options(self.buffer[tlength:])
|
||||
return self
|
||||
|
||||
def decode_with_options(self, buffer):
|
||||
def decode_options(self, buffer):
|
||||
"""This method decodes the section of the buffer that contains an
|
||||
unknown number of options. It returns a dictionary of option names and
|
||||
values."""
|
||||
|
@ -145,15 +89,87 @@ class TftpPacketInitial(TftpPacket):
|
|||
# Unpack the buffer.
|
||||
mystruct = struct.unpack(format, buffer)
|
||||
for key in mystruct:
|
||||
self.debug("option name is %s, value is %s" % (key, mystruct[key]))
|
||||
self.debug("option name is %s, value is %s"
|
||||
% (key, mystruct[key]))
|
||||
|
||||
tftpassert(len(mystruct) % 2 == 0, "packet with odd number of option/value pairs")
|
||||
tftpassert(len(mystruct) % 2 == 0,
|
||||
"packet with odd number of option/value pairs")
|
||||
|
||||
for i in range(0, len(mystruct), 2):
|
||||
options[mystruct[i]] = mystruct[i+1]
|
||||
|
||||
return options
|
||||
|
||||
class TftpPacketInitial(TftpPacket):
|
||||
"""This class is a common parent class for the RRQ and WRQ packets, as
|
||||
they share quite a bit of code."""
|
||||
def __init__(self):
|
||||
TftpPacket.__init__(self)
|
||||
self.filename = None
|
||||
self.mode = None
|
||||
self.options = {}
|
||||
|
||||
def encode(self):
|
||||
"""Encode the packet's buffer from the instance variables."""
|
||||
tftpassert(self.filename, "filename required in initial packet")
|
||||
tftpassert(self.mode, "mode required in initial packet")
|
||||
|
||||
format = "!H"
|
||||
format += "%dsx" % len(self.filename)
|
||||
if self.mode == "octet":
|
||||
format += "5sx"
|
||||
else:
|
||||
raise AssertionError, "Unsupported mode: %s" % mode
|
||||
# Add options.
|
||||
options_list = []
|
||||
if self.options.keys() > 0:
|
||||
for key in self.options:
|
||||
format += "%dsx" % len(key)
|
||||
format += "%dsx" % len(self.options[key])
|
||||
options_list.append(key)
|
||||
options_list.append(self.options[key])
|
||||
#format += "B"
|
||||
logger.debug("format is %s" % format)
|
||||
logger.debug("size of struct is %d" % struct.calcsize(format))
|
||||
|
||||
self.buffer = struct.pack(format, self.opcode, self.filename, self.mode, *options_list)
|
||||
return self
|
||||
|
||||
def decode(self):
|
||||
tftpassert(self.buffer, "Can't decode, buffer is empty")
|
||||
|
||||
# FIXME - this shares a lot of code with decode_options
|
||||
nulls = 0
|
||||
# 2 byte opcode, followed by filename and mode strings, optionally
|
||||
# followed by options.
|
||||
format = ""
|
||||
nulls = length = tlength = 0
|
||||
logger.debug("about to iterate buffer counting nulls")
|
||||
for c in self.buffer:
|
||||
if ord(c) == 0:
|
||||
nulls += 1
|
||||
logger.debug("found a null at length %d, now have %d"
|
||||
% (length, nulls))
|
||||
length = 0
|
||||
format += "%dsx" % length
|
||||
# At 2 nulls, we want to mark that position for decoding.
|
||||
if nulls == 2:
|
||||
break
|
||||
length += 1
|
||||
tlength += 1
|
||||
logger.debug("hopefully found end of mode at length %d" % tlength)
|
||||
# length should now be the end of the mode.
|
||||
tftpassert(nulls == 2, "malformed packet")
|
||||
shortbuf = self.buffer[2:tlength]
|
||||
mystruct = struct.unpack(format, shortbuf)
|
||||
for key in mystruct:
|
||||
logger.debug("option name is %s, value is %s"
|
||||
% (key, mystruct[key]))
|
||||
|
||||
tftpassert(len(mystruct) == 2, "malformed packet")
|
||||
self.options = self.decode_options(self.buffer[tlength:])
|
||||
return self
|
||||
|
||||
class TftpPacketRRQ(TftpPacketInitial):
|
||||
"""
|
||||
2 bytes string 1 byte string 1 byte
|
||||
|
@ -190,11 +206,14 @@ DATA | 03 | Block # | Data |
|
|||
self.data = None
|
||||
|
||||
def encode(self):
|
||||
"""Encode the DAT packet. This method populates self.buffer, and returns
|
||||
self for easy method chaining."""
|
||||
"""Encode the DAT packet. This method populates self.buffer, and
|
||||
returns self for easy method chaining."""
|
||||
tftpassert(len(self.data) > 0, "no point encoding empty data packet")
|
||||
format = "!HH%ds" % len(self.data)
|
||||
self.buffer = struct.pack(format, self.opcode, self.blocknumber, self.data)
|
||||
self.buffer = struct.pack(format,
|
||||
self.opcode,
|
||||
self.blocknumber,
|
||||
self.data)
|
||||
return self
|
||||
|
||||
def decode(self):
|
||||
|
@ -204,11 +223,12 @@ DATA | 03 | Block # | Data |
|
|||
# block number.
|
||||
(self.blocknumber,) = struct.unpack("!H", self.buffer[2:4])
|
||||
logger.info("decoding DAT packet, block number %d" % self.blocknumber)
|
||||
logger.debug("should be %d bytes in the packet total" % len(self.buffer))
|
||||
logger.debug("should be %d bytes in the packet total"
|
||||
% len(self.buffer))
|
||||
# Everything else is data.
|
||||
self.data = self.buffer[4:]
|
||||
logger.debug("found %d bytes of data" \
|
||||
% len(self.data))
|
||||
logger.debug("found %d bytes of data"
|
||||
% len(self.data))
|
||||
return self
|
||||
|
||||
class TftpPacketACK(TftpPacket):
|
||||
|
@ -224,15 +244,15 @@ ACK | 04 | Block # |
|
|||
self.blocknumber = 0
|
||||
|
||||
def encode(self):
|
||||
logger.debug("encoding ACK: opcode = %d, block = %d" \
|
||||
% (self.opcode, self.blocknumber))
|
||||
logger.debug("encoding ACK: opcode = %d, block = %d"
|
||||
% (self.opcode, self.blocknumber))
|
||||
self.buffer = struct.pack("!HH", self.opcode, self.blocknumber)
|
||||
return self
|
||||
|
||||
def decode(self):
|
||||
self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer)
|
||||
logger.debug("decoded ACK packet: opcode = %d, block = %d" \
|
||||
% (self.opcode, self.blocknumber))
|
||||
logger.debug("decoded ACK packet: opcode = %d, block = %d"
|
||||
% (self.opcode, self.blocknumber))
|
||||
return self
|
||||
|
||||
class TftpPacketERR(TftpPacket):
|
||||
|
@ -270,8 +290,8 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
|
|||
}
|
||||
|
||||
def encode(self):
|
||||
"""Encode the DAT packet based on instance variables, populating self.buffer,
|
||||
returning self."""
|
||||
"""Encode the DAT packet based on instance variables, populating
|
||||
self.buffer, returning self."""
|
||||
format = "!HH%dsx" % len(self.errmsgs[self.errorcode])
|
||||
self.debug("encoding ERR packet with format %s" % format)
|
||||
self.buffer = struct.pack(format,
|
||||
|
@ -284,9 +304,36 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
|
|||
"Decode self.buffer, populating instance variables and return self."
|
||||
tftpassert(len(self.buffer) >= 5, "malformed ERR packet")
|
||||
format = "!HH%dsx" % len(self.buffer)-5
|
||||
self.opcode, self.errorcode, self.errmsg = struct.unpack(format, self.buffer)
|
||||
logger.error("ERR packet - errorcode: %d, message: %s" \
|
||||
% (errorcode, self.errmsg))
|
||||
self.opcode, self.errorcode, self.errmsg = struct.unpack(format,
|
||||
self.buffer)
|
||||
logger.error("ERR packet - errorcode: %d, message: %s"
|
||||
% (errorcode, self.errmsg))
|
||||
return self
|
||||
|
||||
class TftpPacketOACK(TftpPacket):
|
||||
"""
|
||||
# +-------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
|
||||
# | opc | opt1 | 0 | value1 | 0 | optN | 0 | valueN | 0 |
|
||||
# +-------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
|
||||
"""
|
||||
def __init__(self):
|
||||
TftpPacket.__init__(self)
|
||||
self.opcode = 6
|
||||
self.options = {}
|
||||
|
||||
def encode(self):
|
||||
format = "!H" # opcode
|
||||
options_list = []
|
||||
for key in self.options:
|
||||
format += "%dsx" % len(key)
|
||||
format += "%dsx" % len(self.options[key])
|
||||
options_list.append(key)
|
||||
options_list.append(self.options[key])
|
||||
self.buffer = struct.pack(format, self.opcode, *options_list)
|
||||
return self
|
||||
|
||||
def decode(self):
|
||||
self.options = self.decode_options(self.buffer[2:])
|
||||
return self
|
||||
|
||||
class TftpPacketFactory(object):
|
||||
|
@ -301,7 +348,8 @@ class TftpPacketFactory(object):
|
|||
}
|
||||
|
||||
def create(self, opcode):
|
||||
tftpassert(self.classes.has_key(opcode), "Unsupported opcode: %d" % opcode)
|
||||
tftpassert(self.classes.has_key(opcode),
|
||||
"Unsupported opcode: %d" % opcode)
|
||||
packet = self.classes[opcode]()
|
||||
logger.debug("packet is %s" % packet)
|
||||
return packet
|
||||
|
@ -367,7 +415,8 @@ class TftpClient(Tftp):
|
|||
logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber)
|
||||
logger.debug("curblock = %d" % curblock)
|
||||
if recvpkt.blocknumber == curblock+1:
|
||||
logger.debug("good, received block %d in sequence" % recvpkt.blocknumber)
|
||||
logger.debug("good, received block %d in sequence"
|
||||
% recvpkt.blocknumber)
|
||||
curblock += 1
|
||||
# ACK the packet, and save the data.
|
||||
logger.info("sending ACK to block %d" % curblock)
|
||||
|
@ -376,7 +425,8 @@ class TftpClient(Tftp):
|
|||
ackpkt.blocknumber = curblock
|
||||
sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
|
||||
|
||||
logger.debug("writing %d bytes to output file" % len(recvpkt.data))
|
||||
logger.debug("writing %d bytes to output file"
|
||||
% len(recvpkt.data))
|
||||
outputfile.write(recvpkt.data)
|
||||
bytes += len(recvpkt.data)
|
||||
# If there is a packethook defined, call it.
|
||||
|
@ -393,8 +443,8 @@ class TftpClient(Tftp):
|
|||
dups[curblock] += 1
|
||||
else:
|
||||
dups[curblock] = 1
|
||||
if dups[curblock] >= MAX_DUPS:
|
||||
raise TftpException, "Max duplicates for block %d reached" % curblock
|
||||
tftpassert(dups[curblock] < MAX_DUPS,
|
||||
"Max duplicates for block %d reached" % curblock)
|
||||
logger.debug("ACKing block %d again, just in case" % curblock)
|
||||
ackpkt = TftpPacketACK()
|
||||
ackpkt.blocknumber = curblock
|
||||
|
@ -428,71 +478,4 @@ class TftpClient(Tftp):
|
|||
dupcount = 0
|
||||
for key in dups:
|
||||
dupcount += dups[key]
|
||||
logger.info("Received %d duplicate packets" % dupcount)
|
||||
|
||||
def main():
|
||||
usage=""
|
||||
parser = OptionParser(usage=usage)
|
||||
parser.add_option('-t',
|
||||
'--test',
|
||||
action='store_true',
|
||||
dest='test',
|
||||
help='run test case(s)',
|
||||
default=False)
|
||||
parser.add_option('-H',
|
||||
'--host',
|
||||
action='store',
|
||||
dest='host',
|
||||
help='remote host or ip address')
|
||||
parser.add_option('-p',
|
||||
'--port',
|
||||
action='store',
|
||||
dest='port',
|
||||
help='remote port to use (default: 69)',
|
||||
default=69)
|
||||
parser.add_option('-f',
|
||||
'--filename',
|
||||
action='store',
|
||||
dest='filename',
|
||||
help='filename to fetch')
|
||||
parser.add_option('-b',
|
||||
'--blocksize',
|
||||
action='store',
|
||||
dest='blocksize',
|
||||
help='udp packet size to use (default: 512)',
|
||||
default=512)
|
||||
parser.add_option('-o',
|
||||
'--output',
|
||||
action='store',
|
||||
dest='output',
|
||||
help='output file (default: out)',
|
||||
default='out')
|
||||
options, args = parser.parse_args()
|
||||
if options.test:
|
||||
options.host = "216.191.234.113"
|
||||
options.port = 20001
|
||||
options.filename = 'ipp510main.bin'
|
||||
options.output = 'ipp510main.bin'
|
||||
if not options.host or not options.filename:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
class Progress(object):
|
||||
def __init__(self, out):
|
||||
self.progress = 0
|
||||
self.out = out
|
||||
def progresshook(self, pkt):
|
||||
self.progress += len(pkt.data)
|
||||
self.out("Downloaded %d bytes" % self.progress)
|
||||
|
||||
progresshook = Progress(logger.info).progresshook
|
||||
|
||||
tclient = TftpClient(options.host,
|
||||
options.port,
|
||||
options.blocksize)
|
||||
tclient.download(options.filename,
|
||||
options.output,
|
||||
progresshook)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
logger.info("Received %d duplicate packets" % dupcount)
|
Reference in a new issue