diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index 4d9b5f7..da35d05 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -1,7 +1,7 @@ import time, types from TftpShared import * from TftpPacketFactory import * -from TftpStates import TftpContextClientDownload +from TftpStates import TftpContextClientDownload, TftpContextClientUpload class TftpClient(TftpSession): """This class is an implementation of a tftp client. Once instantiated, a @@ -63,185 +63,26 @@ class TftpClient(TftpSession): # Open the input file. # FIXME: As of the state machine, this is now broken. Need to # implement with new state machine. - self.fileobj = open(input, "rb") - recvpkt = None - curblock = 0 - start_time = time.time() - self.bytes = 0 + self.context = TftpContextClientUpload(self.host, + self.iport, + filename, + input, + self.options, + packethook, + timeout) + self.context.start() + # Upload happens here + self.context.end() - tftp_factory = TftpPacketFactory() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.sock.settimeout(timeout) + metrics = self.context.metrics - self.filename = filename - - self.send_wrq() - self.state.state = 'wrq' - - timeouts = 0 - while True: - try: - (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) - except socket.timeout, err: - timeouts += 1 - if timeouts >= TIMEOUT_RETRIES: - raise TftpException, "Hit max timeouts, giving up." - else: - if self.state.state == 'dat' or self.state.state == 'fin': - logger.debug("Timing out on DAT. Need to resend.") - self.send_dat(packethook,resend=True) - elif self.state.state == 'wrq': - logger.debug("Timing out on WRQ.") - self.send_wrq(resend=True) - else: - tftpassert(False, - "Timing out in unsupported state %s" % - self.state.state) - continue - - recvpkt = tftp_factory.parse(buffer) - - logger.debug("Received %d bytes from %s:%s" - % (len(buffer), raddress, rport)) - - # Check for known "connection". - if raddress != self.address: - logger.warn("Received traffic from %s, expected host %s. Discarding" - % (raddress, self.host)) - continue - if self.port and self.port != rport: - logger.warn("Received traffic from %s:%s but we're " - "connected to %s:%s. Discarding." - % (raddress, rport, - self.host, self.port)) - continue - - if not self.port and self.state.state == 'wrq': - self.port = rport - logger.debug("Set remote port for session to %s" % rport) - - # Next packet type - if isinstance(recvpkt, TftpPacketACK): - logger.debug("Received an ACK from the server.") - # tftp on wrt54gl seems to answer with an ack to a wrq regardless - # if we sent options. - if recvpkt.blocknumber == 0 and self.state.state in ('oack','wrq'): - logger.debug("Received ACK with 0 blocknumber, starting upload") - self.state.state = 'dat' - self.send_dat(packethook) - else: - if self.state.state == 'dat' or self.state.state == 'fin': - if self.blocknumber == recvpkt.blocknumber: - logger.info("Received ACK for block %d" - % recvpkt.blocknumber) - if self.state.state == 'fin': - break - else: - self.send_dat(packethook) - elif recvpkt.blocknumber < self.blocknumber: - # Don't resend a DAT due to an old ACK. Fixes the - # sorceror's apprentice problem. - logger.warn("Received old ACK for block number %d" - % recvpkt.blocknumber) - else: - logger.warn("Received ACK for block number " - "%d, apparently from the future" - % recvpkt.blocknumber) - else: - logger.error("Received ACK with block number %d " - "while in state %s" - % (recvpkt.blocknumber, - self.state.state)) - - # Check other packet types. - elif isinstance(recvpkt, TftpPacketOACK): - if not self.state.state == 'wrq': - self.errors += 1 - logger.error("Received OACK in state %s" % self.state.state) - continue - - self.state.state = 'oack' - logger.info("Received OACK from server.") - if recvpkt.options.keys() > 0: - if recvpkt.match_options(self.options): - logger.info("Successful negotiation of options") - for key in self.options: - logger.info(" %s = %s" % (key, self.options[key])) - logger.debug("sending ACK to OACK") - ackpkt = TftpPacketACK() - ackpkt.blocknumber = 0 - self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port)) - self.state.state = 'dat' - self.send_dat(packethook) - else: - logger.error("failed to negotiate options") - self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port) - self.state.state = 'err' - raise TftpException, "Failed to negotiate options" - - elif isinstance(recvpkt, TftpPacketERR): - self.state.state = 'err' - self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port) - tftpassert(False, "Received ERR from server: " + str(recvpkt)) - - elif isinstance(recvpkt, TftpPacketWRQ): - self.state.state = 'err' - self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port) - tftpassert(False, "Received WRQ from server: " + str(recvpkt)) - - else: - self.state.state = 'err' - self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port) - tftpassert(False, "Received unknown packet type from server: " - + str(recvpkt)) - - - # end while - self.fileobj.close() - - end_time = time.time() - duration = end_time - start_time - if duration == 0: + # FIXME: Should we output this? Shouldn't we let the client control + # output? This should be in the sample client, but not in the download + # call. + if metrics.duration == 0: logger.info("Duration too short, rate undetermined") else: logger.info('') - logger.info("Uploaded %d bytes in %d seconds" % (self.bytes, duration)) - bps = (self.bytes * 8.0) / duration - kbps = bps / 1024.0 - logger.info("Average rate: %.2f kbps" % kbps) - - def send_dat(self, packethook, resend=False): - """This method reads and sends a DAT packet based on what is in self.buffer.""" - if not resend: - blksize = int(self.options['blksize']) - self.buffer = self.fileobj.read(blksize) - logger.debug("Read %d bytes into buffer" % len(self.buffer)) - if len(self.buffer) < blksize: - logger.info("Reached EOF on file %s" % self.filename) - self.state.state = 'fin' - self.blocknumber += 1 - if self.blocknumber > 65535: - logger.debug("Blocknumber rolled over to zero") - self.blocknumber = 0 - self.bytes += len(self.buffer) - else: - logger.warn("Resending block number %d" % self.blocknumber) - dat = TftpPacketDAT() - dat.data = self.buffer - dat.blocknumber = self.blocknumber - logger.debug("Sending DAT packet %d" % self.blocknumber) - self.sock.sendto(dat.encode().buffer, (self.host, self.port)) - self.timesent = time.time() - if packethook: - packethook(dat) - - def send_wrq(self, resend=False): - """This method sends a wrq""" - logger.info("Sending tftp upload request to %s" % self.host) - logger.info(" filename -> %s" % self.filename) - - wrq = TftpPacketWRQ() - wrq.filename = self.filename - wrq.mode = "octet" # FIXME - shouldn't hardcode this - wrq.options = self.options - self.sock.sendto(wrq.encode().buffer, (self.host, self.iport)) + logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) + logger.info("Average rate: %.2f kbps" % metrics.kbps) + logger.info("Received %d duplicate packets" % metrics.dupcount) \ No newline at end of file diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 398f137..88c4fa1 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -56,17 +56,17 @@ class TftpContext(object): def end(self): return NotImplementedError, "Abstract method" - + def gethost(self): "Simple getter method for use in a property." return self.__host - + def sethost(self, host): """Setter method that also sets the address property as a result of the host that is set.""" self.__host = host self.address = socket.gethostbyname(host) - + host = property(gethost, sethost) def sendAck(self, blocknumber): @@ -76,83 +76,37 @@ class TftpContext(object): ackpkt.blocknumber = blocknumber self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport)) - def senderror(self, errorcode): + def sendError(self, errorcode): """This method uses the socket passed, and uses the errorcode to compose and send an error packet.""" - logger.debug("In senderror, being asked to send error %d" % errorcode) + logger.debug("In sendError, being asked to send error %d" % errorcode) errpkt = TftpPacketERR() errpkt.errorcode = errorcode - sock.sendto(errpkt.encode().buffer, (self.host, self.tidport)) + self.sock.sendto(errpkt.encode().buffer, (self.host, self.tidport)) -class TftpContextServerDownload(TftpContext): - """The download context for the server during a download.""" - pass - -class TftpContextClientDownload(TftpContext): - """The download context for the client during a download.""" - def __init__(self, host, port, filename, output, options, packethook, timeout): +class TftpContextClient(TftpContext): + """This class represents shared functionality by both the download and + upload client contexts.""" + def __init__(self, host, port, filename, options, packethook, timeout): TftpContext.__init__(self, host, port) - # FIXME - need to support alternate return formats than files? - # File-like objects would be ideal, ala duck-typing. - self.requested_file = filename - self.fileobj = open(output, "wb") + self.file_to_transfer = filename self.options = options self.packethook = packethook self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.sock.settimeout(timeout) - self.state = None - self.expected_block = 0 + self.next_block = 0 - ############################ - # Logging - ############################ - logger.debug("TftpContextClientDownload.__init__()") - logger.debug("requested_file = %s, options = %s" % - (self.requested_file, self.options)) - - def setExpectedBlock(self, block): + def setNextBlock(self, block): if block > 2 ** 16: logger.debug("block number rollover to 0 again") block = 0 self.__eblock = block - def getExpectedBlock(self): + def getNextBlock(self): return self.__eblock - expected_block = property(getExpectedBlock, setExpectedBlock) - - def start(self): - """Initiate the download.""" - logger.info("sending tftp download request to %s" % self.host) - logger.info(" filename -> %s" % self.requested_file) - logger.info(" options -> %s" % self.options) - - self.metrics.start_time = time.time() - logger.debug("set metrics.start_time to %s" % self.metrics.start_time) - - # FIXME: put this in a sendRRQ method? - pkt = TftpPacketRRQ() - pkt.filename = self.requested_file - pkt.mode = "octet" # FIXME - shouldn't hardcode this - pkt.options = self.options - self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) - self.expected_block = 1 - - self.state = TftpStateSentRRQ(self) - - try: - while self.state: - logger.debug("state is %s" % self.state) - self.cycle() - finally: - self.fileobj.close() - - def end(self): - """Finish up the context.""" - self.metrics.end_time = time.time() - logger.debug("set metrics.end_time to %s" % self.metrics.end_time) - self.metrics.compute() + next_block = property(getNextBlock, setNextBlock) def cycle(self): """Here we wait for a response from the server after sending it @@ -169,7 +123,7 @@ class TftpContextClientDownload(TftpContext): raise TftpException, "Hit max timeouts, giving up." # Ok, we've received a packet. Log it. - logger.debug("Received %d bytes from %s:%s" + logger.debug("Received %d bytes from %s:%s" % (len(buffer), raddress, rport)) # Decode it. @@ -196,6 +150,101 @@ class TftpContextClientDownload(TftpContext): # And handle it, possibly changing state. self.state = self.state.handle(recvpkt, raddress, rport) +class TftpContextClientUpload(TftpContextClient): + """The upload context for the client during an upload.""" + def __init__(self, host, port, filename, input, options, packethook, timeout): + TftpContextClient.__init__(self, + host, + port, + filename, + options, + packethook, + timeout) + self.fileobj = open(input, "wb") + + logger.debug("TftpContextClientUpload.__init__()") + logger.debug("file_to_transfer = %s, options = %s" % + (self.file_to_transfer, self.options)) + + def start(self): + logger.info("sending tftp upload request to %s" % self.host) + logger.info(" filename -> %s" % self.file_to_transfer) + logger.info(" options -> %s" % self.options) + + self.metrics.start_time = time.time() + logger.debug("set metrics.start_time to %s" % self.metrics.start_time) + + # FIXME: put this in a sendWRQ method? + pkt = TftpPacketWRQ() + pkt.filename = self.file_to_transfer + pkt.mode = "octet" # FIXME - shouldn't hardcode this + pkt.options = self.options + self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) + self.next_block = 1 + + self.state = TftpStateSentWRQ(self) + + try: + while self.state: + logger.debug("state is %s" % self.state) + self.cycle() + finally: + self.fileobj.close() + + def end(self): + pass + +class TftpContextClientDownload(TftpContextClient): + """The download context for the client during a download.""" + def __init__(self, host, port, filename, output, options, packethook, timeout): + TftpContextClient.__init__(self, + host, + port, + filename, + options, + packethook, + timeout) + # FIXME - need to support alternate return formats than files? + # File-like objects would be ideal, ala duck-typing. + self.fileobj = open(output, "wb") + + logger.debug("TftpContextClientDownload.__init__()") + logger.debug("file_to_transfer = %s, options = %s" % + (self.file_to_transfer, self.options)) + + def start(self): + """Initiate the download.""" + logger.info("sending tftp download request to %s" % self.host) + logger.info(" filename -> %s" % self.file_to_transfer) + logger.info(" options -> %s" % self.options) + + self.metrics.start_time = time.time() + logger.debug("set metrics.start_time to %s" % self.metrics.start_time) + + # FIXME: put this in a sendRRQ method? + pkt = TftpPacketRRQ() + pkt.filename = self.file_to_transfer + pkt.mode = "octet" # FIXME - shouldn't hardcode this + pkt.options = self.options + self.sock.sendto(pkt.encode().buffer, (self.host, self.port)) + self.next_block = 1 + + self.state = TftpStateSentRRQ(self) + + try: + while self.state: + logger.debug("state is %s" % self.state) + self.cycle() + finally: + self.fileobj.close() + + def end(self): + """Finish up the context.""" + self.metrics.end_time = time.time() + logger.debug("set metrics.end_time to %s" % self.metrics.end_time) + self.metrics.compute() + + ############################################################################### # State classes ############################################################################### @@ -214,20 +263,62 @@ class TftpState(object): a TftpState object, either itself or a new state.""" raise NotImplementedError, "Abstract method" + def handleOACK(self, pkt): + """This method handles an OACK from the server, syncing any accepted + options.""" + if pkt.options.keys() > 0: + if pkt.match_options(self.context.options): + logger.info("Successful negotiation of options") + # Set options to OACK options + self.context.options = pkt.options + for key in self.context.options: + logger.info(" %s = %s" % (key, self.context.options[key])) + else: + logger.error("failed to negotiate options") + raise TftpException, "Failed to negotiate options" + else: + raise TftpException, "No options found in OACK" + +class TftpStateUpload(TftpState): + """A class holding common code for upload states.""" + def sendDat(self, resend=False): + finished = False + blocknumber = self.context.next_block + if not resend: + blksize = int(self.context.options['blksize']) + buffer = self.context.fileobj.read(blksize) + logger.debug("Read %d bytes into buffer" % len(buffer)) + if len(buffer) < blksize: + logger.info("Reached EOF on file %s" % self.context.input) + finished = True + self.context.next_block += 1 + self.bytes += len(buffer) + else: + logger.warn("Resending block number %d" % blocknumber) + dat = TftpPacketDAT() + dat.data = buffer + dat.blocknumber = blocknumber + logger.debug("Sending DAT packet %d" % blocknumber) + self.context.sock.sendto(dat.encode().buffer, + (self.context.host, self.context.port)) + if self.context.packethook: + self.context.packethook(dat) + return finished + class TftpStateDownload(TftpState): """A class holding common code for download states.""" def handleDat(self, pkt): """This method handles a DAT packet during a download.""" logger.info("handling DAT packet - block %d" % pkt.blocknumber) - logger.debug("expecting block %s" % self.context.expected_block) - if pkt.blocknumber == self.context.expected_block: - logger.debug("good, received block %d in sequence" + logger.debug("expecting block %s" % self.context.next_block) + if pkt.blocknumber == self.context.next_block: + logger.debug("good, received block %d in sequence" % pkt.blocknumber) - - self.context.sendAck(pkt.blocknumber) - self.context.expected_block += 1 - logger.debug("writing %d bytes to output file" + self.context.sendAck(pkt.blocknumber) + self.context.next_block += 1 + + logger.debug("writing %d bytes to output file" % len(pkt.data)) self.context.fileobj.write(pkt.data) self.context.metrics.bytes += len(pkt.data) @@ -236,7 +327,7 @@ class TftpStateDownload(TftpState): logger.info("end of file detected") return None - elif pkt.blocknumber < self.context.expected_block: + elif pkt.blocknumber < self.context.next_block: logger.warn("dropping duplicate block %d" % pkt.blocknumber) if self.context.metrics.dups.has_key(pkt.blocknumber): self.context.metrics.dups[pkt.blocknumber] += 1 @@ -251,16 +342,87 @@ class TftpStateDownload(TftpState): else: # FIXME: should we be more tolerant and just discard instead? msg = "Whoa! Received future block %d but expected %d" \ - % (pkt.blocknumber, self.context.expected_block) + % (pkt.blocknumber, self.context.next_block) logger.error(msg) raise TftpException, msg # Default is to ack return TftpStateSentACK(self.context) +class TftpStateSentWRQ(TftpStateUpload): + """Just sent an WRQ packet for an upload.""" + def handle(self, pkt, raddress, rport): + """Handle a packet we just received.""" + if not self.context.tidport: + self.context.tidport = rport + logger.debug("Set remote port for session to %s" % rport) + + # If we're going to successfully transfer the file, then we should see + # either an OACK for accepted options, or an ACK to ignore options. + if isinstance(pkt, TftpPacketOACK): + logger.info("received OACK from server") + try: + self.handleOACK(pkt) + except TftpException, err: + logger.error("failed to negotiate options") + self.context.sendError(TftpErrors.FailedNegotiation) + raise + else: + logger.debug("sending first DAT packet") + fin = self.context.sendDat() + if fin: + logger.info("Add done") + return None + else: + logger.debug("Changing state to TftpStateSentDAT") + return TftpStateSentDAT(self.context) + + elif isinstance(pkt, TftpPacketACK): + logger.info("received ACK from server") + logger.debug("apparently the server ignored our options") + # The block number should be zero. + if pkt.blocknumber == 0: + logger.debug("ack blocknumber is zero as expected") + logger.debug("sending first DAT packet") + fin = self.context.sendDat() + if fin: + logger.info("Add done") + return None + else: + logger.debug("Changing state to TftpStateSentDAT") + return TftpStateSentDAT(self.context) + else: + logger.warn("discarding ACK to block %s" % pkt.blocknumber) + logger.debug("still waiting for valid response from server") + return self + + elif isinstance(pkt, TftpPacketERR): + self.context.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received ERR from server: " + str(pkt) + + elif isinstance(pkt, TftpPacketRRQ): + self.context.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received RRQ from server while in upload" + + elif isinstance(pkt, TftpPacketDAT): + self.context.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received DAT from server while in upload" + + else: + self.context.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received unknown packet type from server: " + str(pkt) + + # By default, no state change. + return self + +class TftpStateSentDAT(TftpStateUpload): + """This class represents the state of the transfer when a DAT was just + sent, and we are waiting for an ACK from the server. This class is the + same one used by the client during the upload, and the server during the + download.""" + class TftpStateSentRRQ(TftpStateDownload): """Just sent an RRQ packet.""" - def handle(self, pkt, raddress, rport): """Handle the packet in response to an RRQ to the server.""" if not self.context.tidport: @@ -269,24 +431,20 @@ class TftpStateSentRRQ(TftpStateDownload): # Now check the packet type and dispatch it properly. if isinstance(pkt, TftpPacketOACK): - logger.info("received OACK from server.") - if pkt.options.keys() > 0: - if pkt.match_options(self.context.options): - logger.info("Successful negotiation of options") - # Set options to OACK options - self.context.options = pkt.options - for key in self.context.options: - logger.info(" %s = %s" % (key, self.context.options[key])) - logger.debug("sending ACK to OACK") + logger.info("received OACK from server") + try: + self.handleOACK(pkt) + except TftpException, err: + logger.error("failed to negotiate options: %s" % str(err)) + self.context.sendError(TftpErrors.FailedNegotiation) + raise + else: + logger.debug("sending ACK to OACK") - self.context.sendAck(blocknumber=0) + self.context.sendAck(blocknumber=0) - logger.debug("Changing state to TftpStateSentACK") - return TftpStateSentACK(self.context) - else: - logger.error("failed to negotiate options") - self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port) - raise TftpException, "Failed to negotiate options" + logger.debug("Changing state to TftpStateSentACK") + return TftpStateSentACK(self.context) elif isinstance(pkt, TftpPacketDAT): # If there are any options set, then the server didn't honour any @@ -300,19 +458,19 @@ class TftpStateSentRRQ(TftpStateDownload): # Every other packet type is a problem. elif isinstance(recvpkt, TftpPacketACK): # Umm, we ACK, the server doesn't. - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ACK from server while in download" elif isinstance(recvpkt, TftpPacketWRQ): - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received WRQ from server while in download" elif isinstance(recvpkt, TftpPacketERR): - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ERR from server: " + str(recvpkt) else: - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received unknown packet type from server: " + str(recvpkt) # By default, no state change. @@ -328,17 +486,17 @@ class TftpStateSentACK(TftpStateDownload): # Every other packet type is a problem. elif isinstance(recvpkt, TftpPacketACK): # Umm, we ACK, the server doesn't. - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ACK from server while in download" elif isinstance(recvpkt, TftpPacketWRQ): - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received WRQ from server while in download" elif isinstance(recvpkt, TftpPacketERR): - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ERR from server: " + str(recvpkt) else: - self.senderror(self.sock, TftpErrors.IllegalTftpOp) + self.context.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received unknown packet type from server: " + str(recvpkt)