From a6a18c178b4b60d49baa42c55fb7948d955de263 Mon Sep 17 00:00:00 2001 From: "Michael P. Soulier" Date: Sun, 16 Aug 2009 19:44:57 -0400 Subject: [PATCH] First successful download with both client and server. --- .gitignore | 2 + bin/tftpy_client.py | 10 ++-- bin/tftpy_server.py | 11 +++- tftpy/TftpClient.py | 5 ++ tftpy/TftpPacketTypes.py | 18 +------ tftpy/TftpServer.py | 32 +++++------ tftpy/TftpStates.py | 112 +++++++++++++++++++++++++++------------ 7 files changed, 120 insertions(+), 70 deletions(-) diff --git a/.gitignore b/.gitignore index 0d20b64..6a211b7 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ *.pyc +*.swp +tags diff --git a/bin/tftpy_client.py b/bin/tftpy_client.py index 4671f6f..ac09397 100755 --- a/bin/tftpy_client.py +++ b/bin/tftpy_client.py @@ -21,7 +21,7 @@ def main(): '--upload', help='filename to upload') parser.add_option('-b', - '--blocksize', + '--blksize', help='udp packet size to use (default: 512)', default=512) parser.add_option('-o', @@ -76,11 +76,11 @@ def main(): else: tftpy.setLogLevel(logging.INFO) - progresshook = Progress(tftpy.logger.info).progresshook + progresshook = Progress(tftpy.log.info).progresshook tftp_options = {} - if options.blocksize: - tftp_options['blksize'] = int(options.blocksize) + if options.blksize: + tftp_options['blksize'] = int(options.blksize) if options.tsize: tftp_options['tsize'] = 0 @@ -103,6 +103,8 @@ def main(): except tftpy.TftpException, err: sys.stderr.write("%s\n" % str(err)) sys.exit(1) + except KeyboardInterrupt: + pass if __name__ == '__main__': main() diff --git a/bin/tftpy_server.py b/bin/tftpy_server.py index d424deb..b8aec50 100755 --- a/bin/tftpy_server.py +++ b/bin/tftpy_server.py @@ -20,8 +20,8 @@ def main(): parser.add_option('-r', '--root', type='string', - help='path to serve from (default: /tftpboot)', - default="/tftpboot") + help='path to serve from', + default=None) parser.add_option('-d', '--debug', action='store_true', @@ -34,9 +34,16 @@ def main(): else: tftpy.setLogLevel(logging.INFO) + if not options.root: + parser.print_help() + sys.exit(1) + server = tftpy.TftpServer(options.root) try: server.listen(options.ip, options.port) + except tftpy.TftpException, err: + sys.stderr.write("%s\n" % str(err)) + sys.exit(1) except KeyboardInterrupt: pass diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index 89843f1..aacdb2a 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -35,6 +35,11 @@ class TftpClient(TftpSession): SOCK_TIMEOUT setting, which is the amount of time that the client will wait for a receive packet to arrive.""" # We're downloading. + log.debug("Creating download context with the following params:") + log.debug("host = %s, port = %s, filename = %s, output = %s" + % (self.host, self.iport, filename, output)) + log.debug("options = %s, packethook = %s, timeout = %s" + % (self.options, packethook, timeout)) self.context = TftpContextClientDownload(self.host, self.iport, filename, diff --git a/tftpy/TftpPacketTypes.py b/tftpy/TftpPacketTypes.py index b9328c5..76fd06e 100644 --- a/tftpy/TftpPacketTypes.py +++ b/tftpy/TftpPacketTypes.py @@ -4,22 +4,8 @@ from TftpShared import * class TftpSession(object): """This class is the base class for the tftp client and server. Any shared code should be in this class.""" - - def __init__(self): - """Class constructor.""" - self.options = None - self.state = None - self.dups = 0 - self.errors = 0 - - def senderror(self, sock, errorcode, address, port): - """This method uses the socket passed, and uses the errorcode, address - and port to compose and send an error packet.""" - log.debug("In senderror, being asked to send error %d to %s:%s" - % (errorcode, address, port)) - errpkt = TftpPacketERR() - errpkt.errorcode = errorcode - sock.sendto(errpkt.encode().buffer, (address, port)) + # FIXME: do we need this anymore? + pass class TftpPacketWithOptions(object): """This class exists to permit some TftpPacket subclasses to share code diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index ad781a2..d2353a9 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -3,6 +3,7 @@ import select from TftpShared import * from TftpPacketTypes import * from TftpPacketFactory import * +from TftpStates import * class TftpServer(TftpSession): """This class implements a tftp server object.""" @@ -19,9 +20,9 @@ class TftpServer(TftpSession): # FIXME: What about multiple roots? self.root = os.path.abspath(tftproot) self.dyn_file_func = dyn_file_func - # A dict of handlers, where each session is keyed by a string like + # A dict of sessions, where each session is keyed by a string like # ip:tid for the remote end. - self.handlers = {} + self.sessions = {} if os.path.exists(self.root): log.debug("tftproot %s does exist" % self.root) @@ -89,8 +90,8 @@ class TftpServer(TftpSession): log.debug("Read %d bytes" % len(buffer)) recvpkt = tftp_factory.parse(buffer) - # FIXME: Is this the best way to do a session key? What - # about symmetric udp? + # Forge a session key based on the client's IP and port, + # which should safely work through NAT. key = "%s:%s" % (raddress, rport) if not self.sessions.has_key(key): @@ -108,36 +109,37 @@ class TftpServer(TftpSession): else: # Must find the owner of this traffic. - for key in self.session: - if readysock == self.session[key].sock: + for key in self.sessions: + if readysock == self.sessions[key].sock: + log.info("Matched input to session key %s" + % key) try: - self.session[key].cycle() - if self.session[key].state == None: + self.sessions[key].cycle() + if self.sessions[key].state == None: log.info("Successful transfer.") deletion_list.append(key) - break except TftpException, err: deletion_list.append(key) log.error("Fatal exception thrown from " - "handler: %s" % str(err)) + "session %s: %s" + % (key, str(err))) + break else: log.error("Can't find the owner for this packet. " "Discarding.") - log.debug("Looping on all handlers to check for timeouts") + log.debug("Looping on all sessions to check for timeouts") now = time.time() for key in self.sessions: try: self.sessions[key].checkTimeout(now) except TftpException, err: - log.error("Fatal exception thrown from handler: %s" - % str(err)) + log.error(str(err)) deletion_list.append(key) log.debug("Iterating deletion list.") for key in deletion_list: if self.sessions.has_key(key): - log.debug("Deleting handler %s" % key) + log.debug("Deleting session %s" % key) del self.sessions[key] - deletion_list = [] diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 3d71e16..b3d5205 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -1,7 +1,7 @@ from TftpShared import * from TftpPacketTypes import * from TftpPacketFactory import * -import socket, time +import socket, time, os ############################################################################### # Utility classes @@ -39,10 +39,10 @@ class TftpMetrics(object): """This method adds a dup for a block number to the metrics.""" log.debug("Recording a dup for block %d" % blocknumber) if self.dups.has_key(blocknumber): - self.dups[pkt.blocknumber] += 1 + self.dups[blocknumber] += 1 else: - self.dups[pkt.blocknumber] = 1 - tftpassert(self.dups[pkt.blocknumber] < MAX_DUPS, + self.dups[blocknumber] = 1 + tftpassert(self.dups[blocknumber] < MAX_DUPS, "Max duplicates for block %d reached" % blocknumber) ############################################################################### @@ -73,10 +73,14 @@ class TftpContext(object): self.metrics = TftpMetrics() # Flag when the transfer is pending completion. self.pending_complete = False + # Time when this context last received any traffic. + self.last_update = 0 def checkTimeout(self, now): - # FIXME - pass + """Compare current time with last_update time, and raise an exception + if we're over SOCK_TIMEOUT time.""" + if now - self.last_update > SOCK_TIMEOUT: + raise TftpException, "Timeout waiting for traffic" def start(self): return NotImplementedError, "Abstract method" @@ -126,6 +130,8 @@ class TftpContext(object): # Ok, we've received a packet. Log it. log.debug("Received %d bytes from %s:%s" % (len(buffer), raddress, rport)) + # And update our last updated time. + self.last_update = time.time() # Decode it. recvpkt = self.factory.parse(buffer) @@ -160,17 +166,23 @@ class TftpContextServer(TftpContext): timeout) # At this point we have no idea if this is a download or an upload. We # need to let the start state determine that. - self.state = TftpStateServerStart() + self.state = TftpStateServerStart(self) self.root = root self.dyn_file_func = dyn_file_func + # In a server, the tidport is the same as the port. This is also true + # with symmetric UDP, which we haven't implemented yet. + self.tidport = port def start(self, buffer): """Start the state cycle. Note that the server context receives an - initial packet in its start method.""" - log.debug("TftpContextServer.start() - pkt = %s" % pkt) - + initial packet in its start method. Also note that the server does not + loop on cycle(), as it expects the TftpServer object to manage + that.""" + log.debug("In TftpContextServer.start") self.metrics.start_time = time.time() log.debug("set metrics.start_time to %s" % self.metrics.start_time) + # And update our last updated time. + self.last_update = time.time() pkt = self.factory.parse(buffer) log.debug("TftpContextServer.start() - factory returned a %s" % pkt) @@ -181,16 +193,19 @@ class TftpContextServer(TftpContext): self.host, self.port) - try: - while self.state: - log.debug("state is %s" % self.state) - self.cycle() - finally: - self.fileobj.close() + # FIXME + # How do we ensure that the server closes files, even on error? class TftpContextClientUpload(TftpContext): """The upload context for the client during an upload.""" - def __init__(self, host, port, filename, input, options, packethook, timeout): + def __init__(self, + host, + port, + filename, + input, + options, + packethook, + timeout): TftpContext.__init__(self, host, port, @@ -234,14 +249,22 @@ class TftpContextClientUpload(TftpContext): class TftpContextClientDownload(TftpContext): """The download context for the client during a download.""" - def __init__(self, host, port, filename, output, options, packethook, timeout): + def __init__(self, + host, + port, + filename, + output, + options, + packethook, + timeout): TftpContext.__init__(self, host, port, - filename, - options, - packethook, timeout) + # FIXME: should we refactor setting of these params? + self.file_to_transfer = filename + self.options = options + self.packethook = packethook # FIXME - need to support alternate return formats than files? # File-like objects would be ideal, ala duck-typing. self.fileobj = open(output, "wb") @@ -327,12 +350,21 @@ class TftpState(object): if option == 'blksize': # Make sure it's valid. if int(options[option]) > MAX_BLKSIZE: + log.info("Client requested blksize greater than %d " + "setting to maximum" % MAX_BLKSIZE) accepted_options[option] = MAX_BLKSIZE - elif option == 'tsize': - log.debug("tsize option is set") - accepted_options['tsize'] = 1 + elif int(options[option]) < MIN_BLKSIZE: + log.info("Client requested blksize less than %d " + "setting to minimum" % MIN_BLKSIZE) + accepted_options[option] = MIN_BLKSIZE else: - log.info("Dropping unsupported option '%s'" % option) + accepted_options[option] = options[option] + elif option == 'tsize': + log.debug("tsize option is set") + accepted_options['tsize'] = 1 + else: + log.info("Dropping unsupported option '%s'" % option) + log.debug("returning these accepted options: %s" % accepted_options) return accepted_options def serverInitial(self, pkt, raddress, rport): @@ -388,15 +420,16 @@ class TftpState(object): finished.""" finished = False blocknumber = self.context.next_block + tftpassert( blocknumber > 0, "There is no block zero!" ) if not resend: blksize = int(self.context.options['blksize']) buffer = self.context.fileobj.read(blksize) log.debug("Read %d bytes into buffer" % len(buffer)) if len(buffer) < blksize: - log.info("Reached EOF on file %s" % self.context.input) + log.info("Reached EOF on file %s" + % self.context.file_to_transfer) finished = True - self.context.next_block += 1 - self.bytes += len(buffer) + self.context.metrics.bytes += len(buffer) else: log.warn("Resending block number %d" % blocknumber) dat = TftpPacketDAT() @@ -413,7 +446,8 @@ class TftpState(object): """This method sends an ack packet to the block number specified. If none is specified, it defaults to the next_block property in the parent context.""" - if not blocknumber: + log.debug("in sendACK, blocknumber is %s" % blocknumber) + if blocknumber is None: blocknumber = self.context.next_block log.info("sending ack to block %d" % blocknumber) ackpkt = TftpPacketACK() @@ -435,9 +469,9 @@ class TftpState(object): def sendOACK(self): """This method sends an OACK packet with the options from the current context.""" - log.debug("In sendOACK with options %s" % options) + log.debug("In sendOACK with options %s" % self.context.options) pkt = TftpPacketOACK() - pkt.options = self.options + pkt.options = self.context.options self.context.sock.sendto(pkt.encode().buffer, (self.context.host, self.context.tidport)) @@ -464,6 +498,10 @@ class TftpState(object): return None elif pkt.blocknumber < self.context.next_block: + if pkt.blocknumber == 0: + log.warn("There is no block zero!") + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "There is no block zero!" log.warn("dropping duplicate block %d" % pkt.blocknumber) self.context.metrics.add_dup(pkt.blocknumber) log.debug("ACKing block %d again, just in case" % pkt.blocknumber) @@ -502,12 +540,17 @@ class TftpStateServerRecvRRQ(TftpState): # Options negotiation. if sendoack: + # Note, next_block is 0 here since that's the proper + # acknowledgement to an OACK. + # FIXME: perhaps we do need a TftpStateExpectOACK class... self.sendOACK() - return TftpStateServerOACK(self.context) else: + self.context.next_block = 1 log.debug("No requested options, starting send...") self.context.pending_complete = self.sendDAT() - return TftpStateExpectACK(self.context) + # Note, we expect an ack regardless of whether we sent a DAT or an + # OACK. + return TftpStateExpectACK(self.context) # Note, we don't have to check any other states in this method, that's # up to the caller. @@ -579,6 +622,9 @@ class TftpStateExpectACK(TftpState): return None else: log.debug("Good ACK, sending next DAT") + self.context.next_block += 1 + log.debug("Incremented next_block to %d" + % (self.context.next_block)) self.context.pending_complete = self.sendDAT() elif pkt.blocknumber < self.context.next_block: