diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index c9f20b5..3098e45 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -24,9 +24,6 @@ class TftpState(object): file object is required, since in tftp there's always a file involved.""" self.context = context - # This variable is used to store the absolute path to the file being - # managed. Currently only used by the server. - self.full_path = None def handle(self, pkt, raddress, rport): """An abstract method for handling a packet. It is expected to return @@ -76,64 +73,6 @@ class TftpState(object): log.debug("Returning these accepted options: %s" % accepted_options) return accepted_options - def serverInitial(self, pkt, raddress, rport): - """This method performs initial setup for a server context transfer, - put here to refactor code out of the TftpStateServerRecvRRQ and - TftpStateServerRecvWRQ classes, since their initial setup is - identical. The method returns a boolean, sendoack, to indicate whether - it is required to send an OACK to the client.""" - options = pkt.options - sendoack = False - if not self.context.tidport: - self.context.tidport = rport - log.info("Setting tidport to %s" % rport) - - log.debug("Setting default options, blksize") - self.context.options = { 'blksize': DEF_BLKSIZE } - - if options: - log.debug("Options requested: %s" % options) - supported_options = self.returnSupportedOptions(options) - self.context.options.update(supported_options) - sendoack = True - - # FIXME - only octet mode is supported at this time. - if pkt.mode != 'octet': - self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, \ - "Only octet transfers are supported at this time." - - # test host/port of client end - if self.context.host != raddress or self.context.port != rport: - self.sendError(TftpErrors.UnknownTID) - log.error("Expected traffic from %s:%s but received it " - "from %s:%s instead." - % (self.context.host, - self.context.port, - raddress, - rport)) - # FIXME: increment an error count? - # Return same state, we're still waiting for valid traffic. - return self - - log.debug("Requested filename is %s" % pkt.filename) - - # Make sure that the path to the file is contained in the server's - # root directory. - full_path = os.path.join(self.context.root, pkt.filename) - self.full_path = os.path.abspath(full_path) - log.debug("full_path is %s" % full_path) - if self.context.root == full_path[:len(self.context.root)]: - log.info("requested file is in the server root - good") - else: - log.warn("requested file is not within the server root - bad") - self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "bad file path" - - self.context.file_to_transfer = pkt.filename - - return sendoack - def sendDAT(self): """This method sends the next DAT packet based on the data in the context. It returns a boolean indicating whether the transfer is @@ -261,7 +200,76 @@ class TftpState(object): # Default is to ack return TftpStateExpectDAT(self.context) -class TftpStateServerRecvRRQ(TftpState): +class TftpServerState(TftpState): + """The base class for server states.""" + + def __init__(self, context): + TftpState.__init__(self, context) + + # This variable is used to store the absolute path to the file being + # managed. + self.full_path = None + + def serverInitial(self, pkt, raddress, rport): + """This method performs initial setup for a server context transfer, + put here to refactor code out of the TftpStateServerRecvRRQ and + TftpStateServerRecvWRQ classes, since their initial setup is + identical. The method returns a boolean, sendoack, to indicate whether + it is required to send an OACK to the client.""" + options = pkt.options + sendoack = False + if not self.context.tidport: + self.context.tidport = rport + log.info("Setting tidport to %s" % rport) + + log.debug("Setting default options, blksize") + self.context.options = { 'blksize': DEF_BLKSIZE } + + if options: + log.debug("Options requested: %s" % options) + supported_options = self.returnSupportedOptions(options) + self.context.options.update(supported_options) + sendoack = True + + # FIXME - only octet mode is supported at this time. + if pkt.mode != 'octet': + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, \ + "Only octet transfers are supported at this time." + + # test host/port of client end + if self.context.host != raddress or self.context.port != rport: + self.sendError(TftpErrors.UnknownTID) + log.error("Expected traffic from %s:%s but received it " + "from %s:%s instead." + % (self.context.host, + self.context.port, + raddress, + rport)) + # FIXME: increment an error count? + # Return same state, we're still waiting for valid traffic. + return self + + log.debug("Requested filename is %s" % pkt.filename) + + # Make sure that the path to the file is contained in the server's + # root directory. + full_path = os.path.join(self.context.root, pkt.filename) + self.full_path = os.path.abspath(full_path) + log.debug("full_path is %s" % full_path) + if self.context.root == full_path[:len(self.context.root)]: + log.info("requested file is in the server root - good") + else: + log.warn("requested file is not within the server root - bad") + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "bad file path" + + self.context.file_to_transfer = pkt.filename + + return sendoack + + +class TftpStateServerRecvRRQ(TftpServerState): """This class represents the state of the TFTP server when it has just received an RRQ packet.""" def handle(self, pkt, raddress, rport): @@ -306,7 +314,7 @@ class TftpStateServerRecvRRQ(TftpState): # Note, we don't have to check any other states in this method, that's # up to the caller. -class TftpStateServerRecvWRQ(TftpState): +class TftpStateServerRecvWRQ(TftpServerState): """This class represents the state of the TFTP server when it has just received a WRQ packet.""" def make_subdirs(self):