From 62b22fb562eff64a6d6bb6c1a1a3c194d668d9a1 Mon Sep 17 00:00:00 2001 From: "Michael P. Soulier" Date: Sat, 15 Aug 2009 22:36:58 -0400 Subject: [PATCH] Did some rework for the state machine in a server context. Removed the handler framework in favour of a TftpContextServer used as the session. --- tftpy/TftpClient.py | 20 +- tftpy/TftpPacketFactory.py | 4 +- tftpy/TftpPacketTypes.py | 88 ++--- tftpy/TftpServer.py | 408 +++------------------- tftpy/TftpShared.py | 6 +- tftpy/TftpStates.py | 674 +++++++++++++++++++++++++------------ 6 files changed, 564 insertions(+), 636 deletions(-) diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index da35d05..89843f1 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -52,12 +52,12 @@ class TftpClient(TftpSession): # output? This should be in the sample client, but not in the download # call. if metrics.duration == 0: - logger.info("Duration too short, rate undetermined") + log.info("Duration too short, rate undetermined") else: - logger.info('') - logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) - logger.info("Average rate: %.2f kbps" % metrics.kbps) - logger.info("Received %d duplicate packets" % metrics.dupcount) + log.info('') + log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) + log.info("Average rate: %.2f kbps" % metrics.kbps) + log.info("Received %d duplicate packets" % metrics.dupcount) def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT): # Open the input file. @@ -80,9 +80,9 @@ class TftpClient(TftpSession): # output? This should be in the sample client, but not in the download # call. if metrics.duration == 0: - logger.info("Duration too short, rate undetermined") + log.info("Duration too short, rate undetermined") else: - logger.info('') - logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) - logger.info("Average rate: %.2f kbps" % metrics.kbps) - logger.info("Received %d duplicate packets" % metrics.dupcount) \ No newline at end of file + log.info('') + log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) + log.info("Average rate: %.2f kbps" % metrics.kbps) + log.info("Received %d duplicate packets" % metrics.dupcount) diff --git a/tftpy/TftpPacketFactory.py b/tftpy/TftpPacketFactory.py index 642b4d8..3f287de 100644 --- a/tftpy/TftpPacketFactory.py +++ b/tftpy/TftpPacketFactory.py @@ -19,9 +19,9 @@ class TftpPacketFactory(object): """This method is used to parse an existing datagram into its corresponding TftpPacket object. The buffer is the raw bytes off of the network.""" - logger.debug("parsing a %d byte packet" % len(buffer)) + log.debug("parsing a %d byte packet" % len(buffer)) (opcode,) = struct.unpack("!H", buffer[:2]) - logger.debug("opcode is %d" % opcode) + log.debug("opcode is %d" % opcode) packet = self.__create(opcode) packet.buffer = buffer return packet.decode() diff --git a/tftpy/TftpPacketTypes.py b/tftpy/TftpPacketTypes.py index e269deb..b9328c5 100644 --- a/tftpy/TftpPacketTypes.py +++ b/tftpy/TftpPacketTypes.py @@ -15,7 +15,7 @@ class TftpSession(object): def senderror(self, sock, errorcode, address, port): """This method uses the socket passed, and uses the errorcode, address and port to compose and send an error packet.""" - logger.debug("In senderror, being asked to send error %d to %s:%s" + log.debug("In senderror, being asked to send error %d to %s:%s" % (errorcode, address, port)) errpkt = TftpPacketERR() errpkt.errorcode = errorcode @@ -27,23 +27,23 @@ class TftpPacketWithOptions(object): goal is just to share code here, and not cause diamond inheritance.""" def __init__(self): - self.options = [] + self.options = {} def setoptions(self, options): - logger.debug("in TftpPacketWithOptions.setoptions") - logger.debug("options: " + str(options)) + log.debug("in TftpPacketWithOptions.setoptions") + log.debug("options: " + str(options)) myoptions = {} for key in options: newkey = str(key) myoptions[newkey] = str(options[key]) - logger.debug("populated myoptions with %s = %s" + log.debug("populated myoptions with %s = %s" % (newkey, myoptions[newkey])) - logger.debug("setting options hash to: " + str(myoptions)) + log.debug("setting options hash to: " + str(myoptions)) self._options = myoptions def getoptions(self): - logger.debug("in TftpPacketWithOptions.getoptions") + log.debug("in TftpPacketWithOptions.getoptions") return self._options # Set up getter and setter on options to ensure that they are the proper @@ -59,19 +59,19 @@ class TftpPacketWithOptions(object): format = "!" options = {} - logger.debug("decode_options: buffer is: " + repr(buffer)) - logger.debug("size of buffer is %d bytes" % len(buffer)) + log.debug("decode_options: buffer is: " + repr(buffer)) + log.debug("size of buffer is %d bytes" % len(buffer)) if len(buffer) == 0: - logger.debug("size of buffer is zero, returning empty hash") + log.debug("size of buffer is zero, returning empty hash") return {} # Count the nulls in the buffer. Each one terminates a string. - logger.debug("about to iterate options buffer counting nulls") + log.debug("about to iterate options buffer counting nulls") length = 0 for c in buffer: - #logger.debug("iterating this byte: " + repr(c)) + #log.debug("iterating this byte: " + repr(c)) if ord(c) == 0: - logger.debug("found a null at length %d" % length) + log.debug("found a null at length %d" % length) if length > 0: format += "%dsx" % length length = -1 @@ -79,14 +79,14 @@ class TftpPacketWithOptions(object): raise TftpException, "Invalid options in buffer" length += 1 - logger.debug("about to unpack, format is: %s" % format) + log.debug("about to unpack, format is: %s" % format) mystruct = struct.unpack(format, buffer) tftpassert(len(mystruct) % 2 == 0, "packet with odd number of option/value pairs") for i in range(0, len(mystruct), 2): - logger.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1])) + log.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1])) options[mystruct[i]] = mystruct[i+1] return options @@ -134,10 +134,10 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): ptype = None if self.opcode == 1: ptype = "RRQ" else: ptype = "WRQ" - logger.debug("Encoding %s packet, filename = %s, mode = %s" + log.debug("Encoding %s packet, filename = %s, mode = %s" % (ptype, self.filename, self.mode)) for key in self.options: - logger.debug(" Option %s = %s" % (key, self.options[key])) + log.debug(" Option %s = %s" % (key, self.options[key])) format = "!H" format += "%dsx" % len(self.filename) @@ -148,7 +148,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): # Add options. options_list = [] if self.options.keys() > 0: - logger.debug("there are options to encode") + log.debug("there are options to encode") for key in self.options: # Populate the option name format += "%dsx" % len(key) @@ -157,9 +157,9 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): format += "%dsx" % len(str(self.options[key])) options_list.append(str(self.options[key])) - logger.debug("format is %s" % format) - logger.debug("options_list is %s" % options_list) - logger.debug("size of struct is %d" % struct.calcsize(format)) + log.debug("format is %s" % format) + log.debug("options_list is %s" % options_list) + log.debug("size of struct is %d" % struct.calcsize(format)) self.buffer = struct.pack(format, self.opcode, @@ -167,7 +167,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): self.mode, *options_list) - logger.debug("buffer is " + repr(self.buffer)) + log.debug("buffer is " + repr(self.buffer)) return self def decode(self): @@ -177,13 +177,13 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): nulls = 0 format = "" nulls = length = tlength = 0 - logger.debug("in decode: about to iterate buffer counting nulls") + log.debug("in decode: about to iterate buffer counting nulls") subbuf = self.buffer[2:] for c in subbuf: - logger.debug("iterating this byte: " + repr(c)) + log.debug("iterating this byte: " + repr(c)) if ord(c) == 0: nulls += 1 - logger.debug("found a null at length %d, now have %d" + log.debug("found a null at length %d, now have %d" % (length, nulls)) format += "%dsx" % length length = -1 @@ -193,17 +193,17 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): length += 1 tlength += 1 - logger.debug("hopefully found end of mode at length %d" % tlength) + log.debug("hopefully found end of mode at length %d" % tlength) # length should now be the end of the mode. tftpassert(nulls == 2, "malformed packet") shortbuf = subbuf[:tlength+1] - logger.debug("about to unpack buffer with format: %s" % format) - logger.debug("unpacking buffer: " + repr(shortbuf)) + log.debug("about to unpack buffer with format: %s" % format) + log.debug("unpacking buffer: " + repr(shortbuf)) mystruct = struct.unpack(format, shortbuf) tftpassert(len(mystruct) == 2, "malformed packet") - logger.debug("setting filename to %s" % mystruct[0]) - logger.debug("setting mode to %s" % mystruct[1]) + log.debug("setting filename to %s" % mystruct[0]) + log.debug("setting mode to %s" % mystruct[1]) self.filename = mystruct[0] self.mode = mystruct[1] @@ -269,7 +269,7 @@ DATA | 03 | Block # | Data | """Encode the DAT packet. This method populates self.buffer, and returns self for easy method chaining.""" if len(self.data) == 0: - logger.debug("Encoding an empty DAT packet") + log.debug("Encoding an empty DAT packet") format = "!HH%ds" % len(self.data) self.buffer = struct.pack(format, self.opcode, @@ -283,12 +283,12 @@ DATA | 03 | Block # | Data | # We know the first 2 bytes are the opcode. The second two are the # block number. (self.blocknumber,) = struct.unpack("!H", self.buffer[2:4]) - logger.debug("decoding DAT packet, block number %d" % self.blocknumber) - logger.debug("should be %d bytes in the packet total" + log.debug("decoding DAT packet, block number %d" % self.blocknumber) + log.debug("should be %d bytes in the packet total" % len(self.buffer)) # Everything else is data. self.data = self.buffer[4:] - logger.debug("found %d bytes of data" + log.debug("found %d bytes of data" % len(self.data)) return self @@ -308,14 +308,14 @@ ACK | 04 | Block # | return 'ACK packet: block %d' % self.blocknumber def encode(self): - logger.debug("encoding ACK: opcode = %d, block = %d" + log.debug("encoding ACK: opcode = %d, block = %d" % (self.opcode, self.blocknumber)) self.buffer = struct.pack("!HH", self.opcode, self.blocknumber) return self def decode(self): self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer) - logger.debug("decoded ACK packet: opcode = %d, block = %d" + log.debug("decoded ACK packet: opcode = %d, block = %d" % (self.opcode, self.blocknumber)) return self @@ -365,7 +365,7 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 | """Encode the DAT packet based on instance variables, populating self.buffer, returning self.""" format = "!HH%dsx" % len(self.errmsgs[self.errorcode]) - logger.debug("encoding ERR packet with format %s" % format) + log.debug("encoding ERR packet with format %s" % format) self.buffer = struct.pack(format, self.opcode, self.errorcode, @@ -375,13 +375,13 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 | def decode(self): "Decode self.buffer, populating instance variables and return self." tftpassert(len(self.buffer) > 4, "malformed ERR packet, too short") - logger.debug("Decoding ERR packet, length %s bytes" % + log.debug("Decoding ERR packet, length %s bytes" % len(self.buffer)) format = "!HH%dsx" % (len(self.buffer) - 5) - logger.debug("Decoding ERR packet with format: %s" % format) + log.debug("Decoding ERR packet with format: %s" % format) self.opcode, self.errorcode, self.errmsg = struct.unpack(format, self.buffer) - logger.error("ERR packet - errorcode: %d, message: %s" + log.error("ERR packet - errorcode: %d, message: %s" % (self.errorcode, self.errmsg)) return self @@ -402,10 +402,10 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions): def encode(self): format = "!H" # opcode options_list = [] - logger.debug("in TftpPacketOACK.encode") + log.debug("in TftpPacketOACK.encode") for key in self.options: - logger.debug("looping on option key %s" % key) - logger.debug("value is %s" % self.options[key]) + log.debug("looping on option key %s" % key) + log.debug("value is %s" % self.options[key]) format += "%dsx" % len(key) format += "%dsx" % len(self.options[key]) options_list.append(key) @@ -429,7 +429,7 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions): # We can accept anything between the min and max values. size = self.options[name] if size >= MIN_BLKSIZE and size <= MAX_BLKSIZE: - logger.debug("negotiated blksize of %d bytes" % size) + log.debug("negotiated blksize of %d bytes" % size) options[blksize] = size else: raise TftpException, "Unsupported option: %s" % name diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index e846979..ad781a2 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -1,4 +1,5 @@ import socket, os, re, time, random +import select from TftpShared import * from TftpPacketTypes import * from TftpPacketFactory import * @@ -15,26 +16,27 @@ class TftpServer(TftpSession): self.listenip = None self.listenport = None self.sock = None + # FIXME: What about multiple roots? self.root = os.path.abspath(tftproot) - self.dynfunc = dyn_file_func + self.dyn_file_func = dyn_file_func # A dict of handlers, where each session is keyed by a string like # ip:tid for the remote end. self.handlers = {} if os.path.exists(self.root): - logger.debug("tftproot %s does exist" % self.root) + log.debug("tftproot %s does exist" % self.root) if not os.path.isdir(self.root): raise TftpException, "The tftproot must be a directory." else: - logger.debug("tftproot %s is a directory" % self.root) + log.debug("tftproot %s is a directory" % self.root) if os.access(self.root, os.R_OK): - logger.debug("tftproot %s is readable" % self.root) + log.debug("tftproot %s is readable" % self.root) else: raise TftpException, "The tftproot must be readable" if os.access(self.root, os.W_OK): - logger.debug("tftproot %s is writable" % self.root) + log.debug("tftproot %s is writable" % self.root) else: - logger.warning("The tftproot %s is not writable" % self.root) + log.warning("The tftproot %s is not writable" % self.root) else: raise TftpException, "The tftproot does not exist." @@ -45,14 +47,12 @@ class TftpServer(TftpSession): """Start a server listening on the supplied interface and port. This defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also supply a different socket timeout value, if desired.""" - import select - tftp_factory = TftpPacketFactory() # Don't use new 2.5 ternary operator yet # listenip = listenip if listenip else '0.0.0.0' if not listenip: listenip = '0.0.0.0' - logger.info("Server requested on ip %s, port %s" + log.info("Server requested on ip %s, port %s" % (listenip, listenport)) try: # FIXME - sockets should be non-blocking? @@ -62,388 +62,82 @@ class TftpServer(TftpSession): # Reraise it for now. raise - logger.info("Starting receive loop...") + log.info("Starting receive loop...") while True: # Build the inputlist array of sockets to select() on. inputlist = [] inputlist.append(self.sock) - for key in self.handlers: - inputlist.append(self.handlers[key].sock) + for key in self.sessions: + inputlist.append(self.sessions[key].sock) # Block until some socket has input on it. - logger.debug("Performing select on this inputlist: %s" % inputlist) + log.debug("Performing select on this inputlist: %s" % inputlist) readyinput, readyoutput, readyspecial = select.select(inputlist, [], [], SOCK_TIMEOUT) - #(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) - #recvpkt = tftp_factory.parse(buffer) - #key = "%s:%s" % (raddress, rport) - deletion_list = [] + # Handle the available data, if any. Maybe we timed-out. for readysock in readyinput: + # Is the traffic on the main server socket? ie. new session? if readysock == self.sock: - logger.debug("Data ready on our main socket") + log.debug("Data ready on our main socket") buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE) - logger.debug("Read %d bytes" % len(buffer)) + + log.debug("Read %d bytes" % len(buffer)) + recvpkt = tftp_factory.parse(buffer) + # FIXME: Is this the best way to do a session key? What + # about symmetric udp? key = "%s:%s" % (raddress, rport) - if isinstance(recvpkt, TftpPacketRRQ): - logger.debug("RRQ packet from %s:%s" % (raddress, rport)) - if not self.handlers.has_key(key): - try: - logger.debug("New download request, session key = %s" - % key) - self.handlers[key] = TftpServerHandler(key, - 'rrq', - self.root, - listenip, - tftp_factory, - self.dynfunc) - self.handlers[key].handle((recvpkt, raddress, rport)) - except TftpException, err: - logger.error("Fatal exception thrown from handler: %s" - % str(err)) - logger.debug("Deleting handler: %s" % key) - deletion_list.append(key) - - else: - logger.warn("Received RRQ for existing session!") - self.senderror(self.sock, - TftpErrors.IllegalTftpOp, - raddress, - rport) - continue - - elif isinstance(recvpkt, TftpPacketWRQ): - logger.error("Write requests not implemented at this time.") - self.senderror(self.sock, - TftpErrors.IllegalTftpOp, - raddress, - rport) - continue + if not self.sessions.has_key(key): + log.debug("Creating new server context for " + "session key = %s" % key) + self.sessions[key] = TftpContextServer(raddress, + rport, + timeout, + self.root, + self.dyn_file_func) + self.sessions[key].start(buffer) else: - # FIXME - this will have to change if we do symmetric UDP - logger.error("Should only receive RRQ or WRQ packets " - "on main listen port. Received %s" % recvpkt) - self.senderror(self.sock, - TftpErrors.IllegalTftpOp, - raddress, - rport) - continue + log.warn("received traffic on main socket for " + "existing session??") else: - for key in self.handlers: - if readysock == self.handlers[key].sock: - # FIXME - violating DRY principle with above code + # Must find the owner of this traffic. + for key in self.session: + if readysock == self.session[key].sock: try: - self.handlers[key].handle() + self.session[key].cycle() + if self.session[key].state == None: + log.info("Successful transfer.") + deletion_list.append(key) break except TftpException, err: deletion_list.append(key) - if self.handlers[key].state.state == 'fin': - logger.info("Successful transfer.") - break - else: - logger.error("Fatal exception thrown from handler: %s" - % str(err)) + log.error("Fatal exception thrown from " + "handler: %s" % str(err)) else: - logger.error("Can't find the owner for this packet. Discarding.") + log.error("Can't find the owner for this packet. " + "Discarding.") - logger.debug("Looping on all handlers to check for timeouts") + log.debug("Looping on all handlers to check for timeouts") now = time.time() - for key in self.handlers: + for key in self.sessions: try: - self.handlers[key].check_timeout(now) + self.sessions[key].checkTimeout(now) except TftpException, err: - logger.error("Fatal exception thrown from handler: %s" + log.error("Fatal exception thrown from handler: %s" % str(err)) deletion_list.append(key) - logger.debug("Iterating deletion list.") + log.debug("Iterating deletion list.") for key in deletion_list: - if self.handlers.has_key(key): - logger.debug("Deleting handler %s" % key) - del self.handlers[key] + if self.sessions.has_key(key): + log.debug("Deleting handler %s" % key) + del self.sessions[key] deletion_list = [] - -class TftpServerHandler(TftpSession): - """This class implements a handler for a given server session, handling - the work for one download.""" - - def __init__(self, key, state, root, listenip, factory, dyn_file_func): - TftpSession.__init__(self) - logger.info("Starting new handler. Key %s." % key) - self.key = key - self.host, self.port = self.key.split(':') - self.port = int(self.port) - self.listenip = listenip - # Note, correct state here is important as it tells the handler whether it's - # handling a download or an upload. - self.state = state - self.root = root - self.mode = None - self.filename = None - self.sock = False - self.options = { 'blksize': DEF_BLKSIZE } - self.blocknumber = 0 - self.buffer = None - self.fileobj = None - self.timesent = 0 - self.timeouts = 0 - self.tftp_factory = factory - self.dynfunc = dyn_file_func - count = 0 - while not self.sock: - self.sock = self.gensock(listenip) - count += 1 - if count > 10: - raise TftpException, "Failed to bind this handler to any port" - - def check_timeout(self, now): - """This method checks to see if we've timed-out waiting for traffic - from the client.""" - if self.timesent: - if now - self.timesent > SOCK_TIMEOUT: - self.timeout() - - def timeout(self): - """This method handles a timeout condition.""" - logger.debug("Handling timeout for handler %s" % self.key) - self.timeouts += 1 - if self.timeouts > TIMEOUT_RETRIES: - raise TftpException, "Hit max retries, giving up." - - if self.state.state == 'dat' or self.state.state == 'fin': - logger.debug("Timing out on DAT. Need to resend.") - self.send_dat(resend=True) - elif self.state.state == 'oack': - logger.debug("Timing out on OACK. Need to resend.") - self.send_oack() - else: - tftpassert(False, - "Timing out in unsupported state %s" % - self.state.state) - - def gensock(self, listenip): - """This method generates a new UDP socket, whose listening port must - be randomly generated, and not conflict with any already in use. For - now, let the OS do this.""" - random.seed() - port = random.randrange(1025, 65536) - # FIXME - sockets should be non-blocking? - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - logger.debug("Trying a handler socket on port %d" % port) - try: - sock.bind((listenip, port)) - return sock - except socket.error, err: - if err[0] == 98: - logger.warn("Handler %s, port %d was already taken" % (self.key, port)) - return False - else: - raise - - def handle(self, pkttuple=None): - """This method informs a handler instance that it has data waiting on - its socket that it must read and process.""" - recvpkt = raddress = rport = None - if pkttuple: - logger.debug("Handed pkt %s for handler %s" % (recvpkt, self.key)) - recvpkt, raddress, rport = pkttuple - else: - logger.debug("Data ready for handler %s" % self.key) - buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE) - logger.debug("Read %d bytes" % len(buffer)) - recvpkt = self.tftp_factory.parse(buffer) - - # FIXME - refactor into another method, this is too big - if isinstance(recvpkt, TftpPacketRRQ): - logger.debug("Handler %s received RRQ packet" % self.key) - logger.debug("Requested file is %s, mode is %s" % (recvpkt.filename, - recvpkt.mode)) - # FIXME - only octet mode is supported at this time. - if recvpkt.mode != 'octet': - self.senderror(self.sock, - TftpErrors.IllegalTftpOp, - raddress, - rport) - raise TftpException, "Unsupported mode: %s" % recvpkt.mode - - # test host/port of client end - if self.host != raddress or self.port != rport: - self.senderror(self.sock, - TftpErrors.UnknownTID, - raddress, - rport) - logger.error("Expected traffic from %s:%s but received it " - "from %s:%s instead." - % (self.host, self.port, raddress, rport)) - self.errors += 1 - return - - if self.state.state == 'rrq': - logger.debug("Received RRQ. Composing response.") - self.filename = self.root + os.sep + recvpkt.filename - logger.debug("The path to the desired file is %s" % - self.filename) - self.filename = os.path.abspath(self.filename) - logger.debug("The absolute path is %s" % self.filename) - # Security check. Make sure it's prefixed by the tftproot. - if self.filename.find(self.root) == 0: - logger.debug("The path appears to be safe: %s" % - self.filename) - else: - logger.error("Insecure path: %s" % self.filename) - self.errors += 1 - self.senderror(self.sock, - TftpErrors.AccessViolation, - raddress, - rport) - raise TftpException, "Insecure path: %s" % self.filename - - # Does the file exist? - if(os.path.exists(self.filename) or not self.dynfunc is None): - logger.debug("File %s exists." % self.filename) - - # Check options. Currently we only support the blksize - # option. - if recvpkt.options.has_key('blksize'): - logger.debug("RRQ includes a blksize option") - blksize = int(recvpkt.options['blksize']) - # Delete the option now that it's handled. - del recvpkt.options['blksize'] - if blksize >= MIN_BLKSIZE and blksize <= MAX_BLKSIZE: - logger.info("Client requested blksize = %d" - % blksize) - self.options['blksize'] = blksize - else: - logger.warning("Client %s requested invalid " - "blocksize %d, responding with default" - % (self.key, blksize)) - self.options['blksize'] = DEF_BLKSIZE - - if recvpkt.options.has_key('tsize'): - logger.info('RRQ includes tsize option') - self.options['tsize'] = os.stat(self.filename).st_size - # Delete the option now that it's handled. - del recvpkt.options['tsize'] - - if len(recvpkt.options.keys()) > 0: - logger.warning("Client %s requested unsupported options: %s" - % (self.key, recvpkt.options)) - - if self.options: - logger.info("Options requested, sending OACK") - self.send_oack() - else: - logger.debug("Client %s requested no options." - % self.key) - self.start_download() - - else: - logger.error("Requested file %s does not exist." % - self.filename) - self.senderror(self.sock, - TftpErrors.FileNotFound, - raddress, - rport) - raise TftpException, "Requested file not found: %s" % self.filename - - else: - # We're receiving an RRQ when we're not expecting one. - logger.error("Received an RRQ in handler %s " - "but we're in state %s" % (self.key, self.state)) - self.errors += 1 - - # Next packet type - elif isinstance(recvpkt, TftpPacketACK): - logger.debug("Received an ACK from the client.") - if recvpkt.blocknumber == 0 and self.state.state == 'oack': - logger.debug("Received ACK with 0 blocknumber, starting download") - self.start_download() - else: - if self.state.state == 'dat' or self.state.state == 'fin': - if self.blocknumber == recvpkt.blocknumber: - logger.debug("Received ACK for block %d" - % recvpkt.blocknumber) - if self.state.state == 'fin': - raise TftpException, "Successful transfer." - else: - self.send_dat() - elif recvpkt.blocknumber < self.blocknumber: - # Don't resend a DAT due to an old ACK. Fixes the - # sorceror's apprentice problem. - logger.warn("Received old ACK for block number %d" - % recvpkt.blocknumber) - else: - logger.warn("Received ACK for block number " - "%d, apparently from the future" - % recvpkt.blocknumber) - else: - logger.error("Received ACK with block number %d " - "while in state %s" - % (recvpkt.blocknumber, - self.state.state)) - - elif isinstance(recvpkt, TftpPacketERR): - logger.error("Received error packet from client: %s" % recvpkt) - self.state.state = 'err' - raise TftpException, "Received error from client" - - # Handle other packet types. - else: - logger.error("Received packet %s while handling a download" - % recvpkt) - self.senderror(self.sock, - TftpErrors.IllegalTftpOp, - self.host, - self.port) - raise TftpException, "Invalid packet received during download" - - def start_download(self): - """This method opens self.filename, stores the resulting file object - in self.fileobj, and calls send_dat().""" - self.state.state = 'dat' - if os.path.exists(self.filename): - self.fileobj = open(self.filename, "rb") - else: - self.fileobj = self.dynfunc(self.filename) - self.send_dat() - - def send_dat(self, resend=False): - """This method reads sends a DAT packet based on what is in self.buffer.""" - if not resend: - blksize = int(self.options['blksize']) - self.buffer = self.fileobj.read(blksize) - logger.debug("Read %d bytes into buffer" % len(self.buffer)) - if len(self.buffer) < blksize: - logger.info("Reached EOF on file %s" % self.filename) - self.state.state = 'fin' - self.blocknumber += 1 - if self.blocknumber > 65535: - logger.debug("Blocknumber rolled over to zero") - self.blocknumber = 0 - else: - logger.warn("Resending block number %d" % self.blocknumber) - dat = TftpPacketDAT() - dat.data = self.buffer - dat.blocknumber = self.blocknumber - logger.debug("Sending DAT packet %d" % self.blocknumber) - self.sock.sendto(dat.encode().buffer, (self.host, self.port)) - self.timesent = time.time() - - # FIXME - should these be factored-out into the session class? - def send_oack(self): - """This method sends an OACK packet based on current params.""" - logger.debug("Composing and sending OACK packet") - oack = TftpPacketOACK() - oack.options = self.options - self.sock.sendto(oack.encode().buffer, - (self.host, self.port)) - self.timesent = time.time() - self.state.state = 'oack' diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py index 95172c3..bb95ad4 100644 --- a/tftpy/TftpShared.py +++ b/tftpy/TftpShared.py @@ -17,7 +17,7 @@ DEF_TFTP_PORT = 69 logging.basicConfig() # The logger used by this library. Feel free to clobber it with your own, if you like, as # long as it conforms to Python's logging. -logger = logging.getLogger('tftpy') +log = logging.getLogger('tftpy') def tftpassert(condition, msg): """This function is a simple utility that will check the condition @@ -31,8 +31,8 @@ def setLogLevel(level): """This function is a utility function for setting the internal log level. The log level defaults to logging.NOTSET, so unwanted output to stdout is not created.""" - global logger - logger.setLevel(level) + global log + log.setLevel(level) class TftpErrors(object): """This class is a convenience for defining the common tftp error codes, diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 88c4fa1..3d71e16 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -22,17 +22,28 @@ class TftpMetrics(object): # Rates self.bps = 0 self.kbps = 0 + # Generic errors + self.errors = 0 def compute(self): # Compute transfer time self.duration = self.end_time - self.start_time - logger.debug("TftpMetrics.compute: duration is %s" % self.duration) + log.debug("TftpMetrics.compute: duration is %s" % self.duration) self.bps = (self.bytes * 8.0) / self.duration self.kbps = self.bps / 1024.0 - logger.debug("TftpMetrics.compute: kbps is %s" % self.kbps) - dupcount = 0 + log.debug("TftpMetrics.compute: kbps is %s" % self.kbps) for key in self.dups: - dupcount += self.dups[key] + self.dupcount += self.dups[key] + + def add_dup(self, blocknumber): + """This method adds a dup for a block number to the metrics.""" + log.debug("Recording a dup for block %d" % blocknumber) + if self.dups.has_key(blocknumber): + self.dups[pkt.blocknumber] += 1 + else: + self.dups[pkt.blocknumber] = 1 + tftpassert(self.dups[pkt.blocknumber] < MAX_DUPS, + "Max duplicates for block %d reached" % blocknumber) ############################################################################### # Context classes @@ -40,16 +51,32 @@ class TftpMetrics(object): class TftpContext(object): """The base class of the contexts.""" - def __init__(self, host, port): + + def __init__(self, host, port, timeout): """Constructor for the base context, setting shared instance variables.""" + self.file_to_transfer = None + self.fileobj = None + self.options = None + self.packethook = None + self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.sock.settimeout(timeout) + self.state = None + self.next_block = 0 self.factory = TftpPacketFactory() + # Note, setting the host will also set self.address, as it's a property. self.host = host self.port = port # The port associated with the TID self.tidport = None # Metrics self.metrics = TftpMetrics() + # Flag when the transfer is pending completion. + self.pending_complete = False + + def checkTimeout(self, now): + # FIXME + pass def start(self): return NotImplementedError, "Abstract method" @@ -69,37 +96,9 @@ class TftpContext(object): host = property(gethost, sethost) - def sendAck(self, blocknumber): - """This method sends an ack packet to the block number specified.""" - logger.info("sending ack to block %d" % blocknumber) - ackpkt = TftpPacketACK() - ackpkt.blocknumber = blocknumber - self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport)) - - def sendError(self, errorcode): - """This method uses the socket passed, and uses the errorcode to - compose and send an error packet.""" - logger.debug("In sendError, being asked to send error %d" % errorcode) - errpkt = TftpPacketERR() - errpkt.errorcode = errorcode - self.sock.sendto(errpkt.encode().buffer, (self.host, self.tidport)) - -class TftpContextClient(TftpContext): - """This class represents shared functionality by both the download and - upload client contexts.""" - def __init__(self, host, port, filename, options, packethook, timeout): - TftpContext.__init__(self, host, port) - self.file_to_transfer = filename - self.options = options - self.packethook = packethook - self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.sock.settimeout(timeout) - self.state = None - self.next_block = 0 - def setNextBlock(self, block): if block > 2 ** 16: - logger.debug("block number rollover to 0 again") + log.debug("block number rollover to 0 again") block = 0 self.__eblock = block @@ -111,19 +110,21 @@ class TftpContextClient(TftpContext): def cycle(self): """Here we wait for a response from the server after sending it something, and dispatch appropriate action to that response.""" + # FIXME: This won't work very well in a server context with multiple + # sessions running. for i in range(TIMEOUT_RETRIES): - logger.debug("in cycle, receive attempt %d" % i) + log.debug("in cycle, receive attempt %d" % i) try: (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) except socket.timeout, err: - logger.warn("Timeout waiting for traffic, retrying...") + log.warn("Timeout waiting for traffic, retrying...") continue break else: raise TftpException, "Hit max timeouts, giving up." # Ok, we've received a packet. Log it. - logger.debug("Received %d bytes from %s:%s" + log.debug("Received %d bytes from %s:%s" % (len(buffer), raddress, rport)) # Decode it. @@ -131,11 +132,11 @@ class TftpContextClient(TftpContext): # Check for known "connection". if raddress != self.address: - logger.warn("Received traffic from %s, expected host %s. Discarding" + log.warn("Received traffic from %s, expected host %s. Discarding" % (raddress, self.host)) if self.tidport and self.tidport != rport: - logger.warn("Received traffic from %s:%s but we're " + log.warn("Received traffic from %s:%s but we're " "connected to %s:%s. Discarding." % (raddress, rport, self.host, self.tidport)) @@ -150,29 +151,66 @@ class TftpContextClient(TftpContext): # And handle it, possibly changing state. self.state = self.state.handle(recvpkt, raddress, rport) -class TftpContextClientUpload(TftpContextClient): +class TftpContextServer(TftpContext): + """The context for the server.""" + def __init__(self, host, port, timeout, root, dyn_file_func): + TftpContext.__init__(self, + host, + port, + timeout) + # At this point we have no idea if this is a download or an upload. We + # need to let the start state determine that. + self.state = TftpStateServerStart() + self.root = root + self.dyn_file_func = dyn_file_func + + def start(self, buffer): + """Start the state cycle. Note that the server context receives an + initial packet in its start method.""" + log.debug("TftpContextServer.start() - pkt = %s" % pkt) + + self.metrics.start_time = time.time() + log.debug("set metrics.start_time to %s" % self.metrics.start_time) + + pkt = self.factory.parse(buffer) + log.debug("TftpContextServer.start() - factory returned a %s" % pkt) + + # Call handle once with the initial packet. This should put us into + # the download or the upload state. + self.state = self.state.handle(pkt, + self.host, + self.port) + + try: + while self.state: + log.debug("state is %s" % self.state) + self.cycle() + finally: + self.fileobj.close() + +class TftpContextClientUpload(TftpContext): """The upload context for the client during an upload.""" def __init__(self, host, port, filename, input, options, packethook, timeout): - TftpContextClient.__init__(self, - host, - port, - filename, - options, - packethook, - timeout) + TftpContext.__init__(self, + host, + port, + timeout) + self.file_to_transfer = filename + self.options = options + self.packethook = packethook self.fileobj = open(input, "wb") - logger.debug("TftpContextClientUpload.__init__()") - logger.debug("file_to_transfer = %s, options = %s" % + log.debug("TftpContextClientUpload.__init__()") + log.debug("file_to_transfer = %s, options = %s" % (self.file_to_transfer, self.options)) def start(self): - logger.info("sending tftp upload request to %s" % self.host) - logger.info(" filename -> %s" % self.file_to_transfer) - logger.info(" options -> %s" % self.options) + log.info("sending tftp upload request to %s" % self.host) + log.info(" filename -> %s" % self.file_to_transfer) + log.info(" options -> %s" % self.options) self.metrics.start_time = time.time() - logger.debug("set metrics.start_time to %s" % self.metrics.start_time) + log.debug("set metrics.start_time to %s" % self.metrics.start_time) # FIXME: put this in a sendWRQ method? pkt = TftpPacketWRQ() @@ -186,7 +224,7 @@ class TftpContextClientUpload(TftpContextClient): try: while self.state: - logger.debug("state is %s" % self.state) + log.debug("state is %s" % self.state) self.cycle() finally: self.fileobj.close() @@ -194,32 +232,32 @@ class TftpContextClientUpload(TftpContextClient): def end(self): pass -class TftpContextClientDownload(TftpContextClient): +class TftpContextClientDownload(TftpContext): """The download context for the client during a download.""" def __init__(self, host, port, filename, output, options, packethook, timeout): - TftpContextClient.__init__(self, - host, - port, - filename, - options, - packethook, - timeout) + TftpContext.__init__(self, + host, + port, + filename, + options, + packethook, + timeout) # FIXME - need to support alternate return formats than files? # File-like objects would be ideal, ala duck-typing. self.fileobj = open(output, "wb") - logger.debug("TftpContextClientDownload.__init__()") - logger.debug("file_to_transfer = %s, options = %s" % + log.debug("TftpContextClientDownload.__init__()") + log.debug("file_to_transfer = %s, options = %s" % (self.file_to_transfer, self.options)) def start(self): """Initiate the download.""" - logger.info("sending tftp download request to %s" % self.host) - logger.info(" filename -> %s" % self.file_to_transfer) - logger.info(" options -> %s" % self.options) + log.info("sending tftp download request to %s" % self.host) + log.info(" filename -> %s" % self.file_to_transfer) + log.info(" options -> %s" % self.options) self.metrics.start_time = time.time() - logger.debug("set metrics.start_time to %s" % self.metrics.start_time) + log.debug("set metrics.start_time to %s" % self.metrics.start_time) # FIXME: put this in a sendRRQ method? pkt = TftpPacketRRQ() @@ -233,7 +271,7 @@ class TftpContextClientDownload(TftpContextClient): try: while self.state: - logger.debug("state is %s" % self.state) + log.debug("state is %s" % self.state) self.cycle() finally: self.fileobj.close() @@ -241,7 +279,7 @@ class TftpContextClientDownload(TftpContextClient): def end(self): """Finish up the context.""" self.metrics.end_time = time.time() - logger.debug("set metrics.end_time to %s" % self.metrics.end_time) + log.debug("set metrics.end_time to %s" % self.metrics.end_time) self.metrics.compute() @@ -268,235 +306,431 @@ class TftpState(object): options.""" if pkt.options.keys() > 0: if pkt.match_options(self.context.options): - logger.info("Successful negotiation of options") + log.info("Successful negotiation of options") # Set options to OACK options self.context.options = pkt.options for key in self.context.options: - logger.info(" %s = %s" % (key, self.context.options[key])) + log.info(" %s = %s" % (key, self.context.options[key])) else: - logger.error("failed to negotiate options") + log.error("failed to negotiate options") raise TftpException, "Failed to negotiate options" else: raise TftpException, "No options found in OACK" -class TftpStateUpload(TftpState): - """A class holding common code for upload states.""" - def sendDat(self, resend=False): + def returnSupportedOptions(self, options): + """This method takes a requested options list from a client, and + returns the ones that are supported.""" + # We support the options blksize and tsize right now. + # FIXME - put this somewhere else? + accepted_options = {} + for option in options: + if option == 'blksize': + # Make sure it's valid. + if int(options[option]) > MAX_BLKSIZE: + accepted_options[option] = MAX_BLKSIZE + elif option == 'tsize': + log.debug("tsize option is set") + accepted_options['tsize'] = 1 + else: + log.info("Dropping unsupported option '%s'" % option) + return accepted_options + + def serverInitial(self, pkt, raddress, rport): + """This method performs initial setup for a server context transfer, + put here to refactor code out of the TftpStateServerRecvRRQ and + TftpStateServerRecvWRQ classes, since their initial setup is + identical. The method returns a boolean, sendoack, to indicate whether + it is required to send an OACK to the client.""" + options = pkt.options + sendoack = False + if not options: + log.debug("setting default options, blksize") + # FIXME: put default options elsewhere + self.context.options = { 'blksize': DEF_BLKSIZE } + else: + log.debug("options requested: %s" % options) + self.context.options = self.returnSupportedOptions(options) + sendoack = True + + # FIXME - only octet mode is supported at this time. + if pkt.mode != 'octet': + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, \ + "Only octet transfers are supported at this time." + + # test host/port of client end + if self.context.host != raddress or self.context.port != rport: + self.sendError(TftpErrors.UnknownTID) + log.error("Expected traffic from %s:%s but received it " + "from %s:%s instead." + % (self.context.host, + self.context.port, + raddress, + rport)) + # FIXME: increment an error count? + # Return same state, we're still waiting for valid traffic. + return self + + log.debug("requested filename is %s" % pkt.filename) + # There are no os.sep's allowed in the filename. + # FIXME: Should we allow subdirectories? + if pkt.filename.find(os.sep) >= 0: + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "%s found in filename, not permitted" % os.sep + + self.context.file_to_transfer = pkt.filename + + return sendoack + + def sendDAT(self, resend=False): + """This method sends the next DAT packet based on the data in the + context. It returns a boolean indicating whether the transfer is + finished.""" finished = False blocknumber = self.context.next_block if not resend: blksize = int(self.context.options['blksize']) buffer = self.context.fileobj.read(blksize) - logger.debug("Read %d bytes into buffer" % len(buffer)) + log.debug("Read %d bytes into buffer" % len(buffer)) if len(buffer) < blksize: - logger.info("Reached EOF on file %s" % self.context.input) + log.info("Reached EOF on file %s" % self.context.input) finished = True self.context.next_block += 1 self.bytes += len(buffer) else: - logger.warn("Resending block number %d" % blocknumber) + log.warn("Resending block number %d" % blocknumber) dat = TftpPacketDAT() dat.data = buffer dat.blocknumber = blocknumber - logger.debug("Sending DAT packet %d" % blocknumber) + log.debug("Sending DAT packet %d" % blocknumber) self.context.sock.sendto(dat.encode().buffer, (self.context.host, self.context.port)) if self.context.packethook: self.context.packethook(dat) return finished -class TftpStateDownload(TftpState): - """A class holding common code for download states.""" + def sendACK(self, blocknumber=None): + """This method sends an ack packet to the block number specified. If + none is specified, it defaults to the next_block property in the + parent context.""" + if not blocknumber: + blocknumber = self.context.next_block + log.info("sending ack to block %d" % blocknumber) + ackpkt = TftpPacketACK() + ackpkt.blocknumber = blocknumber + self.context.sock.sendto(ackpkt.encode().buffer, + (self.context.host, + self.context.tidport)) + + def sendError(self, errorcode): + """This method uses the socket passed, and uses the errorcode to + compose and send an error packet.""" + log.debug("In sendError, being asked to send error %d" % errorcode) + errpkt = TftpPacketERR() + errpkt.errorcode = errorcode + self.context.sock.sendto(errpkt.encode().buffer, + (self.context.host, + self.context.tidport)) + + def sendOACK(self): + """This method sends an OACK packet with the options from the current + context.""" + log.debug("In sendOACK with options %s" % options) + pkt = TftpPacketOACK() + pkt.options = self.options + self.context.sock.sendto(pkt.encode().buffer, + (self.context.host, + self.context.tidport)) + def handleDat(self, pkt): - """This method handles a DAT packet during a download.""" - logger.info("handling DAT packet - block %d" % pkt.blocknumber) - logger.debug("expecting block %s" % self.context.next_block) + """This method handles a DAT packet during a client download, or a + server upload.""" + log.info("handling DAT packet - block %d" % pkt.blocknumber) + log.debug("expecting block %s" % self.context.next_block) if pkt.blocknumber == self.context.next_block: - logger.debug("good, received block %d in sequence" + log.debug("good, received block %d in sequence" % pkt.blocknumber) - self.context.sendAck(pkt.blocknumber) + self.sendACK() self.context.next_block += 1 - logger.debug("writing %d bytes to output file" + log.debug("writing %d bytes to output file" % len(pkt.data)) self.context.fileobj.write(pkt.data) self.context.metrics.bytes += len(pkt.data) # Check for end-of-file, any less than full data packet. if len(pkt.data) < int(self.context.options['blksize']): - logger.info("end of file detected") + log.info("end of file detected") return None elif pkt.blocknumber < self.context.next_block: - logger.warn("dropping duplicate block %d" % pkt.blocknumber) - if self.context.metrics.dups.has_key(pkt.blocknumber): - self.context.metrics.dups[pkt.blocknumber] += 1 - else: - self.context.metrics.dups[pkt.blocknumber] = 1 - tftpassert(self.context.metrics.dups[pkt.blocknumber] < MAX_DUPS, - "Max duplicates for block %d reached" % pkt.blocknumber) - # FIXME: double-check sorceror's apprentice problem! - logger.debug("ACKing block %d again, just in case" % pkt.blocknumber) - self.context.sendAck(pkt.blocknumber) + log.warn("dropping duplicate block %d" % pkt.blocknumber) + self.context.metrics.add_dup(pkt.blocknumber) + log.debug("ACKing block %d again, just in case" % pkt.blocknumber) + self.sendACK(pkt.blocknumber) else: # FIXME: should we be more tolerant and just discard instead? msg = "Whoa! Received future block %d but expected %d" \ % (pkt.blocknumber, self.context.next_block) - logger.error(msg) + log.error(msg) raise TftpException, msg # Default is to ack - return TftpStateSentACK(self.context) + return TftpStateExpectDAT(self.context) -class TftpStateSentWRQ(TftpStateUpload): - """Just sent an WRQ packet for an upload.""" +class TftpStateServerRecvRRQ(TftpState): + """This class represents the state of the TFTP server when it has just + received an RRQ packet.""" + def handle(self, pkt, raddress, rport): + "Handle an initial RRQ packet as a server." + log.debug("In TftpStateServerRecvRRQ.handle") + sendoack = self.serverInitial(pkt, raddress, rport) + path = self.context.root + os.sep + self.context.file_to_transfer + log.info("Opening file %s for reading" % path) + if os.path.exists(path): + # Note: Open in binary mode for win32 portability, since win32 + # blows. + self.context.fileobj = open(path, "rb") + elif self.dyn_file_func: + log.debug("No such file %s but using dyn_file_func" % path) + self.context.fileobj = \ + self.dyn_file_func(self.context.file_to_transfer) + else: + send.sendError(TftpErrors.FileNotFound) + raise TftpException, "File not found: %s" % path + + # Options negotiation. + if sendoack: + self.sendOACK() + return TftpStateServerOACK(self.context) + else: + log.debug("No requested options, starting send...") + self.context.pending_complete = self.sendDAT() + return TftpStateExpectACK(self.context) + + # Note, we don't have to check any other states in this method, that's + # up to the caller. + +class TftpStateServerRecvWRQ(TftpState): + """This class represents the state of the TFTP server when it has just + received a WRQ packet.""" + def handle(self, pkt, raddress, rport): + "Handle an initial WRQ packet as a server." + log.debug("In TftpStateServerRecvWRQ.handle") + sendoack = self.serverInitial(pkt, raddress, rport) + path = self.context.root + os.sep + self.context.file_to_transfer + log.info("Opening file %s for writing" % path) + if os.path.exists(path): + # FIXME: correct behavior? + log.warn("File %s exists already, overwriting...") + # FIXME: I think we should upload to a temp file and not overwrite the + # existing file until the file is successfully uploaded. + self.context.fileobj = open(path, "wb") + + # Options negotiation. + if sendoack: + log.debug("Sending OACK to client") + self.sendOACK() + else: + log.debug("No requested options, starting send...") + self.sendACK() + # We may have sent an OACK, but we're expecting a DAT as the response + # to either the OACK or an ACK, so lets unconditionally use the + # TftpStateExpectDAT state. + return TftpStateExpectDAT(self.context) + + # Note, we don't have to check any other states in this method, that's + # up to the caller. + +class TftpStateServerStart(TftpState): + """The start state for the server.""" def handle(self, pkt, raddress, rport): """Handle a packet we just received.""" - if not self.context.tidport: - self.context.tidport = rport - logger.debug("Set remote port for session to %s" % rport) - - # If we're going to successfully transfer the file, then we should see - # either an OACK for accepted options, or an ACK to ignore options. - if isinstance(pkt, TftpPacketOACK): - logger.info("received OACK from server") - try: - self.handleOACK(pkt) - except TftpException, err: - logger.error("failed to negotiate options") - self.context.sendError(TftpErrors.FailedNegotiation) - raise - else: - logger.debug("sending first DAT packet") - fin = self.context.sendDat() - if fin: - logger.info("Add done") - return None - else: - logger.debug("Changing state to TftpStateSentDAT") - return TftpStateSentDAT(self.context) - - elif isinstance(pkt, TftpPacketACK): - logger.info("received ACK from server") - logger.debug("apparently the server ignored our options") - # The block number should be zero. - if pkt.blocknumber == 0: - logger.debug("ack blocknumber is zero as expected") - logger.debug("sending first DAT packet") - fin = self.context.sendDat() - if fin: - logger.info("Add done") - return None - else: - logger.debug("Changing state to TftpStateSentDAT") - return TftpStateSentDAT(self.context) - else: - logger.warn("discarding ACK to block %s" % pkt.blocknumber) - logger.debug("still waiting for valid response from server") - return self - - elif isinstance(pkt, TftpPacketERR): - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received ERR from server: " + str(pkt) - - elif isinstance(pkt, TftpPacketRRQ): - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received RRQ from server while in upload" - - elif isinstance(pkt, TftpPacketDAT): - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received DAT from server while in upload" - + log.debug("In TftpStateServerStart.handle") + if isinstance(pkt, TftpPacketRRQ): + log.debug("handling an RRQ packet") + return TftpStateServerRecvRRQ(self.context).handle(pkt, + raddress, + rport) + elif isinstance(pkt, TftpPacketWRQ): + log.debug("handling a WRQ packet") + return TftpStateServerRecvWRQ(self.context).handle(pkt, + raddress, + rport) else: - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received unknown packet type from server: " + str(pkt) + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, \ + "Invalid packet to begin up/download: %s" % pkt - # By default, no state change. - return self - -class TftpStateSentDAT(TftpStateUpload): +class TftpStateExpectACK(TftpState): """This class represents the state of the transfer when a DAT was just sent, and we are waiting for an ACK from the server. This class is the same one used by the client during the upload, and the server during the download.""" - -class TftpStateSentRRQ(TftpStateDownload): - """Just sent an RRQ packet.""" def handle(self, pkt, raddress, rport): - """Handle the packet in response to an RRQ to the server.""" - if not self.context.tidport: - self.context.tidport = rport - logger.debug("Set remote port for session to %s" % rport) + "Handle a packet, hopefully an ACK since we just sent a DAT." + if isinstance(pkt, TftpPacketACK): + log.info("Received ACK for packet %d" % pkt.blocknumber) + # Is this an ack to the one we just sent? + if self.context.next_block == pkt.blocknumber: + if self.context.pending_complete: + log.info("Received ACK to final DAT, we're done.") + return None + else: + log.debug("Good ACK, sending next DAT") + self.context.pending_complete = self.sendDAT() + + elif pkt.blocknumber < self.context.next_block: + self.context.metrics.add_dup(pkt.blocknumber) - # Now check the packet type and dispatch it properly. - if isinstance(pkt, TftpPacketOACK): - logger.info("received OACK from server") - try: - self.handleOACK(pkt) - except TftpException, err: - logger.error("failed to negotiate options: %s" % str(err)) - self.context.sendError(TftpErrors.FailedNegotiation) - raise else: - logger.debug("sending ACK to OACK") - - self.context.sendAck(blocknumber=0) - - logger.debug("Changing state to TftpStateSentACK") - return TftpStateSentACK(self.context) - - elif isinstance(pkt, TftpPacketDAT): - # If there are any options set, then the server didn't honour any - # of them. - logger.info("received DAT from server") - if self.context.options: - logger.info("server ignored options, falling back to defaults") - self.context.options = { 'blksize': DEF_BLKSIZE } - return self.handleDat(pkt) - - # Every other packet type is a problem. - elif isinstance(recvpkt, TftpPacketACK): - # Umm, we ACK, the server doesn't. - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received ACK from server while in download" - - elif isinstance(recvpkt, TftpPacketWRQ): - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received WRQ from server while in download" - - elif isinstance(recvpkt, TftpPacketERR): - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received ERR from server: " + str(recvpkt) - + log.warn("Oooh, time warp. Received ACK to packet we " + "didn't send yet. Discarding.") + self.context.metrics.errors += 1 + return self + elif isinstance(pkt, TftpPacketERR): + log.error("Received ERR packet from peer: %s" % str(pkt)) + raise TftpException, \ + "Received ERR packet from peer: %s" % str(pkt) else: - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received unknown packet type from server: " + str(recvpkt) + log.warn("Discarding unsupported packet: %s" % str(pkt)) + return self - # By default, no state change. - return self - -class TftpStateSentACK(TftpStateDownload): +class TftpStateExpectDAT(TftpState): """Just sent an ACK packet. Waiting for DAT.""" def handle(self, pkt, raddress, rport): """Handle the packet in response to an ACK, which should be a DAT.""" if isinstance(pkt, TftpPacketDAT): return self.handleDat(pkt) + # Every other packet type is a problem. + elif isinstance(recvpkt, TftpPacketACK): + # Umm, we ACK, you don't. + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received ACK from peer when expecting DAT" + + elif isinstance(recvpkt, TftpPacketWRQ): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received WRQ from peer when expecting DAT" + + elif isinstance(recvpkt, TftpPacketERR): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received ERR from peer: " + str(recvpkt) + + else: + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received unknown packet type from peer: " + str(recvpkt) + +class TftpStateSentWRQ(TftpState): + """Just sent an WRQ packet for an upload.""" + def handle(self, pkt, raddress, rport): + """Handle a packet we just received.""" + if not self.context.tidport: + self.context.tidport = rport + log.debug("Set remote port for session to %s" % rport) + + # If we're going to successfully transfer the file, then we should see + # either an OACK for accepted options, or an ACK to ignore options. + if isinstance(pkt, TftpPacketOACK): + log.info("received OACK from server") + try: + self.handleOACK(pkt) + except TftpException, err: + log.error("failed to negotiate options") + self.sendError(TftpErrors.FailedNegotiation) + raise + else: + log.debug("sending first DAT packet") + self.context.pending_complete = self.sendDAT() + log.debug("Changing state to TftpStateExpectACK") + return TftpStateExpectACK(self.context) + + elif isinstance(pkt, TftpPacketACK): + log.info("received ACK from server") + log.debug("apparently the server ignored our options") + # The block number should be zero. + if pkt.blocknumber == 0: + log.debug("ack blocknumber is zero as expected") + log.debug("sending first DAT packet") + self.pending_complete = self.context.sendDAT() + log.debug("Changing state to TftpStateExpectACK") + return TftpStateExpectACK(self.context) + else: + log.warn("discarding ACK to block %s" % pkt.blocknumber) + log.debug("still waiting for valid response from server") + return self + + elif isinstance(pkt, TftpPacketERR): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received ERR from server: " + str(pkt) + + elif isinstance(pkt, TftpPacketRRQ): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received RRQ from server while in upload" + + elif isinstance(pkt, TftpPacketDAT): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received DAT from server while in upload" + + else: + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received unknown packet type from server: " + str(pkt) + + # By default, no state change. + return self + +class TftpStateSentRRQ(TftpState): + """Just sent an RRQ packet.""" + def handle(self, pkt, raddress, rport): + """Handle the packet in response to an RRQ to the server.""" + if not self.context.tidport: + self.context.tidport = rport + log.debug("Set remote port for session to %s" % rport) + + # Now check the packet type and dispatch it properly. + if isinstance(pkt, TftpPacketOACK): + log.info("received OACK from server") + try: + self.handleOACK(pkt) + except TftpException, err: + log.error("failed to negotiate options: %s" % str(err)) + self.sendError(TftpErrors.FailedNegotiation) + raise + else: + log.debug("sending ACK to OACK") + + self.sendACK(blocknumber=0) + + log.debug("Changing state to TftpStateExpectDAT") + return TftpStateExpectDAT(self.context) + + elif isinstance(pkt, TftpPacketDAT): + # If there are any options set, then the server didn't honour any + # of them. + log.info("received DAT from server") + if self.context.options: + log.info("server ignored options, falling back to defaults") + self.context.options = { 'blksize': DEF_BLKSIZE } + return self.handleDat(pkt) + # Every other packet type is a problem. elif isinstance(recvpkt, TftpPacketACK): # Umm, we ACK, the server doesn't. - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ACK from server while in download" elif isinstance(recvpkt, TftpPacketWRQ): - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received WRQ from server while in download" elif isinstance(recvpkt, TftpPacketERR): - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ERR from server: " + str(recvpkt) else: - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received unknown packet type from server: " + str(recvpkt) + + # By default, no state change. + return self