Fixing up some of the upload code.
parent
5ee5f63f9b
commit
03e4e74829
|
@ -1,7 +1,7 @@
|
||||||
import time, types
|
import time, types
|
||||||
from TftpShared import *
|
from TftpShared import *
|
||||||
from TftpPacketFactory import *
|
from TftpPacketFactory import *
|
||||||
from TftpStates import TftpContextClientDownload
|
from TftpStates import TftpContextClientDownload, TftpContextClientUpload
|
||||||
|
|
||||||
class TftpClient(TftpSession):
|
class TftpClient(TftpSession):
|
||||||
"""This class is an implementation of a tftp client. Once instantiated, a
|
"""This class is an implementation of a tftp client. Once instantiated, a
|
||||||
|
@ -63,185 +63,26 @@ class TftpClient(TftpSession):
|
||||||
# Open the input file.
|
# Open the input file.
|
||||||
# FIXME: As of the state machine, this is now broken. Need to
|
# FIXME: As of the state machine, this is now broken. Need to
|
||||||
# implement with new state machine.
|
# implement with new state machine.
|
||||||
self.fileobj = open(input, "rb")
|
self.context = TftpContextClientUpload(self.host,
|
||||||
recvpkt = None
|
self.iport,
|
||||||
curblock = 0
|
filename,
|
||||||
start_time = time.time()
|
input,
|
||||||
self.bytes = 0
|
self.options,
|
||||||
|
packethook,
|
||||||
|
timeout)
|
||||||
|
self.context.start()
|
||||||
|
# Upload happens here
|
||||||
|
self.context.end()
|
||||||
|
|
||||||
tftp_factory = TftpPacketFactory()
|
metrics = self.context.metrics
|
||||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
||||||
self.sock.settimeout(timeout)
|
|
||||||
|
|
||||||
self.filename = filename
|
# FIXME: Should we output this? Shouldn't we let the client control
|
||||||
|
# output? This should be in the sample client, but not in the download
|
||||||
self.send_wrq()
|
# call.
|
||||||
self.state.state = 'wrq'
|
if metrics.duration == 0:
|
||||||
|
|
||||||
timeouts = 0
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
|
|
||||||
except socket.timeout, err:
|
|
||||||
timeouts += 1
|
|
||||||
if timeouts >= TIMEOUT_RETRIES:
|
|
||||||
raise TftpException, "Hit max timeouts, giving up."
|
|
||||||
else:
|
|
||||||
if self.state.state == 'dat' or self.state.state == 'fin':
|
|
||||||
logger.debug("Timing out on DAT. Need to resend.")
|
|
||||||
self.send_dat(packethook,resend=True)
|
|
||||||
elif self.state.state == 'wrq':
|
|
||||||
logger.debug("Timing out on WRQ.")
|
|
||||||
self.send_wrq(resend=True)
|
|
||||||
else:
|
|
||||||
tftpassert(False,
|
|
||||||
"Timing out in unsupported state %s" %
|
|
||||||
self.state.state)
|
|
||||||
continue
|
|
||||||
|
|
||||||
recvpkt = tftp_factory.parse(buffer)
|
|
||||||
|
|
||||||
logger.debug("Received %d bytes from %s:%s"
|
|
||||||
% (len(buffer), raddress, rport))
|
|
||||||
|
|
||||||
# Check for known "connection".
|
|
||||||
if raddress != self.address:
|
|
||||||
logger.warn("Received traffic from %s, expected host %s. Discarding"
|
|
||||||
% (raddress, self.host))
|
|
||||||
continue
|
|
||||||
if self.port and self.port != rport:
|
|
||||||
logger.warn("Received traffic from %s:%s but we're "
|
|
||||||
"connected to %s:%s. Discarding."
|
|
||||||
% (raddress, rport,
|
|
||||||
self.host, self.port))
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not self.port and self.state.state == 'wrq':
|
|
||||||
self.port = rport
|
|
||||||
logger.debug("Set remote port for session to %s" % rport)
|
|
||||||
|
|
||||||
# Next packet type
|
|
||||||
if isinstance(recvpkt, TftpPacketACK):
|
|
||||||
logger.debug("Received an ACK from the server.")
|
|
||||||
# tftp on wrt54gl seems to answer with an ack to a wrq regardless
|
|
||||||
# if we sent options.
|
|
||||||
if recvpkt.blocknumber == 0 and self.state.state in ('oack','wrq'):
|
|
||||||
logger.debug("Received ACK with 0 blocknumber, starting upload")
|
|
||||||
self.state.state = 'dat'
|
|
||||||
self.send_dat(packethook)
|
|
||||||
else:
|
|
||||||
if self.state.state == 'dat' or self.state.state == 'fin':
|
|
||||||
if self.blocknumber == recvpkt.blocknumber:
|
|
||||||
logger.info("Received ACK for block %d"
|
|
||||||
% recvpkt.blocknumber)
|
|
||||||
if self.state.state == 'fin':
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
self.send_dat(packethook)
|
|
||||||
elif recvpkt.blocknumber < self.blocknumber:
|
|
||||||
# Don't resend a DAT due to an old ACK. Fixes the
|
|
||||||
# sorceror's apprentice problem.
|
|
||||||
logger.warn("Received old ACK for block number %d"
|
|
||||||
% recvpkt.blocknumber)
|
|
||||||
else:
|
|
||||||
logger.warn("Received ACK for block number "
|
|
||||||
"%d, apparently from the future"
|
|
||||||
% recvpkt.blocknumber)
|
|
||||||
else:
|
|
||||||
logger.error("Received ACK with block number %d "
|
|
||||||
"while in state %s"
|
|
||||||
% (recvpkt.blocknumber,
|
|
||||||
self.state.state))
|
|
||||||
|
|
||||||
# Check other packet types.
|
|
||||||
elif isinstance(recvpkt, TftpPacketOACK):
|
|
||||||
if not self.state.state == 'wrq':
|
|
||||||
self.errors += 1
|
|
||||||
logger.error("Received OACK in state %s" % self.state.state)
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.state.state = 'oack'
|
|
||||||
logger.info("Received OACK from server.")
|
|
||||||
if recvpkt.options.keys() > 0:
|
|
||||||
if recvpkt.match_options(self.options):
|
|
||||||
logger.info("Successful negotiation of options")
|
|
||||||
for key in self.options:
|
|
||||||
logger.info(" %s = %s" % (key, self.options[key]))
|
|
||||||
logger.debug("sending ACK to OACK")
|
|
||||||
ackpkt = TftpPacketACK()
|
|
||||||
ackpkt.blocknumber = 0
|
|
||||||
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
|
|
||||||
self.state.state = 'dat'
|
|
||||||
self.send_dat(packethook)
|
|
||||||
else:
|
|
||||||
logger.error("failed to negotiate options")
|
|
||||||
self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port)
|
|
||||||
self.state.state = 'err'
|
|
||||||
raise TftpException, "Failed to negotiate options"
|
|
||||||
|
|
||||||
elif isinstance(recvpkt, TftpPacketERR):
|
|
||||||
self.state.state = 'err'
|
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
|
|
||||||
tftpassert(False, "Received ERR from server: " + str(recvpkt))
|
|
||||||
|
|
||||||
elif isinstance(recvpkt, TftpPacketWRQ):
|
|
||||||
self.state.state = 'err'
|
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
|
|
||||||
tftpassert(False, "Received WRQ from server: " + str(recvpkt))
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.state.state = 'err'
|
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
|
|
||||||
tftpassert(False, "Received unknown packet type from server: "
|
|
||||||
+ str(recvpkt))
|
|
||||||
|
|
||||||
|
|
||||||
# end while
|
|
||||||
self.fileobj.close()
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
duration = end_time - start_time
|
|
||||||
if duration == 0:
|
|
||||||
logger.info("Duration too short, rate undetermined")
|
logger.info("Duration too short, rate undetermined")
|
||||||
else:
|
else:
|
||||||
logger.info('')
|
logger.info('')
|
||||||
logger.info("Uploaded %d bytes in %d seconds" % (self.bytes, duration))
|
logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
|
||||||
bps = (self.bytes * 8.0) / duration
|
logger.info("Average rate: %.2f kbps" % metrics.kbps)
|
||||||
kbps = bps / 1024.0
|
logger.info("Received %d duplicate packets" % metrics.dupcount)
|
||||||
logger.info("Average rate: %.2f kbps" % kbps)
|
|
||||||
|
|
||||||
def send_dat(self, packethook, resend=False):
|
|
||||||
"""This method reads and sends a DAT packet based on what is in self.buffer."""
|
|
||||||
if not resend:
|
|
||||||
blksize = int(self.options['blksize'])
|
|
||||||
self.buffer = self.fileobj.read(blksize)
|
|
||||||
logger.debug("Read %d bytes into buffer" % len(self.buffer))
|
|
||||||
if len(self.buffer) < blksize:
|
|
||||||
logger.info("Reached EOF on file %s" % self.filename)
|
|
||||||
self.state.state = 'fin'
|
|
||||||
self.blocknumber += 1
|
|
||||||
if self.blocknumber > 65535:
|
|
||||||
logger.debug("Blocknumber rolled over to zero")
|
|
||||||
self.blocknumber = 0
|
|
||||||
self.bytes += len(self.buffer)
|
|
||||||
else:
|
|
||||||
logger.warn("Resending block number %d" % self.blocknumber)
|
|
||||||
dat = TftpPacketDAT()
|
|
||||||
dat.data = self.buffer
|
|
||||||
dat.blocknumber = self.blocknumber
|
|
||||||
logger.debug("Sending DAT packet %d" % self.blocknumber)
|
|
||||||
self.sock.sendto(dat.encode().buffer, (self.host, self.port))
|
|
||||||
self.timesent = time.time()
|
|
||||||
if packethook:
|
|
||||||
packethook(dat)
|
|
||||||
|
|
||||||
def send_wrq(self, resend=False):
|
|
||||||
"""This method sends a wrq"""
|
|
||||||
logger.info("Sending tftp upload request to %s" % self.host)
|
|
||||||
logger.info(" filename -> %s" % self.filename)
|
|
||||||
|
|
||||||
wrq = TftpPacketWRQ()
|
|
||||||
wrq.filename = self.filename
|
|
||||||
wrq.mode = "octet" # FIXME - shouldn't hardcode this
|
|
||||||
wrq.options = self.options
|
|
||||||
self.sock.sendto(wrq.encode().buffer, (self.host, self.iport))
|
|
|
@ -76,83 +76,37 @@ class TftpContext(object):
|
||||||
ackpkt.blocknumber = blocknumber
|
ackpkt.blocknumber = blocknumber
|
||||||
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport))
|
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport))
|
||||||
|
|
||||||
def senderror(self, errorcode):
|
def sendError(self, errorcode):
|
||||||
"""This method uses the socket passed, and uses the errorcode to
|
"""This method uses the socket passed, and uses the errorcode to
|
||||||
compose and send an error packet."""
|
compose and send an error packet."""
|
||||||
logger.debug("In senderror, being asked to send error %d" % errorcode)
|
logger.debug("In sendError, being asked to send error %d" % errorcode)
|
||||||
errpkt = TftpPacketERR()
|
errpkt = TftpPacketERR()
|
||||||
errpkt.errorcode = errorcode
|
errpkt.errorcode = errorcode
|
||||||
sock.sendto(errpkt.encode().buffer, (self.host, self.tidport))
|
self.sock.sendto(errpkt.encode().buffer, (self.host, self.tidport))
|
||||||
|
|
||||||
class TftpContextServerDownload(TftpContext):
|
class TftpContextClient(TftpContext):
|
||||||
"""The download context for the server during a download."""
|
"""This class represents shared functionality by both the download and
|
||||||
pass
|
upload client contexts."""
|
||||||
|
def __init__(self, host, port, filename, options, packethook, timeout):
|
||||||
class TftpContextClientDownload(TftpContext):
|
|
||||||
"""The download context for the client during a download."""
|
|
||||||
def __init__(self, host, port, filename, output, options, packethook, timeout):
|
|
||||||
TftpContext.__init__(self, host, port)
|
TftpContext.__init__(self, host, port)
|
||||||
# FIXME - need to support alternate return formats than files?
|
self.file_to_transfer = filename
|
||||||
# File-like objects would be ideal, ala duck-typing.
|
|
||||||
self.requested_file = filename
|
|
||||||
self.fileobj = open(output, "wb")
|
|
||||||
self.options = options
|
self.options = options
|
||||||
self.packethook = packethook
|
self.packethook = packethook
|
||||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
self.sock.settimeout(timeout)
|
self.sock.settimeout(timeout)
|
||||||
|
|
||||||
self.state = None
|
self.state = None
|
||||||
self.expected_block = 0
|
self.next_block = 0
|
||||||
|
|
||||||
############################
|
def setNextBlock(self, block):
|
||||||
# Logging
|
|
||||||
############################
|
|
||||||
logger.debug("TftpContextClientDownload.__init__()")
|
|
||||||
logger.debug("requested_file = %s, options = %s" %
|
|
||||||
(self.requested_file, self.options))
|
|
||||||
|
|
||||||
def setExpectedBlock(self, block):
|
|
||||||
if block > 2 ** 16:
|
if block > 2 ** 16:
|
||||||
logger.debug("block number rollover to 0 again")
|
logger.debug("block number rollover to 0 again")
|
||||||
block = 0
|
block = 0
|
||||||
self.__eblock = block
|
self.__eblock = block
|
||||||
|
|
||||||
def getExpectedBlock(self):
|
def getNextBlock(self):
|
||||||
return self.__eblock
|
return self.__eblock
|
||||||
|
|
||||||
expected_block = property(getExpectedBlock, setExpectedBlock)
|
next_block = property(getNextBlock, setNextBlock)
|
||||||
|
|
||||||
def start(self):
|
|
||||||
"""Initiate the download."""
|
|
||||||
logger.info("sending tftp download request to %s" % self.host)
|
|
||||||
logger.info(" filename -> %s" % self.requested_file)
|
|
||||||
logger.info(" options -> %s" % self.options)
|
|
||||||
|
|
||||||
self.metrics.start_time = time.time()
|
|
||||||
logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
|
|
||||||
|
|
||||||
# FIXME: put this in a sendRRQ method?
|
|
||||||
pkt = TftpPacketRRQ()
|
|
||||||
pkt.filename = self.requested_file
|
|
||||||
pkt.mode = "octet" # FIXME - shouldn't hardcode this
|
|
||||||
pkt.options = self.options
|
|
||||||
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
|
|
||||||
self.expected_block = 1
|
|
||||||
|
|
||||||
self.state = TftpStateSentRRQ(self)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while self.state:
|
|
||||||
logger.debug("state is %s" % self.state)
|
|
||||||
self.cycle()
|
|
||||||
finally:
|
|
||||||
self.fileobj.close()
|
|
||||||
|
|
||||||
def end(self):
|
|
||||||
"""Finish up the context."""
|
|
||||||
self.metrics.end_time = time.time()
|
|
||||||
logger.debug("set metrics.end_time to %s" % self.metrics.end_time)
|
|
||||||
self.metrics.compute()
|
|
||||||
|
|
||||||
def cycle(self):
|
def cycle(self):
|
||||||
"""Here we wait for a response from the server after sending it
|
"""Here we wait for a response from the server after sending it
|
||||||
|
@ -196,6 +150,101 @@ class TftpContextClientDownload(TftpContext):
|
||||||
# And handle it, possibly changing state.
|
# And handle it, possibly changing state.
|
||||||
self.state = self.state.handle(recvpkt, raddress, rport)
|
self.state = self.state.handle(recvpkt, raddress, rport)
|
||||||
|
|
||||||
|
class TftpContextClientUpload(TftpContextClient):
|
||||||
|
"""The upload context for the client during an upload."""
|
||||||
|
def __init__(self, host, port, filename, input, options, packethook, timeout):
|
||||||
|
TftpContextClient.__init__(self,
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
filename,
|
||||||
|
options,
|
||||||
|
packethook,
|
||||||
|
timeout)
|
||||||
|
self.fileobj = open(input, "wb")
|
||||||
|
|
||||||
|
logger.debug("TftpContextClientUpload.__init__()")
|
||||||
|
logger.debug("file_to_transfer = %s, options = %s" %
|
||||||
|
(self.file_to_transfer, self.options))
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
logger.info("sending tftp upload request to %s" % self.host)
|
||||||
|
logger.info(" filename -> %s" % self.file_to_transfer)
|
||||||
|
logger.info(" options -> %s" % self.options)
|
||||||
|
|
||||||
|
self.metrics.start_time = time.time()
|
||||||
|
logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
|
||||||
|
|
||||||
|
# FIXME: put this in a sendWRQ method?
|
||||||
|
pkt = TftpPacketWRQ()
|
||||||
|
pkt.filename = self.file_to_transfer
|
||||||
|
pkt.mode = "octet" # FIXME - shouldn't hardcode this
|
||||||
|
pkt.options = self.options
|
||||||
|
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
|
||||||
|
self.next_block = 1
|
||||||
|
|
||||||
|
self.state = TftpStateSentWRQ(self)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while self.state:
|
||||||
|
logger.debug("state is %s" % self.state)
|
||||||
|
self.cycle()
|
||||||
|
finally:
|
||||||
|
self.fileobj.close()
|
||||||
|
|
||||||
|
def end(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TftpContextClientDownload(TftpContextClient):
|
||||||
|
"""The download context for the client during a download."""
|
||||||
|
def __init__(self, host, port, filename, output, options, packethook, timeout):
|
||||||
|
TftpContextClient.__init__(self,
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
filename,
|
||||||
|
options,
|
||||||
|
packethook,
|
||||||
|
timeout)
|
||||||
|
# FIXME - need to support alternate return formats than files?
|
||||||
|
# File-like objects would be ideal, ala duck-typing.
|
||||||
|
self.fileobj = open(output, "wb")
|
||||||
|
|
||||||
|
logger.debug("TftpContextClientDownload.__init__()")
|
||||||
|
logger.debug("file_to_transfer = %s, options = %s" %
|
||||||
|
(self.file_to_transfer, self.options))
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Initiate the download."""
|
||||||
|
logger.info("sending tftp download request to %s" % self.host)
|
||||||
|
logger.info(" filename -> %s" % self.file_to_transfer)
|
||||||
|
logger.info(" options -> %s" % self.options)
|
||||||
|
|
||||||
|
self.metrics.start_time = time.time()
|
||||||
|
logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
|
||||||
|
|
||||||
|
# FIXME: put this in a sendRRQ method?
|
||||||
|
pkt = TftpPacketRRQ()
|
||||||
|
pkt.filename = self.file_to_transfer
|
||||||
|
pkt.mode = "octet" # FIXME - shouldn't hardcode this
|
||||||
|
pkt.options = self.options
|
||||||
|
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
|
||||||
|
self.next_block = 1
|
||||||
|
|
||||||
|
self.state = TftpStateSentRRQ(self)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while self.state:
|
||||||
|
logger.debug("state is %s" % self.state)
|
||||||
|
self.cycle()
|
||||||
|
finally:
|
||||||
|
self.fileobj.close()
|
||||||
|
|
||||||
|
def end(self):
|
||||||
|
"""Finish up the context."""
|
||||||
|
self.metrics.end_time = time.time()
|
||||||
|
logger.debug("set metrics.end_time to %s" % self.metrics.end_time)
|
||||||
|
self.metrics.compute()
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# State classes
|
# State classes
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
@ -214,18 +263,60 @@ class TftpState(object):
|
||||||
a TftpState object, either itself or a new state."""
|
a TftpState object, either itself or a new state."""
|
||||||
raise NotImplementedError, "Abstract method"
|
raise NotImplementedError, "Abstract method"
|
||||||
|
|
||||||
|
def handleOACK(self, pkt):
|
||||||
|
"""This method handles an OACK from the server, syncing any accepted
|
||||||
|
options."""
|
||||||
|
if pkt.options.keys() > 0:
|
||||||
|
if pkt.match_options(self.context.options):
|
||||||
|
logger.info("Successful negotiation of options")
|
||||||
|
# Set options to OACK options
|
||||||
|
self.context.options = pkt.options
|
||||||
|
for key in self.context.options:
|
||||||
|
logger.info(" %s = %s" % (key, self.context.options[key]))
|
||||||
|
else:
|
||||||
|
logger.error("failed to negotiate options")
|
||||||
|
raise TftpException, "Failed to negotiate options"
|
||||||
|
else:
|
||||||
|
raise TftpException, "No options found in OACK"
|
||||||
|
|
||||||
|
class TftpStateUpload(TftpState):
|
||||||
|
"""A class holding common code for upload states."""
|
||||||
|
def sendDat(self, resend=False):
|
||||||
|
finished = False
|
||||||
|
blocknumber = self.context.next_block
|
||||||
|
if not resend:
|
||||||
|
blksize = int(self.context.options['blksize'])
|
||||||
|
buffer = self.context.fileobj.read(blksize)
|
||||||
|
logger.debug("Read %d bytes into buffer" % len(buffer))
|
||||||
|
if len(buffer) < blksize:
|
||||||
|
logger.info("Reached EOF on file %s" % self.context.input)
|
||||||
|
finished = True
|
||||||
|
self.context.next_block += 1
|
||||||
|
self.bytes += len(buffer)
|
||||||
|
else:
|
||||||
|
logger.warn("Resending block number %d" % blocknumber)
|
||||||
|
dat = TftpPacketDAT()
|
||||||
|
dat.data = buffer
|
||||||
|
dat.blocknumber = blocknumber
|
||||||
|
logger.debug("Sending DAT packet %d" % blocknumber)
|
||||||
|
self.context.sock.sendto(dat.encode().buffer,
|
||||||
|
(self.context.host, self.context.port))
|
||||||
|
if self.context.packethook:
|
||||||
|
self.context.packethook(dat)
|
||||||
|
return finished
|
||||||
|
|
||||||
class TftpStateDownload(TftpState):
|
class TftpStateDownload(TftpState):
|
||||||
"""A class holding common code for download states."""
|
"""A class holding common code for download states."""
|
||||||
def handleDat(self, pkt):
|
def handleDat(self, pkt):
|
||||||
"""This method handles a DAT packet during a download."""
|
"""This method handles a DAT packet during a download."""
|
||||||
logger.info("handling DAT packet - block %d" % pkt.blocknumber)
|
logger.info("handling DAT packet - block %d" % pkt.blocknumber)
|
||||||
logger.debug("expecting block %s" % self.context.expected_block)
|
logger.debug("expecting block %s" % self.context.next_block)
|
||||||
if pkt.blocknumber == self.context.expected_block:
|
if pkt.blocknumber == self.context.next_block:
|
||||||
logger.debug("good, received block %d in sequence"
|
logger.debug("good, received block %d in sequence"
|
||||||
% pkt.blocknumber)
|
% pkt.blocknumber)
|
||||||
|
|
||||||
self.context.sendAck(pkt.blocknumber)
|
self.context.sendAck(pkt.blocknumber)
|
||||||
self.context.expected_block += 1
|
self.context.next_block += 1
|
||||||
|
|
||||||
logger.debug("writing %d bytes to output file"
|
logger.debug("writing %d bytes to output file"
|
||||||
% len(pkt.data))
|
% len(pkt.data))
|
||||||
|
@ -236,7 +327,7 @@ class TftpStateDownload(TftpState):
|
||||||
logger.info("end of file detected")
|
logger.info("end of file detected")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
elif pkt.blocknumber < self.context.expected_block:
|
elif pkt.blocknumber < self.context.next_block:
|
||||||
logger.warn("dropping duplicate block %d" % pkt.blocknumber)
|
logger.warn("dropping duplicate block %d" % pkt.blocknumber)
|
||||||
if self.context.metrics.dups.has_key(pkt.blocknumber):
|
if self.context.metrics.dups.has_key(pkt.blocknumber):
|
||||||
self.context.metrics.dups[pkt.blocknumber] += 1
|
self.context.metrics.dups[pkt.blocknumber] += 1
|
||||||
|
@ -251,16 +342,87 @@ class TftpStateDownload(TftpState):
|
||||||
else:
|
else:
|
||||||
# FIXME: should we be more tolerant and just discard instead?
|
# FIXME: should we be more tolerant and just discard instead?
|
||||||
msg = "Whoa! Received future block %d but expected %d" \
|
msg = "Whoa! Received future block %d but expected %d" \
|
||||||
% (pkt.blocknumber, self.context.expected_block)
|
% (pkt.blocknumber, self.context.next_block)
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
raise TftpException, msg
|
raise TftpException, msg
|
||||||
|
|
||||||
# Default is to ack
|
# Default is to ack
|
||||||
return TftpStateSentACK(self.context)
|
return TftpStateSentACK(self.context)
|
||||||
|
|
||||||
|
class TftpStateSentWRQ(TftpStateUpload):
|
||||||
|
"""Just sent an WRQ packet for an upload."""
|
||||||
|
def handle(self, pkt, raddress, rport):
|
||||||
|
"""Handle a packet we just received."""
|
||||||
|
if not self.context.tidport:
|
||||||
|
self.context.tidport = rport
|
||||||
|
logger.debug("Set remote port for session to %s" % rport)
|
||||||
|
|
||||||
|
# If we're going to successfully transfer the file, then we should see
|
||||||
|
# either an OACK for accepted options, or an ACK to ignore options.
|
||||||
|
if isinstance(pkt, TftpPacketOACK):
|
||||||
|
logger.info("received OACK from server")
|
||||||
|
try:
|
||||||
|
self.handleOACK(pkt)
|
||||||
|
except TftpException, err:
|
||||||
|
logger.error("failed to negotiate options")
|
||||||
|
self.context.sendError(TftpErrors.FailedNegotiation)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
logger.debug("sending first DAT packet")
|
||||||
|
fin = self.context.sendDat()
|
||||||
|
if fin:
|
||||||
|
logger.info("Add done")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
logger.debug("Changing state to TftpStateSentDAT")
|
||||||
|
return TftpStateSentDAT(self.context)
|
||||||
|
|
||||||
|
elif isinstance(pkt, TftpPacketACK):
|
||||||
|
logger.info("received ACK from server")
|
||||||
|
logger.debug("apparently the server ignored our options")
|
||||||
|
# The block number should be zero.
|
||||||
|
if pkt.blocknumber == 0:
|
||||||
|
logger.debug("ack blocknumber is zero as expected")
|
||||||
|
logger.debug("sending first DAT packet")
|
||||||
|
fin = self.context.sendDat()
|
||||||
|
if fin:
|
||||||
|
logger.info("Add done")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
logger.debug("Changing state to TftpStateSentDAT")
|
||||||
|
return TftpStateSentDAT(self.context)
|
||||||
|
else:
|
||||||
|
logger.warn("discarding ACK to block %s" % pkt.blocknumber)
|
||||||
|
logger.debug("still waiting for valid response from server")
|
||||||
|
return self
|
||||||
|
|
||||||
|
elif isinstance(pkt, TftpPacketERR):
|
||||||
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
|
raise TftpException, "Received ERR from server: " + str(pkt)
|
||||||
|
|
||||||
|
elif isinstance(pkt, TftpPacketRRQ):
|
||||||
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
|
raise TftpException, "Received RRQ from server while in upload"
|
||||||
|
|
||||||
|
elif isinstance(pkt, TftpPacketDAT):
|
||||||
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
|
raise TftpException, "Received DAT from server while in upload"
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
|
raise TftpException, "Received unknown packet type from server: " + str(pkt)
|
||||||
|
|
||||||
|
# By default, no state change.
|
||||||
|
return self
|
||||||
|
|
||||||
|
class TftpStateSentDAT(TftpStateUpload):
|
||||||
|
"""This class represents the state of the transfer when a DAT was just
|
||||||
|
sent, and we are waiting for an ACK from the server. This class is the
|
||||||
|
same one used by the client during the upload, and the server during the
|
||||||
|
download."""
|
||||||
|
|
||||||
class TftpStateSentRRQ(TftpStateDownload):
|
class TftpStateSentRRQ(TftpStateDownload):
|
||||||
"""Just sent an RRQ packet."""
|
"""Just sent an RRQ packet."""
|
||||||
|
|
||||||
def handle(self, pkt, raddress, rport):
|
def handle(self, pkt, raddress, rport):
|
||||||
"""Handle the packet in response to an RRQ to the server."""
|
"""Handle the packet in response to an RRQ to the server."""
|
||||||
if not self.context.tidport:
|
if not self.context.tidport:
|
||||||
|
@ -269,24 +431,20 @@ class TftpStateSentRRQ(TftpStateDownload):
|
||||||
|
|
||||||
# Now check the packet type and dispatch it properly.
|
# Now check the packet type and dispatch it properly.
|
||||||
if isinstance(pkt, TftpPacketOACK):
|
if isinstance(pkt, TftpPacketOACK):
|
||||||
logger.info("received OACK from server.")
|
logger.info("received OACK from server")
|
||||||
if pkt.options.keys() > 0:
|
try:
|
||||||
if pkt.match_options(self.context.options):
|
self.handleOACK(pkt)
|
||||||
logger.info("Successful negotiation of options")
|
except TftpException, err:
|
||||||
# Set options to OACK options
|
logger.error("failed to negotiate options: %s" % str(err))
|
||||||
self.context.options = pkt.options
|
self.context.sendError(TftpErrors.FailedNegotiation)
|
||||||
for key in self.context.options:
|
raise
|
||||||
logger.info(" %s = %s" % (key, self.context.options[key]))
|
else:
|
||||||
logger.debug("sending ACK to OACK")
|
logger.debug("sending ACK to OACK")
|
||||||
|
|
||||||
self.context.sendAck(blocknumber=0)
|
self.context.sendAck(blocknumber=0)
|
||||||
|
|
||||||
logger.debug("Changing state to TftpStateSentACK")
|
logger.debug("Changing state to TftpStateSentACK")
|
||||||
return TftpStateSentACK(self.context)
|
return TftpStateSentACK(self.context)
|
||||||
else:
|
|
||||||
logger.error("failed to negotiate options")
|
|
||||||
self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port)
|
|
||||||
raise TftpException, "Failed to negotiate options"
|
|
||||||
|
|
||||||
elif isinstance(pkt, TftpPacketDAT):
|
elif isinstance(pkt, TftpPacketDAT):
|
||||||
# If there are any options set, then the server didn't honour any
|
# If there are any options set, then the server didn't honour any
|
||||||
|
@ -300,19 +458,19 @@ class TftpStateSentRRQ(TftpStateDownload):
|
||||||
# Every other packet type is a problem.
|
# Every other packet type is a problem.
|
||||||
elif isinstance(recvpkt, TftpPacketACK):
|
elif isinstance(recvpkt, TftpPacketACK):
|
||||||
# Umm, we ACK, the server doesn't.
|
# Umm, we ACK, the server doesn't.
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
raise TftpException, "Received ACK from server while in download"
|
raise TftpException, "Received ACK from server while in download"
|
||||||
|
|
||||||
elif isinstance(recvpkt, TftpPacketWRQ):
|
elif isinstance(recvpkt, TftpPacketWRQ):
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
raise TftpException, "Received WRQ from server while in download"
|
raise TftpException, "Received WRQ from server while in download"
|
||||||
|
|
||||||
elif isinstance(recvpkt, TftpPacketERR):
|
elif isinstance(recvpkt, TftpPacketERR):
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
raise TftpException, "Received ERR from server: " + str(recvpkt)
|
raise TftpException, "Received ERR from server: " + str(recvpkt)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
|
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
|
||||||
|
|
||||||
# By default, no state change.
|
# By default, no state change.
|
||||||
|
@ -328,17 +486,17 @@ class TftpStateSentACK(TftpStateDownload):
|
||||||
# Every other packet type is a problem.
|
# Every other packet type is a problem.
|
||||||
elif isinstance(recvpkt, TftpPacketACK):
|
elif isinstance(recvpkt, TftpPacketACK):
|
||||||
# Umm, we ACK, the server doesn't.
|
# Umm, we ACK, the server doesn't.
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
raise TftpException, "Received ACK from server while in download"
|
raise TftpException, "Received ACK from server while in download"
|
||||||
|
|
||||||
elif isinstance(recvpkt, TftpPacketWRQ):
|
elif isinstance(recvpkt, TftpPacketWRQ):
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
raise TftpException, "Received WRQ from server while in download"
|
raise TftpException, "Received WRQ from server while in download"
|
||||||
|
|
||||||
elif isinstance(recvpkt, TftpPacketERR):
|
elif isinstance(recvpkt, TftpPacketERR):
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
raise TftpException, "Received ERR from server: " + str(recvpkt)
|
raise TftpException, "Received ERR from server: " + str(recvpkt)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp)
|
self.context.sendError(TftpErrors.IllegalTftpOp)
|
||||||
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
|
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
|
||||||
|
|
Reference in New Issue