First successful download with both client and server.

master
Michael P. Soulier 2009-08-16 19:44:57 -04:00
parent 62b22fb562
commit a6a18c178b
7 changed files with 120 additions and 70 deletions

2
.gitignore vendored
View File

@ -1 +1,3 @@
*.pyc *.pyc
*.swp
tags

View File

@ -21,7 +21,7 @@ def main():
'--upload', '--upload',
help='filename to upload') help='filename to upload')
parser.add_option('-b', parser.add_option('-b',
'--blocksize', '--blksize',
help='udp packet size to use (default: 512)', help='udp packet size to use (default: 512)',
default=512) default=512)
parser.add_option('-o', parser.add_option('-o',
@ -76,11 +76,11 @@ def main():
else: else:
tftpy.setLogLevel(logging.INFO) tftpy.setLogLevel(logging.INFO)
progresshook = Progress(tftpy.logger.info).progresshook progresshook = Progress(tftpy.log.info).progresshook
tftp_options = {} tftp_options = {}
if options.blocksize: if options.blksize:
tftp_options['blksize'] = int(options.blocksize) tftp_options['blksize'] = int(options.blksize)
if options.tsize: if options.tsize:
tftp_options['tsize'] = 0 tftp_options['tsize'] = 0
@ -103,6 +103,8 @@ def main():
except tftpy.TftpException, err: except tftpy.TftpException, err:
sys.stderr.write("%s\n" % str(err)) sys.stderr.write("%s\n" % str(err))
sys.exit(1) sys.exit(1)
except KeyboardInterrupt:
pass
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -20,8 +20,8 @@ def main():
parser.add_option('-r', parser.add_option('-r',
'--root', '--root',
type='string', type='string',
help='path to serve from (default: /tftpboot)', help='path to serve from',
default="/tftpboot") default=None)
parser.add_option('-d', parser.add_option('-d',
'--debug', '--debug',
action='store_true', action='store_true',
@ -34,9 +34,16 @@ def main():
else: else:
tftpy.setLogLevel(logging.INFO) tftpy.setLogLevel(logging.INFO)
if not options.root:
parser.print_help()
sys.exit(1)
server = tftpy.TftpServer(options.root) server = tftpy.TftpServer(options.root)
try: try:
server.listen(options.ip, options.port) server.listen(options.ip, options.port)
except tftpy.TftpException, err:
sys.stderr.write("%s\n" % str(err))
sys.exit(1)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass

View File

@ -35,6 +35,11 @@ class TftpClient(TftpSession):
SOCK_TIMEOUT setting, which is the amount of time that the client will SOCK_TIMEOUT setting, which is the amount of time that the client will
wait for a receive packet to arrive.""" wait for a receive packet to arrive."""
# We're downloading. # We're downloading.
log.debug("Creating download context with the following params:")
log.debug("host = %s, port = %s, filename = %s, output = %s"
% (self.host, self.iport, filename, output))
log.debug("options = %s, packethook = %s, timeout = %s"
% (self.options, packethook, timeout))
self.context = TftpContextClientDownload(self.host, self.context = TftpContextClientDownload(self.host,
self.iport, self.iport,
filename, filename,

View File

@ -4,22 +4,8 @@ from TftpShared import *
class TftpSession(object): class TftpSession(object):
"""This class is the base class for the tftp client and server. Any shared """This class is the base class for the tftp client and server. Any shared
code should be in this class.""" code should be in this class."""
# FIXME: do we need this anymore?
def __init__(self): pass
"""Class constructor."""
self.options = None
self.state = None
self.dups = 0
self.errors = 0
def senderror(self, sock, errorcode, address, port):
"""This method uses the socket passed, and uses the errorcode, address
and port to compose and send an error packet."""
log.debug("In senderror, being asked to send error %d to %s:%s"
% (errorcode, address, port))
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
sock.sendto(errpkt.encode().buffer, (address, port))
class TftpPacketWithOptions(object): class TftpPacketWithOptions(object):
"""This class exists to permit some TftpPacket subclasses to share code """This class exists to permit some TftpPacket subclasses to share code

View File

@ -3,6 +3,7 @@ import select
from TftpShared import * from TftpShared import *
from TftpPacketTypes import * from TftpPacketTypes import *
from TftpPacketFactory import * from TftpPacketFactory import *
from TftpStates import *
class TftpServer(TftpSession): class TftpServer(TftpSession):
"""This class implements a tftp server object.""" """This class implements a tftp server object."""
@ -19,9 +20,9 @@ class TftpServer(TftpSession):
# FIXME: What about multiple roots? # FIXME: What about multiple roots?
self.root = os.path.abspath(tftproot) self.root = os.path.abspath(tftproot)
self.dyn_file_func = dyn_file_func self.dyn_file_func = dyn_file_func
# A dict of handlers, where each session is keyed by a string like # A dict of sessions, where each session is keyed by a string like
# ip:tid for the remote end. # ip:tid for the remote end.
self.handlers = {} self.sessions = {}
if os.path.exists(self.root): if os.path.exists(self.root):
log.debug("tftproot %s does exist" % self.root) log.debug("tftproot %s does exist" % self.root)
@ -89,8 +90,8 @@ class TftpServer(TftpSession):
log.debug("Read %d bytes" % len(buffer)) log.debug("Read %d bytes" % len(buffer))
recvpkt = tftp_factory.parse(buffer) recvpkt = tftp_factory.parse(buffer)
# FIXME: Is this the best way to do a session key? What # Forge a session key based on the client's IP and port,
# about symmetric udp? # which should safely work through NAT.
key = "%s:%s" % (raddress, rport) key = "%s:%s" % (raddress, rport)
if not self.sessions.has_key(key): if not self.sessions.has_key(key):
@ -108,36 +109,37 @@ class TftpServer(TftpSession):
else: else:
# Must find the owner of this traffic. # Must find the owner of this traffic.
for key in self.session: for key in self.sessions:
if readysock == self.session[key].sock: if readysock == self.sessions[key].sock:
log.info("Matched input to session key %s"
% key)
try: try:
self.session[key].cycle() self.sessions[key].cycle()
if self.session[key].state == None: if self.sessions[key].state == None:
log.info("Successful transfer.") log.info("Successful transfer.")
deletion_list.append(key) deletion_list.append(key)
break
except TftpException, err: except TftpException, err:
deletion_list.append(key) deletion_list.append(key)
log.error("Fatal exception thrown from " log.error("Fatal exception thrown from "
"handler: %s" % str(err)) "session %s: %s"
% (key, str(err)))
break
else: else:
log.error("Can't find the owner for this packet. " log.error("Can't find the owner for this packet. "
"Discarding.") "Discarding.")
log.debug("Looping on all handlers to check for timeouts") log.debug("Looping on all sessions to check for timeouts")
now = time.time() now = time.time()
for key in self.sessions: for key in self.sessions:
try: try:
self.sessions[key].checkTimeout(now) self.sessions[key].checkTimeout(now)
except TftpException, err: except TftpException, err:
log.error("Fatal exception thrown from handler: %s" log.error(str(err))
% str(err))
deletion_list.append(key) deletion_list.append(key)
log.debug("Iterating deletion list.") log.debug("Iterating deletion list.")
for key in deletion_list: for key in deletion_list:
if self.sessions.has_key(key): if self.sessions.has_key(key):
log.debug("Deleting handler %s" % key) log.debug("Deleting session %s" % key)
del self.sessions[key] del self.sessions[key]
deletion_list = []

View File

@ -1,7 +1,7 @@
from TftpShared import * from TftpShared import *
from TftpPacketTypes import * from TftpPacketTypes import *
from TftpPacketFactory import * from TftpPacketFactory import *
import socket, time import socket, time, os
############################################################################### ###############################################################################
# Utility classes # Utility classes
@ -39,10 +39,10 @@ class TftpMetrics(object):
"""This method adds a dup for a block number to the metrics.""" """This method adds a dup for a block number to the metrics."""
log.debug("Recording a dup for block %d" % blocknumber) log.debug("Recording a dup for block %d" % blocknumber)
if self.dups.has_key(blocknumber): if self.dups.has_key(blocknumber):
self.dups[pkt.blocknumber] += 1 self.dups[blocknumber] += 1
else: else:
self.dups[pkt.blocknumber] = 1 self.dups[blocknumber] = 1
tftpassert(self.dups[pkt.blocknumber] < MAX_DUPS, tftpassert(self.dups[blocknumber] < MAX_DUPS,
"Max duplicates for block %d reached" % blocknumber) "Max duplicates for block %d reached" % blocknumber)
############################################################################### ###############################################################################
@ -73,10 +73,14 @@ class TftpContext(object):
self.metrics = TftpMetrics() self.metrics = TftpMetrics()
# Flag when the transfer is pending completion. # Flag when the transfer is pending completion.
self.pending_complete = False self.pending_complete = False
# Time when this context last received any traffic.
self.last_update = 0
def checkTimeout(self, now): def checkTimeout(self, now):
# FIXME """Compare current time with last_update time, and raise an exception
pass if we're over SOCK_TIMEOUT time."""
if now - self.last_update > SOCK_TIMEOUT:
raise TftpException, "Timeout waiting for traffic"
def start(self): def start(self):
return NotImplementedError, "Abstract method" return NotImplementedError, "Abstract method"
@ -126,6 +130,8 @@ class TftpContext(object):
# Ok, we've received a packet. Log it. # Ok, we've received a packet. Log it.
log.debug("Received %d bytes from %s:%s" log.debug("Received %d bytes from %s:%s"
% (len(buffer), raddress, rport)) % (len(buffer), raddress, rport))
# And update our last updated time.
self.last_update = time.time()
# Decode it. # Decode it.
recvpkt = self.factory.parse(buffer) recvpkt = self.factory.parse(buffer)
@ -160,17 +166,23 @@ class TftpContextServer(TftpContext):
timeout) timeout)
# At this point we have no idea if this is a download or an upload. We # At this point we have no idea if this is a download or an upload. We
# need to let the start state determine that. # need to let the start state determine that.
self.state = TftpStateServerStart() self.state = TftpStateServerStart(self)
self.root = root self.root = root
self.dyn_file_func = dyn_file_func self.dyn_file_func = dyn_file_func
# In a server, the tidport is the same as the port. This is also true
# with symmetric UDP, which we haven't implemented yet.
self.tidport = port
def start(self, buffer): def start(self, buffer):
"""Start the state cycle. Note that the server context receives an """Start the state cycle. Note that the server context receives an
initial packet in its start method.""" initial packet in its start method. Also note that the server does not
log.debug("TftpContextServer.start() - pkt = %s" % pkt) loop on cycle(), as it expects the TftpServer object to manage
that."""
log.debug("In TftpContextServer.start")
self.metrics.start_time = time.time() self.metrics.start_time = time.time()
log.debug("set metrics.start_time to %s" % self.metrics.start_time) log.debug("set metrics.start_time to %s" % self.metrics.start_time)
# And update our last updated time.
self.last_update = time.time()
pkt = self.factory.parse(buffer) pkt = self.factory.parse(buffer)
log.debug("TftpContextServer.start() - factory returned a %s" % pkt) log.debug("TftpContextServer.start() - factory returned a %s" % pkt)
@ -181,16 +193,19 @@ class TftpContextServer(TftpContext):
self.host, self.host,
self.port) self.port)
try: # FIXME
while self.state: # How do we ensure that the server closes files, even on error?
log.debug("state is %s" % self.state)
self.cycle()
finally:
self.fileobj.close()
class TftpContextClientUpload(TftpContext): class TftpContextClientUpload(TftpContext):
"""The upload context for the client during an upload.""" """The upload context for the client during an upload."""
def __init__(self, host, port, filename, input, options, packethook, timeout): def __init__(self,
host,
port,
filename,
input,
options,
packethook,
timeout):
TftpContext.__init__(self, TftpContext.__init__(self,
host, host,
port, port,
@ -234,14 +249,22 @@ class TftpContextClientUpload(TftpContext):
class TftpContextClientDownload(TftpContext): class TftpContextClientDownload(TftpContext):
"""The download context for the client during a download.""" """The download context for the client during a download."""
def __init__(self, host, port, filename, output, options, packethook, timeout): def __init__(self,
host,
port,
filename,
output,
options,
packethook,
timeout):
TftpContext.__init__(self, TftpContext.__init__(self,
host, host,
port, port,
filename,
options,
packethook,
timeout) timeout)
# FIXME: should we refactor setting of these params?
self.file_to_transfer = filename
self.options = options
self.packethook = packethook
# FIXME - need to support alternate return formats than files? # FIXME - need to support alternate return formats than files?
# File-like objects would be ideal, ala duck-typing. # File-like objects would be ideal, ala duck-typing.
self.fileobj = open(output, "wb") self.fileobj = open(output, "wb")
@ -327,12 +350,21 @@ class TftpState(object):
if option == 'blksize': if option == 'blksize':
# Make sure it's valid. # Make sure it's valid.
if int(options[option]) > MAX_BLKSIZE: if int(options[option]) > MAX_BLKSIZE:
log.info("Client requested blksize greater than %d "
"setting to maximum" % MAX_BLKSIZE)
accepted_options[option] = MAX_BLKSIZE accepted_options[option] = MAX_BLKSIZE
elif option == 'tsize': elif int(options[option]) < MIN_BLKSIZE:
log.debug("tsize option is set") log.info("Client requested blksize less than %d "
accepted_options['tsize'] = 1 "setting to minimum" % MIN_BLKSIZE)
accepted_options[option] = MIN_BLKSIZE
else: else:
log.info("Dropping unsupported option '%s'" % option) accepted_options[option] = options[option]
elif option == 'tsize':
log.debug("tsize option is set")
accepted_options['tsize'] = 1
else:
log.info("Dropping unsupported option '%s'" % option)
log.debug("returning these accepted options: %s" % accepted_options)
return accepted_options return accepted_options
def serverInitial(self, pkt, raddress, rport): def serverInitial(self, pkt, raddress, rport):
@ -388,15 +420,16 @@ class TftpState(object):
finished.""" finished."""
finished = False finished = False
blocknumber = self.context.next_block blocknumber = self.context.next_block
tftpassert( blocknumber > 0, "There is no block zero!" )
if not resend: if not resend:
blksize = int(self.context.options['blksize']) blksize = int(self.context.options['blksize'])
buffer = self.context.fileobj.read(blksize) buffer = self.context.fileobj.read(blksize)
log.debug("Read %d bytes into buffer" % len(buffer)) log.debug("Read %d bytes into buffer" % len(buffer))
if len(buffer) < blksize: if len(buffer) < blksize:
log.info("Reached EOF on file %s" % self.context.input) log.info("Reached EOF on file %s"
% self.context.file_to_transfer)
finished = True finished = True
self.context.next_block += 1 self.context.metrics.bytes += len(buffer)
self.bytes += len(buffer)
else: else:
log.warn("Resending block number %d" % blocknumber) log.warn("Resending block number %d" % blocknumber)
dat = TftpPacketDAT() dat = TftpPacketDAT()
@ -413,7 +446,8 @@ class TftpState(object):
"""This method sends an ack packet to the block number specified. If """This method sends an ack packet to the block number specified. If
none is specified, it defaults to the next_block property in the none is specified, it defaults to the next_block property in the
parent context.""" parent context."""
if not blocknumber: log.debug("in sendACK, blocknumber is %s" % blocknumber)
if blocknumber is None:
blocknumber = self.context.next_block blocknumber = self.context.next_block
log.info("sending ack to block %d" % blocknumber) log.info("sending ack to block %d" % blocknumber)
ackpkt = TftpPacketACK() ackpkt = TftpPacketACK()
@ -435,9 +469,9 @@ class TftpState(object):
def sendOACK(self): def sendOACK(self):
"""This method sends an OACK packet with the options from the current """This method sends an OACK packet with the options from the current
context.""" context."""
log.debug("In sendOACK with options %s" % options) log.debug("In sendOACK with options %s" % self.context.options)
pkt = TftpPacketOACK() pkt = TftpPacketOACK()
pkt.options = self.options pkt.options = self.context.options
self.context.sock.sendto(pkt.encode().buffer, self.context.sock.sendto(pkt.encode().buffer,
(self.context.host, (self.context.host,
self.context.tidport)) self.context.tidport))
@ -464,6 +498,10 @@ class TftpState(object):
return None return None
elif pkt.blocknumber < self.context.next_block: elif pkt.blocknumber < self.context.next_block:
if pkt.blocknumber == 0:
log.warn("There is no block zero!")
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "There is no block zero!"
log.warn("dropping duplicate block %d" % pkt.blocknumber) log.warn("dropping duplicate block %d" % pkt.blocknumber)
self.context.metrics.add_dup(pkt.blocknumber) self.context.metrics.add_dup(pkt.blocknumber)
log.debug("ACKing block %d again, just in case" % pkt.blocknumber) log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
@ -502,12 +540,17 @@ class TftpStateServerRecvRRQ(TftpState):
# Options negotiation. # Options negotiation.
if sendoack: if sendoack:
# Note, next_block is 0 here since that's the proper
# acknowledgement to an OACK.
# FIXME: perhaps we do need a TftpStateExpectOACK class...
self.sendOACK() self.sendOACK()
return TftpStateServerOACK(self.context)
else: else:
self.context.next_block = 1
log.debug("No requested options, starting send...") log.debug("No requested options, starting send...")
self.context.pending_complete = self.sendDAT() self.context.pending_complete = self.sendDAT()
return TftpStateExpectACK(self.context) # Note, we expect an ack regardless of whether we sent a DAT or an
# OACK.
return TftpStateExpectACK(self.context)
# Note, we don't have to check any other states in this method, that's # Note, we don't have to check any other states in this method, that's
# up to the caller. # up to the caller.
@ -579,6 +622,9 @@ class TftpStateExpectACK(TftpState):
return None return None
else: else:
log.debug("Good ACK, sending next DAT") log.debug("Good ACK, sending next DAT")
self.context.next_block += 1
log.debug("Incremented next_block to %d"
% (self.context.next_block))
self.context.pending_complete = self.sendDAT() self.context.pending_complete = self.sendDAT()
elif pkt.blocknumber < self.context.next_block: elif pkt.blocknumber < self.context.next_block: