Restructured in preparation for tftp options
git-svn-id: https://tftpy.svn.sourceforge.net/svnroot/tftpy/trunk@10 63283fd4-ec1e-0410-9879-cb7f675518da
This commit is contained in:
parent
2827cf1e8f
commit
c11ac3a321
1 changed files with 127 additions and 72 deletions
199
lib/tftpy.py
199
lib/tftpy.py
|
@ -19,18 +19,6 @@ MAX_BLKSIZE = 65536
|
|||
SOCK_TIMEOUT = 5
|
||||
MAX_DUPS = 20
|
||||
|
||||
# Initialize the logger.
|
||||
logging.basicConfig(
|
||||
level=LOG_LEVEL,
|
||||
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
|
||||
datefmt='%m-%d %H:%M:%S')
|
||||
logger = logging.getLogger('tftpy')
|
||||
|
||||
class TftpException(Exception):
|
||||
"""This class is the parent class of all exceptions regarding the handling
|
||||
of the TFTP protocol."""
|
||||
pass
|
||||
|
||||
def tftpassert(condition, msg):
|
||||
"""This function is a simple utility that will check the condition
|
||||
passed for a false state. If it finds one, it throws a TftpException
|
||||
|
@ -39,6 +27,29 @@ def tftpassert(condition, msg):
|
|||
if not condition:
|
||||
raise TftpException, msg
|
||||
|
||||
def setLogLevel(level=LOG_LEVEL):
|
||||
"""This function is a utility function for setting the internal log level.
|
||||
The log level defaults to logging.NOTSET, so unwanted output to stdout is not
|
||||
created."""
|
||||
global logger
|
||||
# Initialize the logger.
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
|
||||
datefmt='%m-%d %H:%M:%S')
|
||||
logger = logging.getLogger('tftpy')
|
||||
|
||||
# The logger used by this library. Feel free to clobber it with your own, if you like, as
|
||||
# long as it conforms to Python's logging.
|
||||
logger = None
|
||||
# Set up the default logger.
|
||||
setLogLevel()
|
||||
|
||||
class TftpException(Exception):
|
||||
"""This class is the parent class of all exceptions regarding the handling
|
||||
of the TFTP protocol."""
|
||||
pass
|
||||
|
||||
class TftpPacket(object):
|
||||
"""This class is the parent class of all tftp packet classes. It is an
|
||||
abstract class, providing an interface, and should not be instantiated
|
||||
|
@ -364,18 +375,55 @@ class TftpPacketFactory(object):
|
|||
packet.buffer = buffer
|
||||
return packet.decode()
|
||||
|
||||
class Tftp(object):
|
||||
class TftpState(object):
|
||||
"""This class represents a particular state for a TFTP Session. It encapsulates a
|
||||
state, kind of like an enum. The states mean the following:
|
||||
nil - Session not yet established
|
||||
rrq - Just sent RRQ in a download, waiting for response
|
||||
wrq - Just sent WRQ in an upload, waiting for response
|
||||
dat - Transferring data
|
||||
oack - Received oack, negotiating options
|
||||
ack - Acknowledged oack, awaiting response
|
||||
err - Fatal problems, giving up
|
||||
fin - Transfer completed
|
||||
"""
|
||||
states = ['nil',
|
||||
'rrq',
|
||||
'wrq',
|
||||
'dat',
|
||||
'oack',
|
||||
'ack',
|
||||
'err',
|
||||
'fin']
|
||||
|
||||
def __init__(self, state='nil'):
|
||||
self.state = state
|
||||
|
||||
def getState(self):
|
||||
return self.__state
|
||||
|
||||
def setState(self, state):
|
||||
if state in TftpState.states:
|
||||
self.__state = state
|
||||
|
||||
state = property(getState, setState)
|
||||
|
||||
class TftpSession(object):
|
||||
"""This class is the base class for the tftp client and server. Any shared
|
||||
code should be in this class."""
|
||||
def __init__(self):
|
||||
"Class constructor. Note that the state property must be a TftpState object."
|
||||
self.options = None
|
||||
self.state = TftpState()
|
||||
self.dups = 0
|
||||
self.errors = 0
|
||||
|
||||
class TftpClient(Tftp):
|
||||
class TftpClient(TftpSession):
|
||||
"""This class is an implementation of a tftp client."""
|
||||
def __init__(self, host, port, options):
|
||||
"""This constructor returns an instance of TftpClient, taking the
|
||||
remote host, the remote port, and the filename to fetch."""
|
||||
Tftp.__init__(self)
|
||||
TftpSession.__init__(self)
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.options = options
|
||||
|
@ -400,73 +448,80 @@ class TftpClient(Tftp):
|
|||
pkt.filename = filename
|
||||
pkt.mode = "octet" # FIXME - shouldn't hardcode this
|
||||
sock.sendto(pkt.encode().buffer, (self.host, self.port))
|
||||
self.state.state = 'rrq'
|
||||
|
||||
# FIXME - need to do option negotiation here
|
||||
while True:
|
||||
(buffer, (raddress, rport)) = sock.recvfrom(MAX_BLKSIZE)
|
||||
recvpkt = tftp_factory.parse(buffer)
|
||||
|
||||
# Read the initial response datagram to see if we're in business.
|
||||
(buffer, (raddress, rport)) = sock.recvfrom(MAX_BLKSIZE)
|
||||
recvpkt = tftp_factory.parse(buffer)
|
||||
|
||||
while isinstance(recvpkt, TftpPacketDAT):
|
||||
logger.debug("Received %d bytes from %s:%s" \
|
||||
% (len(buffer), raddress, rport))
|
||||
logger.debug("Received %d bytes from %s:%s"
|
||||
% (len(buffer), raddress, rport))
|
||||
# FIXME - check sender port and ip address
|
||||
# FIXME - can we refactor this into below?
|
||||
logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber)
|
||||
logger.debug("curblock = %d" % curblock)
|
||||
if recvpkt.blocknumber == curblock+1:
|
||||
logger.debug("good, received block %d in sequence"
|
||||
% recvpkt.blocknumber)
|
||||
curblock += 1
|
||||
# ACK the packet, and save the data.
|
||||
logger.info("sending ACK to block %d" % curblock)
|
||||
logger.debug("ip = %s, port = %s" % (self.host, self.port))
|
||||
ackpkt = TftpPacketACK()
|
||||
ackpkt.blocknumber = curblock
|
||||
sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
|
||||
|
||||
logger.debug("writing %d bytes to output file"
|
||||
% len(recvpkt.data))
|
||||
outputfile.write(recvpkt.data)
|
||||
bytes += len(recvpkt.data)
|
||||
# If there is a packethook defined, call it.
|
||||
if packethook:
|
||||
packethook(recvpkt)
|
||||
# Check for end-of-file, any less than full data packet.
|
||||
if len(recvpkt.data) < DEF_BLKSIZE:
|
||||
logger.info("end of file detected")
|
||||
break
|
||||
if isinstance(recvpkt, TftpPacketDAT):
|
||||
logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber)
|
||||
logger.debug("curblock = %d" % curblock)
|
||||
if recvpkt.blocknumber == curblock+1:
|
||||
logger.debug("good, received block %d in sequence"
|
||||
% recvpkt.blocknumber)
|
||||
curblock += 1
|
||||
# ACK the packet, and save the data.
|
||||
logger.info("sending ACK to block %d" % curblock)
|
||||
logger.debug("ip = %s, port = %s" % (self.host, self.port))
|
||||
ackpkt = TftpPacketACK()
|
||||
ackpkt.blocknumber = curblock
|
||||
sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
|
||||
|
||||
logger.debug("writing %d bytes to output file"
|
||||
% len(recvpkt.data))
|
||||
outputfile.write(recvpkt.data)
|
||||
bytes += len(recvpkt.data)
|
||||
# If there is a packethook defined, call it.
|
||||
if packethook:
|
||||
packethook(recvpkt)
|
||||
# Check for end-of-file, any less than full data packet.
|
||||
if len(recvpkt.data) < DEF_BLKSIZE:
|
||||
logger.info("end of file detected")
|
||||
break
|
||||
|
||||
elif recvpkt.blocknumber == curblock:
|
||||
logger.warn("dropping duplicate block %d" % curblock)
|
||||
if dups.has_key(curblock):
|
||||
dups[curblock] += 1
|
||||
else:
|
||||
dups[curblock] = 1
|
||||
tftpassert(dups[curblock] < MAX_DUPS,
|
||||
"Max duplicates for block %d reached" % curblock)
|
||||
logger.debug("ACKing block %d again, just in case" % curblock)
|
||||
ackpkt = TftpPacketACK()
|
||||
ackpkt.blocknumber = curblock
|
||||
sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
|
||||
|
||||
elif recvpkt.blocknumber == curblock:
|
||||
logger.warn("dropping duplicate block %d" % curblock)
|
||||
if dups.has_key(curblock):
|
||||
dups[curblock] += 1
|
||||
else:
|
||||
dups[curblock] = 1
|
||||
tftpassert(dups[curblock] < MAX_DUPS,
|
||||
"Max duplicates for block %d reached" % curblock)
|
||||
logger.debug("ACKing block %d again, just in case" % curblock)
|
||||
ackpkt = TftpPacketACK()
|
||||
ackpkt.blocknumber = curblock
|
||||
sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
|
||||
msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber,
|
||||
curblock+1)
|
||||
logger.error(msg)
|
||||
raise TftpException, msg
|
||||
|
||||
# Check other packet types.
|
||||
elif isinstance(recvpkt, TftpPacketOACK):
|
||||
tftpassert(False, "Options currently unsupported")
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketACK):
|
||||
# Umm, we ACK, the server doesn't.
|
||||
tftpassert(False, "Received ACK from server while in download")
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketERR):
|
||||
tftpassert(False, "Received ERR from server: " + recvpkt)
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketWRQ):
|
||||
tftpassert(False, "Received WRQ from server: " + recvpkt)
|
||||
|
||||
else:
|
||||
msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber,
|
||||
curblock+1)
|
||||
logger.error(msg)
|
||||
raise TftpException, msg
|
||||
tftpassert(False, "Received unknown packet type from server: "
|
||||
+ recvpkt)
|
||||
|
||||
(buffer, (raddress, rport)) = sock.recvfrom(MAX_BLKSIZE)
|
||||
logger.info("Received %d bytes from %s:%s" % (len(buffer),
|
||||
raddress,
|
||||
rport))
|
||||
recvpkt = tftp_factory.parse(buffer)
|
||||
# FIXME - check sender port and ip address
|
||||
|
||||
# end while
|
||||
# Check for errors
|
||||
if isinstance(recvpkt, TftpPacketERR):
|
||||
logger.error("received ERR packet")
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
|
|
Reference in a new issue