Restructured in preparation for tftp options

git-svn-id: https://tftpy.svn.sourceforge.net/svnroot/tftpy/trunk@10 63283fd4-ec1e-0410-9879-cb7f675518da
This commit is contained in:
msoulier 2006-10-04 17:32:05 +00:00
parent 2827cf1e8f
commit c11ac3a321

View file

@ -19,18 +19,6 @@ MAX_BLKSIZE = 65536
SOCK_TIMEOUT = 5
MAX_DUPS = 20
# Initialize the logger.
logging.basicConfig(
level=LOG_LEVEL,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
datefmt='%m-%d %H:%M:%S')
logger = logging.getLogger('tftpy')
class TftpException(Exception):
"""This class is the parent class of all exceptions regarding the handling
of the TFTP protocol."""
pass
def tftpassert(condition, msg):
"""This function is a simple utility that will check the condition
passed for a false state. If it finds one, it throws a TftpException
@ -39,6 +27,29 @@ def tftpassert(condition, msg):
if not condition:
raise TftpException, msg
def setLogLevel(level=LOG_LEVEL):
"""This function is a utility function for setting the internal log level.
The log level defaults to logging.NOTSET, so unwanted output to stdout is not
created."""
global logger
# Initialize the logger.
logging.basicConfig(
level=level,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
datefmt='%m-%d %H:%M:%S')
logger = logging.getLogger('tftpy')
# The logger used by this library. Feel free to clobber it with your own, if you like, as
# long as it conforms to Python's logging.
logger = None
# Set up the default logger.
setLogLevel()
class TftpException(Exception):
"""This class is the parent class of all exceptions regarding the handling
of the TFTP protocol."""
pass
class TftpPacket(object):
"""This class is the parent class of all tftp packet classes. It is an
abstract class, providing an interface, and should not be instantiated
@ -364,18 +375,55 @@ class TftpPacketFactory(object):
packet.buffer = buffer
return packet.decode()
class Tftp(object):
class TftpState(object):
"""This class represents a particular state for a TFTP Session. It encapsulates a
state, kind of like an enum. The states mean the following:
nil - Session not yet established
rrq - Just sent RRQ in a download, waiting for response
wrq - Just sent WRQ in an upload, waiting for response
dat - Transferring data
oack - Received oack, negotiating options
ack - Acknowledged oack, awaiting response
err - Fatal problems, giving up
fin - Transfer completed
"""
states = ['nil',
'rrq',
'wrq',
'dat',
'oack',
'ack',
'err',
'fin']
def __init__(self, state='nil'):
self.state = state
def getState(self):
return self.__state
def setState(self, state):
if state in TftpState.states:
self.__state = state
state = property(getState, setState)
class TftpSession(object):
"""This class is the base class for the tftp client and server. Any shared
code should be in this class."""
def __init__(self):
"Class constructor. Note that the state property must be a TftpState object."
self.options = None
self.state = TftpState()
self.dups = 0
self.errors = 0
class TftpClient(Tftp):
class TftpClient(TftpSession):
"""This class is an implementation of a tftp client."""
def __init__(self, host, port, options):
"""This constructor returns an instance of TftpClient, taking the
remote host, the remote port, and the filename to fetch."""
Tftp.__init__(self)
TftpSession.__init__(self)
self.host = host
self.port = port
self.options = options
@ -400,73 +448,80 @@ class TftpClient(Tftp):
pkt.filename = filename
pkt.mode = "octet" # FIXME - shouldn't hardcode this
sock.sendto(pkt.encode().buffer, (self.host, self.port))
self.state.state = 'rrq'
# FIXME - need to do option negotiation here
while True:
(buffer, (raddress, rport)) = sock.recvfrom(MAX_BLKSIZE)
recvpkt = tftp_factory.parse(buffer)
# Read the initial response datagram to see if we're in business.
(buffer, (raddress, rport)) = sock.recvfrom(MAX_BLKSIZE)
recvpkt = tftp_factory.parse(buffer)
while isinstance(recvpkt, TftpPacketDAT):
logger.debug("Received %d bytes from %s:%s" \
% (len(buffer), raddress, rport))
logger.debug("Received %d bytes from %s:%s"
% (len(buffer), raddress, rport))
# FIXME - check sender port and ip address
# FIXME - can we refactor this into below?
logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber)
logger.debug("curblock = %d" % curblock)
if recvpkt.blocknumber == curblock+1:
logger.debug("good, received block %d in sequence"
% recvpkt.blocknumber)
curblock += 1
# ACK the packet, and save the data.
logger.info("sending ACK to block %d" % curblock)
logger.debug("ip = %s, port = %s" % (self.host, self.port))
ackpkt = TftpPacketACK()
ackpkt.blocknumber = curblock
sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
logger.debug("writing %d bytes to output file"
% len(recvpkt.data))
outputfile.write(recvpkt.data)
bytes += len(recvpkt.data)
# If there is a packethook defined, call it.
if packethook:
packethook(recvpkt)
# Check for end-of-file, any less than full data packet.
if len(recvpkt.data) < DEF_BLKSIZE:
logger.info("end of file detected")
break
if isinstance(recvpkt, TftpPacketDAT):
logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber)
logger.debug("curblock = %d" % curblock)
if recvpkt.blocknumber == curblock+1:
logger.debug("good, received block %d in sequence"
% recvpkt.blocknumber)
curblock += 1
# ACK the packet, and save the data.
logger.info("sending ACK to block %d" % curblock)
logger.debug("ip = %s, port = %s" % (self.host, self.port))
ackpkt = TftpPacketACK()
ackpkt.blocknumber = curblock
sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
logger.debug("writing %d bytes to output file"
% len(recvpkt.data))
outputfile.write(recvpkt.data)
bytes += len(recvpkt.data)
# If there is a packethook defined, call it.
if packethook:
packethook(recvpkt)
# Check for end-of-file, any less than full data packet.
if len(recvpkt.data) < DEF_BLKSIZE:
logger.info("end of file detected")
break
elif recvpkt.blocknumber == curblock:
logger.warn("dropping duplicate block %d" % curblock)
if dups.has_key(curblock):
dups[curblock] += 1
else:
dups[curblock] = 1
tftpassert(dups[curblock] < MAX_DUPS,
"Max duplicates for block %d reached" % curblock)
logger.debug("ACKing block %d again, just in case" % curblock)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = curblock
sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
elif recvpkt.blocknumber == curblock:
logger.warn("dropping duplicate block %d" % curblock)
if dups.has_key(curblock):
dups[curblock] += 1
else:
dups[curblock] = 1
tftpassert(dups[curblock] < MAX_DUPS,
"Max duplicates for block %d reached" % curblock)
logger.debug("ACKing block %d again, just in case" % curblock)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = curblock
sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber,
curblock+1)
logger.error(msg)
raise TftpException, msg
# Check other packet types.
elif isinstance(recvpkt, TftpPacketOACK):
tftpassert(False, "Options currently unsupported")
elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, the server doesn't.
tftpassert(False, "Received ACK from server while in download")
elif isinstance(recvpkt, TftpPacketERR):
tftpassert(False, "Received ERR from server: " + recvpkt)
elif isinstance(recvpkt, TftpPacketWRQ):
tftpassert(False, "Received WRQ from server: " + recvpkt)
else:
msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber,
curblock+1)
logger.error(msg)
raise TftpException, msg
tftpassert(False, "Received unknown packet type from server: "
+ recvpkt)
(buffer, (raddress, rport)) = sock.recvfrom(MAX_BLKSIZE)
logger.info("Received %d bytes from %s:%s" % (len(buffer),
raddress,
rport))
recvpkt = tftp_factory.parse(buffer)
# FIXME - check sender port and ip address
# end while
# Check for errors
if isinstance(recvpkt, TftpPacketERR):
logger.error("received ERR packet")
end_time = time.time()
duration = end_time - start_time