Did some rework for the state machine in a server context.

Removed the handler framework in favour of a TftpContextServer used
as the session.
master
Michael P. Soulier 2009-08-15 22:36:58 -04:00
parent 03e4e74829
commit 62b22fb562
6 changed files with 564 additions and 636 deletions

View File

@ -52,12 +52,12 @@ class TftpClient(TftpSession):
# output? This should be in the sample client, but not in the download
# call.
if metrics.duration == 0:
logger.info("Duration too short, rate undetermined")
log.info("Duration too short, rate undetermined")
else:
logger.info('')
logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
logger.info("Average rate: %.2f kbps" % metrics.kbps)
logger.info("Received %d duplicate packets" % metrics.dupcount)
log.info('')
log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
log.info("Average rate: %.2f kbps" % metrics.kbps)
log.info("Received %d duplicate packets" % metrics.dupcount)
def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT):
# Open the input file.
@ -80,9 +80,9 @@ class TftpClient(TftpSession):
# output? This should be in the sample client, but not in the download
# call.
if metrics.duration == 0:
logger.info("Duration too short, rate undetermined")
log.info("Duration too short, rate undetermined")
else:
logger.info('')
logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
logger.info("Average rate: %.2f kbps" % metrics.kbps)
logger.info("Received %d duplicate packets" % metrics.dupcount)
log.info('')
log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
log.info("Average rate: %.2f kbps" % metrics.kbps)
log.info("Received %d duplicate packets" % metrics.dupcount)

View File

@ -19,9 +19,9 @@ class TftpPacketFactory(object):
"""This method is used to parse an existing datagram into its
corresponding TftpPacket object. The buffer is the raw bytes off of
the network."""
logger.debug("parsing a %d byte packet" % len(buffer))
log.debug("parsing a %d byte packet" % len(buffer))
(opcode,) = struct.unpack("!H", buffer[:2])
logger.debug("opcode is %d" % opcode)
log.debug("opcode is %d" % opcode)
packet = self.__create(opcode)
packet.buffer = buffer
return packet.decode()

View File

@ -15,7 +15,7 @@ class TftpSession(object):
def senderror(self, sock, errorcode, address, port):
"""This method uses the socket passed, and uses the errorcode, address
and port to compose and send an error packet."""
logger.debug("In senderror, being asked to send error %d to %s:%s"
log.debug("In senderror, being asked to send error %d to %s:%s"
% (errorcode, address, port))
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
@ -27,23 +27,23 @@ class TftpPacketWithOptions(object):
goal is just to share code here, and not cause diamond inheritance."""
def __init__(self):
self.options = []
self.options = {}
def setoptions(self, options):
logger.debug("in TftpPacketWithOptions.setoptions")
logger.debug("options: " + str(options))
log.debug("in TftpPacketWithOptions.setoptions")
log.debug("options: " + str(options))
myoptions = {}
for key in options:
newkey = str(key)
myoptions[newkey] = str(options[key])
logger.debug("populated myoptions with %s = %s"
log.debug("populated myoptions with %s = %s"
% (newkey, myoptions[newkey]))
logger.debug("setting options hash to: " + str(myoptions))
log.debug("setting options hash to: " + str(myoptions))
self._options = myoptions
def getoptions(self):
logger.debug("in TftpPacketWithOptions.getoptions")
log.debug("in TftpPacketWithOptions.getoptions")
return self._options
# Set up getter and setter on options to ensure that they are the proper
@ -59,19 +59,19 @@ class TftpPacketWithOptions(object):
format = "!"
options = {}
logger.debug("decode_options: buffer is: " + repr(buffer))
logger.debug("size of buffer is %d bytes" % len(buffer))
log.debug("decode_options: buffer is: " + repr(buffer))
log.debug("size of buffer is %d bytes" % len(buffer))
if len(buffer) == 0:
logger.debug("size of buffer is zero, returning empty hash")
log.debug("size of buffer is zero, returning empty hash")
return {}
# Count the nulls in the buffer. Each one terminates a string.
logger.debug("about to iterate options buffer counting nulls")
log.debug("about to iterate options buffer counting nulls")
length = 0
for c in buffer:
#logger.debug("iterating this byte: " + repr(c))
#log.debug("iterating this byte: " + repr(c))
if ord(c) == 0:
logger.debug("found a null at length %d" % length)
log.debug("found a null at length %d" % length)
if length > 0:
format += "%dsx" % length
length = -1
@ -79,14 +79,14 @@ class TftpPacketWithOptions(object):
raise TftpException, "Invalid options in buffer"
length += 1
logger.debug("about to unpack, format is: %s" % format)
log.debug("about to unpack, format is: %s" % format)
mystruct = struct.unpack(format, buffer)
tftpassert(len(mystruct) % 2 == 0,
"packet with odd number of option/value pairs")
for i in range(0, len(mystruct), 2):
logger.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
log.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
options[mystruct[i]] = mystruct[i+1]
return options
@ -134,10 +134,10 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
ptype = None
if self.opcode == 1: ptype = "RRQ"
else: ptype = "WRQ"
logger.debug("Encoding %s packet, filename = %s, mode = %s"
log.debug("Encoding %s packet, filename = %s, mode = %s"
% (ptype, self.filename, self.mode))
for key in self.options:
logger.debug(" Option %s = %s" % (key, self.options[key]))
log.debug(" Option %s = %s" % (key, self.options[key]))
format = "!H"
format += "%dsx" % len(self.filename)
@ -148,7 +148,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
# Add options.
options_list = []
if self.options.keys() > 0:
logger.debug("there are options to encode")
log.debug("there are options to encode")
for key in self.options:
# Populate the option name
format += "%dsx" % len(key)
@ -157,9 +157,9 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
format += "%dsx" % len(str(self.options[key]))
options_list.append(str(self.options[key]))
logger.debug("format is %s" % format)
logger.debug("options_list is %s" % options_list)
logger.debug("size of struct is %d" % struct.calcsize(format))
log.debug("format is %s" % format)
log.debug("options_list is %s" % options_list)
log.debug("size of struct is %d" % struct.calcsize(format))
self.buffer = struct.pack(format,
self.opcode,
@ -167,7 +167,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
self.mode,
*options_list)
logger.debug("buffer is " + repr(self.buffer))
log.debug("buffer is " + repr(self.buffer))
return self
def decode(self):
@ -177,13 +177,13 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
nulls = 0
format = ""
nulls = length = tlength = 0
logger.debug("in decode: about to iterate buffer counting nulls")
log.debug("in decode: about to iterate buffer counting nulls")
subbuf = self.buffer[2:]
for c in subbuf:
logger.debug("iterating this byte: " + repr(c))
log.debug("iterating this byte: " + repr(c))
if ord(c) == 0:
nulls += 1
logger.debug("found a null at length %d, now have %d"
log.debug("found a null at length %d, now have %d"
% (length, nulls))
format += "%dsx" % length
length = -1
@ -193,17 +193,17 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
length += 1
tlength += 1
logger.debug("hopefully found end of mode at length %d" % tlength)
log.debug("hopefully found end of mode at length %d" % tlength)
# length should now be the end of the mode.
tftpassert(nulls == 2, "malformed packet")
shortbuf = subbuf[:tlength+1]
logger.debug("about to unpack buffer with format: %s" % format)
logger.debug("unpacking buffer: " + repr(shortbuf))
log.debug("about to unpack buffer with format: %s" % format)
log.debug("unpacking buffer: " + repr(shortbuf))
mystruct = struct.unpack(format, shortbuf)
tftpassert(len(mystruct) == 2, "malformed packet")
logger.debug("setting filename to %s" % mystruct[0])
logger.debug("setting mode to %s" % mystruct[1])
log.debug("setting filename to %s" % mystruct[0])
log.debug("setting mode to %s" % mystruct[1])
self.filename = mystruct[0]
self.mode = mystruct[1]
@ -269,7 +269,7 @@ DATA | 03 | Block # | Data |
"""Encode the DAT packet. This method populates self.buffer, and
returns self for easy method chaining."""
if len(self.data) == 0:
logger.debug("Encoding an empty DAT packet")
log.debug("Encoding an empty DAT packet")
format = "!HH%ds" % len(self.data)
self.buffer = struct.pack(format,
self.opcode,
@ -283,12 +283,12 @@ DATA | 03 | Block # | Data |
# We know the first 2 bytes are the opcode. The second two are the
# block number.
(self.blocknumber,) = struct.unpack("!H", self.buffer[2:4])
logger.debug("decoding DAT packet, block number %d" % self.blocknumber)
logger.debug("should be %d bytes in the packet total"
log.debug("decoding DAT packet, block number %d" % self.blocknumber)
log.debug("should be %d bytes in the packet total"
% len(self.buffer))
# Everything else is data.
self.data = self.buffer[4:]
logger.debug("found %d bytes of data"
log.debug("found %d bytes of data"
% len(self.data))
return self
@ -308,14 +308,14 @@ ACK | 04 | Block # |
return 'ACK packet: block %d' % self.blocknumber
def encode(self):
logger.debug("encoding ACK: opcode = %d, block = %d"
log.debug("encoding ACK: opcode = %d, block = %d"
% (self.opcode, self.blocknumber))
self.buffer = struct.pack("!HH", self.opcode, self.blocknumber)
return self
def decode(self):
self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer)
logger.debug("decoded ACK packet: opcode = %d, block = %d"
log.debug("decoded ACK packet: opcode = %d, block = %d"
% (self.opcode, self.blocknumber))
return self
@ -365,7 +365,7 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
"""Encode the DAT packet based on instance variables, populating
self.buffer, returning self."""
format = "!HH%dsx" % len(self.errmsgs[self.errorcode])
logger.debug("encoding ERR packet with format %s" % format)
log.debug("encoding ERR packet with format %s" % format)
self.buffer = struct.pack(format,
self.opcode,
self.errorcode,
@ -375,13 +375,13 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
def decode(self):
"Decode self.buffer, populating instance variables and return self."
tftpassert(len(self.buffer) > 4, "malformed ERR packet, too short")
logger.debug("Decoding ERR packet, length %s bytes" %
log.debug("Decoding ERR packet, length %s bytes" %
len(self.buffer))
format = "!HH%dsx" % (len(self.buffer) - 5)
logger.debug("Decoding ERR packet with format: %s" % format)
log.debug("Decoding ERR packet with format: %s" % format)
self.opcode, self.errorcode, self.errmsg = struct.unpack(format,
self.buffer)
logger.error("ERR packet - errorcode: %d, message: %s"
log.error("ERR packet - errorcode: %d, message: %s"
% (self.errorcode, self.errmsg))
return self
@ -402,10 +402,10 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
def encode(self):
format = "!H" # opcode
options_list = []
logger.debug("in TftpPacketOACK.encode")
log.debug("in TftpPacketOACK.encode")
for key in self.options:
logger.debug("looping on option key %s" % key)
logger.debug("value is %s" % self.options[key])
log.debug("looping on option key %s" % key)
log.debug("value is %s" % self.options[key])
format += "%dsx" % len(key)
format += "%dsx" % len(self.options[key])
options_list.append(key)
@ -429,7 +429,7 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
# We can accept anything between the min and max values.
size = self.options[name]
if size >= MIN_BLKSIZE and size <= MAX_BLKSIZE:
logger.debug("negotiated blksize of %d bytes" % size)
log.debug("negotiated blksize of %d bytes" % size)
options[blksize] = size
else:
raise TftpException, "Unsupported option: %s" % name

View File

@ -1,4 +1,5 @@
import socket, os, re, time, random
import select
from TftpShared import *
from TftpPacketTypes import *
from TftpPacketFactory import *
@ -15,26 +16,27 @@ class TftpServer(TftpSession):
self.listenip = None
self.listenport = None
self.sock = None
# FIXME: What about multiple roots?
self.root = os.path.abspath(tftproot)
self.dynfunc = dyn_file_func
self.dyn_file_func = dyn_file_func
# A dict of handlers, where each session is keyed by a string like
# ip:tid for the remote end.
self.handlers = {}
if os.path.exists(self.root):
logger.debug("tftproot %s does exist" % self.root)
log.debug("tftproot %s does exist" % self.root)
if not os.path.isdir(self.root):
raise TftpException, "The tftproot must be a directory."
else:
logger.debug("tftproot %s is a directory" % self.root)
log.debug("tftproot %s is a directory" % self.root)
if os.access(self.root, os.R_OK):
logger.debug("tftproot %s is readable" % self.root)
log.debug("tftproot %s is readable" % self.root)
else:
raise TftpException, "The tftproot must be readable"
if os.access(self.root, os.W_OK):
logger.debug("tftproot %s is writable" % self.root)
log.debug("tftproot %s is writable" % self.root)
else:
logger.warning("The tftproot %s is not writable" % self.root)
log.warning("The tftproot %s is not writable" % self.root)
else:
raise TftpException, "The tftproot does not exist."
@ -45,14 +47,12 @@ class TftpServer(TftpSession):
"""Start a server listening on the supplied interface and port. This
defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also
supply a different socket timeout value, if desired."""
import select
tftp_factory = TftpPacketFactory()
# Don't use new 2.5 ternary operator yet
# listenip = listenip if listenip else '0.0.0.0'
if not listenip: listenip = '0.0.0.0'
logger.info("Server requested on ip %s, port %s"
log.info("Server requested on ip %s, port %s"
% (listenip, listenport))
try:
# FIXME - sockets should be non-blocking?
@ -62,388 +62,82 @@ class TftpServer(TftpSession):
# Reraise it for now.
raise
logger.info("Starting receive loop...")
log.info("Starting receive loop...")
while True:
# Build the inputlist array of sockets to select() on.
inputlist = []
inputlist.append(self.sock)
for key in self.handlers:
inputlist.append(self.handlers[key].sock)
for key in self.sessions:
inputlist.append(self.sessions[key].sock)
# Block until some socket has input on it.
logger.debug("Performing select on this inputlist: %s" % inputlist)
log.debug("Performing select on this inputlist: %s" % inputlist)
readyinput, readyoutput, readyspecial = select.select(inputlist,
[],
[],
SOCK_TIMEOUT)
#(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
#recvpkt = tftp_factory.parse(buffer)
#key = "%s:%s" % (raddress, rport)
deletion_list = []
# Handle the available data, if any. Maybe we timed-out.
for readysock in readyinput:
# Is the traffic on the main server socket? ie. new session?
if readysock == self.sock:
logger.debug("Data ready on our main socket")
log.debug("Data ready on our main socket")
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
logger.debug("Read %d bytes" % len(buffer))
log.debug("Read %d bytes" % len(buffer))
recvpkt = tftp_factory.parse(buffer)
# FIXME: Is this the best way to do a session key? What
# about symmetric udp?
key = "%s:%s" % (raddress, rport)
if isinstance(recvpkt, TftpPacketRRQ):
logger.debug("RRQ packet from %s:%s" % (raddress, rport))
if not self.handlers.has_key(key):
try:
logger.debug("New download request, session key = %s"
% key)
self.handlers[key] = TftpServerHandler(key,
'rrq',
self.root,
listenip,
tftp_factory,
self.dynfunc)
self.handlers[key].handle((recvpkt, raddress, rport))
except TftpException, err:
logger.error("Fatal exception thrown from handler: %s"
% str(err))
logger.debug("Deleting handler: %s" % key)
deletion_list.append(key)
else:
logger.warn("Received RRQ for existing session!")
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
continue
elif isinstance(recvpkt, TftpPacketWRQ):
logger.error("Write requests not implemented at this time.")
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
continue
if not self.sessions.has_key(key):
log.debug("Creating new server context for "
"session key = %s" % key)
self.sessions[key] = TftpContextServer(raddress,
rport,
timeout,
self.root,
self.dyn_file_func)
self.sessions[key].start(buffer)
else:
# FIXME - this will have to change if we do symmetric UDP
logger.error("Should only receive RRQ or WRQ packets "
"on main listen port. Received %s" % recvpkt)
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
continue
log.warn("received traffic on main socket for "
"existing session??")
else:
for key in self.handlers:
if readysock == self.handlers[key].sock:
# FIXME - violating DRY principle with above code
# Must find the owner of this traffic.
for key in self.session:
if readysock == self.session[key].sock:
try:
self.handlers[key].handle()
self.session[key].cycle()
if self.session[key].state == None:
log.info("Successful transfer.")
deletion_list.append(key)
break
except TftpException, err:
deletion_list.append(key)
if self.handlers[key].state.state == 'fin':
logger.info("Successful transfer.")
break
else:
logger.error("Fatal exception thrown from handler: %s"
% str(err))
log.error("Fatal exception thrown from "
"handler: %s" % str(err))
else:
logger.error("Can't find the owner for this packet. Discarding.")
log.error("Can't find the owner for this packet. "
"Discarding.")
logger.debug("Looping on all handlers to check for timeouts")
log.debug("Looping on all handlers to check for timeouts")
now = time.time()
for key in self.handlers:
for key in self.sessions:
try:
self.handlers[key].check_timeout(now)
self.sessions[key].checkTimeout(now)
except TftpException, err:
logger.error("Fatal exception thrown from handler: %s"
log.error("Fatal exception thrown from handler: %s"
% str(err))
deletion_list.append(key)
logger.debug("Iterating deletion list.")
log.debug("Iterating deletion list.")
for key in deletion_list:
if self.handlers.has_key(key):
logger.debug("Deleting handler %s" % key)
del self.handlers[key]
if self.sessions.has_key(key):
log.debug("Deleting handler %s" % key)
del self.sessions[key]
deletion_list = []
class TftpServerHandler(TftpSession):
"""This class implements a handler for a given server session, handling
the work for one download."""
def __init__(self, key, state, root, listenip, factory, dyn_file_func):
TftpSession.__init__(self)
logger.info("Starting new handler. Key %s." % key)
self.key = key
self.host, self.port = self.key.split(':')
self.port = int(self.port)
self.listenip = listenip
# Note, correct state here is important as it tells the handler whether it's
# handling a download or an upload.
self.state = state
self.root = root
self.mode = None
self.filename = None
self.sock = False
self.options = { 'blksize': DEF_BLKSIZE }
self.blocknumber = 0
self.buffer = None
self.fileobj = None
self.timesent = 0
self.timeouts = 0
self.tftp_factory = factory
self.dynfunc = dyn_file_func
count = 0
while not self.sock:
self.sock = self.gensock(listenip)
count += 1
if count > 10:
raise TftpException, "Failed to bind this handler to any port"
def check_timeout(self, now):
"""This method checks to see if we've timed-out waiting for traffic
from the client."""
if self.timesent:
if now - self.timesent > SOCK_TIMEOUT:
self.timeout()
def timeout(self):
"""This method handles a timeout condition."""
logger.debug("Handling timeout for handler %s" % self.key)
self.timeouts += 1
if self.timeouts > TIMEOUT_RETRIES:
raise TftpException, "Hit max retries, giving up."
if self.state.state == 'dat' or self.state.state == 'fin':
logger.debug("Timing out on DAT. Need to resend.")
self.send_dat(resend=True)
elif self.state.state == 'oack':
logger.debug("Timing out on OACK. Need to resend.")
self.send_oack()
else:
tftpassert(False,
"Timing out in unsupported state %s" %
self.state.state)
def gensock(self, listenip):
"""This method generates a new UDP socket, whose listening port must
be randomly generated, and not conflict with any already in use. For
now, let the OS do this."""
random.seed()
port = random.randrange(1025, 65536)
# FIXME - sockets should be non-blocking?
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
logger.debug("Trying a handler socket on port %d" % port)
try:
sock.bind((listenip, port))
return sock
except socket.error, err:
if err[0] == 98:
logger.warn("Handler %s, port %d was already taken" % (self.key, port))
return False
else:
raise
def handle(self, pkttuple=None):
"""This method informs a handler instance that it has data waiting on
its socket that it must read and process."""
recvpkt = raddress = rport = None
if pkttuple:
logger.debug("Handed pkt %s for handler %s" % (recvpkt, self.key))
recvpkt, raddress, rport = pkttuple
else:
logger.debug("Data ready for handler %s" % self.key)
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
logger.debug("Read %d bytes" % len(buffer))
recvpkt = self.tftp_factory.parse(buffer)
# FIXME - refactor into another method, this is too big
if isinstance(recvpkt, TftpPacketRRQ):
logger.debug("Handler %s received RRQ packet" % self.key)
logger.debug("Requested file is %s, mode is %s" % (recvpkt.filename,
recvpkt.mode))
# FIXME - only octet mode is supported at this time.
if recvpkt.mode != 'octet':
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
raddress,
rport)
raise TftpException, "Unsupported mode: %s" % recvpkt.mode
# test host/port of client end
if self.host != raddress or self.port != rport:
self.senderror(self.sock,
TftpErrors.UnknownTID,
raddress,
rport)
logger.error("Expected traffic from %s:%s but received it "
"from %s:%s instead."
% (self.host, self.port, raddress, rport))
self.errors += 1
return
if self.state.state == 'rrq':
logger.debug("Received RRQ. Composing response.")
self.filename = self.root + os.sep + recvpkt.filename
logger.debug("The path to the desired file is %s" %
self.filename)
self.filename = os.path.abspath(self.filename)
logger.debug("The absolute path is %s" % self.filename)
# Security check. Make sure it's prefixed by the tftproot.
if self.filename.find(self.root) == 0:
logger.debug("The path appears to be safe: %s" %
self.filename)
else:
logger.error("Insecure path: %s" % self.filename)
self.errors += 1
self.senderror(self.sock,
TftpErrors.AccessViolation,
raddress,
rport)
raise TftpException, "Insecure path: %s" % self.filename
# Does the file exist?
if(os.path.exists(self.filename) or not self.dynfunc is None):
logger.debug("File %s exists." % self.filename)
# Check options. Currently we only support the blksize
# option.
if recvpkt.options.has_key('blksize'):
logger.debug("RRQ includes a blksize option")
blksize = int(recvpkt.options['blksize'])
# Delete the option now that it's handled.
del recvpkt.options['blksize']
if blksize >= MIN_BLKSIZE and blksize <= MAX_BLKSIZE:
logger.info("Client requested blksize = %d"
% blksize)
self.options['blksize'] = blksize
else:
logger.warning("Client %s requested invalid "
"blocksize %d, responding with default"
% (self.key, blksize))
self.options['blksize'] = DEF_BLKSIZE
if recvpkt.options.has_key('tsize'):
logger.info('RRQ includes tsize option')
self.options['tsize'] = os.stat(self.filename).st_size
# Delete the option now that it's handled.
del recvpkt.options['tsize']
if len(recvpkt.options.keys()) > 0:
logger.warning("Client %s requested unsupported options: %s"
% (self.key, recvpkt.options))
if self.options:
logger.info("Options requested, sending OACK")
self.send_oack()
else:
logger.debug("Client %s requested no options."
% self.key)
self.start_download()
else:
logger.error("Requested file %s does not exist." %
self.filename)
self.senderror(self.sock,
TftpErrors.FileNotFound,
raddress,
rport)
raise TftpException, "Requested file not found: %s" % self.filename
else:
# We're receiving an RRQ when we're not expecting one.
logger.error("Received an RRQ in handler %s "
"but we're in state %s" % (self.key, self.state))
self.errors += 1
# Next packet type
elif isinstance(recvpkt, TftpPacketACK):
logger.debug("Received an ACK from the client.")
if recvpkt.blocknumber == 0 and self.state.state == 'oack':
logger.debug("Received ACK with 0 blocknumber, starting download")
self.start_download()
else:
if self.state.state == 'dat' or self.state.state == 'fin':
if self.blocknumber == recvpkt.blocknumber:
logger.debug("Received ACK for block %d"
% recvpkt.blocknumber)
if self.state.state == 'fin':
raise TftpException, "Successful transfer."
else:
self.send_dat()
elif recvpkt.blocknumber < self.blocknumber:
# Don't resend a DAT due to an old ACK. Fixes the
# sorceror's apprentice problem.
logger.warn("Received old ACK for block number %d"
% recvpkt.blocknumber)
else:
logger.warn("Received ACK for block number "
"%d, apparently from the future"
% recvpkt.blocknumber)
else:
logger.error("Received ACK with block number %d "
"while in state %s"
% (recvpkt.blocknumber,
self.state.state))
elif isinstance(recvpkt, TftpPacketERR):
logger.error("Received error packet from client: %s" % recvpkt)
self.state.state = 'err'
raise TftpException, "Received error from client"
# Handle other packet types.
else:
logger.error("Received packet %s while handling a download"
% recvpkt)
self.senderror(self.sock,
TftpErrors.IllegalTftpOp,
self.host,
self.port)
raise TftpException, "Invalid packet received during download"
def start_download(self):
"""This method opens self.filename, stores the resulting file object
in self.fileobj, and calls send_dat()."""
self.state.state = 'dat'
if os.path.exists(self.filename):
self.fileobj = open(self.filename, "rb")
else:
self.fileobj = self.dynfunc(self.filename)
self.send_dat()
def send_dat(self, resend=False):
"""This method reads sends a DAT packet based on what is in self.buffer."""
if not resend:
blksize = int(self.options['blksize'])
self.buffer = self.fileobj.read(blksize)
logger.debug("Read %d bytes into buffer" % len(self.buffer))
if len(self.buffer) < blksize:
logger.info("Reached EOF on file %s" % self.filename)
self.state.state = 'fin'
self.blocknumber += 1
if self.blocknumber > 65535:
logger.debug("Blocknumber rolled over to zero")
self.blocknumber = 0
else:
logger.warn("Resending block number %d" % self.blocknumber)
dat = TftpPacketDAT()
dat.data = self.buffer
dat.blocknumber = self.blocknumber
logger.debug("Sending DAT packet %d" % self.blocknumber)
self.sock.sendto(dat.encode().buffer, (self.host, self.port))
self.timesent = time.time()
# FIXME - should these be factored-out into the session class?
def send_oack(self):
"""This method sends an OACK packet based on current params."""
logger.debug("Composing and sending OACK packet")
oack = TftpPacketOACK()
oack.options = self.options
self.sock.sendto(oack.encode().buffer,
(self.host, self.port))
self.timesent = time.time()
self.state.state = 'oack'

View File

@ -17,7 +17,7 @@ DEF_TFTP_PORT = 69
logging.basicConfig()
# The logger used by this library. Feel free to clobber it with your own, if you like, as
# long as it conforms to Python's logging.
logger = logging.getLogger('tftpy')
log = logging.getLogger('tftpy')
def tftpassert(condition, msg):
"""This function is a simple utility that will check the condition
@ -31,8 +31,8 @@ def setLogLevel(level):
"""This function is a utility function for setting the internal log level.
The log level defaults to logging.NOTSET, so unwanted output to stdout is
not created."""
global logger
logger.setLevel(level)
global log
log.setLevel(level)
class TftpErrors(object):
"""This class is a convenience for defining the common tftp error codes,

View File

@ -22,17 +22,28 @@ class TftpMetrics(object):
# Rates
self.bps = 0
self.kbps = 0
# Generic errors
self.errors = 0
def compute(self):
# Compute transfer time
self.duration = self.end_time - self.start_time
logger.debug("TftpMetrics.compute: duration is %s" % self.duration)
log.debug("TftpMetrics.compute: duration is %s" % self.duration)
self.bps = (self.bytes * 8.0) / self.duration
self.kbps = self.bps / 1024.0
logger.debug("TftpMetrics.compute: kbps is %s" % self.kbps)
dupcount = 0
log.debug("TftpMetrics.compute: kbps is %s" % self.kbps)
for key in self.dups:
dupcount += self.dups[key]
self.dupcount += self.dups[key]
def add_dup(self, blocknumber):
"""This method adds a dup for a block number to the metrics."""
log.debug("Recording a dup for block %d" % blocknumber)
if self.dups.has_key(blocknumber):
self.dups[pkt.blocknumber] += 1
else:
self.dups[pkt.blocknumber] = 1
tftpassert(self.dups[pkt.blocknumber] < MAX_DUPS,
"Max duplicates for block %d reached" % blocknumber)
###############################################################################
# Context classes
@ -40,16 +51,32 @@ class TftpMetrics(object):
class TftpContext(object):
"""The base class of the contexts."""
def __init__(self, host, port):
def __init__(self, host, port, timeout):
"""Constructor for the base context, setting shared instance
variables."""
self.file_to_transfer = None
self.fileobj = None
self.options = None
self.packethook = None
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(timeout)
self.state = None
self.next_block = 0
self.factory = TftpPacketFactory()
# Note, setting the host will also set self.address, as it's a property.
self.host = host
self.port = port
# The port associated with the TID
self.tidport = None
# Metrics
self.metrics = TftpMetrics()
# Flag when the transfer is pending completion.
self.pending_complete = False
def checkTimeout(self, now):
# FIXME
pass
def start(self):
return NotImplementedError, "Abstract method"
@ -69,37 +96,9 @@ class TftpContext(object):
host = property(gethost, sethost)
def sendAck(self, blocknumber):
"""This method sends an ack packet to the block number specified."""
logger.info("sending ack to block %d" % blocknumber)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = blocknumber
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport))
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
compose and send an error packet."""
logger.debug("In sendError, being asked to send error %d" % errorcode)
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
self.sock.sendto(errpkt.encode().buffer, (self.host, self.tidport))
class TftpContextClient(TftpContext):
"""This class represents shared functionality by both the download and
upload client contexts."""
def __init__(self, host, port, filename, options, packethook, timeout):
TftpContext.__init__(self, host, port)
self.file_to_transfer = filename
self.options = options
self.packethook = packethook
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(timeout)
self.state = None
self.next_block = 0
def setNextBlock(self, block):
if block > 2 ** 16:
logger.debug("block number rollover to 0 again")
log.debug("block number rollover to 0 again")
block = 0
self.__eblock = block
@ -111,19 +110,21 @@ class TftpContextClient(TftpContext):
def cycle(self):
"""Here we wait for a response from the server after sending it
something, and dispatch appropriate action to that response."""
# FIXME: This won't work very well in a server context with multiple
# sessions running.
for i in range(TIMEOUT_RETRIES):
logger.debug("in cycle, receive attempt %d" % i)
log.debug("in cycle, receive attempt %d" % i)
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
except socket.timeout, err:
logger.warn("Timeout waiting for traffic, retrying...")
log.warn("Timeout waiting for traffic, retrying...")
continue
break
else:
raise TftpException, "Hit max timeouts, giving up."
# Ok, we've received a packet. Log it.
logger.debug("Received %d bytes from %s:%s"
log.debug("Received %d bytes from %s:%s"
% (len(buffer), raddress, rport))
# Decode it.
@ -131,11 +132,11 @@ class TftpContextClient(TftpContext):
# Check for known "connection".
if raddress != self.address:
logger.warn("Received traffic from %s, expected host %s. Discarding"
log.warn("Received traffic from %s, expected host %s. Discarding"
% (raddress, self.host))
if self.tidport and self.tidport != rport:
logger.warn("Received traffic from %s:%s but we're "
log.warn("Received traffic from %s:%s but we're "
"connected to %s:%s. Discarding."
% (raddress, rport,
self.host, self.tidport))
@ -150,29 +151,66 @@ class TftpContextClient(TftpContext):
# And handle it, possibly changing state.
self.state = self.state.handle(recvpkt, raddress, rport)
class TftpContextClientUpload(TftpContextClient):
class TftpContextServer(TftpContext):
"""The context for the server."""
def __init__(self, host, port, timeout, root, dyn_file_func):
TftpContext.__init__(self,
host,
port,
timeout)
# At this point we have no idea if this is a download or an upload. We
# need to let the start state determine that.
self.state = TftpStateServerStart()
self.root = root
self.dyn_file_func = dyn_file_func
def start(self, buffer):
"""Start the state cycle. Note that the server context receives an
initial packet in its start method."""
log.debug("TftpContextServer.start() - pkt = %s" % pkt)
self.metrics.start_time = time.time()
log.debug("set metrics.start_time to %s" % self.metrics.start_time)
pkt = self.factory.parse(buffer)
log.debug("TftpContextServer.start() - factory returned a %s" % pkt)
# Call handle once with the initial packet. This should put us into
# the download or the upload state.
self.state = self.state.handle(pkt,
self.host,
self.port)
try:
while self.state:
log.debug("state is %s" % self.state)
self.cycle()
finally:
self.fileobj.close()
class TftpContextClientUpload(TftpContext):
"""The upload context for the client during an upload."""
def __init__(self, host, port, filename, input, options, packethook, timeout):
TftpContextClient.__init__(self,
host,
port,
filename,
options,
packethook,
timeout)
TftpContext.__init__(self,
host,
port,
timeout)
self.file_to_transfer = filename
self.options = options
self.packethook = packethook
self.fileobj = open(input, "wb")
logger.debug("TftpContextClientUpload.__init__()")
logger.debug("file_to_transfer = %s, options = %s" %
log.debug("TftpContextClientUpload.__init__()")
log.debug("file_to_transfer = %s, options = %s" %
(self.file_to_transfer, self.options))
def start(self):
logger.info("sending tftp upload request to %s" % self.host)
logger.info(" filename -> %s" % self.file_to_transfer)
logger.info(" options -> %s" % self.options)
log.info("sending tftp upload request to %s" % self.host)
log.info(" filename -> %s" % self.file_to_transfer)
log.info(" options -> %s" % self.options)
self.metrics.start_time = time.time()
logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
log.debug("set metrics.start_time to %s" % self.metrics.start_time)
# FIXME: put this in a sendWRQ method?
pkt = TftpPacketWRQ()
@ -186,7 +224,7 @@ class TftpContextClientUpload(TftpContextClient):
try:
while self.state:
logger.debug("state is %s" % self.state)
log.debug("state is %s" % self.state)
self.cycle()
finally:
self.fileobj.close()
@ -194,32 +232,32 @@ class TftpContextClientUpload(TftpContextClient):
def end(self):
pass
class TftpContextClientDownload(TftpContextClient):
class TftpContextClientDownload(TftpContext):
"""The download context for the client during a download."""
def __init__(self, host, port, filename, output, options, packethook, timeout):
TftpContextClient.__init__(self,
host,
port,
filename,
options,
packethook,
timeout)
TftpContext.__init__(self,
host,
port,
filename,
options,
packethook,
timeout)
# FIXME - need to support alternate return formats than files?
# File-like objects would be ideal, ala duck-typing.
self.fileobj = open(output, "wb")
logger.debug("TftpContextClientDownload.__init__()")
logger.debug("file_to_transfer = %s, options = %s" %
log.debug("TftpContextClientDownload.__init__()")
log.debug("file_to_transfer = %s, options = %s" %
(self.file_to_transfer, self.options))
def start(self):
"""Initiate the download."""
logger.info("sending tftp download request to %s" % self.host)
logger.info(" filename -> %s" % self.file_to_transfer)
logger.info(" options -> %s" % self.options)
log.info("sending tftp download request to %s" % self.host)
log.info(" filename -> %s" % self.file_to_transfer)
log.info(" options -> %s" % self.options)
self.metrics.start_time = time.time()
logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
log.debug("set metrics.start_time to %s" % self.metrics.start_time)
# FIXME: put this in a sendRRQ method?
pkt = TftpPacketRRQ()
@ -233,7 +271,7 @@ class TftpContextClientDownload(TftpContextClient):
try:
while self.state:
logger.debug("state is %s" % self.state)
log.debug("state is %s" % self.state)
self.cycle()
finally:
self.fileobj.close()
@ -241,7 +279,7 @@ class TftpContextClientDownload(TftpContextClient):
def end(self):
"""Finish up the context."""
self.metrics.end_time = time.time()
logger.debug("set metrics.end_time to %s" % self.metrics.end_time)
log.debug("set metrics.end_time to %s" % self.metrics.end_time)
self.metrics.compute()
@ -268,235 +306,431 @@ class TftpState(object):
options."""
if pkt.options.keys() > 0:
if pkt.match_options(self.context.options):
logger.info("Successful negotiation of options")
log.info("Successful negotiation of options")
# Set options to OACK options
self.context.options = pkt.options
for key in self.context.options:
logger.info(" %s = %s" % (key, self.context.options[key]))
log.info(" %s = %s" % (key, self.context.options[key]))
else:
logger.error("failed to negotiate options")
log.error("failed to negotiate options")
raise TftpException, "Failed to negotiate options"
else:
raise TftpException, "No options found in OACK"
class TftpStateUpload(TftpState):
"""A class holding common code for upload states."""
def sendDat(self, resend=False):
def returnSupportedOptions(self, options):
"""This method takes a requested options list from a client, and
returns the ones that are supported."""
# We support the options blksize and tsize right now.
# FIXME - put this somewhere else?
accepted_options = {}
for option in options:
if option == 'blksize':
# Make sure it's valid.
if int(options[option]) > MAX_BLKSIZE:
accepted_options[option] = MAX_BLKSIZE
elif option == 'tsize':
log.debug("tsize option is set")
accepted_options['tsize'] = 1
else:
log.info("Dropping unsupported option '%s'" % option)
return accepted_options
def serverInitial(self, pkt, raddress, rport):
"""This method performs initial setup for a server context transfer,
put here to refactor code out of the TftpStateServerRecvRRQ and
TftpStateServerRecvWRQ classes, since their initial setup is
identical. The method returns a boolean, sendoack, to indicate whether
it is required to send an OACK to the client."""
options = pkt.options
sendoack = False
if not options:
log.debug("setting default options, blksize")
# FIXME: put default options elsewhere
self.context.options = { 'blksize': DEF_BLKSIZE }
else:
log.debug("options requested: %s" % options)
self.context.options = self.returnSupportedOptions(options)
sendoack = True
# FIXME - only octet mode is supported at this time.
if pkt.mode != 'octet':
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, \
"Only octet transfers are supported at this time."
# test host/port of client end
if self.context.host != raddress or self.context.port != rport:
self.sendError(TftpErrors.UnknownTID)
log.error("Expected traffic from %s:%s but received it "
"from %s:%s instead."
% (self.context.host,
self.context.port,
raddress,
rport))
# FIXME: increment an error count?
# Return same state, we're still waiting for valid traffic.
return self
log.debug("requested filename is %s" % pkt.filename)
# There are no os.sep's allowed in the filename.
# FIXME: Should we allow subdirectories?
if pkt.filename.find(os.sep) >= 0:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "%s found in filename, not permitted" % os.sep
self.context.file_to_transfer = pkt.filename
return sendoack
def sendDAT(self, resend=False):
"""This method sends the next DAT packet based on the data in the
context. It returns a boolean indicating whether the transfer is
finished."""
finished = False
blocknumber = self.context.next_block
if not resend:
blksize = int(self.context.options['blksize'])
buffer = self.context.fileobj.read(blksize)
logger.debug("Read %d bytes into buffer" % len(buffer))
log.debug("Read %d bytes into buffer" % len(buffer))
if len(buffer) < blksize:
logger.info("Reached EOF on file %s" % self.context.input)
log.info("Reached EOF on file %s" % self.context.input)
finished = True
self.context.next_block += 1
self.bytes += len(buffer)
else:
logger.warn("Resending block number %d" % blocknumber)
log.warn("Resending block number %d" % blocknumber)
dat = TftpPacketDAT()
dat.data = buffer
dat.blocknumber = blocknumber
logger.debug("Sending DAT packet %d" % blocknumber)
log.debug("Sending DAT packet %d" % blocknumber)
self.context.sock.sendto(dat.encode().buffer,
(self.context.host, self.context.port))
if self.context.packethook:
self.context.packethook(dat)
return finished
class TftpStateDownload(TftpState):
"""A class holding common code for download states."""
def sendACK(self, blocknumber=None):
"""This method sends an ack packet to the block number specified. If
none is specified, it defaults to the next_block property in the
parent context."""
if not blocknumber:
blocknumber = self.context.next_block
log.info("sending ack to block %d" % blocknumber)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = blocknumber
self.context.sock.sendto(ackpkt.encode().buffer,
(self.context.host,
self.context.tidport))
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
compose and send an error packet."""
log.debug("In sendError, being asked to send error %d" % errorcode)
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
self.context.sock.sendto(errpkt.encode().buffer,
(self.context.host,
self.context.tidport))
def sendOACK(self):
"""This method sends an OACK packet with the options from the current
context."""
log.debug("In sendOACK with options %s" % options)
pkt = TftpPacketOACK()
pkt.options = self.options
self.context.sock.sendto(pkt.encode().buffer,
(self.context.host,
self.context.tidport))
def handleDat(self, pkt):
"""This method handles a DAT packet during a download."""
logger.info("handling DAT packet - block %d" % pkt.blocknumber)
logger.debug("expecting block %s" % self.context.next_block)
"""This method handles a DAT packet during a client download, or a
server upload."""
log.info("handling DAT packet - block %d" % pkt.blocknumber)
log.debug("expecting block %s" % self.context.next_block)
if pkt.blocknumber == self.context.next_block:
logger.debug("good, received block %d in sequence"
log.debug("good, received block %d in sequence"
% pkt.blocknumber)
self.context.sendAck(pkt.blocknumber)
self.sendACK()
self.context.next_block += 1
logger.debug("writing %d bytes to output file"
log.debug("writing %d bytes to output file"
% len(pkt.data))
self.context.fileobj.write(pkt.data)
self.context.metrics.bytes += len(pkt.data)
# Check for end-of-file, any less than full data packet.
if len(pkt.data) < int(self.context.options['blksize']):
logger.info("end of file detected")
log.info("end of file detected")
return None
elif pkt.blocknumber < self.context.next_block:
logger.warn("dropping duplicate block %d" % pkt.blocknumber)
if self.context.metrics.dups.has_key(pkt.blocknumber):
self.context.metrics.dups[pkt.blocknumber] += 1
else:
self.context.metrics.dups[pkt.blocknumber] = 1
tftpassert(self.context.metrics.dups[pkt.blocknumber] < MAX_DUPS,
"Max duplicates for block %d reached" % pkt.blocknumber)
# FIXME: double-check sorceror's apprentice problem!
logger.debug("ACKing block %d again, just in case" % pkt.blocknumber)
self.context.sendAck(pkt.blocknumber)
log.warn("dropping duplicate block %d" % pkt.blocknumber)
self.context.metrics.add_dup(pkt.blocknumber)
log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
self.sendACK(pkt.blocknumber)
else:
# FIXME: should we be more tolerant and just discard instead?
msg = "Whoa! Received future block %d but expected %d" \
% (pkt.blocknumber, self.context.next_block)
logger.error(msg)
log.error(msg)
raise TftpException, msg
# Default is to ack
return TftpStateSentACK(self.context)
return TftpStateExpectDAT(self.context)
class TftpStateSentWRQ(TftpStateUpload):
"""Just sent an WRQ packet for an upload."""
class TftpStateServerRecvRRQ(TftpState):
"""This class represents the state of the TFTP server when it has just
received an RRQ packet."""
def handle(self, pkt, raddress, rport):
"Handle an initial RRQ packet as a server."
log.debug("In TftpStateServerRecvRRQ.handle")
sendoack = self.serverInitial(pkt, raddress, rport)
path = self.context.root + os.sep + self.context.file_to_transfer
log.info("Opening file %s for reading" % path)
if os.path.exists(path):
# Note: Open in binary mode for win32 portability, since win32
# blows.
self.context.fileobj = open(path, "rb")
elif self.dyn_file_func:
log.debug("No such file %s but using dyn_file_func" % path)
self.context.fileobj = \
self.dyn_file_func(self.context.file_to_transfer)
else:
send.sendError(TftpErrors.FileNotFound)
raise TftpException, "File not found: %s" % path
# Options negotiation.
if sendoack:
self.sendOACK()
return TftpStateServerOACK(self.context)
else:
log.debug("No requested options, starting send...")
self.context.pending_complete = self.sendDAT()
return TftpStateExpectACK(self.context)
# Note, we don't have to check any other states in this method, that's
# up to the caller.
class TftpStateServerRecvWRQ(TftpState):
"""This class represents the state of the TFTP server when it has just
received a WRQ packet."""
def handle(self, pkt, raddress, rport):
"Handle an initial WRQ packet as a server."
log.debug("In TftpStateServerRecvWRQ.handle")
sendoack = self.serverInitial(pkt, raddress, rport)
path = self.context.root + os.sep + self.context.file_to_transfer
log.info("Opening file %s for writing" % path)
if os.path.exists(path):
# FIXME: correct behavior?
log.warn("File %s exists already, overwriting...")
# FIXME: I think we should upload to a temp file and not overwrite the
# existing file until the file is successfully uploaded.
self.context.fileobj = open(path, "wb")
# Options negotiation.
if sendoack:
log.debug("Sending OACK to client")
self.sendOACK()
else:
log.debug("No requested options, starting send...")
self.sendACK()
# We may have sent an OACK, but we're expecting a DAT as the response
# to either the OACK or an ACK, so lets unconditionally use the
# TftpStateExpectDAT state.
return TftpStateExpectDAT(self.context)
# Note, we don't have to check any other states in this method, that's
# up to the caller.
class TftpStateServerStart(TftpState):
"""The start state for the server."""
def handle(self, pkt, raddress, rport):
"""Handle a packet we just received."""
if not self.context.tidport:
self.context.tidport = rport
logger.debug("Set remote port for session to %s" % rport)
# If we're going to successfully transfer the file, then we should see
# either an OACK for accepted options, or an ACK to ignore options.
if isinstance(pkt, TftpPacketOACK):
logger.info("received OACK from server")
try:
self.handleOACK(pkt)
except TftpException, err:
logger.error("failed to negotiate options")
self.context.sendError(TftpErrors.FailedNegotiation)
raise
else:
logger.debug("sending first DAT packet")
fin = self.context.sendDat()
if fin:
logger.info("Add done")
return None
else:
logger.debug("Changing state to TftpStateSentDAT")
return TftpStateSentDAT(self.context)
elif isinstance(pkt, TftpPacketACK):
logger.info("received ACK from server")
logger.debug("apparently the server ignored our options")
# The block number should be zero.
if pkt.blocknumber == 0:
logger.debug("ack blocknumber is zero as expected")
logger.debug("sending first DAT packet")
fin = self.context.sendDat()
if fin:
logger.info("Add done")
return None
else:
logger.debug("Changing state to TftpStateSentDAT")
return TftpStateSentDAT(self.context)
else:
logger.warn("discarding ACK to block %s" % pkt.blocknumber)
logger.debug("still waiting for valid response from server")
return self
elif isinstance(pkt, TftpPacketERR):
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(pkt)
elif isinstance(pkt, TftpPacketRRQ):
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received RRQ from server while in upload"
elif isinstance(pkt, TftpPacketDAT):
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received DAT from server while in upload"
log.debug("In TftpStateServerStart.handle")
if isinstance(pkt, TftpPacketRRQ):
log.debug("handling an RRQ packet")
return TftpStateServerRecvRRQ(self.context).handle(pkt,
raddress,
rport)
elif isinstance(pkt, TftpPacketWRQ):
log.debug("handling a WRQ packet")
return TftpStateServerRecvWRQ(self.context).handle(pkt,
raddress,
rport)
else:
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(pkt)
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, \
"Invalid packet to begin up/download: %s" % pkt
# By default, no state change.
return self
class TftpStateSentDAT(TftpStateUpload):
class TftpStateExpectACK(TftpState):
"""This class represents the state of the transfer when a DAT was just
sent, and we are waiting for an ACK from the server. This class is the
same one used by the client during the upload, and the server during the
download."""
class TftpStateSentRRQ(TftpStateDownload):
"""Just sent an RRQ packet."""
def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an RRQ to the server."""
if not self.context.tidport:
self.context.tidport = rport
logger.debug("Set remote port for session to %s" % rport)
"Handle a packet, hopefully an ACK since we just sent a DAT."
if isinstance(pkt, TftpPacketACK):
log.info("Received ACK for packet %d" % pkt.blocknumber)
# Is this an ack to the one we just sent?
if self.context.next_block == pkt.blocknumber:
if self.context.pending_complete:
log.info("Received ACK to final DAT, we're done.")
return None
else:
log.debug("Good ACK, sending next DAT")
self.context.pending_complete = self.sendDAT()
elif pkt.blocknumber < self.context.next_block:
self.context.metrics.add_dup(pkt.blocknumber)
# Now check the packet type and dispatch it properly.
if isinstance(pkt, TftpPacketOACK):
logger.info("received OACK from server")
try:
self.handleOACK(pkt)
except TftpException, err:
logger.error("failed to negotiate options: %s" % str(err))
self.context.sendError(TftpErrors.FailedNegotiation)
raise
else:
logger.debug("sending ACK to OACK")
self.context.sendAck(blocknumber=0)
logger.debug("Changing state to TftpStateSentACK")
return TftpStateSentACK(self.context)
elif isinstance(pkt, TftpPacketDAT):
# If there are any options set, then the server didn't honour any
# of them.
logger.info("received DAT from server")
if self.context.options:
logger.info("server ignored options, falling back to defaults")
self.context.options = { 'blksize': DEF_BLKSIZE }
return self.handleDat(pkt)
# Every other packet type is a problem.
elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, the server doesn't.
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ACK from server while in download"
elif isinstance(recvpkt, TftpPacketWRQ):
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received WRQ from server while in download"
elif isinstance(recvpkt, TftpPacketERR):
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(recvpkt)
log.warn("Oooh, time warp. Received ACK to packet we "
"didn't send yet. Discarding.")
self.context.metrics.errors += 1
return self
elif isinstance(pkt, TftpPacketERR):
log.error("Received ERR packet from peer: %s" % str(pkt))
raise TftpException, \
"Received ERR packet from peer: %s" % str(pkt)
else:
self.context.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
log.warn("Discarding unsupported packet: %s" % str(pkt))
return self
# By default, no state change.
return self
class TftpStateSentACK(TftpStateDownload):
class TftpStateExpectDAT(TftpState):
"""Just sent an ACK packet. Waiting for DAT."""
def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an ACK, which should be a DAT."""
if isinstance(pkt, TftpPacketDAT):
return self.handleDat(pkt)
# Every other packet type is a problem.
elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, you don't.
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ACK from peer when expecting DAT"
elif isinstance(recvpkt, TftpPacketWRQ):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received WRQ from peer when expecting DAT"
elif isinstance(recvpkt, TftpPacketERR):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from peer: " + str(recvpkt)
else:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from peer: " + str(recvpkt)
class TftpStateSentWRQ(TftpState):
"""Just sent an WRQ packet for an upload."""
def handle(self, pkt, raddress, rport):
"""Handle a packet we just received."""
if not self.context.tidport:
self.context.tidport = rport
log.debug("Set remote port for session to %s" % rport)
# If we're going to successfully transfer the file, then we should see
# either an OACK for accepted options, or an ACK to ignore options.
if isinstance(pkt, TftpPacketOACK):
log.info("received OACK from server")
try:
self.handleOACK(pkt)
except TftpException, err:
log.error("failed to negotiate options")
self.sendError(TftpErrors.FailedNegotiation)
raise
else:
log.debug("sending first DAT packet")
self.context.pending_complete = self.sendDAT()
log.debug("Changing state to TftpStateExpectACK")
return TftpStateExpectACK(self.context)
elif isinstance(pkt, TftpPacketACK):
log.info("received ACK from server")
log.debug("apparently the server ignored our options")
# The block number should be zero.
if pkt.blocknumber == 0:
log.debug("ack blocknumber is zero as expected")
log.debug("sending first DAT packet")
self.pending_complete = self.context.sendDAT()
log.debug("Changing state to TftpStateExpectACK")
return TftpStateExpectACK(self.context)
else:
log.warn("discarding ACK to block %s" % pkt.blocknumber)
log.debug("still waiting for valid response from server")
return self
elif isinstance(pkt, TftpPacketERR):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(pkt)
elif isinstance(pkt, TftpPacketRRQ):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received RRQ from server while in upload"
elif isinstance(pkt, TftpPacketDAT):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received DAT from server while in upload"
else:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(pkt)
# By default, no state change.
return self
class TftpStateSentRRQ(TftpState):
"""Just sent an RRQ packet."""
def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an RRQ to the server."""
if not self.context.tidport:
self.context.tidport = rport
log.debug("Set remote port for session to %s" % rport)
# Now check the packet type and dispatch it properly.
if isinstance(pkt, TftpPacketOACK):
log.info("received OACK from server")
try:
self.handleOACK(pkt)
except TftpException, err:
log.error("failed to negotiate options: %s" % str(err))
self.sendError(TftpErrors.FailedNegotiation)
raise
else:
log.debug("sending ACK to OACK")
self.sendACK(blocknumber=0)
log.debug("Changing state to TftpStateExpectDAT")
return TftpStateExpectDAT(self.context)
elif isinstance(pkt, TftpPacketDAT):
# If there are any options set, then the server didn't honour any
# of them.
log.info("received DAT from server")
if self.context.options:
log.info("server ignored options, falling back to defaults")
self.context.options = { 'blksize': DEF_BLKSIZE }
return self.handleDat(pkt)
# Every other packet type is a problem.
elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, the server doesn't.
self.context.sendError(TftpErrors.IllegalTftpOp)
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ACK from server while in download"
elif isinstance(recvpkt, TftpPacketWRQ):
self.context.sendError(TftpErrors.IllegalTftpOp)
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received WRQ from server while in download"
elif isinstance(recvpkt, TftpPacketERR):
self.context.sendError(TftpErrors.IllegalTftpOp)
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(recvpkt)
else:
self.context.sendError(TftpErrors.IllegalTftpOp)
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
# By default, no state change.
return self