Did some rework for the state machine in a server context.

Removed the handler framework in favour of a TftpContextServer used
as the session.
master
Michael P. Soulier 2009-08-15 22:36:58 -04:00
parent 03e4e74829
commit 62b22fb562
6 changed files with 564 additions and 636 deletions

View File

@ -52,12 +52,12 @@ class TftpClient(TftpSession):
# output? This should be in the sample client, but not in the download
# call.
if metrics.duration == 0:
logger.info("Duration too short, rate undetermined")
log.info("Duration too short, rate undetermined")
else:
logger.info('')
logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
logger.info("Average rate: %.2f kbps" % metrics.kbps)
logger.info("Received %d duplicate packets" % metrics.dupcount)
log.info('')
log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
log.info("Average rate: %.2f kbps" % metrics.kbps)
log.info("Received %d duplicate packets" % metrics.dupcount)
def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT):
# Open the input file.
@ -80,9 +80,9 @@ class TftpClient(TftpSession):
# output? This should be in the sample client, but not in the download
# call.
if metrics.duration == 0:
logger.info("Duration too short, rate undetermined")
log.info("Duration too short, rate undetermined")
else:
logger.info('')
logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
logger.info("Average rate: %.2f kbps" % metrics.kbps)
logger.info("Received %d duplicate packets" % metrics.dupcount)
log.info('')
log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
log.info("Average rate: %.2f kbps" % metrics.kbps)
log.info("Received %d duplicate packets" % metrics.dupcount)

View File

@ -19,9 +19,9 @@ class TftpPacketFactory(object):
"""This method is used to parse an existing datagram into its
corresponding TftpPacket object. The buffer is the raw bytes off of
the network."""
logger.debug("parsing a %d byte packet" % len(buffer))
log.debug("parsing a %d byte packet" % len(buffer))
(opcode,) = struct.unpack("!H", buffer[:2])
logger.debug("opcode is %d" % opcode)
log.debug("opcode is %d" % opcode)
packet = self.__create(opcode)
packet.buffer = buffer
return packet.decode()

View File

@ -15,7 +15,7 @@ class TftpSession(object):
def senderror(self, sock, errorcode, address, port):
"""This method uses the socket passed, and uses the errorcode, address
and port to compose and send an error packet."""
logger.debug("In senderror, being asked to send error %d to %s:%s"
log.debug("In senderror, being asked to send error %d to %s:%s"
% (errorcode, address, port))
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
@ -27,23 +27,23 @@ class TftpPacketWithOptions(object):
goal is just to share code here, and not cause diamond inheritance."""
def __init__(self):
self.options = []
self.options = {}
def setoptions(self, options):
logger.debug("in TftpPacketWithOptions.setoptions")
logger.debug("options: " + str(options))
log.debug("in TftpPacketWithOptions.setoptions")
log.debug("options: " + str(options))
myoptions = {}
for key in options:
newkey = str(key)
myoptions[newkey] = str(options[key])
logger.debug("populated myoptions with %s = %s"
log.debug("populated myoptions with %s = %s"
% (newkey, myoptions[newkey]))
logger.debug("setting options hash to: " + str(myoptions))
log.debug("setting options hash to: " + str(myoptions))
self._options = myoptions
def getoptions(self):
logger.debug("in TftpPacketWithOptions.getoptions")
log.debug("in TftpPacketWithOptions.getoptions")
return self._options
# Set up getter and setter on options to ensure that they are the proper
@ -59,19 +59,19 @@ class TftpPacketWithOptions(object):
format = "!"
options = {}
logger.debug("decode_options: buffer is: " + repr(buffer))
logger.debug("size of buffer is %d bytes" % len(buffer))
log.debug("decode_options: buffer is: " + repr(buffer))
log.debug("size of buffer is %d bytes" % len(buffer))
if len(buffer) == 0:
logger.debug("size of buffer is zero, returning empty hash")
log.debug("size of buffer is zero, returning empty hash")
return {}
# Count the nulls in the buffer. Each one terminates a string.
logger.debug("about to iterate options buffer counting nulls")
log.debug("about to iterate options buffer counting nulls")
length = 0
for c in buffer:
#logger.debug("iterating this byte: " + repr(c))
#log.debug("iterating this byte: " + repr(c))
if ord(c) == 0:
logger.debug("found a null at length %d" % length)
log.debug("found a null at length %d" % length)
if length > 0:
format += "%dsx" % length
length = -1
@ -79,14 +79,14 @@ class TftpPacketWithOptions(object):
raise TftpException, "Invalid options in buffer"
length += 1
logger.debug("about to unpack, format is: %s" % format)
log.debug("about to unpack, format is: %s" % format)
mystruct = struct.unpack(format, buffer)
tftpassert(len(mystruct) % 2 == 0,
"packet with odd number of option/value pairs")
for i in range(0, len(mystruct), 2):
logger.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
log.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
options[mystruct[i]] = mystruct[i+1]
return options
@ -134,10 +134,10 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
ptype = None
if self.opcode == 1: ptype = "RRQ"
else: ptype = "WRQ"
logger.debug("Encoding %s packet, filename = %s, mode = %s"
log.debug("Encoding %s packet, filename = %s, mode = %s"
% (ptype, self.filename, self.mode))
for key in self.options:
logger.debug(" Option %s = %s" % (key, self.options[key]))
log.debug(" Option %s = %s" % (key, self.options[key]))
format = "!H"
format += "%dsx" % len(self.filename)
@ -148,7 +148,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
# Add options.
options_list = []
if self.options.keys() > 0:
logger.debug("there are options to encode")
log.debug("there are options to encode")
for key in self.options:
# Populate the option name
format += "%dsx" % len(key)
@ -157,9 +157,9 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
format += "%dsx" % len(str(self.options[key]))
options_list.append(str(self.options[key]))
logger.debug("format is %s" % format)
logger.debug("options_list is %s" % options_list)
logger.debug("size of struct is %d" % struct.calcsize(format))
log.debug("format is %s" % format)
log.debug("options_list is %s" % options_list)
log.debug("size of struct is %d" % struct.calcsize(format))
self.buffer = struct.pack(format,
self.opcode,
@ -167,7 +167,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
self.mode,
*options_list)
logger.debug("buffer is " + repr(self.buffer))
log.debug("buffer is " + repr(self.buffer))
return self
def decode(self):
@ -177,13 +177,13 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
nulls = 0
format = ""
nulls = length = tlength = 0
logger.debug("in decode: about to iterate buffer counting nulls")
log.debug("in decode: about to iterate buffer counting nulls")
subbuf = self.buffer[2:]
for c in subbuf:
logger.debug("iterating this byte: " + repr(c))
log.debug("iterating this byte: " + repr(c))
if ord(c) == 0:
nulls += 1
logger.debug("found a null at length %d, now have %d"
log.debug("found a null at length %d, now have %d"
% (length, nulls))
format += "%dsx" % length
length = -1
@ -193,17 +193,17 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
length += 1
tlength += 1
logger.debug("hopefully found end of mode at length %d" % tlength)
log.debug("hopefully found end of mode at length %d" % tlength)
# length should now be the end of the mode.
tftpassert(nulls == 2, "malformed packet")
shortbuf = subbuf[:tlength+1]
logger.debug("about to unpack buffer with format: %s" % format)
logger.debug("unpacking buffer: " + repr(shortbuf))
log.debug("about to unpack buffer with format: %s" % format)
log.debug("unpacking buffer: " + repr(shortbuf))
mystruct = struct.unpack(format, shortbuf)
tftpassert(len(mystruct) == 2, "malformed packet")
logger.debug("setting filename to %s" % mystruct[0])
logger.debug("setting mode to %s" % mystruct[1])
log.debug("setting filename to %s" % mystruct[0])
log.debug("setting mode to %s" % mystruct[1])
self.filename = mystruct[0]
self.mode = mystruct[1]
@ -269,7 +269,7 @@ DATA | 03 | Block # | Data |
"""Encode the DAT packet. This method populates self.buffer, and
returns self for easy method chaining."""
if len(self.data) == 0:
logger.debug("Encoding an empty DAT packet")
log.debug("Encoding an empty DAT packet")
format = "!HH%ds" % len(self.data)
self.buffer = struct.pack(format,
self.opcode,
@ -283,12 +283,12 @@ DATA | 03 | Block # | Data |
# We know the first 2 bytes are the opcode. The second two are the
# block number.
(self.blocknumber,) = struct.unpack("!H", self.buffer[2:4])
logger.debug("decoding DAT packet, block number %d" % self.blocknumber)
logger.debug("should be %d bytes in the packet total"
log.debug("decoding DAT packet, block number %d" % self.blocknumber)
log.debug("should be %d bytes in the packet total"
% len(self.buffer))
# Everything else is data.
self.data = self.buffer[4:]
logger.debug("found %d bytes of data"
log.debug("found %d bytes of data"
% len(self.data))
return self
@ -308,14 +308,14 @@ ACK | 04 | Block # |
return 'ACK packet: block %d' % self.blocknumber
def encode(self):
logger.debug("encoding ACK: opcode = %d, block = %d"
log.debug("encoding ACK: opcode = %d, block = %d"
% (self.opcode, self.blocknumber))
self.buffer = struct.pack("!HH", self.opcode, self.blocknumber)
return self
def decode(self):
self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer)
logger.debug("decoded ACK packet: opcode = %d, block = %d"
log.debug("decoded ACK packet: opcode = %d, block = %d"
% (self.opcode, self.blocknumber))
return self
@ -365,7 +365,7 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
"""Encode the DAT packet based on instance variables, populating
self.buffer, returning self."""
format = "!HH%dsx" % len(self.errmsgs[self.errorcode])
logger.debug("encoding ERR packet with format %s" % format)
log.debug("encoding ERR packet with format %s" % format)
self.buffer = struct.pack(format,
self.opcode,
self.errorcode,
@ -375,13 +375,13 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
def decode(self):
"Decode self.buffer, populating instance variables and return self."
tftpassert(len(self.buffer) > 4, "malformed ERR packet, too short")
logger.debug("Decoding ERR packet, length %s bytes" %
log.debug("Decoding ERR packet, length %s bytes" %
len(self.buffer))
format = "!HH%dsx" % (len(self.buffer) - 5)
logger.debug("Decoding ERR packet with format: %s" % format)
log.debug("Decoding ERR packet with format: %s" % format)
self.opcode, self.errorcode, self.errmsg = struct.unpack(format,
self.buffer)
logger.error("ERR packet - errorcode: %d, message: %s"
log.error("ERR packet - errorcode: %d, message: %s"
% (self.errorcode, self.errmsg))
return self
@ -402,10 +402,10 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
def encode(self):
format = "!H" # opcode
options_list = []
logger.debug("in TftpPacketOACK.encode")
log.debug("in TftpPacketOACK.encode")
for key in self.options:
logger.debug("looping on option key %s" % key)
logger.debug("value is %s" % self.options[key])
log.debug("looping on option key %s" % key)
log.debug("value is %s" % self.options[key])
format += "%dsx" % len(key)
format += "%dsx" % len(self.options[key])
options_list.append(key)
@ -429,7 +429,7 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
# We can accept anything between the min and max values.
size = self.options[name]
if size >= MIN_BLKSIZE and size <= MAX_BLKSIZE:
logger.debug("negotiated blksize of %d bytes" % size)
log.debug("negotiated blksize of %d bytes" % size)
options[blksize] = size
else:
raise TftpException, "Unsupported option: %s" % name

View File

@ -1,4 +1,5 @@
import socket, os, re, time, random
import select
from TftpShared import *
from TftpPacketTypes import *
from TftpPacketFactory import *
@ -15,26 +16,27 @@ class TftpServer(TftpSession):
self.listenip = None
self.listenport = None
self.sock = None
# FIXME: What about multiple roots?
self.root = os.path.abspath(tftproot)
self.dynfunc = dyn_file_func
self.dyn_file_func = dyn_file_func
# A dict of handlers, where each session is keyed by a string like
# ip:tid for the remote end.
self.handlers = {}
if os.path.exists(self.root):
logger.debug("tftproot %s does exist" % self.root)
log.debug("tftproot %s does exist" % self.root)
if not os.path.isdir(self.root):
raise TftpException, "The tftproot must be a directory."
else:
logger.debug("tftproot %s is a directory" % self.root)
log.debug("tftproot %s is a directory" % self.root)
if os.access(self.root, os.R_OK):
logger.debug("tftproot %s is readable" % self.root)
log.debug("tftproot %s is readable" % self.root)
else:
raise TftpException, "The tftproot must be readable"
if os.access(self.root, os.W_OK):
logger.debug("tftproot %s is writable" % self.root)
log.debug("tftproot %s is writable" % self.root)
else:
logger.warning("The tftproot %s is not writable" % self.root)
log.warning("The tftproot %s is not writable" % self.root)
else:
raise TftpException, "The tftproot does not exist."
@ -45,14 +47,12 @@ class TftpServer(TftpSession):
"""Start a server listening on the supplied interface and port. This
defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also
supply a different socket timeout value, if desired."""
import select
tftp_factory = TftpPacketFactory()
# Don't use new 2.5 ternary operator yet
# listenip = listenip if listenip else '0.0.0.0'
if not listenip: listenip = '0.0.0.0'
logger.info("Server requested on ip %s, port %s"
log.info("Server requested on ip %s, port %s"
% (listenip, listenport))
try:
# FIXME - sockets should be non-blocking?
@ -62,388 +62,82 @@ class TftpServer(TftpSession):
# Reraise it for now.
raise
logger.info("Starting receive loop...")
log.info("Starting receive loop...")
while True:
# Build the inputlist array of sockets to select() on.
inputlist = []
inputlist.append(self.sock)
for key in self.handlers:
inputlist.append(self.handlers[key].sock)
for key in self.sessions:
inputlist.append(self.sessions[key].sock)
# Block until some socket has input on it.
logger.debug("Performing select on this inputlist: %s" % inputlist)
log.debug("Performing select on this inputlist: %s" % inputlist)
readyinput, readyoutput, readyspecial = select.select(inputlist,
[],
[],
SOCK_TIMEOUT)
#(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
#recvpkt = tftp_factory.parse(buffer)
#key = "%s:%s" % (raddress, rport)
deletion_list = []
# Handle the available data, if any. Maybe we timed-out.
for readysock in readyinput:
# Is the traffic on the main server socket? ie. new session?
if readysock == self.sock:
logger.debug("Data ready on our main socket")
log.debug("Data ready on our main socket")
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
logger.debug("Read %d bytes" % len(buffer))
log.debug("Read %d bytes" % len(buffer))
recvpkt = tftp_factory.parse(buffer)
# FIXME: Is this the best way to do a session key? What
# about symmetric udp?
key = "%s:%s" % (raddress, rport)
if isinstance(recvpkt, TftpPacketRRQ):
logger.debug("RRQ packet from %s:%s" % (raddress, rport))
if not self.handlers.has_key(key):
try:
logger.debug("New download request, session key = %s"
% key)
self.handlers[key] = TftpServerHandler(key,
'rrq',
self.root,
listenip,
tftp_factory,
self.dynfunc)
self.handlers[key].handle((recvpkt, raddress, rport))
except TftpException, err:
logger.error("Fatal exception thrown from handler: %s"
% str(err))
logger.debug("Deleting handler: %s" % key)
deletion_list.append(key)
else:
logger.warn("Received RRQ for existing session!")
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
continue
elif isinstance(recvpkt, TftpPacketWRQ):
logger.error("Write requests not implemented at this time.")
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
continue
if not self.sessions.has_key(key):
log.debug("Creating new server context for "
"session key = %s" % key)
self.sessions[key] = TftpContextServer(raddress,
rport,
timeout,
self.root,
self.dyn_file_func)
self.sessions[key].start(buffer)
else:
# FIXME - this will have to change if we do symmetric UDP
logger.error("Should only receive RRQ or WRQ packets "
"on main listen port. Received %s" % recvpkt)
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
continue
log.warn("received traffic on main socket for "
"existing session??")
else:
for key in self.handlers:
if readysock == self.handlers[key].sock:
# FIXME - violating DRY principle with above code
# Must find the owner of this traffic.
for key in self.session:
if readysock == self.session[key].sock:
try:
self.handlers[key].handle()
self.session[key].cycle()
if self.session[key].state == None:
log.info("Successful transfer.")
deletion_list.append(key)
break
except TftpException, err:
deletion_list.append(key)
if self.handlers[key].state.state == 'fin':
logger.info("Successful transfer.")
break
else:
logger.error("Fatal exception thrown from handler: %s"
% str(err))
log.error("Fatal exception thrown from "
"handler: %s" % str(err))
else:
logger.error("Can't find the owner for this packet. Discarding.")
log.error("Can't find the owner for this packet. "
"Discarding.")
logger.debug("Looping on all handlers to check for timeouts")
log.debug("Looping on all handlers to check for timeouts")
now = time.time()
for key in self.handlers:
for key in self.sessions:
try:
self.handlers[key].check_timeout(now)
self.sessions[key].checkTimeout(now)
except TftpException, err:
logger.error("Fatal exception thrown from handler: %s"
log.error("Fatal exception thrown from handler: %s"
% str(err))
deletion_list.append(key)
logger.debug("Iterating deletion list.")
log.debug("Iterating deletion list.")
for key in deletion_list:
if self.handlers.has_key(key):
logger.debug("Deleting handler %s" % key)
del self.handlers[key]
if self.sessions.has_key(key):
log.debug("Deleting handler %s" % key)
del self.sessions[key]
deletion_list = []
class TftpServerHandler(TftpSession):
"""This class implements a handler for a given server session, handling
the work for one download."""
def __init__(self, key, state, root, listenip, factory, dyn_file_func):
TftpSession.__init__(self)
logger.info("Starting new handler. Key %s." % key)
self.key = key
self.host, self.port = self.key.split(':')
self.port = int(self.port)
self.listenip = listenip
# Note, correct state here is important as it tells the handler whether it's
# handling a download or an upload.
self.state = state
self.root = root
self.mode = None
self.filename = None
self.sock = False
self.options = { 'blksize': DEF_BLKSIZE }
self.blocknumber = 0
self.buffer = None
self.fileobj = None
self.timesent = 0
self.timeouts = 0
self.tftp_factory = factory
self.dynfunc = dyn_file_func
count = 0
while not self.sock:
self.sock = self.gensock(listenip)
count += 1
if count > 10:
raise TftpException, "Failed to bind this handler to any port"
def check_timeout(self, now):
"""This method checks to see if we've timed-out waiting for traffic
from the client."""
if self.timesent:
if now - self.timesent > SOCK_TIMEOUT:
self.timeout()
def timeout(self):
"""This method handles a timeout condition."""
logger.debug("Handling timeout for handler %s" % self.key)
self.timeouts += 1
if self.timeouts > TIMEOUT_RETRIES:
raise TftpException, "Hit max retries, giving up."
if self.state.state == 'dat' or self.state.state == 'fin':
logger.debug("Timing out on DAT. Need to resend.")
self.send_dat(resend=True)
elif self.state.state == 'oack':
logger.debug("Timing out on OACK. Need to resend.")
self.send_oack()
else:
tftpassert(False,
"Timing out in unsupported state %s" %
self.state.state)
def gensock(self, listenip):
"""This method generates a new UDP socket, whose listening port must
be randomly generated, and not conflict with any already in use. For
now, let the OS do this."""
random.seed()
port = random.randrange(1025, 65536)
# FIXME - sockets should be non-blocking?
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
logger.debug("Trying a handler socket on port %d" % port)
try:
sock.bind((listenip, port))
return sock
except socket.error, err:
if err[0] == 98:
logger.warn("Handler %s, port %d was already taken" % (self.key, port))
return False
else:
raise
def handle(self, pkttuple=None):
"""This method informs a handler instance that it has data waiting on
its socket that it must read and process."""
recvpkt = raddress = rport = None
if pkttuple:
logger.debug("Handed pkt %s for handler %s" % (recvpkt, self.key))
recvpkt, raddress, rport = pkttuple
else:
logger.debug("Data ready for handler %s" % self.key)
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
logger.debug("Read %d bytes" % len(buffer))
recvpkt = self.tftp_factory.parse(buffer)
# FIXME - refactor into another method, this is too big
if isinstance(recvpkt, TftpPacketRRQ):
logger.debug("Handler %s received RRQ packet" % self.key)
logger.debug("Requested file is %s, mode is %s" % (recvpkt.filename,
recvpkt.mode))
# FIXME - only octet mode is supported at this time.
if recvpkt.mode != 'octet':
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
raise TftpException, "Unsupported mode: %s" % recvpkt.mode
# test host/port of client end
if self.host != raddress or self.port != rport:
self.senderror(self.sock,
TftpErrors.UnknownTID,
raddress,
rport)
logger.error("Expected traffic from %s:%s but received it "
"from %s:%s instead."
% (self.host, self.port, raddress, rport))
self.errors += 1
return
if self.state.state == 'rrq':
logger.debug("Received RRQ. Composing response.")
self.filename = self.root + os.sep + recvpkt.filename
logger.debug("The path to the desired file is %s" %
self.filename)
self.filename = os.path.abspath(self.filename)
logger.debug("The absolute path is %s" % self.filename)
# Security check. Make sure it's prefixed by the tftproot.
if self.filename.find(self.root) == 0:
logger.debug("The path appears to be safe: %s" %
self.filename)
else:
logger.error("Insecure path: %s" % self.filename)
self.errors += 1
self.senderror(self.sock,
TftpErrors.AccessViolation,
raddress,
rport)
raise TftpException, "Insecure path: %s" % self.filename
# Does the file exist?
if(os.path.exists(self.filename) or not self.dynfunc is None):
logger.debug("File %s exists." % self.filename)
# Check options. Currently we only support the blksize
# option.
if recvpkt.options.has_key('blksize'):
logger.debug("RRQ includes a blksize option")
blksize = int(recvpkt.options['blksize'])
# Delete the option now that it's handled.
del recvpkt.options['blksize']
if blksize >= MIN_BLKSIZE and blksize <= MAX_BLKSIZE:
logger.info("Client requested blksize = %d"
% blksize)
self.options['blksize'] = blksize
else:
logger.warning("Client %s requested invalid "
"blocksize %d, responding with default"
% (self.key, blksize))
self.options['blksize'] = DEF_BLKSIZE
if recvpkt.options.has_key('tsize'):
logger.info('RRQ includes tsize option')
self.options['tsize'] = os.stat(self.filename).st_size
# Delete the option now that it's handled.
del recvpkt.options['tsize']
if len(recvpkt.options.keys()) > 0:
logger.warning("Client %s requested unsupported options: %s"
% (self.key, recvpkt.options))
if self.options:
logger.info("Options requested, sending OACK")
self.send_oack()
else:
logger.debug("Client %s requested no options."
% self.key)
self.start_download()
else:
logger.error("Requested file %s does not exist." %
self.filename)
self.senderror(self.sock,
TftpErrors.FileNotFound,
raddress,
rport)
raise TftpException, "Requested file not found: %s" % self.filename
else:
# We're receiving an RRQ when we're not expecting one.
logger.error("Received an RRQ in handler %s "
"but we're in state %s" % (self.key, self.state))
self.errors += 1
# Next packet type
elif isinstance(recvpkt, TftpPacketACK):
logger.debug("Received an ACK from the client.")
if recvpkt.blocknumber == 0 and self.state.state == 'oack':
logger.debug("Received ACK with 0 blocknumber, starting download")
self.start_download()
else:
if self.state.state == 'dat' or self.state.state == 'fin':
if self.blocknumber == recvpkt.blocknumber:
logger.debug("Received ACK for block %d"
% recvpkt.blocknumber)
if self.state.state == 'fin':
raise TftpException, "Successful transfer."
else:
self.send_dat()
elif recvpkt.blocknumber < self.blocknumber:
# Don't resend a DAT due to an old ACK. Fixes the
# sorceror's apprentice problem.
logger.warn("Received old ACK for block number %d"
% recvpkt.blocknumber)
else:
logger.warn("Received ACK for block number "
"%d, apparently from the future"
% recvpkt.blocknumber)
else:
logger.error("Received ACK with block number %d "
"while in state %s"
% (recvpkt.blocknumber,
self.state.state))
elif isinstance(recvpkt, TftpPacketERR):
logger.error("Received error packet from client: %s" % recvpkt)
self.state.state = 'err'
raise TftpException, "Received error from client"
# Handle other packet types.
else:
logger.error("Received packet %s while handling a download"
% recvpkt)
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
self.host,
self.port)
raise TftpException, "Invalid packet received during download"
def start_download(self):
"""This method opens self.filename, stores the resulting file object
in self.fileobj, and calls send_dat()."""
self.state.state = 'dat'
if os.path.exists(self.filename):
self.fileobj = open(self.filename, "rb")
else:
self.fileobj = self.dynfunc(self.filename)
self.send_dat()
def send_dat(self, resend=False):
"""This method reads sends a DAT packet based on what is in self.buffer."""
if not resend:
blksize = int(self.options['blksize'])
self.buffer = self.fileobj.read(blksize)
logger.debug("Read %d bytes into buffer" % len(self.buffer))
if len(self.buffer) < blksize:
logger.info("Reached EOF on file %s" % self.filename)
self.state.state = 'fin'
self.blocknumber += 1
if self.blocknumber > 65535:
logger.debug("Blocknumber rolled over to zero")
self.blocknumber = 0
else:
logger.warn("Resending block number %d" % self.blocknumber)
dat = TftpPacketDAT()
dat.data = self.buffer
dat.blocknumber = self.blocknumber
logger.debug("Sending DAT packet %d" % self.blocknumber)
self.sock.sendto(dat.encode().buffer, (self.host, self.port))
self.timesent = time.time()
# FIXME - should these be factored-out into the session class?
def send_oack(self):
"""This method sends an OACK packet based on current params."""
logger.debug("Composing and sending OACK packet")
oack = TftpPacketOACK()
oack.options = self.options
self.sock.sendto(oack.encode().buffer,
(self.host, self.port))
self.timesent = time.time()
self.state.state = 'oack'

View File

@ -17,7 +17,7 @@ DEF_TFTP_PORT = 69
logging.basicConfig()
# The logger used by this library. Feel free to clobber it with your own, if you like, as
# long as it conforms to Python's logging.
logger = logging.getLogger('tftpy')
log = logging.getLogger('tftpy')
def tftpassert(condition, msg):
"""This function is a simple utility that will check the condition
@ -31,8 +31,8 @@ def setLogLevel(level):
"""This function is a utility function for setting the internal log level.
The log level defaults to logging.NOTSET, so unwanted output to stdout is
not created."""
global logger
logger.setLevel(level)
global log
log.setLevel(level)
class TftpErrors(object):
"""This class is a convenience for defining the common tftp error codes,

View File

@ -22,17 +22,28 @@ class TftpMetrics(object):
# Rates
self.bps = 0
self.kbps = 0
# Generic errors
self.errors = 0
def compute(self):
# Compute transfer time
self.duration = self.end_time - self.start_time
logger.debug("TftpMetrics.compute: duration is %s" % self.duration)
log.debug("TftpMetrics.compute: duration is %s" % self.duration)
self.bps = (self.bytes * 8.0) / self.duration
self.kbps = self.bps / 1024.0
logger.debug("TftpMetrics.compute: kbps is %s" % self.kbps)
dupcount = 0
log.debug("TftpMetrics.compute: kbps is %s" % self.kbps)
for key in self.dups:
dupcount += self.dups[key]
self.dupcount += self.dups[key]
def add_dup(self, blocknumber):
"""This method adds a dup for a block number to the metrics."""
log.debug("Recording a dup for block %d" % blocknumber)
if self.dups.has_key(blocknumber):
self.dups[pkt.blocknumber] += 1
else:
self.dups[pkt.blocknumber] = 1
tftpassert(self.dups[pkt.blocknumber] < MAX_DUPS,
"Max duplicates for block %d reached" % blocknumber)
###############################################################################
# Context classes
@ -40,16 +51,32 @@ class TftpMetrics(object):
class TftpContext(object):
"""The base class of the contexts."""
def __init__(self, host, port):
def __init__(self, host, port, timeout):
"""Constructor for the base context, setting shared instance
variables."""
self.file_to_transfer = None
self.fileobj = None
self.options = None
self.packethook = None
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(timeout)
self.state = None
self.next_block = 0
self.factory = TftpPacketFactory()
# Note, setting the host will also set self.address, as it's a property.
self.host = host
self.port = port
# The port associated with the TID
self.tidport = None
# Metrics
self.metrics = TftpMetrics()
# Flag when the transfer is pending completion.
self.pending_complete = False
def checkTimeout(self, now):
# FIXME
pass
def start(self):
return NotImplementedError, "Abstract method"
@ -69,37 +96,9 @@ class TftpContext(object):
host = property(gethost, sethost)
def sendAck(self, blocknumber):
"""This method sends an ack packet to the block number specified."""
logger.info("sending ack to block %d" % blocknumber)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = blocknumber
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport))
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
compose and send an error packet."""
logger.debug("In sendError, being asked to send error %d" % errorcode)
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
self.sock.sendto(errpkt.encode().buffer, (self.host, self.tidport))
class TftpContextClient(TftpContext):
"""This class represents shared functionality by both the download and
upload client contexts."""
def __init__(self, host, port, filename, options, packethook, timeout):
TftpContext.__init__(self, host, port)
self.file_to_transfer = filename
self.options = options
self.packethook = packethook
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(timeout)
self.state = None
self.next_block = 0
def setNextBlock(self, block):
if block > 2 ** 16:
logger.debug("block number rollover to 0 again")
log.debug("block number rollover to 0 again")
block = 0
self.__eblock = block
@ -111,19 +110,21 @@ class TftpContextClient(TftpContext):
def cycle(self):
"""Here we wait for a response from the server after sending it
something, and dispatch appropriate action to that response."""
# FIXME: This won't work very well in a server context with multiple
# sessions running.
for i in range(TIMEOUT_RETRIES):
logger.debug("in cycle, receive attempt %d" % i)
log.debug("in cycle, receive attempt %d" % i)
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
except socket.timeout, err:
logger.warn("Timeout waiting for traffic, retrying...")
log.warn("Timeout waiting for traffic, retrying...")
continue
break
else:
raise TftpException, "Hit max timeouts, giving up."
# Ok, we've received a packet. Log it.
logger.debug("Received %d bytes from %s:%s"
log.debug("Received %d bytes from %s:%s"
% (len(buffer), raddress, rport))
# Decode it.
@ -131,11 +132,11 @@ class TftpContextClient(TftpContext):
# Check for known "connection".
if raddress != self.address:
logger.warn("Received traffic from %s, expected host %s. Discarding"
log.warn("Received traffic from %s, expected host %s. Discarding"
% (raddress, self.host))
if self.tidport and self.tidport != rport:
logger.warn("Received traffic from %s:%s but we're "
log.warn("Received traffic from %s:%s but we're "
"connected to %s:%s. Discarding."
% (raddress, rport,
self.host, self.tidport))
@ -150,29 +151,66 @@ class TftpContextClient(TftpContext):
# And handle it, possibly changing state.
self.state = self.state.handle(recvpkt, raddress, rport)
class TftpContextClientUpload(TftpContextClient):
class TftpContextServer(TftpContext):
"""The context for the server."""
def __init__(self, host, port, timeout, root, dyn_file_func):
TftpContext.__init__(self,
host,
port,
timeout)
# At this point we have no idea if this is a download or an upload. We
# need to let the start state determine that.
self.state = TftpStateServerStart()
self.root = root
self.dyn_file_func = dyn_file_func
def start(self, buffer):
"""Start the state cycle. Note that the server context receives an
initial packet in its start method."""
log.debug("TftpContextServer.start() - pkt = %s" % pkt)
self.metrics.start_time = time.time()
log.debug("set metrics.start_time to %s" % self.metrics.start_time)
pkt = self.factory.parse(buffer)
log.debug("TftpContextServer.start() - factory returned a %s" % pkt)
# Call handle once with the initial packet. This should put us into
# the download or the upload state.
self.state = self.state.handle(pkt,
self.host,
self.port)
try:
while self.state:
log.debug("state is %s" % self.state)
self.cycle()
finally:
self.fileobj.close()
class TftpContextClientUpload(TftpContext):
"""The upload context for the client during an upload."""
def __init__(self, host, port, filename, input, options, packethook, timeout):
TftpContextClient.__init__(self,
host,
port,
filename,
options,
packethook,
timeout)
TftpContext.__init__(self,
host,
port,
timeout)
self.file_to_transfer = filename
self.options = options
self.packethook = packethook
self.fileobj = open(input, "wb")
logger.debug("TftpContextClientUpload.__init__()")
logger.debug("file_to_transfer = %s, options = %s" %
log.debug("TftpContextClientUpload.__init__()")
log.debug("file_to_transfer = %s, options = %s" %
(self.file_to_transfer, self.options))
def start(self):
logger.info("sending tftp upload request to %s" % self.host)
logger.info(" filename -> %s" % self.file_to_transfer)
logger.info(" options -> %s" % self.options)
log.info("sending tftp upload request to %s" % self.host)
log.info(" filename -> %s" % self.file_to_transfer)
log.info(" options -> %s" % self.options)
self.metrics.start_time = time.time()
logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
log.debug("set metrics.start_time to %s" % self.metrics.start_time)
# FIXME: put this in a sendWRQ method?
pkt = TftpPacketWRQ()
@ -186,7 +224,7 @@ class TftpContextClientUpload(TftpContextClient):
try:
while self.state:
logger.debug("state is %s" % self.state)
log.debug("state is %s" % self.state)
self.cycle()
finally:
self.fileobj.close()
@ -194,32 +232,32 @@ class TftpContextClientUpload(TftpContextClient):
def end(self):
pass
class TftpContextClientDownload(TftpContextClient):
class TftpContextClientDownload(TftpContext):
"""The download context for the client during a download."""
def __init__(self, host, port, filename, output, options, packethook, timeout):
TftpContextClient.__init__(self,
host,
port,
filename,
options,
packethook,
timeout)
TftpContext.__init__(self,
host,
port,
filename,
options,
packethook,
timeout)
# FIXME - need to support alternate return formats than files?
# File-like objects would be ideal, ala duck-typing.
self.fileobj = open(output, "wb")
logger.debug("TftpContextClientDownload.__init__()")
logger.debug("file_to_transfer = %s, options = %s" %
log.debug("TftpContextClientDownload.__init__()")
log.debug("file_to_transfer = %s, options = %s" %
(self.file_to_transfer, self.options))
def start(self):
"""Initiate the download."""
logger.info("sending tftp download request to %s" % self.host)
logger.info(" filename -> %s" % self.file_to_transfer)
logger.info(" options -> %s" % self.options)
log.info("sending tftp download request to %s" % self.host)
log.info(" filename -> %s" % self.file_to_transfer)
log.info(" options -> %s" % self.options)
self.metrics.start_time = time.time()
logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
log.debug("set metrics.start_time to %s" % self.metrics.start_time)
# FIXME: put this in a sendRRQ method?
pkt = TftpPacketRRQ()
@ -233,7 +271,7 @@ class TftpContextClientDownload(TftpContextClient):
try:
while self.state:
logger.debug("state is %s" % self.state)
log.debug("state is %s" % self.state)
self.cycle()
finally:
self.fileobj.close()
@ -241,7 +279,7 @@ class TftpContextClientDownload(TftpContextClient):
def end(self):
"""Finish up the context."""
self.metrics.end_time = time.time()
logger.debug("set metrics.end_time to %s" % self.metrics.end_time)
log.debug("set metrics.end_time to %s" % self.metrics.end_time)
self.metrics.compute()
@ -268,235 +306,431 @@ class TftpState(object):
options."""
if pkt.options.keys() > 0:
if pkt.match_options(self.context.options):
logger.info("Successful negotiation of options")
log.info("Successful negotiation of options")
# Set options to OACK options
self.context.options = pkt.options
for key in self.context.options:
logger.info(" %s = %s" % (key, self.context.options[key]))
log.info(" %s = %s" % (key, self.context.options[key]))
else:
logger.error("failed to negotiate options")
log.error("failed to negotiate options")
raise TftpException, "Failed to negotiate options"
else:
raise TftpException, "No options found in OACK"
class TftpStateUpload(TftpState):
"""A class holding common code for upload states."""
def sendDat(self, resend=False):
def returnSupportedOptions(self, options):
"""This method takes a requested options list from a client, and
returns the ones that are supported."""
# We support the options blksize and tsize right now.
# FIXME - put this somewhere else?
accepted_options = {}
for option in options:
if option == 'blksize':
# Make sure it's valid.
if int(options[option]) > MAX_BLKSIZE:
accepted_options[option] = MAX_BLKSIZE
elif option == 'tsize':
log.debug("tsize option is set")
accepted_options['tsize'] = 1
else:
log.info("Dropping unsupported option '%s'" % option)
return accepted_options
def serverInitial(self, pkt, raddress, rport):
"""This method performs initial setup for a server context transfer,
put here to refactor code out of the TftpStateServerRecvRRQ and
TftpStateServerRecvWRQ classes, since their initial setup is
identical. The method returns a boolean, sendoack, to indicate whether
it is required to send an OACK to the client."""
options = pkt.options
sendoack = False
if not options:
log.debug("setting default options, blksize")
# FIXME: put default options elsewhere
self.context.options = { 'blksize': DEF_BLKSIZE }
else:
log.debug("options requested: %s" % options)
self.context.options = self.returnSupportedOptions(options)
sendoack = True
# FIXME - only octet mode is supported at this time.
if pkt.mode != 'octet':
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, \
"Only octet transfers are supported at this time."
# test host/port of client end
if self.context.host != raddress or self.context.port != rport:
self.sendError(TftpErrors.UnknownTID)
log.error("Expected traffic from %s:%s but received it "
"from %s:%s instead."
% (self.context.host,
self.context.port,
raddress,
rport))
# FIXME: increment an error count?
# Return same state, we're still waiting for valid traffic.
return self
log.debug("requested filename is %s" % pkt.filename)
# There are no os.sep's allowed in the filename.
# FIXME: Should we allow subdirectories?
if pkt.filename.find(os.sep) >= 0:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "%s found in filename, not permitted" % os.sep
self.context.file_to_transfer = pkt.filename
return sendoack
def sendDAT(self, resend=False):
"""This method sends the next DAT packet based on the data in the
context. It returns a boolean indicating whether the transfer is
finished."""
finished = False
blocknumber = self.context.next_block
if not resend:
blksize = int(self.context.options['blksize'])
buffer = self.context.fileobj.read(blksize)
logger.debug("Read %d bytes into buffer" % len(buffer))
log.debug("Read %d bytes into buffer" % len(buffer))
if len(buffer) < blksize:
logger.info("Reached EOF on file %s" % self.context.input)
log.info("Reached EOF on file %s" % self.context.input)
finished = True
self.context.next_block += 1
self.bytes += len(buffer)
else:
logger.warn("Resending block number %d" % blocknumber)
log.warn("Resending block number %d" % blocknumber)
dat = TftpPacketDAT()
dat.data = buffer
dat.blocknumber = blocknumber
logger.debug("Sending DAT packet %d" % blocknumber)
log.debug("Sending DAT packet %d" % blocknumber)
self.context.sock.sendto(dat.encode().buffer,
(self.context.host, self.context.port))
if self.context.packethook:
self.context.packethook(dat)
return finished
class TftpStateDownload(TftpState):
"""A class holding common code for download states."""
def sendACK(self, blocknumber=None):
"""This method sends an ack packet to the block number specified. If
none is specified, it defaults to the next_block property in the
parent context."""
if not blocknumber:
blocknumber = self.context.next_block
log.info("sending ack to block %d" % blocknumber)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = blocknumber
self.context.sock.sendto(ackpkt.encode().buffer,
(self.context.host,
self.context.tidport))
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
compose and send an error packet."""
log.debug("In sendError, being asked to send error %d" % errorcode)
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
self.context.sock.sendto(errpkt.encode().buffer,
(self.context.host,
self.context.tidport))
def sendOACK(self):
"""This method sends an OACK packet with the options from the current
context."""
log.debug("In sendOACK with options %s" % options)
pkt = TftpPacketOACK()
pkt.options = self.options
self.context.sock.sendto(pkt.encode().buffer,
(self.context.host,
self.context.tidport))
def handleDat(self, pkt):
"""This method handles a DAT packet during a download."""
logger.info("handling DAT packet - block %d" % pkt.blocknumber)
logger.debug("expecting block %s" % self.context.next_block)
"""This method handles a DAT packet during a client download, or a
server upload."""
log.info("handling DAT packet - block %d" % pkt.blocknumber)
log.debug("expecting block %s" % self.context.next_block)
if pkt.blocknumber == self.context.next_block:
logger.debug("good, received block %d in sequence"
log.debug("good, received block %d in sequence"
% pkt.blocknumber)
self.context.sendAck(pkt.blocknumber)
self.sendACK()
self.context.next_block += 1
logger.debug("writing %d bytes to output file"
log.debug("writing %d bytes to output file"
% len(pkt.data))
self.context.fileobj.write(pkt.data)
self.context.metrics.bytes += len(pkt.data)
# Check for end-of-file, any less than full data packet.
if len(pkt.data) < int(self.context.options['blksize']):
logger.info("end of file detected")
log.info("end of file detected")
return None
elif pkt.blocknumber < self.context.next_block:
logger.warn("dropping duplicate block %d" % pkt.blocknumber)
if self.context.metrics.dups.has_key(pkt.blocknumber):
self.context.metrics.dups[pkt.blocknumber] += 1
else:
self.context.metrics.dups[pkt.blocknumber] = 1
tftpassert(self.context.metrics.dups[pkt.blocknumber] < MAX_DUPS,
"Max duplicates for block %d reached" % pkt.blocknumber)
# FIXME: double-check sorceror's apprentice problem!
logger.debug("ACKing block %d again, just in case" % pkt.blocknumber)
self.context.sendAck(pkt.blocknumber)
log.warn("dropping duplicate block %d" % pkt.bl