Started overhaul of state machine.

master
Michael P. Soulier 2009-04-08 23:29:43 -04:00
parent 41bf3a25e6
commit e7a63bbbc2
5 changed files with 349 additions and 223 deletions

View File

@ -1,6 +1,7 @@
import socket, time, types
import time, types
from TftpShared import *
from TftpPacketFactory import *
from TftpStates import TftpContextClientDownload
class TftpClient(TftpSession):
"""This class is an implementation of a tftp client. Once instantiated, a
@ -9,6 +10,7 @@ class TftpClient(TftpSession):
"""This constructor returns an instance of TftpClient, taking the
remote host, the remote port, and the filename to fetch."""
TftpSession.__init__(self)
self.context = None
self.host = host
self.iport = port
self.filename = None
@ -51,192 +53,30 @@ class TftpClient(TftpSession):
object. The timeout parameter may be used to override the default
SOCK_TIMEOUT setting, which is the amount of time that the client will
wait for a receive packet to arrive."""
# Open the output file.
# FIXME - need to support alternate return formats than files?
# File-like objects would be ideal, ala duck-typing.
self.fileobj = open(output, "wb")
recvpkt = None
curblock = 0
dups = {}
start_time = time.time()
self.bytes = 0
# We're downloading.
self.context = TftpContextClientDownload(self.host,
self.iport,
filename,
output,
self.options,
packethook,
timeout)
self.context.start()
# Download happens here
self.context.end()
self.filename = filename
metrics = self.context.metrics
tftp_factory = TftpPacketFactory()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(timeout)
logger.info("Sending tftp download request to %s" % self.host)
logger.info(" filename -> %s" % filename)
pkt = TftpPacketRRQ()
pkt.filename = filename
pkt.mode = "octet" # FIXME - shouldn't hardcode this
pkt.options = self.options
self.sock.sendto(pkt.encode().buffer, (self.host, self.iport))
self.state.state = 'rrq'
timeouts = 0
while True:
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
except socket.timeout, err:
timeouts += 1
if timeouts >= TIMEOUT_RETRIES:
raise TftpException, "Hit max timeouts, giving up."
else:
logger.warn("Timeout waiting for traffic, retrying...")
continue
recvpkt = tftp_factory.parse(buffer)
logger.debug("Received %d bytes from %s:%s"
% (len(buffer), raddress, rport))
# Check for known "connection".
if raddress != self.address:
logger.warn("Received traffic from %s, expected host %s. Discarding"
% (raddress, self.host))
continue
if self.port and self.port != rport:
logger.warn("Received traffic from %s:%s but we're "
"connected to %s:%s. Discarding."
% (raddress, rport,
self.host, self.port))
continue
# If there is a packethook defined, call it. We unconditionally
# pass all packets, it's up to the client to screen out different
# kinds of packets. This way, the client is privy to things like
# negotiated options.
if packethook:
packethook(recvpkt)
if not self.port and self.state.state == 'rrq':
self.port = rport
logger.debug("Set remote port for session to %s" % rport)
if isinstance(recvpkt, TftpPacketDAT):
logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber)
logger.debug("curblock = %d" % curblock)
if self.state.state == 'rrq' and self.options:
logger.info("no OACK, our options were ignored")
self.options = { 'blksize': DEF_BLKSIZE }
self.state.state = 'ack'
expected_block = curblock + 1
if expected_block > 65535:
logger.debug("block number rollover to 0 again")
expected_block = 0
if recvpkt.blocknumber == expected_block:
logger.debug("good, received block %d in sequence"
% recvpkt.blocknumber)
curblock = expected_block
# ACK the packet, and save the data.
logger.info("sending ACK to block %d" % curblock)
logger.debug("ip = %s, port = %s" % (self.host, self.port))
ackpkt = TftpPacketACK()
ackpkt.blocknumber = curblock
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
logger.debug("writing %d bytes to output file"
% len(recvpkt.data))
self.fileobj.write(recvpkt.data)
self.bytes += len(recvpkt.data)
# Check for end-of-file, any less than full data packet.
if len(recvpkt.data) < int(self.options['blksize']):
logger.info("end of file detected")
break
elif recvpkt.blocknumber == curblock:
logger.warn("dropping duplicate block %d" % curblock)
if dups.has_key(curblock):
dups[curblock] += 1
else:
dups[curblock] = 1
tftpassert(dups[curblock] < MAX_DUPS,
"Max duplicates for block %d reached" % curblock)
logger.debug("ACKing block %d again, just in case" % curblock)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = curblock
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
else:
msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber,
curblock+1)
logger.error(msg)
raise TftpException, msg
# Check other packet types.
elif isinstance(recvpkt, TftpPacketOACK):
if not self.state.state == 'rrq':
self.errors += 1
logger.error("Received OACK in state %s" % self.state.state)
continue
self.state.state = 'oack'
logger.info("Received OACK from server.")
if recvpkt.options.keys() > 0:
if recvpkt.match_options(self.options):
logger.info("Successful negotiation of options")
# Set options to OACK options
self.options = recvpkt.options
for key in self.options:
logger.info(" %s = %s" % (key, self.options[key]))
logger.debug("sending ACK to OACK")
ackpkt = TftpPacketACK()
ackpkt.blocknumber = 0
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
self.state.state = 'ack'
else:
logger.error("failed to negotiate options")
self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port)
self.state.state = 'err'
raise TftpException, "Failed to negotiate options"
elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, the server doesn't.
self.state.state = 'err'
self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
tftpassert(False, "Received ACK from server while in download")
elif isinstance(recvpkt, TftpPacketERR):
self.state.state = 'err'
self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
tftpassert(False, "Received ERR from server: " + str(recvpkt))
elif isinstance(recvpkt, TftpPacketWRQ):
self.state.state = 'err'
self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
tftpassert(False, "Received WRQ from server: " + str(recvpkt))
else:
self.state.state = 'err'
self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
tftpassert(False, "Received unknown packet type from server: "
+ str(recvpkt))
# end while
self.fileobj.close()
end_time = time.time()
duration = end_time - start_time
if duration == 0:
# FIXME: Should we output this? Shouldn't we let the client control
# output? This should be in the sample client, but not in the download
# call.
if metrics.duration == 0:
logger.info("Duration too short, rate undetermined")
else:
logger.info('')
logger.info("Downloaded %d bytes in %d seconds" % (self.bytes, duration))
bps = (self.bytes * 8.0) / duration
kbps = bps / 1024.0
logger.info("Average rate: %.2f kbps" % kbps)
dupcount = 0
for key in dups:
dupcount += dups[key]
logger.info("Received %d duplicate packets" % dupcount)
logger.info("Downloaded %d bytes in %d seconds" % (metrics.bytes, metrics.duration))
logger.info("Average rate: %.2f kbps" % metrics.kbps)
logger.info("Received %d duplicate packets" % metrics.dupcount)
def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT):
# Open the input file.

View File

@ -6,10 +6,9 @@ class TftpSession(object):
code should be in this class."""
def __init__(self):
"""Class constructor. Note that the state property must be a TftpState
object."""
"""Class constructor."""
self.options = None
self.state = TftpState()
self.state = None
self.dups = 0
self.errors = 0

View File

@ -95,7 +95,7 @@ class TftpServer(TftpSession):
logger.debug("New download request, session key = %s"
% key)
self.handlers[key] = TftpServerHandler(key,
TftpState('rrq'),
'rrq',
self.root,
listenip,
tftp_factory)

View File

@ -51,40 +51,3 @@ class TftpException(Exception):
"""This class is the parent class of all exceptions regarding the handling
of the TFTP protocol."""
pass
class TftpState(object):
"""This class represents a particular state for a TFTP Session. It encapsulates a
state, kind of like an enum. The states mean the following:
nil - Client/Server - Session not yet established
rrq - Client - Just sent RRQ in a download, waiting for response
Server - Just received an RRQ
wrq - Client - Just sent WRQ in an upload, waiting for response
Server - Just received a WRQ
dat - Client/Server - Transferring data
oack - Client - Just received oack
Server - Just sent OACK
ack - Client - Acknowledged oack, awaiting response
Server - Just received ACK to OACK
err - Client/Server - Fatal problems, giving up
fin - Client/Server - Transfer completed
"""
states = ['nil',
'rrq',
'wrq',
'dat',
'oack',
'ack',
'err',
'fin']
def __init__(self, state='nil'):
self.state = state
def getState(self):
return self.__state
def setState(self, state):
if state in TftpState.states:
self.__state = state
state = property(getState, setState)

324
tftpy/TftpStates.py Normal file
View File

@ -0,0 +1,324 @@
from TftpShared import *
from TftpPacketTypes import *
from TftpPacketFactory import *
import socket, time
###############################################################################
# Utility classes
###############################################################################
class TftpMetrics(object):
"""A class representing metrics of the transfer."""
def __init__(self):
# Bytes transferred
self.bytes = 0
# Duplicate packets received
self.dups = {}
self.dupcount = 0
# Times
self.start_time = 0
self.end_time = 0
self.duration = 0
# Rates
self.bps = 0
self.kbps = 0
def compute(self):
# Compute transfer time
self.duration = int(self.end_time - self.start_time)
self.bps = (metrics.bytes * 8.0) / metrics.duration
self.kbps = bps / 1024.0
for key in self.dups:
dupcount += metrics.dups[key]
###############################################################################
# Context classes
###############################################################################
class TftpContext(object):
"""The base class of the contexts."""
def __init__(self, host, port):
"""Constructor for the base context, setting shared instance
variables."""
self.factory = TftpPacketFactory()
self.host = host
self.port = port
# The port associated with the TID
self.tidport = None
# Metrics
self.metrics = TftpMetrics()
def start(self):
return NotImplementedError, "Abstract method"
def end(self):
return NotImplementedError, "Abstract method"
def gethost(self):
"Simple getter method for use in a property."
return self.__host
def sethost(self, host):
"""Setter method that also sets the address property as a result
of the host that is set."""
self.__host = host
self.address = socket.gethostbyname(host)
host = property(gethost, sethost)
def sendAck(self, blocknumber):
"""This method sends an ack packet to the block number specified."""
logger.info("sending ack to block %d" % blocknumber)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = 0
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
def senderror(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
compose and send an error packet."""
logger.debug("In senderror, being asked to send error %d" % errorcode)
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
sock.sendto(errpkt.encode().buffer, (self.host, self.tidport))
class TftpContextServerDownload(TftpContext):
"""The download context for the server during a download."""
pass
class TftpContextClientDownload(TftpContext):
"""The download context for the client during a download."""
def __init__(self, host, port, filename, output, options, packethook, timeout):
TftpContext.__init__(self, host, port)
# Open the output file.
# FIXME - need to support alternate return formats than files?
# File-like objects would be ideal, ala duck-typing.
self.requested_file = filename
self.fileobj = open(output, "wb")
self.options = options
self.packethook = packethook
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(timeout)
self.state = None
self.expected_block = 0
def setExpectedBlock(self, block):
if block > 2 ** 16:
logger.debug("block number rollover to 0 again")
block = 0
self.__eblock = block
def getExpectedBlock(self):
return self.__eblock
expected_block = property(getExpectedBlock, setExpectedBlock)
def start(self):
"""Initiate the download."""
logger.info("Sending tftp download request to %s" % self.host)
logger.info(" filename -> %s" % self.requested_file)
self.metrics.start_time = time.time()
# FIXME: put this in a sendRRQ method?
pkt = TftpPacketRRQ()
pkt.filename = self.requested_file
pkt.mode = "octet" # FIXME - shouldn't hardcode this
pkt.options = self.options
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
self.expected_block = 1
self.state = TftpStateSentRRQ(self)
try:
while self.state:
self.cycle()
finally:
self.fileobj.close()
def end(self):
"""Finish up the context."""
self.metrics.end_time = time.time()
self.metrics.compute()
def cycle(self):
"""Here we wait for a response from the server after sending it
something, and dispatch appropriate action to that response."""
for i in range(TIMEOUT_RETRIES):
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
except socket.timeout, err:
logger.warn("Timeout waiting for traffic, retrying...")
continue
break
else:
raise TftpException, "Hit max timeouts, giving up."
# Ok, we've received a packet. Decode it.
recvpkt = self.factory.parse(buffer)
# Log it.
logger.debug("Received %d bytes from %s:%s"
% (len(buffer), raddress, rport))
# Check for known "connection".
if raddress != self.address:
logger.warn("Received traffic from %s, expected host %s. Discarding"
% (raddress, self.host))
if self.port and self.port != rport:
logger.warn("Received traffic from %s:%s but we're "
"connected to %s:%s. Discarding."
% (raddress, rport,
self.host, self.port))
# If there is a packethook defined, call it. We unconditionally
# pass all packets, it's up to the client to screen out different
# kinds of packets. This way, the client is privy to things like
# negotiated options.
if self.packethook:
self.packethook(recvpkt)
# And handle it, possibly changing state.
self.state = self.state.handle(recvpkt, raddress, rport)
###############################################################################
# State classes
###############################################################################
class TftpState(object):
"""The base class for the states."""
def __init__(self, context):
"""Constructor for setting up common instance variables. The involved
file object is required, since in tftp there's always a file
involved."""
self.context = context
def handle(self, pkt, raddress, rport):
"""An abstract method for handling a packet. It is expected to return
a TftpState object, either itself or a new state."""
raise NotImplementedError, "Abstract method"
class TftpStateDownload(TftpState):
"""A class holding common code for download states."""
def handleDat(self, pkt):
"""This method handles a DAT packet during a download."""
logger.info("handling DAT packet - block %d" % pkt.blocknumber)
logger.debug("expecting block %s" % self.expected_block)
if pkt.blocknumber == self.expected_block:
logger.debug("good, received block %d in sequence"
% pkt.blocknumber)
self.context.sendAck(pkt.blocknumber)
self.expected_block += 1
logger.debug("writing %d bytes to output file"
% len(pkt.data))
self.context.fileobj.write(pkt.data)
self.context.metrics.bytes += len(pkt.data)
# Check for end-of-file, any less than full data packet.
if len(pkt.data) < int(self.context.options['blksize']):
logger.info("end of file detected")
return None
elif pkt.blocknumber == curblock:
logger.warn("dropping duplicate block %d" % pkt.blocknumber)
if self.context.metrics.dups.has_key(curblock):
self.context.metrics.dups[pkt.blocknumber] += 1
else:
self.context.metrics.dups[pkt.blocknumber] = 1
tftpassert(self.context.metrics.dups[curblock] < MAX_DUPS,
"Max duplicates for block %d reached" % curblock)
# FIXME: double-check sorceror's apprentice problem!
logger.debug("ACKing block %d again, just in case" % curblock)
self.context.sendAck(pkt.blocknumber)
else:
# FIXME: should we be more tolerant and just discard instead?
msg = "Whoa! Received block %d but expected %d" % (pkt.blocknumber,
self.expected_block)
logger.error(msg)
raise TftpException, msg
# Default is to ack
return TftpStateSentACK(self.context)
class TftpStateSentRRQ(TftpStateDownload):
"""Just sent an RRQ packet."""
def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an RRQ to the server."""
if not self.tidport:
self.tidport = rport
logger.debug("Set remote port for session to %s" % rport)
# Now check the packet type and dispatch it properly.
if isinstance(pkt, TftpPacketOACK):
logger.info("Received OACK from server.")
if pkt.options.keys() > 0:
if pkt.match_options(self.options):
logger.info("Successful negotiation of options")
# Set options to OACK options
self.options = pkt.options
for key in self.options:
logger.info(" %s = %s" % (key, self.options[key]))
logger.debug("sending ACK to OACK")
self.context.sendAck(blocknumber=0)
logger.debug("Changing state to TftpStateSentACK")
return TftpStateSentACK(self.context)
else:
logger.error("failed to negotiate options")
self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port)
raise TftpException, "Failed to negotiate options"
elif isinstance(pkt, TftpPacketDAT):
return self.handleDat(pkt)
# Every other packet type is a problem.
elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, the server doesn't.
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
raise TftpException, "Received ACK from server while in download"
elif isinstance(recvpkt, TftpPacketWRQ):
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
raise TftpException, "Received WRQ from server while in download"
elif isinstance(recvpkt, TftpPacketERR):
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(recvpkt)
else:
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
# By default, no state change.
return self
class TftpStateSentACK(TftpState):
"""Just sent an ACK packet. Waiting for DAT."""
def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an ACK, which should be a DAT."""
if isinstance(pkt, TftpPacketDAT):
return self.handleDat(pkt)
# Every other packet type is a problem.
elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, the server doesn't.
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
raise TftpException, "Received ACK from server while in download"
elif isinstance(recvpkt, TftpPacketWRQ):
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
raise TftpException, "Received WRQ from server while in download"
elif isinstance(recvpkt, TftpPacketERR):
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(recvpkt)
else:
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)