Did some rework for the state machine in a server context.

Removed the handler framework in favour of a TftpContextServer used
as the session.
This commit is contained in:
Michael P. Soulier 2009-08-15 22:36:58 -04:00
parent 03e4e74829
commit 62b22fb562
6 changed files with 564 additions and 636 deletions

View file

@ -52,12 +52,12 @@ class TftpClient(TftpSession):
# output? This should be in the sample client, but not in the download # output? This should be in the sample client, but not in the download
# call. # call.
if metrics.duration == 0: if metrics.duration == 0:
logger.info("Duration too short, rate undetermined") log.info("Duration too short, rate undetermined")
else: else:
logger.info('') log.info('')
logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
logger.info("Average rate: %.2f kbps" % metrics.kbps) log.info("Average rate: %.2f kbps" % metrics.kbps)
logger.info("Received %d duplicate packets" % metrics.dupcount) log.info("Received %d duplicate packets" % metrics.dupcount)
def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT): def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT):
# Open the input file. # Open the input file.
@ -80,9 +80,9 @@ class TftpClient(TftpSession):
# output? This should be in the sample client, but not in the download # output? This should be in the sample client, but not in the download
# call. # call.
if metrics.duration == 0: if metrics.duration == 0:
logger.info("Duration too short, rate undetermined") log.info("Duration too short, rate undetermined")
else: else:
logger.info('') log.info('')
logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
logger.info("Average rate: %.2f kbps" % metrics.kbps) log.info("Average rate: %.2f kbps" % metrics.kbps)
logger.info("Received %d duplicate packets" % metrics.dupcount) log.info("Received %d duplicate packets" % metrics.dupcount)

View file

@ -19,9 +19,9 @@ class TftpPacketFactory(object):
"""This method is used to parse an existing datagram into its """This method is used to parse an existing datagram into its
corresponding TftpPacket object. The buffer is the raw bytes off of corresponding TftpPacket object. The buffer is the raw bytes off of
the network.""" the network."""
logger.debug("parsing a %d byte packet" % len(buffer)) log.debug("parsing a %d byte packet" % len(buffer))
(opcode,) = struct.unpack("!H", buffer[:2]) (opcode,) = struct.unpack("!H", buffer[:2])
logger.debug("opcode is %d" % opcode) log.debug("opcode is %d" % opcode)
packet = self.__create(opcode) packet = self.__create(opcode)
packet.buffer = buffer packet.buffer = buffer
return packet.decode() return packet.decode()

View file

@ -15,7 +15,7 @@ class TftpSession(object):
def senderror(self, sock, errorcode, address, port): def senderror(self, sock, errorcode, address, port):
"""This method uses the socket passed, and uses the errorcode, address """This method uses the socket passed, and uses the errorcode, address
and port to compose and send an error packet.""" and port to compose and send an error packet."""
logger.debug("In senderror, being asked to send error %d to %s:%s" log.debug("In senderror, being asked to send error %d to %s:%s"
% (errorcode, address, port)) % (errorcode, address, port))
errpkt = TftpPacketERR() errpkt = TftpPacketERR()
errpkt.errorcode = errorcode errpkt.errorcode = errorcode
@ -27,23 +27,23 @@ class TftpPacketWithOptions(object):
goal is just to share code here, and not cause diamond inheritance.""" goal is just to share code here, and not cause diamond inheritance."""
def __init__(self): def __init__(self):
self.options = [] self.options = {}
def setoptions(self, options): def setoptions(self, options):
logger.debug("in TftpPacketWithOptions.setoptions") log.debug("in TftpPacketWithOptions.setoptions")
logger.debug("options: " + str(options)) log.debug("options: " + str(options))
myoptions = {} myoptions = {}
for key in options: for key in options:
newkey = str(key) newkey = str(key)
myoptions[newkey] = str(options[key]) myoptions[newkey] = str(options[key])
logger.debug("populated myoptions with %s = %s" log.debug("populated myoptions with %s = %s"
% (newkey, myoptions[newkey])) % (newkey, myoptions[newkey]))
logger.debug("setting options hash to: " + str(myoptions)) log.debug("setting options hash to: " + str(myoptions))
self._options = myoptions self._options = myoptions
def getoptions(self): def getoptions(self):
logger.debug("in TftpPacketWithOptions.getoptions") log.debug("in TftpPacketWithOptions.getoptions")
return self._options return self._options
# Set up getter and setter on options to ensure that they are the proper # Set up getter and setter on options to ensure that they are the proper
@ -59,19 +59,19 @@ class TftpPacketWithOptions(object):
format = "!" format = "!"
options = {} options = {}
logger.debug("decode_options: buffer is: " + repr(buffer)) log.debug("decode_options: buffer is: " + repr(buffer))
logger.debug("size of buffer is %d bytes" % len(buffer)) log.debug("size of buffer is %d bytes" % len(buffer))
if len(buffer) == 0: if len(buffer) == 0:
logger.debug("size of buffer is zero, returning empty hash") log.debug("size of buffer is zero, returning empty hash")
return {} return {}
# Count the nulls in the buffer. Each one terminates a string. # Count the nulls in the buffer. Each one terminates a string.
logger.debug("about to iterate options buffer counting nulls") log.debug("about to iterate options buffer counting nulls")
length = 0 length = 0
for c in buffer: for c in buffer:
#logger.debug("iterating this byte: " + repr(c)) #log.debug("iterating this byte: " + repr(c))
if ord(c) == 0: if ord(c) == 0:
logger.debug("found a null at length %d" % length) log.debug("found a null at length %d" % length)
if length > 0: if length > 0:
format += "%dsx" % length format += "%dsx" % length
length = -1 length = -1
@ -79,14 +79,14 @@ class TftpPacketWithOptions(object):
raise TftpException, "Invalid options in buffer" raise TftpException, "Invalid options in buffer"
length += 1 length += 1
logger.debug("about to unpack, format is: %s" % format) log.debug("about to unpack, format is: %s" % format)
mystruct = struct.unpack(format, buffer) mystruct = struct.unpack(format, buffer)
tftpassert(len(mystruct) % 2 == 0, tftpassert(len(mystruct) % 2 == 0,
"packet with odd number of option/value pairs") "packet with odd number of option/value pairs")
for i in range(0, len(mystruct), 2): for i in range(0, len(mystruct), 2):
logger.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1])) log.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
options[mystruct[i]] = mystruct[i+1] options[mystruct[i]] = mystruct[i+1]
return options return options
@ -134,10 +134,10 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
ptype = None ptype = None
if self.opcode == 1: ptype = "RRQ" if self.opcode == 1: ptype = "RRQ"
else: ptype = "WRQ" else: ptype = "WRQ"
logger.debug("Encoding %s packet, filename = %s, mode = %s" log.debug("Encoding %s packet, filename = %s, mode = %s"
% (ptype, self.filename, self.mode)) % (ptype, self.filename, self.mode))
for key in self.options: for key in self.options:
logger.debug(" Option %s = %s" % (key, self.options[key])) log.debug(" Option %s = %s" % (key, self.options[key]))
format = "!H" format = "!H"
format += "%dsx" % len(self.filename) format += "%dsx" % len(self.filename)
@ -148,7 +148,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
# Add options. # Add options.
options_list = [] options_list = []
if self.options.keys() > 0: if self.options.keys() > 0:
logger.debug("there are options to encode") log.debug("there are options to encode")
for key in self.options: for key in self.options:
# Populate the option name # Populate the option name
format += "%dsx" % len(key) format += "%dsx" % len(key)
@ -157,9 +157,9 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
format += "%dsx" % len(str(self.options[key])) format += "%dsx" % len(str(self.options[key]))
options_list.append(str(self.options[key])) options_list.append(str(self.options[key]))
logger.debug("format is %s" % format) log.debug("format is %s" % format)
logger.debug("options_list is %s" % options_list) log.debug("options_list is %s" % options_list)
logger.debug("size of struct is %d" % struct.calcsize(format)) log.debug("size of struct is %d" % struct.calcsize(format))
self.buffer = struct.pack(format, self.buffer = struct.pack(format,
self.opcode, self.opcode,
@ -167,7 +167,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
self.mode, self.mode,
*options_list) *options_list)
logger.debug("buffer is " + repr(self.buffer)) log.debug("buffer is " + repr(self.buffer))
return self return self
def decode(self): def decode(self):
@ -177,13 +177,13 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
nulls = 0 nulls = 0
format = "" format = ""
nulls = length = tlength = 0 nulls = length = tlength = 0
logger.debug("in decode: about to iterate buffer counting nulls") log.debug("in decode: about to iterate buffer counting nulls")
subbuf = self.buffer[2:] subbuf = self.buffer[2:]
for c in subbuf: for c in subbuf:
logger.debug("iterating this byte: " + repr(c)) log.debug("iterating this byte: " + repr(c))
if ord(c) == 0: if ord(c) == 0:
nulls += 1 nulls += 1
logger.debug("found a null at length %d, now have %d" log.debug("found a null at length %d, now have %d"
% (length, nulls)) % (length, nulls))
format += "%dsx" % length format += "%dsx" % length
length = -1 length = -1
@ -193,17 +193,17 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
length += 1 length += 1
tlength += 1 tlength += 1
logger.debug("hopefully found end of mode at length %d" % tlength) log.debug("hopefully found end of mode at length %d" % tlength)
# length should now be the end of the mode. # length should now be the end of the mode.
tftpassert(nulls == 2, "malformed packet") tftpassert(nulls == 2, "malformed packet")
shortbuf = subbuf[:tlength+1] shortbuf = subbuf[:tlength+1]
logger.debug("about to unpack buffer with format: %s" % format) log.debug("about to unpack buffer with format: %s" % format)
logger.debug("unpacking buffer: " + repr(shortbuf)) log.debug("unpacking buffer: " + repr(shortbuf))
mystruct = struct.unpack(format, shortbuf) mystruct = struct.unpack(format, shortbuf)
tftpassert(len(mystruct) == 2, "malformed packet") tftpassert(len(mystruct) == 2, "malformed packet")
logger.debug("setting filename to %s" % mystruct[0]) log.debug("setting filename to %s" % mystruct[0])
logger.debug("setting mode to %s" % mystruct[1]) log.debug("setting mode to %s" % mystruct[1])
self.filename = mystruct[0] self.filename = mystruct[0]
self.mode = mystruct[1] self.mode = mystruct[1]
@ -269,7 +269,7 @@ DATA | 03 | Block # | Data |
"""Encode the DAT packet. This method populates self.buffer, and """Encode the DAT packet. This method populates self.buffer, and
returns self for easy method chaining.""" returns self for easy method chaining."""
if len(self.data) == 0: if len(self.data) == 0:
logger.debug("Encoding an empty DAT packet") log.debug("Encoding an empty DAT packet")
format = "!HH%ds" % len(self.data) format = "!HH%ds" % len(self.data)
self.buffer = struct.pack(format, self.buffer = struct.pack(format,
self.opcode, self.opcode,
@ -283,12 +283,12 @@ DATA | 03 | Block # | Data |
# We know the first 2 bytes are the opcode. The second two are the # We know the first 2 bytes are the opcode. The second two are the
# block number. # block number.
(self.blocknumber,) = struct.unpack("!H", self.buffer[2:4]) (self.blocknumber,) = struct.unpack("!H", self.buffer[2:4])
logger.debug("decoding DAT packet, block number %d" % self.blocknumber) log.debug("decoding DAT packet, block number %d" % self.blocknumber)
logger.debug("should be %d bytes in the packet total" log.debug("should be %d bytes in the packet total"
% len(self.buffer)) % len(self.buffer))
# Everything else is data. # Everything else is data.
self.data = self.buffer[4:] self.data = self.buffer[4:]
logger.debug("found %d bytes of data" log.debug("found %d bytes of data"
% len(self.data)) % len(self.data))
return self return self
@ -308,14 +308,14 @@ ACK | 04 | Block # |
return 'ACK packet: block %d' % self.blocknumber return 'ACK packet: block %d' % self.blocknumber
def encode(self): def encode(self):
logger.debug("encoding ACK: opcode = %d, block = %d" log.debug("encoding ACK: opcode = %d, block = %d"
% (self.opcode, self.blocknumber)) % (self.opcode, self.blocknumber))
self.buffer = struct.pack("!HH", self.opcode, self.blocknumber) self.buffer = struct.pack("!HH", self.opcode, self.blocknumber)
return self return self
def decode(self): def decode(self):
self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer) self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer)
logger.debug("decoded ACK packet: opcode = %d, block = %d" log.debug("decoded ACK packet: opcode = %d, block = %d"
% (self.opcode, self.blocknumber)) % (self.opcode, self.blocknumber))
return self return self
@ -365,7 +365,7 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
"""Encode the DAT packet based on instance variables, populating """Encode the DAT packet based on instance variables, populating
self.buffer, returning self.""" self.buffer, returning self."""
format = "!HH%dsx" % len(self.errmsgs[self.errorcode]) format = "!HH%dsx" % len(self.errmsgs[self.errorcode])
logger.debug("encoding ERR packet with format %s" % format) log.debug("encoding ERR packet with format %s" % format)
self.buffer = struct.pack(format, self.buffer = struct.pack(format,
self.opcode, self.opcode,
self.errorcode, self.errorcode,
@ -375,13 +375,13 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
def decode(self): def decode(self):
"Decode self.buffer, populating instance variables and return self." "Decode self.buffer, populating instance variables and return self."
tftpassert(len(self.buffer) > 4, "malformed ERR packet, too short") tftpassert(len(self.buffer) > 4, "malformed ERR packet, too short")
logger.debug("Decoding ERR packet, length %s bytes" % log.debug("Decoding ERR packet, length %s bytes" %
len(self.buffer)) len(self.buffer))
format = "!HH%dsx" % (len(self.buffer) - 5) format = "!HH%dsx" % (len(self.buffer) - 5)
logger.debug("Decoding ERR packet with format: %s" % format) log.debug("Decoding ERR packet with format: %s" % format)
self.opcode, self.errorcode, self.errmsg = struct.unpack(format, self.opcode, self.errorcode, self.errmsg = struct.unpack(format,
self.buffer) self.buffer)
logger.error("ERR packet - errorcode: %d, message: %s" log.error("ERR packet - errorcode: %d, message: %s"
% (self.errorcode, self.errmsg)) % (self.errorcode, self.errmsg))
return self return self
@ -402,10 +402,10 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
def encode(self): def encode(self):
format = "!H" # opcode format = "!H" # opcode
options_list = [] options_list = []
logger.debug("in TftpPacketOACK.encode") log.debug("in TftpPacketOACK.encode")
for key in self.options: for key in self.options:
logger.debug("looping on option key %s" % key) log.debug("looping on option key %s" % key)
logger.debug("value is %s" % self.options[key]) log.debug("value is %s" % self.options[key])
format += "%dsx" % len(key) format += "%dsx" % len(key)
format += "%dsx" % len(self.options[key]) format += "%dsx" % len(self.options[key])
options_list.append(key) options_list.append(key)
@ -429,7 +429,7 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
# We can accept anything between the min and max values. # We can accept anything between the min and max values.
size = self.options[name] size = self.options[name]
if size >= MIN_BLKSIZE and size <= MAX_BLKSIZE: if size >= MIN_BLKSIZE and size <= MAX_BLKSIZE:
logger.debug("negotiated blksize of %d bytes" % size) log.debug("negotiated blksize of %d bytes" % size)
options[blksize] = size options[blksize] = size
else: else:
raise TftpException, "Unsupported option: %s" % name raise TftpException, "Unsupported option: %s" % name

View file

@ -1,4 +1,5 @@
import socket, os, re, time, random import socket, os, re, time, random
import select
from TftpShared import * from TftpShared import *
from TftpPacketTypes import * from TftpPacketTypes import *
from TftpPacketFactory import * from TftpPacketFactory import *
@ -15,26 +16,27 @@ class TftpServer(TftpSession):
self.listenip = None self.listenip = None
self.listenport = None self.listenport = None
self.sock = None self.sock = None
# FIXME: What about multiple roots?
self.root = os.path.abspath(tftproot) self.root = os.path.abspath(tftproot)
self.dynfunc = dyn_file_func self.dyn_file_func = dyn_file_func
# A dict of handlers, where each session is keyed by a string like # A dict of handlers, where each session is keyed by a string like
# ip:tid for the remote end. # ip:tid for the remote end.
self.handlers = {} self.handlers = {}
if os.path.exists(self.root): if os.path.exists(self.root):
logger.debug("tftproot %s does exist" % self.root) log.debug("tftproot %s does exist" % self.root)
if not os.path.isdir(self.root): if not os.path.isdir(self.root):
raise TftpException, "The tftproot must be a directory." raise TftpException, "The tftproot must be a directory."
else: else:
logger.debug("tftproot %s is a directory" % self.root) log.debug("tftproot %s is a directory" % self.root)
if os.access(self.root, os.R_OK): if os.access(self.root, os.R_OK):
logger.debug("tftproot %s is readable" % self.root) log.debug("tftproot %s is readable" % self.root)
else: else:
raise TftpException, "The tftproot must be readable" raise TftpException, "The tftproot must be readable"
if os.access(self.root, os.W_OK): if os.access(self.root, os.W_OK):
logger.debug("tftproot %s is writable" % self.root) log.debug("tftproot %s is writable" % self.root)
else: else:
logger.warning("The tftproot %s is not writable" % self.root) log.warning("The tftproot %s is not writable" % self.root)
else: else:
raise TftpException, "The tftproot does not exist." raise TftpException, "The tftproot does not exist."
@ -45,14 +47,12 @@ class TftpServer(TftpSession):
"""Start a server listening on the supplied interface and port. This """Start a server listening on the supplied interface and port. This
defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also
supply a different socket timeout value, if desired.""" supply a different socket timeout value, if desired."""
import select
tftp_factory = TftpPacketFactory() tftp_factory = TftpPacketFactory()
# Don't use new 2.5 ternary operator yet # Don't use new 2.5 ternary operator yet
# listenip = listenip if listenip else '0.0.0.0' # listenip = listenip if listenip else '0.0.0.0'
if not listenip: listenip = '0.0.0.0' if not listenip: listenip = '0.0.0.0'
logger.info("Server requested on ip %s, port %s" log.info("Server requested on ip %s, port %s"
% (listenip, listenport)) % (listenip, listenport))
try: try:
# FIXME - sockets should be non-blocking? # FIXME - sockets should be non-blocking?
@ -62,388 +62,82 @@ class TftpServer(TftpSession):
# Reraise it for now. # Reraise it for now.
raise raise
logger.info("Starting receive loop...") log.info("Starting receive loop...")
while True: while True:
# Build the inputlist array of sockets to select() on. # Build the inputlist array of sockets to select() on.
inputlist = [] inputlist = []
inputlist.append(self.sock) inputlist.append(self.sock)
for key in self.handlers: for key in self.sessions:
inputlist.append(self.handlers[key].sock) inputlist.append(self.sessions[key].sock)
# Block until some socket has input on it. # Block until some socket has input on it.
logger.debug("Performing select on this inputlist: %s" % inputlist) log.debug("Performing select on this inputlist: %s" % inputlist)
readyinput, readyoutput, readyspecial = select.select(inputlist, readyinput, readyoutput, readyspecial = select.select(inputlist,
[], [],
[], [],
SOCK_TIMEOUT) SOCK_TIMEOUT)
#(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
#recvpkt = tftp_factory.parse(buffer)
#key = "%s:%s" % (raddress, rport)
deletion_list = [] deletion_list = []
# Handle the available data, if any. Maybe we timed-out.
for readysock in readyinput: for readysock in readyinput:
# Is the traffic on the main server socket? ie. new session?
if readysock == self.sock: if readysock == self.sock:
logger.debug("Data ready on our main socket") log.debug("Data ready on our main socket")
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE) buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
logger.debug("Read %d bytes" % len(buffer))
log.debug("Read %d bytes" % len(buffer))
recvpkt = tftp_factory.parse(buffer) recvpkt = tftp_factory.parse(buffer)
# FIXME: Is this the best way to do a session key? What
# about symmetric udp?
key = "%s:%s" % (raddress, rport) key = "%s:%s" % (raddress, rport)
if isinstance(recvpkt, TftpPacketRRQ): if not self.sessions.has_key(key):
logger.debug("RRQ packet from %s:%s" % (raddress, rport)) log.debug("Creating new server context for "
if not self.handlers.has_key(key): "session key = %s" % key)
try: self.sessions[key] = TftpContextServer(raddress,
logger.debug("New download request, session key = %s" rport,
% key) timeout,
self.handlers[key] = TftpServerHandler(key, self.root,
'rrq', self.dyn_file_func)
self.root, self.sessions[key].start(buffer)
listenip,
tftp_factory,
self.dynfunc)
self.handlers[key].handle((recvpkt, raddress, rport))
except TftpException, err:
logger.error("Fatal exception thrown from handler: %s"
% str(err))
logger.debug("Deleting handler: %s" % key)
deletion_list.append(key)
else:
logger.warn("Received RRQ for existing session!")
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
continue
elif isinstance(recvpkt, TftpPacketWRQ):
logger.error("Write requests not implemented at this time.")
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
continue
else: else:
# FIXME - this will have to change if we do symmetric UDP log.warn("received traffic on main socket for "
logger.error("Should only receive RRQ or WRQ packets " "existing session??")
"on main listen port. Received %s" % recvpkt)
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
continue
else: else:
for key in self.handlers: # Must find the owner of this traffic.
if readysock == self.handlers[key].sock: for key in self.session:
# FIXME - violating DRY principle with above code if readysock == self.session[key].sock:
try: try:
self.handlers[key].handle() self.session[key].cycle()
if self.session[key].state == None:
log.info("Successful transfer.")
deletion_list.append(key)
break break
except TftpException, err: except TftpException, err:
deletion_list.append(key) deletion_list.append(key)
if self.handlers[key].state.state == 'fin': log.error("Fatal exception thrown from "
logger.info("Successful transfer.") "handler: %s" % str(err))
break
else:
logger.error("Fatal exception thrown from handler: %s"
% str(err))
else: else:
logger.error("Can't find the owner for this packet. Discarding.") log.error("Can't find the owner for this packet. "
"Discarding.")
logger.debug("Looping on all handlers to check for timeouts") log.debug("Looping on all handlers to check for timeouts")
now = time.time() now = time.time()
for key in self.handlers: for key in self.sessions:
try: try:
self.handlers[key].check_timeout(now) self.sessions[key].checkTimeout(now)
except TftpException, err: except TftpException, err:
logger.error("Fatal exception thrown from handler: %s" log.error("Fatal exception thrown from handler: %s"
% str(err)) % str(err))
deletion_list.append(key) deletion_list.append(key)
logger.debug("Iterating deletion list.") log.debug("Iterating deletion list.")
for key in deletion_list: for key in deletion_list:
if self.handlers.has_key(key): if self.sessions.has_key(key):
logger.debug("Deleting handler %s" % key) log.debug("Deleting handler %s" % key)
del self.handlers[key] del self.sessions[key]
deletion_list = [] deletion_list = []
class TftpServerHandler(TftpSession):
"""This class implements a handler for a given server session, handling
the work for one download."""
def __init__(self, key, state, root, listenip, factory, dyn_file_func):
TftpSession.__init__(self)
logger.info("Starting new handler. Key %s." % key)
self.key = key
self.host, self.port = self.key.split(':')
self.port = int(self.port)
self.listenip = listenip
# Note, correct state here is important as it tells the handler whether it's
# handling a download or an upload.
self.state = state
self.root = root
self.mode = None
self.filename = None
self.sock = False
self.options = { 'blksize': DEF_BLKSIZE }
self.blocknumber = 0
self.buffer = None
self.fileobj = None
self.timesent = 0
self.timeouts = 0
self.tftp_factory = factory
self.dynfunc = dyn_file_func
count = 0
while not self.sock:
self.sock = self.gensock(listenip)
count += 1
if count > 10:
raise TftpException, "Failed to bind this handler to any port"
def check_timeout(self, now):
"""This method checks to see if we've timed-out waiting for traffic
from the client."""
if self.timesent:
if now - self.timesent > SOCK_TIMEOUT:
self.timeout()
def timeout(self):
"""This method handles a timeout condition."""
logger.debug("Handling timeout for handler %s" % self.key)
self.timeouts += 1
if self.timeouts > TIMEOUT_RETRIES:
raise TftpException, "Hit max retries, giving up."
if self.state.state == 'dat' or self.state.state == 'fin':
logger.debug("Timing out on DAT. Need to resend.")
self.send_dat(resend=True)
elif self.state.state == 'oack':
logger.debug("Timing out on OACK. Need to resend.")
self.send_oack()
else:
tftpassert(False,
"Timing out in unsupported state %s" %
self.state.state)
def gensock(self, listenip):
"""This method generates a new UDP socket, whose listening port must
be randomly generated, and not conflict with any already in use. For
now, let the OS do this."""
random.seed()
port = random.randrange(1025, 65536)
# FIXME - sockets should be non-blocking?
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
logger.debug("Trying a handler socket on port %d" % port)
try:
sock.bind((listenip, port))
return sock
except socket.error, err:
if err[0] == 98:
logger.warn("Handler %s, port %d was already taken" % (self.key, port))
return False
else:
raise
def handle(self, pkttuple=None):
"""This method informs a handler instance that it has data waiting on
its socket that it must read and process."""
recvpkt = raddress = rport = None
if pkttuple:
logger.debug("Handed pkt %s for handler %s" % (recvpkt, self.key))
recvpkt, raddress, rport = pkttuple
else:
logger.debug("Data ready for handler %s" % self.key)
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
logger.debug("Read %d bytes" % len(buffer))
recvpkt = self.tftp_factory.parse(buffer)
# FIXME - refactor into another method, this is too big
if isinstance(recvpkt, TftpPacketRRQ):
logger.debug("Handler %s received RRQ packet" % self.key)
logger.debug("Requested file is %s, mode is %s" % (recvpkt.filename,
recvpkt.mode))
# FIXME - only octet mode is supported at this time.
if recvpkt.mode != 'octet':
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
raise TftpException, "Unsupported mode: %s" % recvpkt.mode
# test host/port of client end
if self.host != raddress or self.port != rport:
self.senderror(self.sock,
TftpErrors.UnknownTID,
raddress,
rport)
logger.error("Expected traffic from %s:%s but received it "
"from %s:%s instead."
% (self.host, self.port, raddress, rport))
self.errors += 1
return
if self.state.state == 'rrq':
logger.debug("Received RRQ. Composing response.")
self.filename = self.root + os.sep + recvpkt.filename
logger.debug("The path to the desired file is %s" %
self.filename)
self.filename = os.path.abspath(self.filename)
logger.debug("The absolute path is %s" % self.filename)
# Security check. Make sure it's prefixed by the tftproot.
if self.filename.find(self.root) == 0:
logger.debug("The path appears to be safe: %s" %
self.filename)
else:
logger.error("Insecure path: %s" % self.filename)
self.errors += 1
self.senderror(self.sock,
TftpErrors.AccessViolation,
raddress,
rport)
raise TftpException, "Insecure path: %s" % self.filename
# Does the file exist?
if(os.path.exists(self.filename) or not self.dynfunc is None):
logger.debug("File %s exists." % self.filename)
# Check options. Currently we only support the blksize
# option.
if recvpkt.options.has_key('blksize'):
logger.debug("RRQ includes a blksize option")
blksize = int(recvpkt.options['blksize'])
# Delete the option now that it's handled.
del recvpkt.options['blksize']
if blksize >= MIN_BLKSIZE and blksize <= MAX_BLKSIZE:
logger.info("Client requested blksize = %d"
% blksize)
self.options['blksize'] = blksize
else:
logger.warning("Client %s requested invalid "
"blocksize %d, responding with default"
% (self.key, blksize))
self.options['blksize'] = DEF_BLKSIZE
if recvpkt.options.has_key('tsize'):
logger.info('RRQ includes tsize option')
self.options['tsize'] = os.stat(self.filename).st_size
# Delete the option now that it's handled.
del recvpkt.options['tsize']
if len(recvpkt.options.keys()) > 0:
logger.warning("Client %s requested unsupported options: %s"
% (self.key, recvpkt.options))
if self.options:
logger.info("Options requested, sending OACK")
self.send_oack()
else:
logger.debug("Client %s requested no options."
% self.key)
self.start_download()
else:
logger.error("Requested file %s does not exist." %
self.filename)
self.senderror(self.sock,
TftpErrors.FileNotFound,
raddress,
rport)
raise TftpException, "Requested file not found: %s" % self.filename
else:
# We're receiving an RRQ when we're not expecting one.
logger.error("Received an RRQ in handler %s "
"but we're in state %s" % (self.key, self.state))
self.errors += 1
# Next packet type
elif isinstance(recvpkt, TftpPacketACK):
logger.debug("Received an ACK from the client.")
if recvpkt.blocknumber == 0 and self.state.state == 'oack':
logger.debug("Received ACK with 0 blocknumber, starting download")
self.start_download()
else:
if self.state.state == 'dat' or self.state.state == 'fin':
if self.blocknumber == recvpkt.blocknumber:
logger.debug("Received ACK for block %d"
% recvpkt.blocknumber)
if self.state.state == 'fin':
raise TftpException, "Successful transfer."
else:
self.send_dat()
elif recvpkt.blocknumber < self.blocknumber:
# Don't resend a DAT due to an old ACK. Fixes the
# sorceror's apprentice problem.
logger.warn("Received old ACK for block number %d"
% recvpkt.blocknumber)
else:
logger.warn("Received ACK for block number "
"%d, apparently from the future"
% recvpkt.blocknumber)
else:
logger.error("Received ACK with block number %d "
"while in state %s"
% (recvpkt.blocknumber,
self.state.state))
elif isinstance(recvpkt, TftpPacketERR):
logger.error("Received error packet from client: %s" % recvpkt)
self.state.state = 'err'
raise TftpException, "Received error from client"
# Handle other packet types.
else:
logger.error("Received packet %s while handling a download"
% recvpkt)
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
self.host,
self.port)
raise TftpException, "Invalid packet received during download"
def start_download(self):
"""This method opens self.filename, stores the resulting file object
in self.fileobj, and calls send_dat()."""
self.state.state = 'dat'
if os.path.exists(self.filename):
self.fileobj = open(self.filename, "rb")
else:
self.fileobj = self.dynfunc(self.filename)
self.send_dat()
def send_dat(self, resend=False):
"""This method reads sends a DAT packet based on what is in self.buffer."""
if not resend:
blksize = int(self.options['blksize'])
self.buffer = self.fileobj.read(blksize)
logger.debug("Read %d bytes into buffer" % len(self.buffer))
if len(self.buffer) < blksize:
logger.info("Reached EOF on file %s" % self.filename)
self.state.state = 'fin'
self.blocknumber += 1
if self.blocknumber > 65535:
logger.debug("Blocknumber rolled over to zero")
self.blocknumber = 0
else:
logger.warn("Resending block number %d" % self.blocknumber)
dat = TftpPacketDAT()
dat.data = self.buffer
dat.blocknumber = self.blocknumber
logger.debug("Sending DAT packet %d" % self.blocknumber)
self.sock.sendto(dat.encode().buffer, (self.host, self.port))
self.timesent = time.time()
# FIXME - should these be factored-out into the session class?
def send_oack(self):
"""This method sends an OACK packet based on current params."""
logger.debug("Composing and sending OACK packet")
oack = TftpPacketOACK()
oack.options = self.options
self.sock.sendto(oack.encode().buffer,
(self.host, self.port))
self.timesent = time.time()
self.state.state = 'oack'

View file

@ -17,7 +17,7 @@ DEF_TFTP_PORT = 69
logging.basicConfig() logging.basicConfig()
# The logger used by this library. Feel free to clobber it with your own, if you like, as # The logger used by this library. Feel free to clobber it with your own, if you like, as
# long as it conforms to Python's logging. # long as it conforms to Python's logging.
logger = logging.getLogger('tftpy') log = logging.getLogger('tftpy')
def tftpassert(condition, msg): def tftpassert(condition, msg):
"""This function is a simple utility that will check the condition """This function is a simple utility that will check the condition
@ -31,8 +31,8 @@ def setLogLevel(level):
"""This function is a utility function for setting the internal log level. """This function is a utility function for setting the internal log level.
The log level defaults to logging.NOTSET, so unwanted output to stdout is The log level defaults to logging.NOTSET, so unwanted output to stdout is
not created.""" not created."""
global logger global log
logger.setLevel(level) log.setLevel(level)
class TftpErrors(object): class TftpErrors(object):
"""This class is a convenience for defining the common tftp error codes, """This class is a convenience for defining the common tftp error codes,

View file

@ -22,17 +22,28 @@ class TftpMetrics(object):
# Rates # Rates
self.bps = 0 self.bps = 0
self.kbps = 0 self.kbps = 0
# Generic errors
self.errors = 0
def compute(self): def compute(self):
# Compute transfer time # Compute transfer time
self.duration = self.end_time - self.start_time self.duration = self.end_time - self.start_time
logger.debug("TftpMetrics.compute: duration is %s" % self.duration) log.debug("TftpMetrics.compute: duration is %s" % self.duration)
self.bps = (self.bytes * 8.0) / self.duration self.bps = (self.bytes * 8.0) / self.duration
self.kbps = self.bps / 1024.0 self.kbps = self.bps / 1024.0
logger.debug("TftpMetrics.compute: kbps is %s" % self.kbps) log.debug("TftpMetrics.compute: kbps is %s" % self.kbps)
dupcount = 0
for key in self.dups: for key in self.dups:
dupcount += self.dups[key] self.dupcount += self.dups[key]
def add_dup(self, blocknumber):
"""This method adds a dup for a block number to the metrics."""
log.debug("Recording a dup for block %d" % blocknumber)
if self.dups.has_key(blocknumber):
self.dups[pkt.blocknumber] += 1
else:
self.dups[pkt.blocknumber] = 1
tftpassert(self.dups[pkt.blocknumber] < MAX_DUPS,
"Max duplicates for block %d reached" % blocknumber)
############################################################################### ###############################################################################
# Context classes # Context classes
@ -40,16 +51,32 @@ class TftpMetrics(object):
class TftpContext(object): class TftpContext(object):
"""The base class of the contexts.""" """The base class of the contexts."""
def __init__(self, host, port):
def __init__(self, host, port, timeout):
"""Constructor for the base context, setting shared instance """Constructor for the base context, setting shared instance
variables.""" variables."""
self.file_to_transfer = None
self.fileobj = None
self.options = None
self.packethook = None
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(timeout)
self.state = None
self.next_block = 0
self.factory = TftpPacketFactory() self.factory = TftpPacketFactory()
# Note, setting the host will also set self.address, as it's a property.
self.host = host self.host = host
self.port = port self.port = port
# The port associated with the TID # The port associated with the TID
self.tidport = None self.tidport = None
# Metrics # Metrics
self.metrics = TftpMetrics() self.metrics = TftpMetrics()
# Flag when the transfer is pending completion.
self.pending_complete = False
def checkTimeout(self, now):
# FIXME
pass
def start(self): def start(self):
return NotImplementedError, "Abstract method" return NotImplementedError, "Abstract method"
@ -69,37 +96,9 @@ class TftpContext(object):
host = property(gethost, sethost) host = property(gethost, sethost)
def sendAck(self, blocknumber):
"""This method sends an ack packet to the block number specified."""
logger.info("sending ack to block %d" % blocknumber)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = blocknumber
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport))
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
compose and send an error packet."""
logger.debug("In sendError, being asked to send error %d" % errorcode)
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
self.sock.sendto(errpkt.encode().buffer, (self.host, self.tidport))
class TftpContextClient(TftpContext):
"""This class represents shared functionality by both the download and
upload client contexts."""
def __init__(self, host, port, filename, options, packethook, timeout):
TftpContext.__init__(self, host, port)
self.file_to_transfer = filename
self.options = options
self.packethook = packethook
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(timeout)
self.state = None
self.next_block = 0
def setNextBlock(self, block): def setNextBlock(self, block):
if block > 2 ** 16: if block > 2 ** 16:
logger.debug("block number rollover to 0 again") log.debug("block number rollover to 0 again")
block = 0 block = 0
self.__eblock = block self.__eblock = block
@ -111,19 +110,21 @@ class TftpContextClient(TftpContext):
def cycle(self): def cycle(self):
"""Here we wait for a response from the server after sending it """Here we wait for a response from the server after sending it
something, and dispatch appropriate action to that response.""" something, and dispatch appropriate action to that response."""
# FIXME: This won't work very well in a server context with multiple
# sessions running.
for i in range(TIMEOUT_RETRIES): for i in range(TIMEOUT_RETRIES):
logger.debug("in cycle, receive attempt %d" % i) log.debug("in cycle, receive attempt %d" % i)
try: try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
except socket.timeout, err: except socket.timeout, err:
logger.warn("Timeout waiting for traffic, retrying...") log.warn("Timeout waiting for traffic, retrying...")
continue continue
break break
else: else:
raise TftpException, "Hit max timeouts, giving up." raise TftpException, "Hit max timeouts, giving up."
# Ok, we've received a packet. Log it. # Ok, we've received a packet. Log it.
logger.debug("Received %d bytes from %s:%s" log.debug("Received %d bytes from %s:%s"
% (len(buffer), raddress, rport)) % (len(buffer), raddress, rport))
# Decode it. # Decode it.
@ -131,11 +132,11 @@ class TftpContextClient(TftpContext):
# Check for known "connection". # Check for known "connection".
if raddress != self.address: if raddress != self.address:
logger.warn("Received traffic from %s, expected host %s. Discarding" log.warn("Received traffic from %s, expected host %s. Discarding"
% (raddress, self.host)) % (raddress, self.host))
if self.tidport and self.tidport != rport: if self.tidport and self.tidport != rport:
logger.warn("Received traffic from %s:%s but we're " log.warn("Received traffic from %s:%s but we're "
"connected to %s:%s. Discarding." "connected to %s:%s. Discarding."
% (raddress, rport, % (raddress, rport,
self.host, self.tidport)) self.host, self.tidport))
@ -150,29 +151,66 @@ class TftpContextClient(TftpContext):
# And handle it, possibly changing state. # And handle it, possibly changing state.
self.state = self.state.handle(recvpkt, raddress, rport) self.state = self.state.handle(recvpkt, raddress, rport)
class TftpContextClientUpload(TftpContextClient): class TftpContextServer(TftpContext):
"""The context for the server."""
def __init__(self, host, port, timeout, root, dyn_file_func):
TftpContext.__init__(self,
host,
port,
timeout)
# At this point we have no idea if this is a download or an upload. We
# need to let the start state determine that.
self.state = TftpStateServerStart()
self.root = root
self.dyn_file_func = dyn_file_func
def start(self, buffer):
"""Start the state cycle. Note that the server context receives an
initial packet in its start method."""
log.debug("TftpContextServer.start() - pkt = %s" % pkt)
self.metrics.start_time = time.time()
log.debug("set metrics.start_time to %s" % self.metrics.start_time)
pkt = self.factory.parse(buffer)
log.debug("TftpContextServer.start() - factory returned a %s" % pkt)
# Call handle once with the initial packet. This should put us into
# the download or the upload state.
self.state = self.state.handle(pkt,
self.host,
self.port)
try:
while self.state:
log.debug("state is %s" % self.state)
self.cycle()
finally:
self.fileobj.close()
class TftpContextClientUpload(TftpContext):
"""The upload context for the client during an upload.""" """The upload context for the client during an upload."""
def __init__(self, host, port, filename, input, options, packethook, timeout): def __init__(self, host, port, filename, input, options, packethook, timeout):
TftpContextClient.__init__(self, TftpContext.__init__(self,
host, host,
port, port,
filename, timeout)
options, self.file_to_transfer = filename
packethook, self.options = options
timeout) self.packethook = packethook
self.fileobj = open(input, "wb") self.fileobj = open(input, "wb")
logger.debug("TftpContextClientUpload.__init__()") log.debug("TftpContextClientUpload.__init__()")
logger.debug("file_to_transfer = %s, options = %s" % log.debug("file_to_transfer = %s, options = %s" %
(self.file_to_transfer, self.options)) (self.file_to_transfer, self.options))
def start(self): def start(self):
logger.info("sending tftp upload request to %s" % self.host) log.info("sending tftp upload request to %s" % self.host)
logger.info(" filename -> %s" % self.file_to_transfer) log.info(" filename -> %s" % self.file_to_transfer)
logger.info(" options -> %s" % self.options) log.info(" options -> %s" % self.options)
self.metrics.start_time = time.time() self.metrics.start_time = time.time()
logger.debug("set metrics.start_time to %s" % self.metrics.start_time) log.debug("set metrics.start_time to %s" % self.metrics.start_time)
# FIXME: put this in a sendWRQ method? # FIXME: put this in a sendWRQ method?
pkt = TftpPacketWRQ() pkt = TftpPacketWRQ()
@ -186,7 +224,7 @@ class TftpContextClientUpload(TftpContextClient):
try: try:
while self.state: while self.state:
logger.debug("state is %s" % self.state) log.debug("state is %s" % self.state)
self.cycle() self.cycle()
finally: finally:
self.fileobj.close() self.fileobj.close()
@ -194,32 +232,32 @@ class TftpContextClientUpload(TftpContextClient):
def end(self): def end(self):
pass pass
class TftpContextClientDownload(TftpContextClient): class TftpContextClientDownload(TftpContext):
"""The download context for the client during a download.""" """The download context for the client during a download."""
def __init__(self, host, port, filename, output, options, packethook, timeout): def __init__(self, host, port, filename, output, options, packethook, timeout):
TftpContextClient.__init__(self, TftpContext.__init__(self,
host, host,
port, port,
filename, filename,
options, options,
packethook, packethook,
timeout) timeout)
# FIXME - need to support alternate return formats than files? # FIXME - need to support alternate return formats than files?
# File-like objects would be ideal, ala duck-typing. # File-like objects would be ideal, ala duck-typing.
self.fileobj = open(output, "wb") self.fileobj = open(output, "wb")
logger.debug("TftpContextClientDownload.__init__()") log.debug("TftpContextClientDownload.__init__()")
logger.debug("file_to_transfer = %s, options = %s" % log.debug("file_to_transfer = %s, options = %s" %
(self.file_to_transfer, self.options)) (self.file_to_transfer, self.options))
def start(self): def start(self):
"""Initiate the download.""" """Initiate the download."""
logger.info("sending tftp download request to %s" % self.host) log.info("sending tftp download request to %s" % self.host)
logger.info(" filename -> %s" % self.file_to_transfer) log.info(" filename -> %s" % self.file_to_transfer)
logger.info(" options -> %s" % self.options) log.info(" options -> %s" % self.options)
self.metrics.start_time = time.time() self.metrics.start_time = time.time()
logger.debug("set metrics.start_time to %s" % self.metrics.start_time) log.debug("set metrics.start_time to %s" % self.metrics.start_time)
# FIXME: put this in a sendRRQ method? # FIXME: put this in a sendRRQ method?
pkt = TftpPacketRRQ() pkt = TftpPacketRRQ()
@ -233,7 +271,7 @@ class TftpContextClientDownload(TftpContextClient):
try: try:
while self.state: while self.state:
logger.debug("state is %s" % self.state) log.debug("state is %s" % self.state)
self.cycle() self.cycle()
finally: finally:
self.fileobj.close() self.fileobj.close()
@ -241,7 +279,7 @@ class TftpContextClientDownload(TftpContextClient):
def end(self): def end(self):
"""Finish up the context.""" """Finish up the context."""
self.metrics.end_time = time.time() self.metrics.end_time = time.time()
logger.debug("set metrics.end_time to %s" % self.metrics.end_time) log.debug("set metrics.end_time to %s" % self.metrics.end_time)
self.metrics.compute() self.metrics.compute()
@ -268,235 +306,431 @@ class TftpState(object):
options.""" options."""
if pkt.options.keys() > 0: if pkt.options.keys() > 0:
if pkt.match_options(self.context.options): if pkt.match_options(self.context.options):
logger.info("Successful negotiation of options") log.info("Successful negotiation of options")
# Set options to OACK options # Set options to OACK options
self.context.options = pkt.options self.context.options = pkt.options
for key in self.context.options: for key in self.context.options:
logger.info(" %s = %s" % (key, self.context.options[key])) log.info(" %s = %s" % (key, self.context.options[key]))
else: else:
logger.error("failed to negotiate options") log.error("failed to negotiate options")
raise TftpException, "Failed to negotiate options" raise TftpException, "Failed to negotiate options"
else: else:
raise TftpException, "No options found in OACK" raise TftpException, "No options found in OACK"
class TftpStateUpload(TftpState): def returnSupportedOptions(self, options):
"""A class holding common code for upload states.""" """This method takes a requested options list from a client, and
def sendDat(self, resend=False): returns the ones that are supported."""
# We support the options blksize and tsize right now.
# FIXME - put this somewhere else?
accepted_options = {}
for option in options:
if option == 'blksize':
# Make sure it's valid.
if int(options[option]) > MAX_BLKSIZE:
accepted_options[option] = MAX_BLKSIZE
elif option == 'tsize':
log.debug("tsize option is set")
accepted_options['tsize'] = 1
else:
log.info("Dropping unsupported option '%s'" % option)
return accepted_options
def serverInitial(self, pkt, raddress, rport):
"""This method performs initial setup for a server context transfer,
put here to refactor code out of the TftpStateServerRecvRRQ and
TftpStateServerRecvWRQ classes, since their initial setup is
identical. The method returns a boolean, sendoack, to indicate whether
it is required to send an OACK to the client."""
options = pkt.options
sendoack = False
if not options:
log.debug("setting default options, blksize")
# FIXME: put default options elsewhere
self.context.options = { 'blksize': DEF_BLKSIZE }
else:
log.debug("options requested: %s" % options)
self.context.options = self.returnSupportedOptions(options)
sendoack = True
# FIXME - only octet mode is supported at this time.
if pkt.mode != 'octet':
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, \
"Only octet transfers are supported at this time."
# test host/port of client end
if self.context.host != raddress or self.context.port != rport:
self.sendError(TftpErrors.UnknownTID)
log.error("Expected traffic from %s:%s but received it "
"from %s:%s instead."
% (self.context.host,
self.context.port,
raddress,
rport))
# FIXME: increment an error count?
# Return same state, we're still waiting for valid traffic.
return self
log.debug("requested filename is %s" % pkt.filename)
# There are no os.sep's allowed in the filename.
# FIXME: Should we allow subdirectories?
if pkt.filename.find(os.sep) >= 0:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "%s found in filename, not permitted" % os.sep
self.context.file_to_transfer = pkt.filename
return sendoack
def sendDAT(self, resend=False):
"""This method sends the next DAT packet based on the data in the
context. It returns a boolean indicating whether the transfer is
finished."""
finished = False finished = False
blocknumber = self.context.next_block blocknumber = self.context.next_block
if not resend: if not resend:
blksize = int(self.context.options['blksize']) blksize = int(self.context.options['blksize'])
buffer = self.context.fileobj.read(blksize) buffer = self.context.fileobj.read(blksize)
logger.debug("Read %d bytes into buffer" % len(buffer)) log.debug("Read %d bytes into buffer" % len(buffer))
if len(buffer) < blksize: if len(buffer) < blksize:
logger.info("Reached EOF on file %s" % self.context.input) log.info("Reached EOF on file %s" % self.context.input)
finished = True finished = True
self.context.next_block += 1 self.context.next_block += 1
self.bytes += len(buffer) self.bytes += len(buffer)
else: else:
logger.warn("Resending block number %d" % blocknumber) log.warn("Resending block number %d" % blocknumber)
dat = TftpPacketDAT() dat = TftpPacketDAT()
dat.data = buffer dat.data = buffer
dat.blocknumber = blocknumber dat.blocknumber = blocknumber
logger.debug("Sending DAT packet %d" % blocknumber) log.debug("Sending DAT packet %d" % blocknumber)
self.context.sock.sendto(dat.encode().buffer, self.context.sock.sendto(dat.encode().buffer,
(self.context.host, self.context.port)) (self.context.host, self.context.port))
if self.context.packethook: if self.context.packethook:
self.context.packethook(dat) self.context.packethook(dat)
return finished return finished
class TftpStateDownload(TftpState): def sendACK(self, blocknumber=None):
"""A class holding common code for download states.""" """This method sends an ack packet to the block number specified. If
none is specified, it defaults to the next_block property in the
parent context."""
if not blocknumber:
blocknumber = self.context.next_block
log.info("sending ack to block %d" % blocknumber)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = blocknumber
self.context.sock.sendto(ackpkt.encode().buffer,
(self.context.host,
self.context.tidport))
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
compose and send an error packet."""
log.debug("In sendError, being asked to send error %d" % errorcode)
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
self.context.sock.sendto(errpkt.encode().buffer,
(self.context.host,
self.context.tidport))
def sendOACK(self):
"""This method sends an OACK packet with the options from the current
context."""
log.debug("In sendOACK with options %s" % options)
pkt = TftpPacketOACK()
pkt.options = self.options
self.context.sock.sendto(pkt.encode().buffer,
(self.context.host,
self.context.tidport))
def handleDat(self, pkt): def handleDat(self, pkt):
"""This method handles a DAT packet during a download.""" """This method handles a DAT packet during a client download, or a
logger.info("handling DAT packet - block %d" % pkt.blocknumber) server upload."""
logger.debug("expecting block %s" % self.context.next_block) log.info("handling DAT packet - block %d" % pkt.blocknumber)
log.debug("expecting block %s" % self.context.next_block)
if pkt.blocknumber == self.context.next_block: if pkt.blocknumber == self.context.next_block:
logger.debug("good, received block %d in sequence" log.debug("good, received block %d in sequence"
% pkt.blocknumber) % pkt.blocknumber)
self.context.sendAck(pkt.blocknumber) self.sendACK()
self.context.next_block += 1 self.context.next_block += 1
logger.debug("writing %d bytes to output file" log.debug("writing %d bytes to output file"
% len(pkt.data)) % len(pkt.data))
self.context.fileobj.write(pkt.data) self.context.fileobj.write(pkt.data)
self.context.metrics.bytes += len(pkt.data) self.context.metrics.bytes += len(pkt.data)
# Check for end-of-file, any less than full data packet. # Check for end-of-file, any less than full data packet.
if len(pkt.data) < int(self.context.options['blksize']): if len(pkt.data) < int(self.context.options['blksize']):
logger.info("end of file detected") log.info("end of file detected")
return None return None
elif pkt.blocknumber < self.context.next_block: elif pkt.blocknumber < self.context.next_block:
logger.warn("dropping duplicate block %d" % pkt.blocknumber) log.warn("dropping duplicate block %d" % pkt.blocknumber)
if self.context.metrics.dups.has_key(pkt.blocknumber): self.context.metrics.add_dup(pkt.blocknumber)
self.context.metrics.dups[pkt.blocknumber] += 1 log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
else: self.sendACK(pkt.blocknumber)
self.context.metrics.dups[pkt.blocknumber] = 1
tftpassert(self.context.metrics.dups[pkt.blocknumber] < MAX_DUPS,
"Max duplicates for block %d reached" % pkt.blocknumber)
# FIXME: double-check sorceror's apprentice problem!
logger.debug("ACKing block %d again, just in case" % pkt.blocknumber)
self.context.sendAck(pkt.blocknumber)
else: else:
# FIXME: should we be more tolerant and just discard instead? # FIXME: should we be more tolerant and just discard instead?
msg = "Whoa! Received future block %d but expected %d" \ msg = "Whoa! Received future block %d but expected %d" \
% (pkt.blocknumber, self.context.next_block) % (pkt.blocknumber, self.context.next_block)
logger.error(msg) log.error(msg)
raise TftpException, msg raise TftpException, msg
# Default is to ack # Default is to ack
return TftpStateSentACK(self.context) return TftpStateExpectDAT(self.context)
class TftpStateSentWRQ(TftpStateUpload): class TftpStateServerRecvRRQ(TftpState):
"""Just sent an WRQ packet for an upload.""" """This class represents the state of the TFTP server when it has just
received an RRQ packet."""
def handle(self, pkt, raddress, rport):
"Handle an initial RRQ packet as a server."
log.debug("In TftpStateServerRecvRRQ.handle")
sendoack = self.serverInitial(pkt, raddress, rport)
path = self.context.root + os.sep + self.context.file_to_transfer
log.info("Opening file %s for reading" % path)
if os.path.exists(path):
# Note: Open in binary mode for win32 portability, since win32
# blows.
self.context.fileobj = open(path, "rb")
elif self.dyn_file_func:
log.debug("No such file %s but using dyn_file_func" % path)
self.context.fileobj = \
self.dyn_file_func(self.context.file_to_transfer)
else:
send.sendError(TftpErrors.FileNotFound)
raise TftpException, "File not found: %s" % path
# Options negotiation.
if sendoack:
self.sendOACK()
return TftpStateServerOACK(self.context)
else:
log.debug("No requested options, starting send...")
self.context.pending_complete = self.sendDAT()
return TftpStateExpectACK(self.context)
# Note, we don't have to check any other states in this method, that's
# up to the caller.
class TftpStateServerRecvWRQ(TftpState):
"""This class represents the state of the TFTP server when it has just
received a WRQ packet."""
def handle(self, pkt, raddress, rport):
"Handle an initial WRQ packet as a server."
log.debug("In TftpStateServerRecvWRQ.handle")
sendoack = self.serverInitial(pkt, raddress, rport)
path = self.context.root + os.sep + self.context.file_to_transfer
log.info("Opening file %s for writing" % path)
if os.path.exists(path):
# FIXME: correct behavior?
log.warn("File %s exists already, overwriting...")
# FIXME: I think we should upload to a temp file and not overwrite the
# existing file until the file is successfully uploaded.
self.context.fileobj = open(path, "wb")
# Options negotiation.
if sendoack:
log.debug("Sending OACK to client")
self.sendOACK()
else:
log.debug("No requested options, starting send...")
self.sendACK()
# We may have sent an OACK, but we're expecting a DAT as the response
# to either the OACK or an ACK, so lets unconditionally use the
# TftpStateExpectDAT state.
return TftpStateExpectDAT(self.context)
# Note, we don't have to check any other states in this method, that's
# up to the caller.
class TftpStateServerStart(TftpState):
"""The start state for the server."""
def handle(self, pkt, raddress, rport): def handle(self, pkt, raddress, rport):
"""Handle a packet we just received.""" """Handle a packet we just received."""
if not self.context.tidport: log.debug("In TftpStateServerStart.handle")
self.context.tidport = rport if isinstance(pkt, TftpPacketRRQ):
logger.debug("Set remote port for session to %s" % rport) log.debug("handling an RRQ packet")
return TftpStateServerRecvRRQ(self.context).handle(pkt,
# If we're going to successfully transfer the file, then we should see raddress,
# either an OACK for accepted options, or an ACK to ignore options. rport)
if isinstance(pkt, TftpPacketOACK): elif isinstance(pkt, TftpPacketWRQ):
logger.info("received OACK from server") log.debug("handling a WRQ packet")
try: return TftpStateServerRecvWRQ(self.context).handle(pkt,
self.handleOACK(pkt) raddress,
except TftpException, err: rport)
logger.error("failed to negotiate options")
self.context.sendError(TftpErrors.FailedNegotiation)
raise
else:
logger.debug("sending first DAT packet")
fin = self.context.sendDat()
if fin:
logger.info("Add done")
return None
else:
logger.debug("Changing state to TftpStateSentDAT")
return TftpStateSentDAT(self.context)
elif isinstance(pkt, TftpPacketACK):
logger.info("received ACK from server")
logger.debug("apparently the server ignored our options")
# The block number should be zero.
if pkt.blocknumber == 0:
logger.debug("ack blocknumber is zero as expected")
logger.debug("sending first DAT packet")
fin = self.context.sendDat()
if fin:
logger.info("Add done")
return None
else:
logger.debug("Changing state to TftpStateSentDAT")
return TftpStateSentDAT(self.context)
else:
logger.warn("discarding ACK to block %s" % pkt.blocknumber)
logger.debug("still waiting for valid response from server")
return self
elif isinstance(pkt, TftpPacketERR):
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(pkt)
elif isinstance(pkt, TftpPacketRRQ):
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received RRQ from server while in upload"
elif isinstance(pkt, TftpPacketDAT):
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received DAT from server while in upload"
else: else:
self.context.sendError(TftpErrors.IllegalTftpOp) self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(pkt) raise TftpException, \
"Invalid packet to begin up/download: %s" % pkt
# By default, no state change. class TftpStateExpectACK(TftpState):
return self
class TftpStateSentDAT(TftpStateUpload):
"""This class represents the state of the transfer when a DAT was just """This class represents the state of the transfer when a DAT was just
sent, and we are waiting for an ACK from the server. This class is the sent, and we are waiting for an ACK from the server. This class is the
same one used by the client during the upload, and the server during the same one used by the client during the upload, and the server during the
download.""" download."""
class TftpStateSentRRQ(TftpStateDownload):
"""Just sent an RRQ packet."""
def handle(self, pkt, raddress, rport): def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an RRQ to the server.""" "Handle a packet, hopefully an ACK since we just sent a DAT."
if not self.context.tidport: if isinstance(pkt, TftpPacketACK):
self.context.tidport = rport log.info("Received ACK for packet %d" % pkt.blocknumber)
logger.debug("Set remote port for session to %s" % rport) # Is this an ack to the one we just sent?
if self.context.next_block == pkt.blocknumber:
if self.context.pending_complete:
log.info("Received ACK to final DAT, we're done.")
return None
else:
log.debug("Good ACK, sending next DAT")
self.context.pending_complete = self.sendDAT()
elif pkt.blocknumber < self.context.next_block:
self.context.metrics.add_dup(pkt.blocknumber)
# Now check the packet type and dispatch it properly.
if isinstance(pkt, TftpPacketOACK):
logger.info("received OACK from server")
try:
self.handleOACK(pkt)
except TftpException, err:
logger.error("failed to negotiate options: %s" % str(err))
self.context.sendError(TftpErrors.FailedNegotiation)
raise
else: else:
logger.debug("sending ACK to OACK") log.warn("Oooh, time warp. Received ACK to packet we "
"didn't send yet. Discarding.")
self.context.sendAck(blocknumber=0) self.context.metrics.errors += 1
return self
logger.debug("Changing state to TftpStateSentACK") elif isinstance(pkt, TftpPacketERR):
return TftpStateSentACK(self.context) log.error("Received ERR packet from peer: %s" % str(pkt))
raise TftpException, \
elif isinstance(pkt, TftpPacketDAT): "Received ERR packet from peer: %s" % str(pkt)
# If there are any options set, then the server didn't honour any
# of them.
logger.info("received DAT from server")
if self.context.options:
logger.info("server ignored options, falling back to defaults")
self.context.options = { 'blksize': DEF_BLKSIZE }
return self.handleDat(pkt)
# Every other packet type is a problem.
elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, the server doesn't.
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ACK from server while in download"
elif isinstance(recvpkt, TftpPacketWRQ):
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received WRQ from server while in download"
elif isinstance(recvpkt, TftpPacketERR):
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(recvpkt)
else: else:
self.context.sendError(TftpErrors.IllegalTftpOp) log.warn("Discarding unsupported packet: %s" % str(pkt))
raise TftpException, "Received unknown packet type from server: " + str(recvpkt) return self
# By default, no state change. class TftpStateExpectDAT(TftpState):
return self
class TftpStateSentACK(TftpStateDownload):
"""Just sent an ACK packet. Waiting for DAT.""" """Just sent an ACK packet. Waiting for DAT."""
def handle(self, pkt, raddress, rport): def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an ACK, which should be a DAT.""" """Handle the packet in response to an ACK, which should be a DAT."""
if isinstance(pkt, TftpPacketDAT): if isinstance(pkt, TftpPacketDAT):
return self.handleDat(pkt) return self.handleDat(pkt)
# Every other packet type is a problem.
elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, you don't.
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ACK from peer when expecting DAT"
elif isinstance(recvpkt, TftpPacketWRQ):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received WRQ from peer when expecting DAT"
elif isinstance(recvpkt, TftpPacketERR):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from peer: " + str(recvpkt)
else:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from peer: " + str(recvpkt)
class TftpStateSentWRQ(TftpState):
"""Just sent an WRQ packet for an upload."""
def handle(self, pkt, raddress, rport):
"""Handle a packet we just received."""
if not self.context.tidport:
self.context.tidport = rport
log.debug("Set remote port for session to %s" % rport)
# If we're going to successfully transfer the file, then we should see
# either an OACK for accepted options, or an ACK to ignore options.
if isinstance(pkt, TftpPacketOACK):
log.info("received OACK from server")
try:
self.handleOACK(pkt)
except TftpException, err:
log.error("failed to negotiate options")
self.sendError(TftpErrors.FailedNegotiation)
raise
else:
log.debug("sending first DAT packet")
self.context.pending_complete = self.sendDAT()
log.debug("Changing state to TftpStateExpectACK")
return TftpStateExpectACK(self.context)
elif isinstance(pkt, TftpPacketACK):
log.info("received ACK from server")
log.debug("apparently the server ignored our options")
# The block number should be zero.
if pkt.blocknumber == 0:
log.debug("ack blocknumber is zero as expected")
log.debug("sending first DAT packet")
self.pending_complete = self.context.sendDAT()
log.debug("Changing state to TftpStateExpectACK")
return TftpStateExpectACK(self.context)
else:
log.warn("discarding ACK to block %s" % pkt.blocknumber)
log.debug("still waiting for valid response from server")
return self
elif isinstance(pkt, TftpPacketERR):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(pkt)
elif isinstance(pkt, TftpPacketRRQ):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received RRQ from server while in upload"
elif isinstance(pkt, TftpPacketDAT):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received DAT from server while in upload"
else:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(pkt)
# By default, no state change.
return self
class TftpStateSentRRQ(TftpState):
"""Just sent an RRQ packet."""
def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an RRQ to the server."""
if not self.context.tidport:
self.context.tidport = rport
log.debug("Set remote port for session to %s" % rport)
# Now check the packet type and dispatch it properly.
if isinstance(pkt, TftpPacketOACK):
log.info("received OACK from server")
try:
self.handleOACK(pkt)
except TftpException, err:
log.error("failed to negotiate options: %s" % str(err))
self.sendError(TftpErrors.FailedNegotiation)
raise
else:
log.debug("sending ACK to OACK")
self.sendACK(blocknumber=0)
log.debug("Changing state to TftpStateExpectDAT")
return TftpStateExpectDAT(self.context)
elif isinstance(pkt, TftpPacketDAT):
# If there are any options set, then the server didn't honour any
# of them.
log.info("received DAT from server")
if self.context.options:
log.info("server ignored options, falling back to defaults")
self.context.options = { 'blksize': DEF_BLKSIZE }
return self.handleDat(pkt)
# Every other packet type is a problem. # Every other packet type is a problem.
elif isinstance(recvpkt, TftpPacketACK): elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, the server doesn't. # Umm, we ACK, the server doesn't.
self.context.sendError(TftpErrors.IllegalTftpOp) self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ACK from server while in download" raise TftpException, "Received ACK from server while in download"
elif isinstance(recvpkt, TftpPacketWRQ): elif isinstance(recvpkt, TftpPacketWRQ):
self.context.sendError(TftpErrors.IllegalTftpOp) self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received WRQ from server while in download" raise TftpException, "Received WRQ from server while in download"
elif isinstance(recvpkt, TftpPacketERR): elif isinstance(recvpkt, TftpPacketERR):
self.context.sendError(TftpErrors.IllegalTftpOp) self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(recvpkt) raise TftpException, "Received ERR from server: " + str(recvpkt)
else: else:
self.context.sendError(TftpErrors.IllegalTftpOp) self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(recvpkt) raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
# By default, no state change.
return self