Did some rework for the state machine in a server context.
Removed the handler framework in favour of a TftpContextServer used as the session.
This commit is contained in:
parent
03e4e74829
commit
62b22fb562
6 changed files with 564 additions and 636 deletions
|
@ -52,12 +52,12 @@ class TftpClient(TftpSession):
|
|||
# output? This should be in the sample client, but not in the download
|
||||
# call.
|
||||
if metrics.duration == 0:
|
||||
logger.info("Duration too short, rate undetermined")
|
||||
log.info("Duration too short, rate undetermined")
|
||||
else:
|
||||
logger.info('')
|
||||
logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
|
||||
logger.info("Average rate: %.2f kbps" % metrics.kbps)
|
||||
logger.info("Received %d duplicate packets" % metrics.dupcount)
|
||||
log.info('')
|
||||
log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
|
||||
log.info("Average rate: %.2f kbps" % metrics.kbps)
|
||||
log.info("Received %d duplicate packets" % metrics.dupcount)
|
||||
|
||||
def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT):
|
||||
# Open the input file.
|
||||
|
@ -80,9 +80,9 @@ class TftpClient(TftpSession):
|
|||
# output? This should be in the sample client, but not in the download
|
||||
# call.
|
||||
if metrics.duration == 0:
|
||||
logger.info("Duration too short, rate undetermined")
|
||||
log.info("Duration too short, rate undetermined")
|
||||
else:
|
||||
logger.info('')
|
||||
logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
|
||||
logger.info("Average rate: %.2f kbps" % metrics.kbps)
|
||||
logger.info("Received %d duplicate packets" % metrics.dupcount)
|
||||
log.info('')
|
||||
log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
|
||||
log.info("Average rate: %.2f kbps" % metrics.kbps)
|
||||
log.info("Received %d duplicate packets" % metrics.dupcount)
|
||||
|
|
|
@ -19,9 +19,9 @@ class TftpPacketFactory(object):
|
|||
"""This method is used to parse an existing datagram into its
|
||||
corresponding TftpPacket object. The buffer is the raw bytes off of
|
||||
the network."""
|
||||
logger.debug("parsing a %d byte packet" % len(buffer))
|
||||
log.debug("parsing a %d byte packet" % len(buffer))
|
||||
(opcode,) = struct.unpack("!H", buffer[:2])
|
||||
logger.debug("opcode is %d" % opcode)
|
||||
log.debug("opcode is %d" % opcode)
|
||||
packet = self.__create(opcode)
|
||||
packet.buffer = buffer
|
||||
return packet.decode()
|
||||
|
|
|
@ -15,7 +15,7 @@ class TftpSession(object):
|
|||
def senderror(self, sock, errorcode, address, port):
|
||||
"""This method uses the socket passed, and uses the errorcode, address
|
||||
and port to compose and send an error packet."""
|
||||
logger.debug("In senderror, being asked to send error %d to %s:%s"
|
||||
log.debug("In senderror, being asked to send error %d to %s:%s"
|
||||
% (errorcode, address, port))
|
||||
errpkt = TftpPacketERR()
|
||||
errpkt.errorcode = errorcode
|
||||
|
@ -27,23 +27,23 @@ class TftpPacketWithOptions(object):
|
|||
goal is just to share code here, and not cause diamond inheritance."""
|
||||
|
||||
def __init__(self):
|
||||
self.options = []
|
||||
self.options = {}
|
||||
|
||||
def setoptions(self, options):
|
||||
logger.debug("in TftpPacketWithOptions.setoptions")
|
||||
logger.debug("options: " + str(options))
|
||||
log.debug("in TftpPacketWithOptions.setoptions")
|
||||
log.debug("options: " + str(options))
|
||||
myoptions = {}
|
||||
for key in options:
|
||||
newkey = str(key)
|
||||
myoptions[newkey] = str(options[key])
|
||||
logger.debug("populated myoptions with %s = %s"
|
||||
log.debug("populated myoptions with %s = %s"
|
||||
% (newkey, myoptions[newkey]))
|
||||
|
||||
logger.debug("setting options hash to: " + str(myoptions))
|
||||
log.debug("setting options hash to: " + str(myoptions))
|
||||
self._options = myoptions
|
||||
|
||||
def getoptions(self):
|
||||
logger.debug("in TftpPacketWithOptions.getoptions")
|
||||
log.debug("in TftpPacketWithOptions.getoptions")
|
||||
return self._options
|
||||
|
||||
# Set up getter and setter on options to ensure that they are the proper
|
||||
|
@ -59,19 +59,19 @@ class TftpPacketWithOptions(object):
|
|||
format = "!"
|
||||
options = {}
|
||||
|
||||
logger.debug("decode_options: buffer is: " + repr(buffer))
|
||||
logger.debug("size of buffer is %d bytes" % len(buffer))
|
||||
log.debug("decode_options: buffer is: " + repr(buffer))
|
||||
log.debug("size of buffer is %d bytes" % len(buffer))
|
||||
if len(buffer) == 0:
|
||||
logger.debug("size of buffer is zero, returning empty hash")
|
||||
log.debug("size of buffer is zero, returning empty hash")
|
||||
return {}
|
||||
|
||||
# Count the nulls in the buffer. Each one terminates a string.
|
||||
logger.debug("about to iterate options buffer counting nulls")
|
||||
log.debug("about to iterate options buffer counting nulls")
|
||||
length = 0
|
||||
for c in buffer:
|
||||
#logger.debug("iterating this byte: " + repr(c))
|
||||
#log.debug("iterating this byte: " + repr(c))
|
||||
if ord(c) == 0:
|
||||
logger.debug("found a null at length %d" % length)
|
||||
log.debug("found a null at length %d" % length)
|
||||
if length > 0:
|
||||
format += "%dsx" % length
|
||||
length = -1
|
||||
|
@ -79,14 +79,14 @@ class TftpPacketWithOptions(object):
|
|||
raise TftpException, "Invalid options in buffer"
|
||||
length += 1
|
||||
|
||||
logger.debug("about to unpack, format is: %s" % format)
|
||||
log.debug("about to unpack, format is: %s" % format)
|
||||
mystruct = struct.unpack(format, buffer)
|
||||
|
||||
tftpassert(len(mystruct) % 2 == 0,
|
||||
"packet with odd number of option/value pairs")
|
||||
|
||||
for i in range(0, len(mystruct), 2):
|
||||
logger.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
|
||||
log.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
|
||||
options[mystruct[i]] = mystruct[i+1]
|
||||
|
||||
return options
|
||||
|
@ -134,10 +134,10 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
|
|||
ptype = None
|
||||
if self.opcode == 1: ptype = "RRQ"
|
||||
else: ptype = "WRQ"
|
||||
logger.debug("Encoding %s packet, filename = %s, mode = %s"
|
||||
log.debug("Encoding %s packet, filename = %s, mode = %s"
|
||||
% (ptype, self.filename, self.mode))
|
||||
for key in self.options:
|
||||
logger.debug(" Option %s = %s" % (key, self.options[key]))
|
||||
log.debug(" Option %s = %s" % (key, self.options[key]))
|
||||
|
||||
format = "!H"
|
||||
format += "%dsx" % len(self.filename)
|
||||
|
@ -148,7 +148,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
|
|||
# Add options.
|
||||
options_list = []
|
||||
if self.options.keys() > 0:
|
||||
logger.debug("there are options to encode")
|
||||
log.debug("there are options to encode")
|
||||
for key in self.options:
|
||||
# Populate the option name
|
||||
format += "%dsx" % len(key)
|
||||
|
@ -157,9 +157,9 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
|
|||
format += "%dsx" % len(str(self.options[key]))
|
||||
options_list.append(str(self.options[key]))
|
||||
|
||||
logger.debug("format is %s" % format)
|
||||
logger.debug("options_list is %s" % options_list)
|
||||
logger.debug("size of struct is %d" % struct.calcsize(format))
|
||||
log.debug("format is %s" % format)
|
||||
log.debug("options_list is %s" % options_list)
|
||||
log.debug("size of struct is %d" % struct.calcsize(format))
|
||||
|
||||
self.buffer = struct.pack(format,
|
||||
self.opcode,
|
||||
|
@ -167,7 +167,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
|
|||
self.mode,
|
||||
*options_list)
|
||||
|
||||
logger.debug("buffer is " + repr(self.buffer))
|
||||
log.debug("buffer is " + repr(self.buffer))
|
||||
return self
|
||||
|
||||
def decode(self):
|
||||
|
@ -177,13 +177,13 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
|
|||
nulls = 0
|
||||
format = ""
|
||||
nulls = length = tlength = 0
|
||||
logger.debug("in decode: about to iterate buffer counting nulls")
|
||||
log.debug("in decode: about to iterate buffer counting nulls")
|
||||
subbuf = self.buffer[2:]
|
||||
for c in subbuf:
|
||||
logger.debug("iterating this byte: " + repr(c))
|
||||
log.debug("iterating this byte: " + repr(c))
|
||||
if ord(c) == 0:
|
||||
nulls += 1
|
||||
logger.debug("found a null at length %d, now have %d"
|
||||
log.debug("found a null at length %d, now have %d"
|
||||
% (length, nulls))
|
||||
format += "%dsx" % length
|
||||
length = -1
|
||||
|
@ -193,17 +193,17 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
|
|||
length += 1
|
||||
tlength += 1
|
||||
|
||||
logger.debug("hopefully found end of mode at length %d" % tlength)
|
||||
log.debug("hopefully found end of mode at length %d" % tlength)
|
||||
# length should now be the end of the mode.
|
||||
tftpassert(nulls == 2, "malformed packet")
|
||||
shortbuf = subbuf[:tlength+1]
|
||||
logger.debug("about to unpack buffer with format: %s" % format)
|
||||
logger.debug("unpacking buffer: " + repr(shortbuf))
|
||||
log.debug("about to unpack buffer with format: %s" % format)
|
||||
log.debug("unpacking buffer: " + repr(shortbuf))
|
||||
mystruct = struct.unpack(format, shortbuf)
|
||||
|
||||
tftpassert(len(mystruct) == 2, "malformed packet")
|
||||
logger.debug("setting filename to %s" % mystruct[0])
|
||||
logger.debug("setting mode to %s" % mystruct[1])
|
||||
log.debug("setting filename to %s" % mystruct[0])
|
||||
log.debug("setting mode to %s" % mystruct[1])
|
||||
self.filename = mystruct[0]
|
||||
self.mode = mystruct[1]
|
||||
|
||||
|
@ -269,7 +269,7 @@ DATA | 03 | Block # | Data |
|
|||
"""Encode the DAT packet. This method populates self.buffer, and
|
||||
returns self for easy method chaining."""
|
||||
if len(self.data) == 0:
|
||||
logger.debug("Encoding an empty DAT packet")
|
||||
log.debug("Encoding an empty DAT packet")
|
||||
format = "!HH%ds" % len(self.data)
|
||||
self.buffer = struct.pack(format,
|
||||
self.opcode,
|
||||
|
@ -283,12 +283,12 @@ DATA | 03 | Block # | Data |
|
|||
# We know the first 2 bytes are the opcode. The second two are the
|
||||
# block number.
|
||||
(self.blocknumber,) = struct.unpack("!H", self.buffer[2:4])
|
||||
logger.debug("decoding DAT packet, block number %d" % self.blocknumber)
|
||||
logger.debug("should be %d bytes in the packet total"
|
||||
log.debug("decoding DAT packet, block number %d" % self.blocknumber)
|
||||
log.debug("should be %d bytes in the packet total"
|
||||
% len(self.buffer))
|
||||
# Everything else is data.
|
||||
self.data = self.buffer[4:]
|
||||
logger.debug("found %d bytes of data"
|
||||
log.debug("found %d bytes of data"
|
||||
% len(self.data))
|
||||
return self
|
||||
|
||||
|
@ -308,14 +308,14 @@ ACK | 04 | Block # |
|
|||
return 'ACK packet: block %d' % self.blocknumber
|
||||
|
||||
def encode(self):
|
||||
logger.debug("encoding ACK: opcode = %d, block = %d"
|
||||
log.debug("encoding ACK: opcode = %d, block = %d"
|
||||
% (self.opcode, self.blocknumber))
|
||||
self.buffer = struct.pack("!HH", self.opcode, self.blocknumber)
|
||||
return self
|
||||
|
||||
def decode(self):
|
||||
self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer)
|
||||
logger.debug("decoded ACK packet: opcode = %d, block = %d"
|
||||
log.debug("decoded ACK packet: opcode = %d, block = %d"
|
||||
% (self.opcode, self.blocknumber))
|
||||
return self
|
||||
|
||||
|
@ -365,7 +365,7 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
|
|||
"""Encode the DAT packet based on instance variables, populating
|
||||
self.buffer, returning self."""
|
||||
format = "!HH%dsx" % len(self.errmsgs[self.errorcode])
|
||||
logger.debug("encoding ERR packet with format %s" % format)
|
||||
log.debug("encoding ERR packet with format %s" % format)
|
||||
self.buffer = struct.pack(format,
|
||||
self.opcode,
|
||||
self.errorcode,
|
||||
|
@ -375,13 +375,13 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
|
|||
def decode(self):
|
||||
"Decode self.buffer, populating instance variables and return self."
|
||||
tftpassert(len(self.buffer) > 4, "malformed ERR packet, too short")
|
||||
logger.debug("Decoding ERR packet, length %s bytes" %
|
||||
log.debug("Decoding ERR packet, length %s bytes" %
|
||||
len(self.buffer))
|
||||
format = "!HH%dsx" % (len(self.buffer) - 5)
|
||||
logger.debug("Decoding ERR packet with format: %s" % format)
|
||||
log.debug("Decoding ERR packet with format: %s" % format)
|
||||
self.opcode, self.errorcode, self.errmsg = struct.unpack(format,
|
||||
self.buffer)
|
||||
logger.error("ERR packet - errorcode: %d, message: %s"
|
||||
log.error("ERR packet - errorcode: %d, message: %s"
|
||||
% (self.errorcode, self.errmsg))
|
||||
return self
|
||||
|
||||
|
@ -402,10 +402,10 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
|
|||
def encode(self):
|
||||
format = "!H" # opcode
|
||||
options_list = []
|
||||
logger.debug("in TftpPacketOACK.encode")
|
||||
log.debug("in TftpPacketOACK.encode")
|
||||
for key in self.options:
|
||||
logger.debug("looping on option key %s" % key)
|
||||
logger.debug("value is %s" % self.options[key])
|
||||
log.debug("looping on option key %s" % key)
|
||||
log.debug("value is %s" % self.options[key])
|
||||
format += "%dsx" % len(key)
|
||||
format += "%dsx" % len(self.options[key])
|
||||
options_list.append(key)
|
||||
|
@ -429,7 +429,7 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
|
|||
# We can accept anything between the min and max values.
|
||||
size = self.options[name]
|
||||
if size >= MIN_BLKSIZE and size <= MAX_BLKSIZE:
|
||||
logger.debug("negotiated blksize of %d bytes" % size)
|
||||
log.debug("negotiated blksize of %d bytes" % size)
|
||||
options[blksize] = size
|
||||
else:
|
||||
raise TftpException, "Unsupported option: %s" % name
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import socket, os, re, time, random
|
||||
import select
|
||||
from TftpShared import *
|
||||
from TftpPacketTypes import *
|
||||
from TftpPacketFactory import *
|
||||
|
@ -15,26 +16,27 @@ class TftpServer(TftpSession):
|
|||
self.listenip = None
|
||||
self.listenport = None
|
||||
self.sock = None
|
||||
# FIXME: What about multiple roots?
|
||||
self.root = os.path.abspath(tftproot)
|
||||
self.dynfunc = dyn_file_func
|
||||
self.dyn_file_func = dyn_file_func
|
||||
# A dict of handlers, where each session is keyed by a string like
|
||||
# ip:tid for the remote end.
|
||||
self.handlers = {}
|
||||
|
||||
if os.path.exists(self.root):
|
||||
logger.debug("tftproot %s does exist" % self.root)
|
||||
log.debug("tftproot %s does exist" % self.root)
|
||||
if not os.path.isdir(self.root):
|
||||
raise TftpException, "The tftproot must be a directory."
|
||||
else:
|
||||
logger.debug("tftproot %s is a directory" % self.root)
|
||||
log.debug("tftproot %s is a directory" % self.root)
|
||||
if os.access(self.root, os.R_OK):
|
||||
logger.debug("tftproot %s is readable" % self.root)
|
||||
log.debug("tftproot %s is readable" % self.root)
|
||||
else:
|
||||
raise TftpException, "The tftproot must be readable"
|
||||
if os.access(self.root, os.W_OK):
|
||||
logger.debug("tftproot %s is writable" % self.root)
|
||||
log.debug("tftproot %s is writable" % self.root)
|
||||
else:
|
||||
logger.warning("The tftproot %s is not writable" % self.root)
|
||||
log.warning("The tftproot %s is not writable" % self.root)
|
||||
else:
|
||||
raise TftpException, "The tftproot does not exist."
|
||||
|
||||
|
@ -45,14 +47,12 @@ class TftpServer(TftpSession):
|
|||
"""Start a server listening on the supplied interface and port. This
|
||||
defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also
|
||||
supply a different socket timeout value, if desired."""
|
||||
import select
|
||||
|
||||
tftp_factory = TftpPacketFactory()
|
||||
|
||||
# Don't use new 2.5 ternary operator yet
|
||||
# listenip = listenip if listenip else '0.0.0.0'
|
||||
if not listenip: listenip = '0.0.0.0'
|
||||
logger.info("Server requested on ip %s, port %s"
|
||||
log.info("Server requested on ip %s, port %s"
|
||||
% (listenip, listenport))
|
||||
try:
|
||||
# FIXME - sockets should be non-blocking?
|
||||
|
@ -62,388 +62,82 @@ class TftpServer(TftpSession):
|
|||
# Reraise it for now.
|
||||
raise
|
||||
|
||||
logger.info("Starting receive loop...")
|
||||
log.info("Starting receive loop...")
|
||||
while True:
|
||||
# Build the inputlist array of sockets to select() on.
|
||||
inputlist = []
|
||||
inputlist.append(self.sock)
|
||||
for key in self.handlers:
|
||||
inputlist.append(self.handlers[key].sock)
|
||||
for key in self.sessions:
|
||||
inputlist.append(self.sessions[key].sock)
|
||||
|
||||
# Block until some socket has input on it.
|
||||
logger.debug("Performing select on this inputlist: %s" % inputlist)
|
||||
log.debug("Performing select on this inputlist: %s" % inputlist)
|
||||
readyinput, readyoutput, readyspecial = select.select(inputlist,
|
||||
[],
|
||||
[],
|
||||
SOCK_TIMEOUT)
|
||||
|
||||
#(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
|
||||
#recvpkt = tftp_factory.parse(buffer)
|
||||
#key = "%s:%s" % (raddress, rport)
|
||||
|
||||
deletion_list = []
|
||||
|
||||
# Handle the available data, if any. Maybe we timed-out.
|
||||
for readysock in readyinput:
|
||||
# Is the traffic on the main server socket? ie. new session?
|
||||
if readysock == self.sock:
|
||||
logger.debug("Data ready on our main socket")
|
||||
log.debug("Data ready on our main socket")
|
||||
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
|
||||
logger.debug("Read %d bytes" % len(buffer))
|
||||
|
||||
log.debug("Read %d bytes" % len(buffer))
|
||||
|
||||
recvpkt = tftp_factory.parse(buffer)
|
||||
# FIXME: Is this the best way to do a session key? What
|
||||
# about symmetric udp?
|
||||
key = "%s:%s" % (raddress, rport)
|
||||
|
||||
if isinstance(recvpkt, TftpPacketRRQ):
|
||||
logger.debug("RRQ packet from %s:%s" % (raddress, rport))
|
||||
if not self.handlers.has_key(key):
|
||||
try:
|
||||
logger.debug("New download request, session key = %s"
|
||||
% key)
|
||||
self.handlers[key] = TftpServerHandler(key,
|
||||
'rrq',
|
||||
self.root,
|
||||
listenip,
|
||||
tftp_factory,
|
||||
self.dynfunc)
|
||||
self.handlers[key].handle((recvpkt, raddress, rport))
|
||||
except TftpException, err:
|
||||
logger.error("Fatal exception thrown from handler: %s"
|
||||
% str(err))
|
||||
logger.debug("Deleting handler: %s" % key)
|
||||
deletion_list.append(key)
|
||||
|
||||
else:
|
||||
logger.warn("Received RRQ for existing session!")
|
||||
self.senderror(self.sock,
|
||||
TftpErrors.IllegalTftpOp,
|
||||
raddress,
|
||||
rport)
|
||||
continue
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketWRQ):
|
||||
logger.error("Write requests not implemented at this time.")
|
||||
self.senderror(self.sock,
|
||||
TftpErrors.IllegalTftpOp,
|
||||
raddress,
|
||||
rport)
|
||||
continue
|
||||
if not self.sessions.has_key(key):
|
||||
log.debug("Creating new server context for "
|
||||
"session key = %s" % key)
|
||||
self.sessions[key] = TftpContextServer(raddress,
|
||||
rport,
|
||||
timeout,
|
||||
self.root,
|
||||
self.dyn_file_func)
|
||||
self.sessions[key].start(buffer)
|
||||
else:
|
||||
# FIXME - this will have to change if we do symmetric UDP
|
||||
logger.error("Should only receive RRQ or WRQ packets "
|
||||
"on main listen port. Received %s" % recvpkt)
|
||||
self.senderror(self.sock,
|
||||
TftpErrors.IllegalTftpOp,
|
||||
raddress,
|
||||
rport)
|
||||
continue
|
||||
log.warn("received traffic on main socket for "
|
||||
"existing session??")
|
||||
|
||||
else:
|
||||
for key in self.handlers:
|
||||
if readysock == self.handlers[key].sock:
|
||||
# FIXME - violating DRY principle with above code
|
||||
# Must find the owner of this traffic.
|
||||
for key in self.session:
|
||||
if readysock == self.session[key].sock:
|
||||
try:
|
||||
self.handlers[key].handle()
|
||||
self.session[key].cycle()
|
||||
if self.session[key].state == None:
|
||||
log.info("Successful transfer.")
|
||||
deletion_list.append(key)
|
||||
break
|
||||
except TftpException, err:
|
||||
deletion_list.append(key)
|
||||
if self.handlers[key].state.state == 'fin':
|
||||
logger.info("Successful transfer.")
|
||||
break
|
||||
else:
|
||||
logger.error("Fatal exception thrown from handler: %s"
|
||||
% str(err))
|
||||
log.error("Fatal exception thrown from "
|
||||
"handler: %s" % str(err))
|
||||
|
||||
else:
|
||||
logger.error("Can't find the owner for this packet. Discarding.")
|
||||
log.error("Can't find the owner for this packet. "
|
||||
"Discarding.")
|
||||
|
||||
logger.debug("Looping on all handlers to check for timeouts")
|
||||
log.debug("Looping on all handlers to check for timeouts")
|
||||
now = time.time()
|
||||
for key in self.handlers:
|
||||
for key in self.sessions:
|
||||
try:
|
||||
self.handlers[key].check_timeout(now)
|
||||
self.sessions[key].checkTimeout(now)
|
||||
except TftpException, err:
|
||||
logger.error("Fatal exception thrown from handler: %s"
|
||||
log.error("Fatal exception thrown from handler: %s"
|
||||
% str(err))
|
||||
deletion_list.append(key)
|
||||
|
||||
logger.debug("Iterating deletion list.")
|
||||
log.debug("Iterating deletion list.")
|
||||
for key in deletion_list:
|
||||
if self.handlers.has_key(key):
|
||||
logger.debug("Deleting handler %s" % key)
|
||||
del self.handlers[key]
|
||||
if self.sessions.has_key(key):
|
||||
log.debug("Deleting handler %s" % key)
|
||||
del self.sessions[key]
|
||||
deletion_list = []
|
||||
|
||||
class TftpServerHandler(TftpSession):
|
||||
"""This class implements a handler for a given server session, handling
|
||||
the work for one download."""
|
||||
|
||||
def __init__(self, key, state, root, listenip, factory, dyn_file_func):
|
||||
TftpSession.__init__(self)
|
||||
logger.info("Starting new handler. Key %s." % key)
|
||||
self.key = key
|
||||
self.host, self.port = self.key.split(':')
|
||||
self.port = int(self.port)
|
||||
self.listenip = listenip
|
||||
# Note, correct state here is important as it tells the handler whether it's
|
||||
# handling a download or an upload.
|
||||
self.state = state
|
||||
self.root = root
|
||||
self.mode = None
|
||||
self.filename = None
|
||||
self.sock = False
|
||||
self.options = { 'blksize': DEF_BLKSIZE }
|
||||
self.blocknumber = 0
|
||||
self.buffer = None
|
||||
self.fileobj = None
|
||||
self.timesent = 0
|
||||
self.timeouts = 0
|
||||
self.tftp_factory = factory
|
||||
self.dynfunc = dyn_file_func
|
||||
count = 0
|
||||
while not self.sock:
|
||||
self.sock = self.gensock(listenip)
|
||||
count += 1
|
||||
if count > 10:
|
||||
raise TftpException, "Failed to bind this handler to any port"
|
||||
|
||||
def check_timeout(self, now):
|
||||
"""This method checks to see if we've timed-out waiting for traffic
|
||||
from the client."""
|
||||
if self.timesent:
|
||||
if now - self.timesent > SOCK_TIMEOUT:
|
||||
self.timeout()
|
||||
|
||||
def timeout(self):
|
||||
"""This method handles a timeout condition."""
|
||||
logger.debug("Handling timeout for handler %s" % self.key)
|
||||
self.timeouts += 1
|
||||
if self.timeouts > TIMEOUT_RETRIES:
|
||||
raise TftpException, "Hit max retries, giving up."
|
||||
|
||||
if self.state.state == 'dat' or self.state.state == 'fin':
|
||||
logger.debug("Timing out on DAT. Need to resend.")
|
||||
self.send_dat(resend=True)
|
||||
elif self.state.state == 'oack':
|
||||
logger.debug("Timing out on OACK. Need to resend.")
|
||||
self.send_oack()
|
||||
else:
|
||||
tftpassert(False,
|
||||
"Timing out in unsupported state %s" %
|
||||
self.state.state)
|
||||
|
||||
def gensock(self, listenip):
|
||||
"""This method generates a new UDP socket, whose listening port must
|
||||
be randomly generated, and not conflict with any already in use. For
|
||||
now, let the OS do this."""
|
||||
random.seed()
|
||||
port = random.randrange(1025, 65536)
|
||||
# FIXME - sockets should be non-blocking?
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
logger.debug("Trying a handler socket on port %d" % port)
|
||||
try:
|
||||
sock.bind((listenip, port))
|
||||
return sock
|
||||
except socket.error, err:
|
||||
if err[0] == 98:
|
||||
logger.warn("Handler %s, port %d was already taken" % (self.key, port))
|
||||
return False
|
||||
else:
|
||||
raise
|
||||
|
||||
def handle(self, pkttuple=None):
|
||||
"""This method informs a handler instance that it has data waiting on
|
||||
its socket that it must read and process."""
|
||||
recvpkt = raddress = rport = None
|
||||
if pkttuple:
|
||||
logger.debug("Handed pkt %s for handler %s" % (recvpkt, self.key))
|
||||
recvpkt, raddress, rport = pkttuple
|
||||
else:
|
||||
logger.debug("Data ready for handler %s" % self.key)
|
||||
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
|
||||
logger.debug("Read %d bytes" % len(buffer))
|
||||
recvpkt = self.tftp_factory.parse(buffer)
|
||||
|
||||
# FIXME - refactor into another method, this is too big
|
||||
if isinstance(recvpkt, TftpPacketRRQ):
|
||||
logger.debug("Handler %s received RRQ packet" % self.key)
|
||||
logger.debug("Requested file is %s, mode is %s" % (recvpkt.filename,
|
||||
recvpkt.mode))
|
||||
# FIXME - only octet mode is supported at this time.
|
||||
if recvpkt.mode != 'octet':
|
||||
self.senderror(self.sock,
|
||||
TftpErrors.IllegalTftpOp,
|
||||
raddress,
|
||||
rport)
|
||||
raise TftpException, "Unsupported mode: %s" % recvpkt.mode
|
||||
|
||||
# test host/port of client end
|
||||
if self.host != raddress or self.port != rport:
|
||||
self.senderror(self.sock,
|
||||
TftpErrors.UnknownTID,
|
||||
raddress,
|
||||
rport)
|
||||
logger.error("Expected traffic from %s:%s but received it "
|
||||
"from %s:%s instead."
|
||||
% (self.host, self.port, raddress, rport))
|
||||
self.errors += 1
|
||||
return
|
||||
|
||||
if self.state.state == 'rrq':
|
||||
logger.debug("Received RRQ. Composing response.")
|
||||
self.filename = self.root + os.sep + recvpkt.filename
|
||||
logger.debug("The path to the desired file is %s" %
|
||||
self.filename)
|
||||
self.filename = os.path.abspath(self.filename)
|
||||
logger.debug("The absolute path is %s" % self.filename)
|
||||
# Security check. Make sure it's prefixed by the tftproot.
|
||||
if self.filename.find(self.root) == 0:
|
||||
logger.debug("The path appears to be safe: %s" %
|
||||
self.filename)
|
||||
else:
|
||||
logger.error("Insecure path: %s" % self.filename)
|
||||
self.errors += 1
|
||||
self.senderror(self.sock,
|
||||
TftpErrors.AccessViolation,
|
||||
raddress,
|
||||
rport)
|
||||
raise TftpException, "Insecure path: %s" % self.filename
|
||||
|
||||
# Does the file exist?
|
||||
if(os.path.exists(self.filename) or not self.dynfunc is None):
|
||||
logger.debug("File %s exists." % self.filename)
|
||||
|
||||
# Check options. Currently we only support the blksize
|
||||
# option.
|
||||
if recvpkt.options.has_key('blksize'):
|
||||
logger.debug("RRQ includes a blksize option")
|
||||
blksize = int(recvpkt.options['blksize'])
|
||||
# Delete the option now that it's handled.
|
||||
del recvpkt.options['blksize']
|
||||
if blksize >= MIN_BLKSIZE and blksize <= MAX_BLKSIZE:
|
||||
logger.info("Client requested blksize = %d"
|
||||
% blksize)
|
||||
self.options['blksize'] = blksize
|
||||
else:
|
||||
logger.warning("Client %s requested invalid "
|
||||
"blocksize %d, responding with default"
|
||||
% (self.key, blksize))
|
||||
self.options['blksize'] = DEF_BLKSIZE
|
||||
|
||||
if recvpkt.options.has_key('tsize'):
|
||||
logger.info('RRQ includes tsize option')
|
||||
self.options['tsize'] = os.stat(self.filename).st_size
|
||||
# Delete the option now that it's handled.
|
||||
del recvpkt.options['tsize']
|
||||
|
||||
if len(recvpkt.options.keys()) > 0:
|
||||
logger.warning("Client %s requested unsupported options: %s"
|
||||
% (self.key, recvpkt.options))
|
||||
|
||||
if self.options:
|
||||
logger.info("Options requested, sending OACK")
|
||||
self.send_oack()
|
||||
else:
|
||||
logger.debug("Client %s requested no options."
|
||||
% self.key)
|
||||
self.start_download()
|
||||
|
||||
else:
|
||||
logger.error("Requested file %s does not exist." %
|
||||
self.filename)
|
||||
self.senderror(self.sock,
|
||||
TftpErrors.FileNotFound,
|
||||
raddress,
|
||||
rport)
|
||||
raise TftpException, "Requested file not found: %s" % self.filename
|
||||
|
||||
else:
|
||||
# We're receiving an RRQ when we're not expecting one.
|
||||
logger.error("Received an RRQ in handler %s "
|
||||
"but we're in state %s" % (self.key, self.state))
|
||||
self.errors += 1
|
||||
|
||||
# Next packet type
|
||||
elif isinstance(recvpkt, TftpPacketACK):
|
||||
logger.debug("Received an ACK from the client.")
|
||||
if recvpkt.blocknumber == 0 and self.state.state == 'oack':
|
||||
logger.debug("Received ACK with 0 blocknumber, starting download")
|
||||
self.start_download()
|
||||
else:
|
||||
if self.state.state == 'dat' or self.state.state == 'fin':
|
||||
if self.blocknumber == recvpkt.blocknumber:
|
||||
logger.debug("Received ACK for block %d"
|
||||
% recvpkt.blocknumber)
|
||||
if self.state.state == 'fin':
|
||||
raise TftpException, "Successful transfer."
|
||||
else:
|
||||
self.send_dat()
|
||||
elif recvpkt.blocknumber < self.blocknumber:
|
||||
# Don't resend a DAT due to an old ACK. Fixes the
|
||||
# sorceror's apprentice problem.
|
||||
logger.warn("Received old ACK for block number %d"
|
||||
% recvpkt.blocknumber)
|
||||
else:
|
||||
logger.warn("Received ACK for block number "
|
||||
"%d, apparently from the future"
|
||||
% recvpkt.blocknumber)
|
||||
else:
|
||||
logger.error("Received ACK with block number %d "
|
||||
"while in state %s"
|
||||
% (recvpkt.blocknumber,
|
||||
self.state.state))
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketERR):
|
||||
logger.error("Received error packet from client: %s" % recvpkt)
|
||||
self.state.state = 'err'
|
||||
raise TftpException, "Received error from client"
|
||||
|
||||
# Handle other packet types.
|
||||
else:
|
||||
logger.error("Received packet %s while handling a download"
|
||||
% recvpkt)
|
||||
self.senderror(self.sock,
|
||||
TftpErrors.IllegalTftpOp,
|
||||
self.host,
|
||||
self.port)
|
||||
raise TftpException, "Invalid packet received during download"
|
||||
|
||||
def start_download(self):
|
||||
"""This method opens self.filename, stores the resulting file object
|
||||
in self.fileobj, and calls send_dat()."""
|
||||
self.state.state = 'dat'
|
||||
if os.path.exists(self.filename):
|
||||
self.fileobj = open(self.filename, "rb")
|
||||
else:
|
||||
self.fileobj = self.dynfunc(self.filename)
|
||||
self.send_dat()
|
||||
|
||||
def send_dat(self, resend=False):
|
||||
"""This method reads sends a DAT packet based on what is in self.buffer."""
|
||||
if not resend:
|
||||
blksize = int(self.options['blksize'])
|
||||
self.buffer = self.fileobj.read(blksize)
|
||||
logger.debug("Read %d bytes into buffer" % len(self.buffer))
|
||||
if len(self.buffer) < blksize:
|
||||
logger.info("Reached EOF on file %s" % self.filename)
|
||||
self.state.state = 'fin'
|
||||
self.blocknumber += 1
|
||||
if self.blocknumber > 65535:
|
||||
logger.debug("Blocknumber rolled over to zero")
|
||||
self.blocknumber = 0
|
||||
else:
|
||||
logger.warn("Resending block number %d" % self.blocknumber)
|
||||
dat = TftpPacketDAT()
|
||||
dat.data = self.buffer
|
||||
dat.blocknumber = self.blocknumber
|
||||
logger.debug("Sending DAT packet %d" % self.blocknumber)
|
||||
self.sock.sendto(dat.encode().buffer, (self.host, self.port))
|
||||
self.timesent = time.time()
|
||||
|
||||
# FIXME - should these be factored-out into the session class?
|
||||
def send_oack(self):
|
||||
"""This method sends an OACK packet based on current params."""
|
||||
logger.debug("Composing and sending OACK packet")
|
||||
oack = TftpPacketOACK()
|
||||
oack.options = self.options
|
||||
self.sock.sendto(oack.encode().buffer,
|
||||
(self.host, self.port))
|
||||
self.timesent = time.time()
|
||||
self.state.state = 'oack'
|
||||
|
|
|
@ -17,7 +17,7 @@ DEF_TFTP_PORT = 69
|
|||
logging.basicConfig()
|
||||
# The logger used by this library. Feel free to clobber it with your own, if you like, as
|
||||
# long as it conforms to Python's logging.
|
||||
logger = logging.getLogger('tftpy')
|
||||
log = logging.getLogger('tftpy')
|
||||
|
||||
def tftpassert(condition, msg):
|
||||
"""This function is a simple utility that will check the condition
|
||||
|
@ -31,8 +31,8 @@ def setLogLevel(level):
|
|||
"""This function is a utility function for setting the internal log level.
|
||||
The log level defaults to logging.NOTSET, so unwanted output to stdout is
|
||||
not created."""
|
||||
global logger
|
||||
logger.setLevel(level)
|
||||
global log
|
||||
log.setLevel(level)
|
||||
|
||||
class TftpErrors(object):
|
||||
"""This class is a convenience for defining the common tftp error codes,
|
||||
|
|
|
@ -22,17 +22,28 @@ class TftpMetrics(object):
|
|||
# Rates
|
||||
self.bps = 0
|
||||
self.kbps = 0
|
||||
# Generic errors
|
||||
self.errors = 0
|
||||
|
||||
def compute(self):
|
||||
# Compute transfer time
|
||||
self.duration = self.end_time - self.start_time
|
||||
logger.debug("TftpMetrics.compute: duration is %s" % self.duration)
|
||||
log.debug("TftpMetrics.compute: duration is %s" % self.duration)
|
||||
self.bps = (self.bytes * 8.0) / self.duration
|
||||
self.kbps = self.bps / 1024.0
|
||||
logger.debug("TftpMetrics.compute: kbps is %s" % self.kbps)
|
||||
dupcount = 0
|
||||
log.debug("TftpMetrics.compute: kbps is %s" % self.kbps)
|
||||
for key in self.dups:
|
||||
dupcount += self.dups[key]
|
||||
self.dupcount += self.dups[key]
|
||||
|
||||
def add_dup(self, blocknumber):
|
||||
"""This method adds a dup for a block number to the metrics."""
|
||||
log.debug("Recording a dup for block %d" % blocknumber)
|
||||
if self.dups.has_key(blocknumber):
|
||||
self.dups[pkt.blocknumber] += 1
|
||||
else:
|
||||
self.dups[pkt.blocknumber] = 1
|
||||
tftpassert(self.dups[pkt.blocknumber] < MAX_DUPS,
|
||||
"Max duplicates for block %d reached" % blocknumber)
|
||||
|
||||
###############################################################################
|
||||
# Context classes
|
||||
|
@ -40,16 +51,32 @@ class TftpMetrics(object):
|
|||
|
||||
class TftpContext(object):
|
||||
"""The base class of the contexts."""
|
||||
def __init__(self, host, port):
|
||||
|
||||
def __init__(self, host, port, timeout):
|
||||
"""Constructor for the base context, setting shared instance
|
||||
variables."""
|
||||
self.file_to_transfer = None
|
||||
self.fileobj = None
|
||||
self.options = None
|
||||
self.packethook = None
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.sock.settimeout(timeout)
|
||||
self.state = None
|
||||
self.next_block = 0
|
||||
self.factory = TftpPacketFactory()
|
||||
# Note, setting the host will also set self.address, as it's a property.
|
||||
self.host = host
|
||||
self.port = port
|
||||
# The port associated with the TID
|
||||
self.tidport = None
|
||||
# Metrics
|
||||
self.metrics = TftpMetrics()
|
||||
# Flag when the transfer is pending completion.
|
||||
self.pending_complete = False
|
||||
|
||||
def checkTimeout(self, now):
|
||||
# FIXME
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
return NotImplementedError, "Abstract method"
|
||||
|
@ -69,37 +96,9 @@ class TftpContext(object):
|
|||
|
||||
host = property(gethost, sethost)
|
||||
|
||||
def sendAck(self, blocknumber):
|
||||
"""This method sends an ack packet to the block number specified."""
|
||||
logger.info("sending ack to block %d" % blocknumber)
|
||||
ackpkt = TftpPacketACK()
|
||||
ackpkt.blocknumber = blocknumber
|
||||
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport))
|
||||
|
||||
def sendError(self, errorcode):
|
||||
"""This method uses the socket passed, and uses the errorcode to
|
||||
compose and send an error packet."""
|
||||
logger.debug("In sendError, being asked to send error %d" % errorcode)
|
||||
errpkt = TftpPacketERR()
|
||||
errpkt.errorcode = errorcode
|
||||
self.sock.sendto(errpkt.encode().buffer, (self.host, self.tidport))
|
||||
|
||||
class TftpContextClient(TftpContext):
|
||||
"""This class represents shared functionality by both the download and
|
||||
upload client contexts."""
|
||||
def __init__(self, host, port, filename, options, packethook, timeout):
|
||||
TftpContext.__init__(self, host, port)
|
||||
self.file_to_transfer = filename
|
||||
self.options = options
|
||||
self.packethook = packethook
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.sock.settimeout(timeout)
|
||||
self.state = None
|
||||
self.next_block = 0
|
||||
|
||||
def setNextBlock(self, block):
|
||||
if block > 2 ** 16:
|
||||
logger.debug("block number rollover to 0 again")
|
||||
log.debug("block number rollover to 0 again")
|
||||
block = 0
|
||||
self.__eblock = block
|
||||
|
||||
|
@ -111,19 +110,21 @@ class TftpContextClient(TftpContext):
|
|||
def cycle(self):
|
||||
"""Here we wait for a response from the server after sending it
|
||||
something, and dispatch appropriate action to that response."""
|
||||
# FIXME: This won't work very well in a server context with multiple
|
||||
# sessions running.
|
||||
for i in range(TIMEOUT_RETRIES):
|
||||
logger.debug("in cycle, receive attempt %d" % i)
|
||||
log.debug("in cycle, receive attempt %d" % i)
|
||||
try:
|
||||
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
|
||||
except socket.timeout, err:
|
||||
logger.warn("Timeout waiting for traffic, retrying...")
|
||||
log.warn("Timeout waiting for traffic, retrying...")
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise TftpException, "Hit max timeouts, giving up."
|
||||
|
||||
# Ok, we've received a packet. Log it.
|
||||
logger.debug("Received %d bytes from %s:%s"
|
||||
log.debug("Received %d bytes from %s:%s"
|
||||
% (len(buffer), raddress, rport))
|
||||
|
||||
# Decode it.
|
||||
|
@ -131,11 +132,11 @@ class TftpContextClient(TftpContext):
|
|||
|
||||
# Check for known "connection".
|
||||
if raddress != self.address:
|
||||
logger.warn("Received traffic from %s, expected host %s. Discarding"
|
||||
log.warn("Received traffic from %s, expected host %s. Discarding"
|
||||
% (raddress, self.host))
|
||||
|
||||
if self.tidport and self.tidport != rport:
|
||||
logger.warn("Received traffic from %s:%s but we're "
|
||||
log.warn("Received traffic from %s:%s but we're "
|
||||
"connected to %s:%s. Discarding."
|
||||
% (raddress, rport,
|
||||
self.host, self.tidport))
|
||||
|
@ -150,29 +151,66 @@ class TftpContextClient(TftpContext):
|
|||
# And handle it, possibly changing state.
|
||||
self.state = self.state.handle(recvpkt, raddress, rport)
|
||||
|
||||
class TftpContextClientUpload(TftpContextClient):
|
||||
class TftpContextServer(TftpContext):
|
||||
"""The context for the server."""
|
||||
def __init__(self, host, port, timeout, root, dyn_file_func):
|
||||
TftpContext.__init__(self,
|
||||
host,
|
||||
port,
|
||||
timeout)
|
||||
# At this point we have no idea if this is a download or an upload. We
|
||||
# need to let the start state determine that.
|
||||
self.state = TftpStateServerStart()
|
||||
self.root = root
|
||||
self.dyn_file_func = dyn_file_func
|
||||
|
||||
def start(self, buffer):
|
||||
"""Start the state cycle. Note that the server context receives an
|
||||
initial packet in its start method."""
|
||||
log.debug("TftpContextServer.start() - pkt = %s" % pkt)
|
||||
|
||||
self.metrics.start_time = time.time()
|
||||
log.debug("set metrics.start_time to %s" % self.metrics.start_time)
|
||||
|
||||
pkt = self.factory.parse(buffer)
|
||||
log.debug("TftpContextServer.start() - factory returned a %s" % pkt)
|
||||
|
||||
# Call handle once with the initial packet. This should put us into
|
||||
# the download or the upload state.
|
||||
self.state = self.state.handle(pkt,
|
||||
self.host,
|
||||
self.port)
|
||||
|
||||
try:
|
||||
while self.state:
|
||||
log.debug("state is %s" % self.state)
|
||||
self.cycle()
|
||||
finally:
|
||||
self.fileobj.close()
|
||||
|
||||
class TftpContextClientUpload(TftpContext):
|
||||
"""The upload context for the client during an upload."""
|
||||
def __init__(self, host, port, filename, input, options, packethook, timeout):
|
||||
TftpContextClient.__init__(self,
|
||||
host,
|
||||
port,
|
||||
filename,
|
||||
options,
|
||||
packethook,
|
||||
timeout)
|
||||
TftpContext.__init__(self,
|
||||
host,
|
||||
port,
|
||||
timeout)
|
||||
self.file_to_transfer = filename
|
||||
self.options = options
|
||||
self.packethook = packethook
|
||||
self.fileobj = open(input, "wb")
|
||||
|
||||
logger.debug("TftpContextClientUpload.__init__()")
|
||||
logger.debug("file_to_transfer = %s, options = %s" %
|
||||
log.debug("TftpContextClientUpload.__init__()")
|
||||
log.debug("file_to_transfer = %s, options = %s" %
|
||||
(self.file_to_transfer, self.options))
|
||||
|
||||
def start(self):
|
||||
logger.info("sending tftp upload request to %s" % self.host)
|
||||
logger.info(" filename -> %s" % self.file_to_transfer)
|
||||
logger.info(" options -> %s" % self.options)
|
||||
log.info("sending tftp upload request to %s" % self.host)
|
||||
log.info(" filename -> %s" % self.file_to_transfer)
|
||||
log.info(" options -> %s" % self.options)
|
||||
|
||||
self.metrics.start_time = time.time()
|
||||
logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
|
||||
log.debug("set metrics.start_time to %s" % self.metrics.start_time)
|
||||
|
||||
# FIXME: put this in a sendWRQ method?
|
||||
pkt = TftpPacketWRQ()
|
||||
|
@ -186,7 +224,7 @@ class TftpContextClientUpload(TftpContextClient):
|
|||
|
||||
try:
|
||||
while self.state:
|
||||
logger.debug("state is %s" % self.state)
|
||||
log.debug("state is %s" % self.state)
|
||||
self.cycle()
|
||||
finally:
|
||||
self.fileobj.close()
|
||||
|
@ -194,32 +232,32 @@ class TftpContextClientUpload(TftpContextClient):
|
|||
def end(self):
|
||||
pass
|
||||
|
||||
class TftpContextClientDownload(TftpContextClient):
|
||||
class TftpContextClientDownload(TftpContext):
|
||||
"""The download context for the client during a download."""
|
||||
def __init__(self, host, port, filename, output, options, packethook, timeout):
|
||||
TftpContextClient.__init__(self,
|
||||
host,
|
||||
port,
|
||||
filename,
|
||||
options,
|
||||
packethook,
|
||||
timeout)
|
||||
TftpContext.__init__(self,
|
||||
host,
|
||||
port,
|
||||
filename,
|
||||
options,
|
||||
packethook,
|
||||
timeout)
|
||||
# FIXME - need to support alternate return formats than files?
|
||||
# File-like objects would be ideal, ala duck-typing.
|
||||
self.fileobj = open(output, "wb")
|
||||
|
||||
logger.debug("TftpContextClientDownload.__init__()")
|
||||
logger.debug("file_to_transfer = %s, options = %s" %
|
||||
log.debug("TftpContextClientDownload.__init__()")
|
||||
log.debug("file_to_transfer = %s, options = %s" %
|
||||
(self.file_to_transfer, self.options))
|
||||
|
||||
def start(self):
|
||||
"""Initiate the download."""
|
||||
logger.info("sending tftp download request to %s" % self.host)
|
||||
logger.info(" filename -> %s" % self.file_to_transfer)
|
||||
logger.info(" options -> %s" % self.options)
|
||||
log.info("sending tftp download request to %s" % self.host)
|
||||
log.info(" filename -> %s" % self.file_to_transfer)
|
||||
log.info(" options -> %s" % self.options)
|
||||
|
||||
self.metrics.start_time = time.time()
|
||||
logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
|
||||
log.debug("set metrics.start_time to %s" % self.metrics.start_time)
|
||||
|
||||
# FIXME: put this in a sendRRQ method?
|
||||
pkt = TftpPacketRRQ()
|
||||
|
@ -233,7 +271,7 @@ class TftpContextClientDownload(TftpContextClient):
|
|||
|
||||
try:
|
||||
while self.state:
|
||||
logger.debug("state is %s" % self.state)
|
||||
log.debug("state is %s" % self.state)
|
||||
self.cycle()
|
||||
finally:
|
||||
self.fileobj.close()
|
||||
|
@ -241,7 +279,7 @@ class TftpContextClientDownload(TftpContextClient):
|
|||
def end(self):
|
||||
"""Finish up the context."""
|
||||
self.metrics.end_time = time.time()
|
||||
logger.debug("set metrics.end_time to %s" % self.metrics.end_time)
|
||||
log.debug("set metrics.end_time to %s" % self.metrics.end_time)
|
||||
self.metrics.compute()
|
||||
|
||||
|
||||
|
@ -268,235 +306,431 @@ class TftpState(object):
|
|||
options."""
|
||||
if pkt.options.keys() > 0:
|
||||
if pkt.match_options(self.context.options):
|
||||
logger.info("Successful negotiation of options")
|
||||
log.info("Successful negotiation of options")
|
||||
# Set options to OACK options
|
||||
self.context.options = pkt.options
|
||||
for key in self.context.options:
|
||||
logger.info(" %s = %s" % (key, self.context.options[key]))
|
||||
log.info(" %s = %s" % (key, self.context.options[key]))
|
||||
else:
|
||||
logger.error("failed to negotiate options")
|
||||
log.error("failed to negotiate options")
|
||||
raise TftpException, "Failed to negotiate options"
|
||||
else:
|
||||
raise TftpException, "No options found in OACK"
|
||||
|
||||
class TftpStateUpload(TftpState):
|
||||
"""A class holding common code for upload states."""
|
||||
def sendDat(self, resend=False):
|
||||
def returnSupportedOptions(self, options):
|
||||
"""This method takes a requested options list from a client, and
|
||||
returns the ones that are supported."""
|
||||
# We support the options blksize and tsize right now.
|
||||
# FIXME - put this somewhere else?
|
||||
accepted_options = {}
|
||||
for option in options:
|
||||
if option == 'blksize':
|
||||
# Make sure it's valid.
|
||||
if int(options[option]) > MAX_BLKSIZE:
|
||||
accepted_options[option] = MAX_BLKSIZE
|
||||
elif option == 'tsize':
|
||||
log.debug("tsize option is set")
|
||||
accepted_options['tsize'] = 1
|
||||
else:
|
||||
log.info("Dropping unsupported option '%s'" % option)
|
||||
return accepted_options
|
||||
|
||||
def serverInitial(self, pkt, raddress, rport):
|
||||
"""This method performs initial setup for a server context transfer,
|
||||
put here to refactor code out of the TftpStateServerRecvRRQ and
|
||||
TftpStateServerRecvWRQ classes, since their initial setup is
|
||||
identical. The method returns a boolean, sendoack, to indicate whether
|
||||
it is required to send an OACK to the client."""
|
||||
options = pkt.options
|
||||
sendoack = False
|
||||
if not options:
|
||||
log.debug("setting default options, blksize")
|
||||
# FIXME: put default options elsewhere
|
||||
self.context.options = { 'blksize': DEF_BLKSIZE }
|
||||
else:
|
||||
log.debug("options requested: %s" % options)
|
||||
self.context.options = self.returnSupportedOptions(options)
|
||||
sendoack = True
|
||||
|
||||
# FIXME - only octet mode is supported at this time.
|
||||
if pkt.mode != 'octet':
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, \
|
||||
"Only octet transfers are supported at this time."
|
||||
|
||||
# test host/port of client end
|
||||
if self.context.host != raddress or self.context.port != rport:
|
||||
self.sendError(TftpErrors.UnknownTID)
|
||||
log.error("Expected traffic from %s:%s but received it "
|
||||
"from %s:%s instead."
|
||||
% (self.context.host,
|
||||
self.context.port,
|
||||
raddress,
|
||||
rport))
|
||||
# FIXME: increment an error count?
|
||||
# Return same state, we're still waiting for valid traffic.
|
||||
return self
|
||||
|
||||
log.debug("requested filename is %s" % pkt.filename)
|
||||
# There are no os.sep's allowed in the filename.
|
||||
# FIXME: Should we allow subdirectories?
|
||||
if pkt.filename.find(os.sep) >= 0:
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "%s found in filename, not permitted" % os.sep
|
||||
|
||||
self.context.file_to_transfer = pkt.filename
|
||||
|
||||
return sendoack
|
||||
|
||||
def sendDAT(self, resend=False):
|
||||
"""This method sends the next DAT packet based on the data in the
|
||||
context. It returns a boolean indicating whether the transfer is
|
||||
finished."""
|
||||
finished = False
|
||||
blocknumber = self.context.next_block
|
||||
if not resend:
|
||||
blksize = int(self.context.options['blksize'])
|
||||
buffer = self.context.fileobj.read(blksize)
|
||||
logger.debug("Read %d bytes into buffer" % len(buffer))
|
||||
log.debug("Read %d bytes into buffer" % len(buffer))
|
||||
if len(buffer) < blksize:
|
||||
logger.info("Reached EOF on file %s" % self.context.input)
|
||||
log.info("Reached EOF on file %s" % self.context.input)
|
||||
finished = True
|
||||
self.context.next_block += 1
|
||||
self.bytes += len(buffer)
|
||||
else:
|
||||
logger.warn("Resending block number %d" % blocknumber)
|
||||
log.warn("Resending block number %d" % blocknumber)
|
||||
dat = TftpPacketDAT()
|
||||
dat.data = buffer
|
||||
dat.blocknumber = blocknumber
|
||||
logger.debug("Sending DAT packet %d" % blocknumber)
|
||||
log.debug("Sending DAT packet %d" % blocknumber)
|
||||
self.context.sock.sendto(dat.encode().buffer,
|
||||
(self.context.host, self.context.port))
|
||||
if self.context.packethook:
|
||||
self.context.packethook(dat)
|
||||
return finished
|
||||
|
||||
class TftpStateDownload(TftpState):
|
||||
"""A class holding common code for download states."""
|
||||
def sendACK(self, blocknumber=None):
|
||||
"""This method sends an ack packet to the block number specified. If
|
||||
none is specified, it defaults to the next_block property in the
|
||||
parent context."""
|
||||
if not blocknumber:
|
||||
blocknumber = self.context.next_block
|
||||
log.info("sending ack to block %d" % blocknumber)
|
||||
ackpkt = TftpPacketACK()
|
||||
ackpkt.blocknumber = blocknumber
|
||||
self.context.sock.sendto(ackpkt.encode().buffer,
|
||||
(self.context.host,
|
||||
self.context.tidport))
|
||||
|
||||
def sendError(self, errorcode):
|
||||
"""This method uses the socket passed, and uses the errorcode to
|
||||
compose and send an error packet."""
|
||||
log.debug("In sendError, being asked to send error %d" % errorcode)
|
||||
errpkt = TftpPacketERR()
|
||||
errpkt.errorcode = errorcode
|
||||
self.context.sock.sendto(errpkt.encode().buffer,
|
||||
(self.context.host,
|
||||
self.context.tidport))
|
||||
|
||||
def sendOACK(self):
|
||||
"""This method sends an OACK packet with the options from the current
|
||||
context."""
|
||||
log.debug("In sendOACK with options %s" % options)
|
||||
pkt = TftpPacketOACK()
|
||||
pkt.options = self.options
|
||||
self.context.sock.sendto(pkt.encode().buffer,
|
||||
(self.context.host,
|
||||
self.context.tidport))
|
||||
|
||||
def handleDat(self, pkt):
|
||||
"""This method handles a DAT packet during a download."""
|
||||
logger.info("handling DAT packet - block %d" % pkt.blocknumber)
|
||||
logger.debug("expecting block %s" % self.context.next_block)
|
||||
"""This method handles a DAT packet during a client download, or a
|
||||
server upload."""
|
||||
log.info("handling DAT packet - block %d" % pkt.blocknumber)
|
||||
log.debug("expecting block %s" % self.context.next_block)
|
||||
if pkt.blocknumber == self.context.next_block:
|
||||
logger.debug("good, received block %d in sequence"
|
||||
log.debug("good, received block %d in sequence"
|
||||
% pkt.blocknumber)
|
||||
|
||||
self.context.sendAck(pkt.blocknumber)
|
||||
self.sendACK()
|
||||
self.context.next_block += 1
|
||||
|
||||
logger.debug("writing %d bytes to output file"
|
||||
log.debug("writing %d bytes to output file"
|
||||
% len(pkt.data))
|
||||
self.context.fileobj.write(pkt.data)
|
||||
self.context.metrics.bytes += len(pkt.data)
|
||||
# Check for end-of-file, any less than full data packet.
|
||||
if len(pkt.data) < int(self.context.options['blksize']):
|
||||
logger.info("end of file detected")
|
||||
log.info("end of file detected")
|
||||
return None
|
||||
|
||||
elif pkt.blocknumber < self.context.next_block:
|
||||
logger.warn("dropping duplicate block %d" % pkt.blocknumber)
|
||||
if self.context.metrics.dups.has_key(pkt.blocknumber):
|
||||
self.context.metrics.dups[pkt.blocknumber] += 1
|
||||
else:
|
||||
self.context.metrics.dups[pkt.blocknumber] = 1
|
||||
tftpassert(self.context.metrics.dups[pkt.blocknumber] < MAX_DUPS,
|
||||
"Max duplicates for block %d reached" % pkt.blocknumber)
|
||||
# FIXME: double-check sorceror's apprentice problem!
|
||||
logger.debug("ACKing block %d again, just in case" % pkt.blocknumber)
|
||||
self.context.sendAck(pkt.blocknumber)
|
||||
log.warn("dropping duplicate block %d" % pkt.blocknumber)
|
||||
self.context.metrics.add_dup(pkt.blocknumber)
|
||||
log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
|
||||
self.sendACK(pkt.blocknumber)
|
||||
|
||||
else:
|
||||
# FIXME: should we be more tolerant and just discard instead?
|
||||
msg = "Whoa! Received future block %d but expected %d" \
|
||||
% (pkt.blocknumber, self.context.next_block)
|
||||
logger.error(msg)
|
||||
log.error(msg)
|
||||
raise TftpException, msg
|
||||
|
||||
# Default is to ack
|
||||
return TftpStateSentACK(self.context)
|
||||
return TftpStateExpectDAT(self.context)
|
||||
|
||||
class TftpStateSentWRQ(TftpStateUpload):
|
||||
"""Just sent an WRQ packet for an upload."""
|
||||
class TftpStateServerRecvRRQ(TftpState):
|
||||
"""This class represents the state of the TFTP server when it has just
|
||||
received an RRQ packet."""
|
||||
def handle(self, pkt, raddress, rport):
|
||||
"Handle an initial RRQ packet as a server."
|
||||
log.debug("In TftpStateServerRecvRRQ.handle")
|
||||
sendoack = self.serverInitial(pkt, raddress, rport)
|
||||
path = self.context.root + os.sep + self.context.file_to_transfer
|
||||
log.info("Opening file %s for reading" % path)
|
||||
if os.path.exists(path):
|
||||
# Note: Open in binary mode for win32 portability, since win32
|
||||
# blows.
|
||||
self.context.fileobj = open(path, "rb")
|
||||
elif self.dyn_file_func:
|
||||
log.debug("No such file %s but using dyn_file_func" % path)
|
||||
self.context.fileobj = \
|
||||
self.dyn_file_func(self.context.file_to_transfer)
|
||||
else:
|
||||
send.sendError(TftpErrors.FileNotFound)
|
||||
raise TftpException, "File not found: %s" % path
|
||||
|
||||
# Options negotiation.
|
||||
if sendoack:
|
||||
self.sendOACK()
|
||||
return TftpStateServerOACK(self.context)
|
||||
else:
|
||||
log.debug("No requested options, starting send...")
|
||||
self.context.pending_complete = self.sendDAT()
|
||||
return TftpStateExpectACK(self.context)
|
||||
|
||||
# Note, we don't have to check any other states in this method, that's
|
||||
# up to the caller.
|
||||
|
||||
class TftpStateServerRecvWRQ(TftpState):
|
||||
"""This class represents the state of the TFTP server when it has just
|
||||
received a WRQ packet."""
|
||||
def handle(self, pkt, raddress, rport):
|
||||
"Handle an initial WRQ packet as a server."
|
||||
log.debug("In TftpStateServerRecvWRQ.handle")
|
||||
sendoack = self.serverInitial(pkt, raddress, rport)
|
||||
path = self.context.root + os.sep + self.context.file_to_transfer
|
||||
log.info("Opening file %s for writing" % path)
|
||||
if os.path.exists(path):
|
||||
# FIXME: correct behavior?
|
||||
log.warn("File %s exists already, overwriting...")
|
||||
# FIXME: I think we should upload to a temp file and not overwrite the
|
||||
# existing file until the file is successfully uploaded.
|
||||
self.context.fileobj = open(path, "wb")
|
||||
|
||||
# Options negotiation.
|
||||
if sendoack:
|
||||
log.debug("Sending OACK to client")
|
||||
self.sendOACK()
|
||||
else:
|
||||
log.debug("No requested options, starting send...")
|
||||
self.sendACK()
|
||||
# We may have sent an OACK, but we're expecting a DAT as the response
|
||||
# to either the OACK or an ACK, so lets unconditionally use the
|
||||
# TftpStateExpectDAT state.
|
||||
return TftpStateExpectDAT(self.context)
|
||||
|
||||
# Note, we don't have to check any other states in this method, that's
|
||||
# up to the caller.
|
||||
|
||||
class TftpStateServerStart(TftpState):
|
||||
"""The start state for the server."""
|
||||
def handle(self, pkt, raddress, rport):
|
||||
"""Handle a packet we just received."""
|
||||
if not self.context.tidport:
|
||||
self.context.tidport = rport
|
||||
logger.debug("Set remote port for session to %s" % rport)
|
||||
|
||||
# If we're going to successfully transfer the file, then we should see
|
||||
# either an OACK for accepted options, or an ACK to ignore options.
|
||||
if isinstance(pkt, TftpPacketOACK):
|
||||
logger.info("received OACK from server")
|
||||
try:
|
||||
self.handleOACK(pkt)
|
||||
except TftpException, err:
|
||||
logger.error("failed to negotiate options")
|
||||
self.context.sendError(TftpErrors.FailedNegotiation)
|
||||
raise
|
||||
else:
|
||||
logger.debug("sending first DAT packet")
|
||||
fin = self.context.sendDat()
|
||||
if fin:
|
||||
logger.info("Add done")
|
||||
return None
|
||||
else:
|
||||
logger.debug("Changing state to TftpStateSentDAT")
|
||||
return TftpStateSentDAT(self.context)
|
||||
|
||||
elif isinstance(pkt, TftpPacketACK):
|
||||
logger.info("received ACK from server")
|
||||
logger.debug("apparently the server ignored our options")
|
||||
# The block number should be zero.
|
||||
if pkt.blocknumber == 0:
|
||||
logger.debug("ack blocknumber is zero as expected")
|
||||
logger.debug("sending first DAT packet")
|
||||
fin = self.context.sendDat()
|
||||
if fin:
|
||||
logger.info("Add done")
|
||||
return None
|
||||
else:
|
||||
logger.debug("Changing state to TftpStateSentDAT")
|
||||
return TftpStateSentDAT(self.context)
|
||||
else:
|
||||
logger.warn("discarding ACK to block %s" % pkt.blocknumber)
|
||||
logger.debug("still waiting for valid response from server")
|
||||
return self
|
||||
|
||||
elif isinstance(pkt, TftpPacketERR):
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received ERR from server: " + str(pkt)
|
||||
|
||||
elif isinstance(pkt, TftpPacketRRQ):
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received RRQ from server while in upload"
|
||||
|
||||
elif isinstance(pkt, TftpPacketDAT):
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received DAT from server while in upload"
|
||||
|
||||
log.debug("In TftpStateServerStart.handle")
|
||||
if isinstance(pkt, TftpPacketRRQ):
|
||||
log.debug("handling an RRQ packet")
|
||||
return TftpStateServerRecvRRQ(self.context).handle(pkt,
|
||||
raddress,
|
||||
rport)
|
||||
elif isinstance(pkt, TftpPacketWRQ):
|
||||
log.debug("handling a WRQ packet")
|
||||
return TftpStateServerRecvWRQ(self.context).handle(pkt,
|
||||
raddress,
|
||||
rport)
|
||||
else:
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received unknown packet type from server: " + str(pkt)
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, \
|
||||
"Invalid packet to begin up/download: %s" % pkt
|
||||
|
||||
# By default, no state change.
|
||||
return self
|
||||
|
||||
class TftpStateSentDAT(TftpStateUpload):
|
||||
class TftpStateExpectACK(TftpState):
|
||||
"""This class represents the state of the transfer when a DAT was just
|
||||
sent, and we are waiting for an ACK from the server. This class is the
|
||||
same one used by the client during the upload, and the server during the
|
||||
download."""
|
||||
|
||||
class TftpStateSentRRQ(TftpStateDownload):
|
||||
"""Just sent an RRQ packet."""
|
||||
def handle(self, pkt, raddress, rport):
|
||||
"""Handle the packet in response to an RRQ to the server."""
|
||||
if not self.context.tidport:
|
||||
self.context.tidport = rport
|
||||
logger.debug("Set remote port for session to %s" % rport)
|
||||
"Handle a packet, hopefully an ACK since we just sent a DAT."
|
||||
if isinstance(pkt, TftpPacketACK):
|
||||
log.info("Received ACK for packet %d" % pkt.blocknumber)
|
||||
# Is this an ack to the one we just sent?
|
||||
if self.context.next_block == pkt.blocknumber:
|
||||
if self.context.pending_complete:
|
||||
log.info("Received ACK to final DAT, we're done.")
|
||||
return None
|
||||
else:
|
||||
log.debug("Good ACK, sending next DAT")
|
||||
self.context.pending_complete = self.sendDAT()
|
||||
|
||||
elif pkt.blocknumber < self.context.next_block:
|
||||
self.context.metrics.add_dup(pkt.blocknumber)
|
||||
|
||||
# Now check the packet type and dispatch it properly.
|
||||
if isinstance(pkt, TftpPacketOACK):
|
||||
logger.info("received OACK from server")
|
||||
try:
|
||||
self.handleOACK(pkt)
|
||||
except TftpException, err:
|
||||
logger.error("failed to negotiate options: %s" % str(err))
|
||||
self.context.sendError(TftpErrors.FailedNegotiation)
|
||||
raise
|
||||
else:
|
||||
logger.debug("sending ACK to OACK")
|
||||
|
||||
self.context.sendAck(blocknumber=0)
|
||||
|
||||
logger.debug("Changing state to TftpStateSentACK")
|
||||
return TftpStateSentACK(self.context)
|
||||
|
||||
elif isinstance(pkt, TftpPacketDAT):
|
||||
# If there are any options set, then the server didn't honour any
|
||||
# of them.
|
||||
logger.info("received DAT from server")
|
||||
if self.context.options:
|
||||
logger.info("server ignored options, falling back to defaults")
|
||||
self.context.options = { 'blksize': DEF_BLKSIZE }
|
||||
return self.handleDat(pkt)
|
||||
|
||||
# Every other packet type is a problem.
|
||||
elif isinstance(recvpkt, TftpPacketACK):
|
||||
# Umm, we ACK, the server doesn't.
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received ACK from server while in download"
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketWRQ):
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received WRQ from server while in download"
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketERR):
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received ERR from server: " + str(recvpkt)
|
||||
|
||||
log.warn("Oooh, time warp. Received ACK to packet we "
|
||||
"didn't send yet. Discarding.")
|
||||
self.context.metrics.errors += 1
|
||||
return self
|
||||
elif isinstance(pkt, TftpPacketERR):
|
||||
log.error("Received ERR packet from peer: %s" % str(pkt))
|
||||
raise TftpException, \
|
||||
"Received ERR packet from peer: %s" % str(pkt)
|
||||
else:
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
|
||||
log.warn("Discarding unsupported packet: %s" % str(pkt))
|
||||
return self
|
||||
|
||||
# By default, no state change.
|
||||
return self
|
||||
|
||||
class TftpStateSentACK(TftpStateDownload):
|
||||
class TftpStateExpectDAT(TftpState):
|
||||
"""Just sent an ACK packet. Waiting for DAT."""
|
||||
def handle(self, pkt, raddress, rport):
|
||||
"""Handle the packet in response to an ACK, which should be a DAT."""
|
||||
if isinstance(pkt, TftpPacketDAT):
|
||||
return self.handleDat(pkt)
|
||||
|
||||
# Every other packet type is a problem.
|
||||
elif isinstance(recvpkt, TftpPacketACK):
|
||||
# Umm, we ACK, you don't.
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received ACK from peer when expecting DAT"
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketWRQ):
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received WRQ from peer when expecting DAT"
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketERR):
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received ERR from peer: " + str(recvpkt)
|
||||
|
||||
else:
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received unknown packet type from peer: " + str(recvpkt)
|
||||
|
||||
class TftpStateSentWRQ(TftpState):
|
||||
"""Just sent an WRQ packet for an upload."""
|
||||
def handle(self, pkt, raddress, rport):
|
||||
"""Handle a packet we just received."""
|
||||
if not self.context.tidport:
|
||||
self.context.tidport = rport
|
||||
log.debug("Set remote port for session to %s" % rport)
|
||||
|
||||
# If we're going to successfully transfer the file, then we should see
|
||||
# either an OACK for accepted options, or an ACK to ignore options.
|
||||
if isinstance(pkt, TftpPacketOACK):
|
||||
log.info("received OACK from server")
|
||||
try:
|
||||
self.handleOACK(pkt)
|
||||
except TftpException, err:
|
||||
log.error("failed to negotiate options")
|
||||
self.sendError(TftpErrors.FailedNegotiation)
|
||||
raise
|
||||
else:
|
||||
log.debug("sending first DAT packet")
|
||||
self.context.pending_complete = self.sendDAT()
|
||||
log.debug("Changing state to TftpStateExpectACK")
|
||||
return TftpStateExpectACK(self.context)
|
||||
|
||||
elif isinstance(pkt, TftpPacketACK):
|
||||
log.info("received ACK from server")
|
||||
log.debug("apparently the server ignored our options")
|
||||
# The block number should be zero.
|
||||
if pkt.blocknumber == 0:
|
||||
log.debug("ack blocknumber is zero as expected")
|
||||
log.debug("sending first DAT packet")
|
||||
self.pending_complete = self.context.sendDAT()
|
||||
log.debug("Changing state to TftpStateExpectACK")
|
||||
return TftpStateExpectACK(self.context)
|
||||
else:
|
||||
log.warn("discarding ACK to block %s" % pkt.blocknumber)
|
||||
log.debug("still waiting for valid response from server")
|
||||
return self
|
||||
|
||||
elif isinstance(pkt, TftpPacketERR):
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received ERR from server: " + str(pkt)
|
||||
|
||||
elif isinstance(pkt, TftpPacketRRQ):
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received RRQ from server while in upload"
|
||||
|
||||
elif isinstance(pkt, TftpPacketDAT):
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received DAT from server while in upload"
|
||||
|
||||
else:
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received unknown packet type from server: " + str(pkt)
|
||||
|
||||
# By default, no state change.
|
||||
return self
|
||||
|
||||
class TftpStateSentRRQ(TftpState):
|
||||
"""Just sent an RRQ packet."""
|
||||
def handle(self, pkt, raddress, rport):
|
||||
"""Handle the packet in response to an RRQ to the server."""
|
||||
if not self.context.tidport:
|
||||
self.context.tidport = rport
|
||||
log.debug("Set remote port for session to %s" % rport)
|
||||
|
||||
# Now check the packet type and dispatch it properly.
|
||||
if isinstance(pkt, TftpPacketOACK):
|
||||
log.info("received OACK from server")
|
||||
try:
|
||||
self.handleOACK(pkt)
|
||||
except TftpException, err:
|
||||
log.error("failed to negotiate options: %s" % str(err))
|
||||
self.sendError(TftpErrors.FailedNegotiation)
|
||||
raise
|
||||
else:
|
||||
log.debug("sending ACK to OACK")
|
||||
|
||||
self.sendACK(blocknumber=0)
|
||||
|
||||
log.debug("Changing state to TftpStateExpectDAT")
|
||||
return TftpStateExpectDAT(self.context)
|
||||
|
||||
elif isinstance(pkt, TftpPacketDAT):
|
||||
# If there are any options set, then the server didn't honour any
|
||||
# of them.
|
||||
log.info("received DAT from server")
|
||||
if self.context.options:
|
||||
log.info("server ignored options, falling back to defaults")
|
||||
self.context.options = { 'blksize': DEF_BLKSIZE }
|
||||
return self.handleDat(pkt)
|
||||
|
||||
# Every other packet type is a problem.
|
||||
elif isinstance(recvpkt, TftpPacketACK):
|
||||
# Umm, we ACK, the server doesn't.
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received ACK from server while in download"
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketWRQ):
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received WRQ from server while in download"
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketERR):
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received ERR from server: " + str(recvpkt)
|
||||
|
||||
else:
|
||||
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||
self.sendError(TftpErrors.IllegalTftpOp)
|
||||
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
|
||||
|
||||
# By default, no state change.
|
||||
return self
|
||||
|
|
Reference in a new issue