From add444006ca53d1469ef4f333e9bbbaea05a8ac1 Mon Sep 17 00:00:00 2001 From: "Michael P. Soulier" Date: Sat, 23 Jul 2011 23:20:53 -0400 Subject: [PATCH] Fixes issue #23, breaking up TftpStates into TftpStates and TftpContexts. --- tftpy/TftpClient.py | 5 +- tftpy/TftpContexts.py | 385 ++++++++++++++++++++++++++++++++++++++++++ tftpy/TftpServer.py | 4 +- tftpy/TftpStates.py | 376 +---------------------------------------- tftpy/__init__.py | 2 + 5 files changed, 394 insertions(+), 378 deletions(-) create mode 100644 tftpy/TftpContexts.py diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index f9250bf..b8cfa34 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -4,8 +4,9 @@ performed via a standard logging object set in TftpShared.""" import time, types from TftpShared import * -from TftpPacketFactory import * -from TftpStates import TftpContextClientDownload, TftpContextClientUpload +from TftpPacketTypes import * +from TftpPacketFactory import TftpPacketFactory +from TftpContexts import TftpContextClientDownload, TftpContextClientUpload class TftpClient(TftpSession): """This class is an implementation of a tftp client. Once instantiated, a diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py new file mode 100644 index 0000000..b317c5e --- /dev/null +++ b/tftpy/TftpContexts.py @@ -0,0 +1,385 @@ +"""This module implements all contexts for state handling during uploads and +downloads, the main interface to which being the TftpContext base class. + +The concept is simple. Each context object represents a single upload or +download, and the state object in the context object represents the current +state of that transfer. The state object has a handle() method that expects +the next packet in the transfer, and returns a state object until the transfer +is complete, at which point it returns None. That is, unless there is a fatal +error, in which case a TftpException is returned instead.""" + +from TftpShared import * +from TftpPacketTypes import * +from TftpPacketFactory import TftpPacketFactory +from TftpStates import * +import socket, time, sys + +############################################################################### +# Utility classes +############################################################################### + +class TftpMetrics(object): + """A class representing metrics of the transfer.""" + def __init__(self): + # Bytes transferred + self.bytes = 0 + # Bytes re-sent + self.resent_bytes = 0 + # Duplicate packets received + self.dups = {} + self.dupcount = 0 + # Times + self.start_time = 0 + self.end_time = 0 + self.duration = 0 + # Rates + self.bps = 0 + self.kbps = 0 + # Generic errors + self.errors = 0 + + def compute(self): + # Compute transfer time + self.duration = self.end_time - self.start_time + if self.duration == 0: + self.duration = 1 + log.debug("TftpMetrics.compute: duration is %s" % self.duration) + self.bps = (self.bytes * 8.0) / self.duration + self.kbps = self.bps / 1024.0 + log.debug("TftpMetrics.compute: kbps is %s" % self.kbps) + for key in self.dups: + self.dupcount += self.dups[key] + + def add_dup(self, blocknumber): + """This method adds a dup for a block number to the metrics.""" + log.debug("Recording a dup for block %d" % blocknumber) + if self.dups.has_key(blocknumber): + self.dups[blocknumber] += 1 + else: + self.dups[blocknumber] = 1 + tftpassert(self.dups[blocknumber] < MAX_DUPS, + "Max duplicates for block %d reached" % blocknumber) + +############################################################################### +# Context classes +############################################################################### + +class TftpContext(object): + """The base class of the contexts.""" + + def __init__(self, host, port, timeout, dyn_file_func=None): + """Constructor for the base context, setting shared instance + variables.""" + self.file_to_transfer = None + self.fileobj = None + self.options = None + self.packethook = None + self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.sock.settimeout(timeout) + self.timeout = timeout + self.state = None + self.next_block = 0 + self.factory = TftpPacketFactory() + # Note, setting the host will also set self.address, as it's a property. + self.host = host + self.port = port + # The port associated with the TID + self.tidport = None + # Metrics + self.metrics = TftpMetrics() + # Fluag when the transfer is pending completion. + self.pending_complete = False + # Time when this context last received any traffic. + # FIXME: does this belong in metrics? + self.last_update = 0 + # The last packet we sent, if applicable, to make resending easy. + self.last_pkt = None + self.dyn_file_func = dyn_file_func + # Count the number of retry attempts. + self.retry_count = 0 + + def getBlocksize(self): + """Fetch the current blocksize for this session.""" + return int(self.options.get('blksize', 512)) + + def __del__(self): + """Simple destructor to try to call housekeeping in the end method if + not called explicitely. Leaking file descriptors is not a good + thing.""" + self.end() + + def checkTimeout(self, now): + """Compare current time with last_update time, and raise an exception + if we're over the timeout time.""" + log.debug("checking for timeout on session %s" % self) + if now - self.last_update > self.timeout: + raise TftpTimeout, "Timeout waiting for traffic" + + def start(self): + raise NotImplementedError, "Abstract method" + + def end(self): + """Perform session cleanup, since the end method should always be + called explicitely by the calling code, this works better than the + destructor.""" + log.debug("in TftpContext.end") + if self.fileobj is not None and not self.fileobj.closed: + log.debug("self.fileobj is open - closing") + self.fileobj.close() + + def gethost(self): + "Simple getter method for use in a property." + return self.__host + + def sethost(self, host): + """Setter method that also sets the address property as a result + of the host that is set.""" + self.__host = host + self.address = socket.gethostbyname(host) + + host = property(gethost, sethost) + + def setNextBlock(self, block): + if block >= 2 ** 16: + log.debug("Block number rollover to 0 again") + block = 0 + self.__eblock = block + + def getNextBlock(self): + return self.__eblock + + next_block = property(getNextBlock, setNextBlock) + + def cycle(self): + """Here we wait for a response from the server after sending it + something, and dispatch appropriate action to that response.""" + try: + (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) + except socket.timeout, err: + log.warn("Timeout waiting for traffic, retrying...") + raise TftpTimeout, "Timed-out waiting for traffic" + + # Ok, we've received a packet. Log it. + log.debug("Received %d bytes from %s:%s" + % (len(buffer), raddress, rport)) + # And update our last updated time. + self.last_update = time.time() + + # Decode it. + recvpkt = self.factory.parse(buffer) + + # Check for known "connection". + if raddress != self.address: + log.warn("Received traffic from %s, expected host %s. Discarding" + % (raddress, self.host)) + + if self.tidport and self.tidport != rport: + log.warn("Received traffic from %s:%s but we're " + "connected to %s:%s. Discarding." + % (raddress, rport, + self.host, self.tidport)) + + # If there is a packethook defined, call it. We unconditionally + # pass all packets, it's up to the client to screen out different + # kinds of packets. This way, the client is privy to things like + # negotiated options. + if self.packethook: + self.packethook(recvpkt) + + # And handle it, possibly changing state. + self.state = self.state.handle(recvpkt, raddress, rport) + # If we didn't throw any exceptions here, reset the retry_count to + # zero. + self.retry_count = 0 + +class TftpContextServer(TftpContext): + """The context for the server.""" + def __init__(self, host, port, timeout, root, dyn_file_func=None): + TftpContext.__init__(self, + host, + port, + timeout, + dyn_file_func + ) + # At this point we have no idea if this is a download or an upload. We + # need to let the start state determine that. + self.state = TftpStateServerStart(self) + self.root = root + self.dyn_file_func = dyn_file_func + + def __str__(self): + return "%s:%s %s" % (self.host, self.port, self.state) + + def start(self, buffer): + """Start the state cycle. Note that the server context receives an + initial packet in its start method. Also note that the server does not + loop on cycle(), as it expects the TftpServer object to manage + that.""" + log.debug("In TftpContextServer.start") + self.metrics.start_time = time.time() + log.debug("Set metrics.start_time to %s" % self.metrics.start_time) + # And update our last updated time. + self.last_update = time.time() + + pkt = self.factory.parse(buffer) + log.debug("TftpContextServer.start() - factory returned a %s" % pkt) + + # Call handle once with the initial packet. This should put us into + # the download or the upload state. + self.state = self.state.handle(pkt, + self.host, + self.port) + + def end(self): + """Finish up the context.""" + TftpContext.end(self) + self.metrics.end_time = time.time() + log.debug("Set metrics.end_time to %s" % self.metrics.end_time) + self.metrics.compute() + +class TftpContextClientUpload(TftpContext): + """The upload context for the client during an upload. + Note: If input is a hyphen, then we will use stdin.""" + def __init__(self, + host, + port, + filename, + input, + options, + packethook, + timeout): + TftpContext.__init__(self, + host, + port, + timeout) + self.file_to_transfer = filename + self.options = options + self.packethook = packethook + if input == '-': + self.fileobj = sys.stdin + else: + self.fileobj = open(input, "rb") + + log.debug("TftpContextClientUpload.__init__()") + log.debug("file_to_transfer = %s, options = %s" % + (self.file_to_transfer, self.options)) + + def __str__(self): + return "%s:%s %s" % (self.host, self.port, self.state) + + def start(self): + log.info("Sending tftp upload request to %s" % self.host) + log.info(" filename -> %s" % self.file_to_transfer) + log.info(" options -> %s" % self.options) + + self.metrics.start_time = time.time() + log.debug("Set metrics.start_time to %s" % self.metrics.start_time) + + # FIXME: put this in a sendWRQ method? + pkt = TftpPacketWRQ() + pkt.filename = self.file_to_transfer + pkt.mode = "octet" # FIXME - shouldn't hardcode this + pkt.options = self.options + self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) + self.next_block = 1 + self.last_pkt = pkt + # FIXME: should we centralize sendto operations so we can refactor all + # saving of the packet to the last_pkt field? + + self.state = TftpStateSentWRQ(self) + + while self.state: + try: + log.debug("State is %s" % self.state) + self.cycle() + except TftpTimeout, err: + log.error(str(err)) + self.retry_count += 1 + if self.retry_count >= TIMEOUT_RETRIES: + log.debug("hit max retries, giving up") + raise + else: + log.warn("resending last packet") + self.state.resendLast() + + def end(self): + """Finish up the context.""" + TftpContext.end(self) + self.metrics.end_time = time.time() + log.debug("Set metrics.end_time to %s" % self.metrics.end_time) + self.metrics.compute() + +class TftpContextClientDownload(TftpContext): + """The download context for the client during a download. + Note: If output is a hyphen, then the output will be sent to stdout.""" + def __init__(self, + host, + port, + filename, + output, + options, + packethook, + timeout): + TftpContext.__init__(self, + host, + port, + timeout) + # FIXME: should we refactor setting of these params? + self.file_to_transfer = filename + self.options = options + self.packethook = packethook + # FIXME - need to support alternate return formats than files? + # File-like objects would be ideal, ala duck-typing. + # If the filename is -, then use stdout + if output == '-': + self.fileobj = sys.stdout + else: + self.fileobj = open(output, "wb") + + log.debug("TftpContextClientDownload.__init__()") + log.debug("file_to_transfer = %s, options = %s" % + (self.file_to_transfer, self.options)) + + def __str__(self): + return "%s:%s %s" % (self.host, self.port, self.state) + + def start(self): + """Initiate the download.""" + log.info("Sending tftp download request to %s" % self.host) + log.info(" filename -> %s" % self.file_to_transfer) + log.info(" options -> %s" % self.options) + + self.metrics.start_time = time.time() + log.debug("Set metrics.start_time to %s" % self.metrics.start_time) + + # FIXME: put this in a sendRRQ method? + pkt = TftpPacketRRQ() + pkt.filename = self.file_to_transfer + pkt.mode = "octet" # FIXME - shouldn't hardcode this + pkt.options = self.options + self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) + self.next_block = 1 + self.last_pkt = pkt + + self.state = TftpStateSentRRQ(self) + + while self.state: + try: + log.debug("State is %s" % self.state) + self.cycle() + except TftpTimeout, err: + log.error(str(err)) + self.retry_count += 1 + if self.retry_count >= TIMEOUT_RETRIES: + log.debug("hit max retries, giving up") + raise + else: + log.warn("resending last packet") + self.state.resendLast() + + def end(self): + """Finish up the context.""" + TftpContext.end(self) + self.metrics.end_time = time.time() + log.debug("Set metrics.end_time to %s" % self.metrics.end_time) + self.metrics.compute() diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index 46c662b..1efaba2 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -7,8 +7,8 @@ import socket, os, re, time, random import select from TftpShared import * from TftpPacketTypes import * -from TftpPacketFactory import * -from TftpStates import * +from TftpPacketFactory import TftpPacketFactory +from TftpContexts import TftpContextServer class TftpServer(TftpSession): """This class implements a tftp server object. Run the listen() method to diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index f3467f8..1e903e7 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -1,6 +1,5 @@ """This module implements all state handling during uploads and downloads, the -main interface to which being the TftpContext base class and the TftpState -base class. +main interface to which being the TftpState base class. The concept is simple. Each context object represents a single upload or download, and the state object in the context object represents the current @@ -11,378 +10,7 @@ error, in which case a TftpException is returned instead.""" from TftpShared import * from TftpPacketTypes import * -from TftpPacketFactory import * -import socket, time, os, sys - -############################################################################### -# Utility classes -############################################################################### - -class TftpMetrics(object): - """A class representing metrics of the transfer.""" - def __init__(self): - # Bytes transferred - self.bytes = 0 - # Bytes re-sent - self.resent_bytes = 0 - # Duplicate packets received - self.dups = {} - self.dupcount = 0 - # Times - self.start_time = 0 - self.end_time = 0 - self.duration = 0 - # Rates - self.bps = 0 - self.kbps = 0 - # Generic errors - self.errors = 0 - - def compute(self): - # Compute transfer time - self.duration = self.end_time - self.start_time - if self.duration == 0: - self.duration = 1 - log.debug("TftpMetrics.compute: duration is %s" % self.duration) - self.bps = (self.bytes * 8.0) / self.duration - self.kbps = self.bps / 1024.0 - log.debug("TftpMetrics.compute: kbps is %s" % self.kbps) - for key in self.dups: - self.dupcount += self.dups[key] - - def add_dup(self, blocknumber): - """This method adds a dup for a block number to the metrics.""" - log.debug("Recording a dup for block %d" % blocknumber) - if self.dups.has_key(blocknumber): - self.dups[blocknumber] += 1 - else: - self.dups[blocknumber] = 1 - tftpassert(self.dups[blocknumber] < MAX_DUPS, - "Max duplicates for block %d reached" % blocknumber) - -############################################################################### -# Context classes -############################################################################### - -class TftpContext(object): - """The base class of the contexts.""" - - def __init__(self, host, port, timeout, dyn_file_func=None): - """Constructor for the base context, setting shared instance - variables.""" - self.file_to_transfer = None - self.fileobj = None - self.options = None - self.packethook = None - self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.sock.settimeout(timeout) - self.timeout = timeout - self.state = None - self.next_block = 0 - self.factory = TftpPacketFactory() - # Note, setting the host will also set self.address, as it's a property. - self.host = host - self.port = port - # The port associated with the TID - self.tidport = None - # Metrics - self.metrics = TftpMetrics() - # Fluag when the transfer is pending completion. - self.pending_complete = False - # Time when this context last received any traffic. - # FIXME: does this belong in metrics? - self.last_update = 0 - # The last packet we sent, if applicable, to make resending easy. - self.last_pkt = None - self.dyn_file_func = dyn_file_func - # Count the number of retry attempts. - self.retry_count = 0 - - def getBlocksize(self): - """Fetch the current blocksize for this session.""" - return int(self.options.get('blksize', 512)) - - def __del__(self): - """Simple destructor to try to call housekeeping in the end method if - not called explicitely. Leaking file descriptors is not a good - thing.""" - self.end() - - def checkTimeout(self, now): - """Compare current time with last_update time, and raise an exception - if we're over the timeout time.""" - log.debug("checking for timeout on session %s" % self) - if now - self.last_update > self.timeout: - raise TftpTimeout, "Timeout waiting for traffic" - - def start(self): - raise NotImplementedError, "Abstract method" - - def end(self): - """Perform session cleanup, since the end method should always be - called explicitely by the calling code, this works better than the - destructor.""" - log.debug("in TftpContext.end") - if self.fileobj is not None and not self.fileobj.closed: - log.debug("self.fileobj is open - closing") - self.fileobj.close() - - def gethost(self): - "Simple getter method for use in a property." - return self.__host - - def sethost(self, host): - """Setter method that also sets the address property as a result - of the host that is set.""" - self.__host = host - self.address = socket.gethostbyname(host) - - host = property(gethost, sethost) - - def setNextBlock(self, block): - if block >= 2 ** 16: - log.debug("Block number rollover to 0 again") - block = 0 - self.__eblock = block - - def getNextBlock(self): - return self.__eblock - - next_block = property(getNextBlock, setNextBlock) - - def cycle(self): - """Here we wait for a response from the server after sending it - something, and dispatch appropriate action to that response.""" - try: - (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) - except socket.timeout, err: - log.warn("Timeout waiting for traffic, retrying...") - raise TftpTimeout, "Timed-out waiting for traffic" - - # Ok, we've received a packet. Log it. - log.debug("Received %d bytes from %s:%s" - % (len(buffer), raddress, rport)) - # And update our last updated time. - self.last_update = time.time() - - # Decode it. - recvpkt = self.factory.parse(buffer) - - # Check for known "connection". - if raddress != self.address: - log.warn("Received traffic from %s, expected host %s. Discarding" - % (raddress, self.host)) - - if self.tidport and self.tidport != rport: - log.warn("Received traffic from %s:%s but we're " - "connected to %s:%s. Discarding." - % (raddress, rport, - self.host, self.tidport)) - - # If there is a packethook defined, call it. We unconditionally - # pass all packets, it's up to the client to screen out different - # kinds of packets. This way, the client is privy to things like - # negotiated options. - if self.packethook: - self.packethook(recvpkt) - - # And handle it, possibly changing state. - self.state = self.state.handle(recvpkt, raddress, rport) - # If we didn't throw any exceptions here, reset the retry_count to - # zero. - self.retry_count = 0 - -class TftpContextServer(TftpContext): - """The context for the server.""" - def __init__(self, host, port, timeout, root, dyn_file_func=None): - TftpContext.__init__(self, - host, - port, - timeout, - dyn_file_func - ) - # At this point we have no idea if this is a download or an upload. We - # need to let the start state determine that. - self.state = TftpStateServerStart(self) - self.root = root - self.dyn_file_func = dyn_file_func - - def __str__(self): - return "%s:%s %s" % (self.host, self.port, self.state) - - def start(self, buffer): - """Start the state cycle. Note that the server context receives an - initial packet in its start method. Also note that the server does not - loop on cycle(), as it expects the TftpServer object to manage - that.""" - log.debug("In TftpContextServer.start") - self.metrics.start_time = time.time() - log.debug("Set metrics.start_time to %s" % self.metrics.start_time) - # And update our last updated time. - self.last_update = time.time() - - pkt = self.factory.parse(buffer) - log.debug("TftpContextServer.start() - factory returned a %s" % pkt) - - # Call handle once with the initial packet. This should put us into - # the download or the upload state. - self.state = self.state.handle(pkt, - self.host, - self.port) - - def end(self): - """Finish up the context.""" - TftpContext.end(self) - self.metrics.end_time = time.time() - log.debug("Set metrics.end_time to %s" % self.metrics.end_time) - self.metrics.compute() - -class TftpContextClientUpload(TftpContext): - """The upload context for the client during an upload. - Note: If input is a hyphen, then we will use stdin.""" - def __init__(self, - host, - port, - filename, - input, - options, - packethook, - timeout): - TftpContext.__init__(self, - host, - port, - timeout) - self.file_to_transfer = filename - self.options = options - self.packethook = packethook - if input == '-': - self.fileobj = sys.stdin - else: - self.fileobj = open(input, "rb") - - log.debug("TftpContextClientUpload.__init__()") - log.debug("file_to_transfer = %s, options = %s" % - (self.file_to_transfer, self.options)) - - def __str__(self): - return "%s:%s %s" % (self.host, self.port, self.state) - - def start(self): - log.info("Sending tftp upload request to %s" % self.host) - log.info(" filename -> %s" % self.file_to_transfer) - log.info(" options -> %s" % self.options) - - self.metrics.start_time = time.time() - log.debug("Set metrics.start_time to %s" % self.metrics.start_time) - - # FIXME: put this in a sendWRQ method? - pkt = TftpPacketWRQ() - pkt.filename = self.file_to_transfer - pkt.mode = "octet" # FIXME - shouldn't hardcode this - pkt.options = self.options - self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) - self.next_block = 1 - self.last_pkt = pkt - # FIXME: should we centralize sendto operations so we can refactor all - # saving of the packet to the last_pkt field? - - self.state = TftpStateSentWRQ(self) - - while self.state: - try: - log.debug("State is %s" % self.state) - self.cycle() - except TftpTimeout, err: - log.error(str(err)) - self.retry_count += 1 - if self.retry_count >= TIMEOUT_RETRIES: - log.debug("hit max retries, giving up") - raise - else: - log.warn("resending last packet") - self.state.resendLast() - - def end(self): - """Finish up the context.""" - TftpContext.end(self) - self.metrics.end_time = time.time() - log.debug("Set metrics.end_time to %s" % self.metrics.end_time) - self.metrics.compute() - -class TftpContextClientDownload(TftpContext): - """The download context for the client during a download. - Note: If output is a hyphen, then the output will be sent to stdout.""" - def __init__(self, - host, - port, - filename, - output, - options, - packethook, - timeout): - TftpContext.__init__(self, - host, - port, - timeout) - # FIXME: should we refactor setting of these params? - self.file_to_transfer = filename - self.options = options - self.packethook = packethook - # FIXME - need to support alternate return formats than files? - # File-like objects would be ideal, ala duck-typing. - # If the filename is -, then use stdout - if output == '-': - self.fileobj = sys.stdout - else: - self.fileobj = open(output, "wb") - - log.debug("TftpContextClientDownload.__init__()") - log.debug("file_to_transfer = %s, options = %s" % - (self.file_to_transfer, self.options)) - - def __str__(self): - return "%s:%s %s" % (self.host, self.port, self.state) - - def start(self): - """Initiate the download.""" - log.info("Sending tftp download request to %s" % self.host) - log.info(" filename -> %s" % self.file_to_transfer) - log.info(" options -> %s" % self.options) - - self.metrics.start_time = time.time() - log.debug("Set metrics.start_time to %s" % self.metrics.start_time) - - # FIXME: put this in a sendRRQ method? - pkt = TftpPacketRRQ() - pkt.filename = self.file_to_transfer - pkt.mode = "octet" # FIXME - shouldn't hardcode this - pkt.options = self.options - self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) - self.next_block = 1 - self.last_pkt = pkt - - self.state = TftpStateSentRRQ(self) - - while self.state: - try: - log.debug("State is %s" % self.state) - self.cycle() - except TftpTimeout, err: - log.error(str(err)) - self.retry_count += 1 - if self.retry_count >= TIMEOUT_RETRIES: - log.debug("hit max retries, giving up") - raise - else: - log.warn("resending last packet") - self.state.resendLast() - - def end(self): - """Finish up the context.""" - TftpContext.end(self) - self.metrics.end_time = time.time() - log.debug("Set metrics.end_time to %s" % self.metrics.end_time) - self.metrics.compute() +import os ############################################################################### # State classes diff --git a/tftpy/__init__.py b/tftpy/__init__.py index b1400df..e8ef87f 100644 --- a/tftpy/__init__.py +++ b/tftpy/__init__.py @@ -20,3 +20,5 @@ from TftpPacketTypes import * from TftpPacketFactory import * from TftpClient import * from TftpServer import * +from TftpContexts import * +from TftpStates import *