Merged upload patch.
commit
bd2e19529f
|
@ -17,6 +17,9 @@ def main():
|
|||
parser.add_option('-f',
|
||||
'--filename',
|
||||
help='filename to fetch')
|
||||
parser.add_option('-u',
|
||||
'--upload',
|
||||
help='filename to upload')
|
||||
parser.add_option('-b',
|
||||
'--blocksize',
|
||||
help='udp packet size to use (default: 512)',
|
||||
|
@ -24,6 +27,9 @@ def main():
|
|||
parser.add_option('-o',
|
||||
'--output',
|
||||
help='output file (default: same as requested filename)')
|
||||
parser.add_option('-i',
|
||||
'--input',
|
||||
help='input file (default: same as upload filename)')
|
||||
parser.add_option('-d',
|
||||
'--debug',
|
||||
action='store_true',
|
||||
|
@ -40,7 +46,7 @@ def main():
|
|||
default=False,
|
||||
help="ask client to send tsize option in download")
|
||||
options, args = parser.parse_args()
|
||||
if not options.host or not options.filename:
|
||||
if not options.host or (not options.filename and not options.upload):
|
||||
sys.stderr.write("Both the --host and --filename options "
|
||||
"are required.\n")
|
||||
parser.print_help()
|
||||
|
@ -52,9 +58,6 @@ def main():
|
|||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
if not options.output:
|
||||
options.output = os.path.basename(options.filename)
|
||||
|
||||
class Progress(object):
|
||||
def __init__(self, out):
|
||||
self.progress = 0
|
||||
|
@ -62,7 +65,7 @@ def main():
|
|||
def progresshook(self, pkt):
|
||||
if isinstance(pkt, tftpy.TftpPacketDAT):
|
||||
self.progress += len(pkt.data)
|
||||
self.out("Downloaded %d bytes" % self.progress)
|
||||
self.out("Transferred %d bytes" % self.progress)
|
||||
elif isinstance(pkt, tftpy.TftpPacketOACK):
|
||||
self.out("Received OACK, options are: %s" % pkt.options)
|
||||
|
||||
|
@ -84,10 +87,18 @@ def main():
|
|||
tclient = tftpy.TftpClient(options.host,
|
||||
int(options.port),
|
||||
tftp_options)
|
||||
|
||||
tclient.download(options.filename,
|
||||
options.output,
|
||||
progresshook)
|
||||
if(options.filename):
|
||||
if not options.output:
|
||||
options.output = os.path.basename(options.filename)
|
||||
tclient.download(options.filename,
|
||||
options.output,
|
||||
progresshook)
|
||||
elif(options.upload):
|
||||
if not options.input:
|
||||
options.input = os.path.basename(options.upload)
|
||||
tclient.upload(options.upload,
|
||||
options.input,
|
||||
progresshook)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -11,7 +11,13 @@ class TftpClient(TftpSession):
|
|||
TftpSession.__init__(self)
|
||||
self.host = host
|
||||
self.iport = port
|
||||
self.filename = None
|
||||
self.options = options
|
||||
self.blocknumber = 0
|
||||
self.fileobj = None
|
||||
self.timesent = 0
|
||||
self.buffer = None
|
||||
self.bytes = 0
|
||||
if self.options.has_key('blksize'):
|
||||
size = self.options['blksize']
|
||||
tftpassert(types.IntType == type(size), "blksize must be an int")
|
||||
|
@ -20,21 +26,20 @@ class TftpClient(TftpSession):
|
|||
else:
|
||||
self.options['blksize'] = DEF_BLKSIZE
|
||||
# Support other options here? timeout time, retries, etc?
|
||||
|
||||
# The remote sending port, to identify the connection.
|
||||
self.port = None
|
||||
self.sock = None
|
||||
|
||||
|
||||
def gethost(self):
|
||||
"Simple getter method for use in a property."
|
||||
return self.__host
|
||||
|
||||
|
||||
def sethost(self, host):
|
||||
"""Setter method that also sets the address property as a result
|
||||
of the host that is set."""
|
||||
self.__host = host
|
||||
self.address = socket.gethostbyname(host)
|
||||
|
||||
|
||||
host = property(gethost, sethost)
|
||||
|
||||
def download(self, filename, output, packethook=None, timeout=SOCK_TIMEOUT):
|
||||
|
@ -49,12 +54,14 @@ class TftpClient(TftpSession):
|
|||
# Open the output file.
|
||||
# FIXME - need to support alternate return formats than files?
|
||||
# File-like objects would be ideal, ala duck-typing.
|
||||
outputfile = open(output, "wb")
|
||||
self.fileobj = open(output, "wb")
|
||||
recvpkt = None
|
||||
curblock = 0
|
||||
dups = {}
|
||||
start_time = time.time()
|
||||
bytes = 0
|
||||
self.bytes = 0
|
||||
|
||||
self.filename = filename
|
||||
|
||||
tftp_factory = TftpPacketFactory()
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
|
@ -68,7 +75,7 @@ class TftpClient(TftpSession):
|
|||
pkt.options = self.options
|
||||
self.sock.sendto(pkt.encode().buffer, (self.host, self.iport))
|
||||
self.state.state = 'rrq'
|
||||
|
||||
|
||||
timeouts = 0
|
||||
while True:
|
||||
try:
|
||||
|
@ -83,9 +90,9 @@ class TftpClient(TftpSession):
|
|||
|
||||
recvpkt = tftp_factory.parse(buffer)
|
||||
|
||||
logger.debug("Received %d bytes from %s:%s"
|
||||
logger.debug("Received %d bytes from %s:%s"
|
||||
% (len(buffer), raddress, rport))
|
||||
|
||||
|
||||
# Check for known "connection".
|
||||
if raddress != self.address:
|
||||
logger.warn("Received traffic from %s, expected host %s. Discarding"
|
||||
|
@ -108,7 +115,7 @@ class TftpClient(TftpSession):
|
|||
if not self.port and self.state.state == 'rrq':
|
||||
self.port = rport
|
||||
logger.debug("Set remote port for session to %s" % rport)
|
||||
|
||||
|
||||
if isinstance(recvpkt, TftpPacketDAT):
|
||||
logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber)
|
||||
logger.debug("curblock = %d" % curblock)
|
||||
|
@ -117,22 +124,22 @@ class TftpClient(TftpSession):
|
|||
logger.debug("block number rollover to 0 again")
|
||||
expected_block = 0
|
||||
if recvpkt.blocknumber == expected_block:
|
||||
logger.debug("good, received block %d in sequence"
|
||||
logger.debug("good, received block %d in sequence"
|
||||
% recvpkt.blocknumber)
|
||||
curblock = expected_block
|
||||
|
||||
|
||||
|
||||
# ACK the packet, and save the data.
|
||||
logger.info("sending ACK to block %d" % curblock)
|
||||
logger.debug("ip = %s, port = %s" % (self.host, self.port))
|
||||
ackpkt = TftpPacketACK()
|
||||
ackpkt.blocknumber = curblock
|
||||
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
|
||||
|
||||
logger.debug("writing %d bytes to output file"
|
||||
|
||||
logger.debug("writing %d bytes to output file"
|
||||
% len(recvpkt.data))
|
||||
outputfile.write(recvpkt.data)
|
||||
bytes += len(recvpkt.data)
|
||||
self.fileobj.write(recvpkt.data)
|
||||
self.bytes += len(recvpkt.data)
|
||||
# Check for end-of-file, any less than full data packet.
|
||||
if len(recvpkt.data) < int(self.options['blksize']):
|
||||
logger.info("end of file detected")
|
||||
|
@ -152,7 +159,7 @@ class TftpClient(TftpSession):
|
|||
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
|
||||
|
||||
else:
|
||||
msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber,
|
||||
msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber,
|
||||
curblock+1)
|
||||
logger.error(msg)
|
||||
raise TftpException, msg
|
||||
|
@ -163,7 +170,7 @@ class TftpClient(TftpSession):
|
|||
self.errors += 1
|
||||
logger.error("Received OACK in state %s" % self.state.state)
|
||||
continue
|
||||
|
||||
|
||||
self.state.state = 'oack'
|
||||
logger.info("Received OACK from server.")
|
||||
if recvpkt.options.keys() > 0:
|
||||
|
@ -208,7 +215,7 @@ class TftpClient(TftpSession):
|
|||
|
||||
|
||||
# end while
|
||||
outputfile.close()
|
||||
self.fileobj.close()
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
|
@ -216,11 +223,196 @@ class TftpClient(TftpSession):
|
|||
logger.info("Duration too short, rate undetermined")
|
||||
else:
|
||||
logger.info('')
|
||||
logger.info("Downloaded %d bytes in %d seconds" % (bytes, duration))
|
||||
bps = (bytes * 8.0) / duration
|
||||
logger.info("Downloaded %d bytes in %d seconds" % (self.bytes, duration))
|
||||
bps = (self.bytes * 8.0) / duration
|
||||
kbps = bps / 1024.0
|
||||
logger.info("Average rate: %.2f kbps" % kbps)
|
||||
dupcount = 0
|
||||
for key in dups:
|
||||
dupcount += dups[key]
|
||||
logger.info("Received %d duplicate packets" % dupcount)
|
||||
|
||||
def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT):
|
||||
# Open the input file.
|
||||
self.fileobj = open(input, "rb")
|
||||
recvpkt = None
|
||||
curblock = 0
|
||||
start_time = time.time()
|
||||
self.bytes = 0
|
||||
|
||||
tftp_factory = TftpPacketFactory()
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.sock.settimeout(timeout)
|
||||
|
||||
self.filename = filename
|
||||
|
||||
self.send_wrq()
|
||||
self.state.state = 'wrq'
|
||||
|
||||
timeouts = 0
|
||||
while True:
|
||||
try:
|
||||
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
|
||||
except socket.timeout, err:
|
||||
timeouts += 1
|
||||
if timeouts >= TIMEOUT_RETRIES:
|
||||
raise TftpException, "Hit max timeouts, giving up."
|
||||
else:
|
||||
if self.state.state == 'dat' or self.state.state == 'fin':
|
||||
logger.debug("Timing out on DAT. Need to resend.")
|
||||
self.send_dat(packethook,resend=True)
|
||||
elif self.state.state == 'wrq':
|
||||
logger.debug("Timing out on WRQ.")
|
||||
self.send_wrq(resend=True)
|
||||
else:
|
||||
tftpassert(False,
|
||||
"Timing out in unsupported state %s" %
|
||||
self.state.state)
|
||||
continue
|
||||
|
||||
recvpkt = tftp_factory.parse(buffer)
|
||||
|
||||
logger.debug("Received %d bytes from %s:%s"
|
||||
% (len(buffer), raddress, rport))
|
||||
|
||||
# Check for known "connection".
|
||||
if raddress != self.address:
|
||||
logger.warn("Received traffic from %s, expected host %s. Discarding"
|
||||
% (raddress, self.host))
|
||||
continue
|
||||
if self.port and self.port != rport:
|
||||
logger.warn("Received traffic from %s:%s but we're "
|
||||
"connected to %s:%s. Discarding."
|
||||
% (raddress, rport,
|
||||
self.host, self.port))
|
||||
continue
|
||||
|
||||
if not self.port and self.state.state == 'wrq':
|
||||
self.port = rport
|
||||
logger.debug("Set remote port for session to %s" % rport)
|
||||
|
||||
# Next packet type
|
||||
if isinstance(recvpkt, TftpPacketACK):
|
||||
logger.debug("Received an ACK from the server.")
|
||||
# tftp on wrt54gl seems to answer with an ack to a wrq regardless
|
||||
# if we sent options.
|
||||
if recvpkt.blocknumber == 0 and self.state.state in ('oack','wrq'):
|
||||
logger.debug("Received ACK with 0 blocknumber, starting upload")
|
||||
self.state.state = 'dat'
|
||||
self.send_dat(packethook)
|
||||
else:
|
||||
if self.state.state == 'dat' or self.state.state == 'fin':
|
||||
if self.blocknumber == recvpkt.blocknumber:
|
||||
logger.info("Received ACK for block %d"
|
||||
% recvpkt.blocknumber)
|
||||
if self.state.state == 'fin':
|
||||
break
|
||||
else:
|
||||
self.send_dat(packethook)
|
||||
elif recvpkt.blocknumber < self.blocknumber:
|
||||
# Don't resend a DAT due to an old ACK. Fixes the
|
||||
# sorceror's apprentice problem.
|
||||
logger.warn("Received old ACK for block number %d"
|
||||
% recvpkt.blocknumber)
|
||||
else:
|
||||
logger.warn("Received ACK for block number "
|
||||
"%d, apparently from the future"
|
||||
% recvpkt.blocknumber)
|
||||
else:
|
||||
logger.error("Received ACK with block number %d "
|
||||
"while in state %s"
|
||||
% (recvpkt.blocknumber,
|
||||
self.state.state))
|
||||
|
||||
# Check other packet types.
|
||||
elif isinstance(recvpkt, TftpPacketOACK):
|
||||
if not self.state.state == 'wrq':
|
||||
self.errors += 1
|
||||
logger.error("Received OACK in state %s" % self.state.state)
|
||||
continue
|
||||
|
||||
self.state.state = 'oack'
|
||||
logger.info("Received OACK from server.")
|
||||
if recvpkt.options.keys() > 0:
|
||||
if recvpkt.match_options(self.options):
|
||||
logger.info("Successful negotiation of options")
|
||||
for key in self.options:
|
||||
logger.info(" %s = %s" % (key, self.options[key]))
|
||||
logger.debug("sending ACK to OACK")
|
||||
ackpkt = TftpPacketACK()
|
||||
ackpkt.blocknumber = 0
|
||||
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
|
||||
self.state.state = 'dat'
|
||||
self.send_dat(packethook)
|
||||
else:
|
||||
logger.error("failed to negotiate options")
|
||||
self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port)
|
||||
self.state.state = 'err'
|
||||
raise TftpException, "Failed to negotiate options"
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketERR):
|
||||
self.state.state = 'err'
|
||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
|
||||
tftpassert(False, "Received ERR from server: " + str(recvpkt))
|
||||
|
||||
elif isinstance(recvpkt, TftpPacketWRQ):
|
||||
self.state.state = 'err'
|
||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
|
||||
tftpassert(False, "Received WRQ from server: " + str(recvpkt))
|
||||
|
||||
else:
|
||||
self.state.state = 'err'
|
||||
self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
|
||||
tftpassert(False, "Received unknown packet type from server: "
|
||||
+ str(recvpkt))
|
||||
|
||||
|
||||
# end while
|
||||
self.fileobj.close()
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
if duration == 0:
|
||||
logger.info("Duration too short, rate undetermined")
|
||||
else:
|
||||
logger.info('')
|
||||
logger.info("Uploaded %d bytes in %d seconds" % (self.bytes, duration))
|
||||
bps = (self.bytes * 8.0) / duration
|
||||
kbps = bps / 1024.0
|
||||
logger.info("Average rate: %.2f kbps" % kbps)
|
||||
|
||||
def send_dat(self, packethook, resend=False):
|
||||
"""This method reads and sends a DAT packet based on what is in self.buffer."""
|
||||
if not resend:
|
||||
blksize = int(self.options['blksize'])
|
||||
self.buffer = self.fileobj.read(blksize)
|
||||
logger.debug("Read %d bytes into buffer" % len(self.buffer))
|
||||
if len(self.buffer) < blksize:
|
||||
logger.info("Reached EOF on file %s" % self.filename)
|
||||
self.state.state = 'fin'
|
||||
self.blocknumber += 1
|
||||
if self.blocknumber > 65535:
|
||||
logger.debug("Blocknumber rolled over to zero")
|
||||
self.blocknumber = 0
|
||||
self.bytes += len(self.buffer)
|
||||
else:
|
||||
logger.warn("Resending block number %d" % self.blocknumber)
|
||||
dat = TftpPacketDAT()
|
||||
dat.data = self.buffer
|
||||
dat.blocknumber = self.blocknumber
|
||||
logger.debug("Sending DAT packet %d" % self.blocknumber)
|
||||
self.sock.sendto(dat.encode().buffer, (self.host, self.port))
|
||||
self.timesent = time.time()
|
||||
if packethook:
|
||||
packethook(dat)
|
||||
|
||||
def send_wrq(self, resend=False):
|
||||
"""This method sends a wrq"""
|
||||
logger.info("Sending tftp upload request to %s" % self.host)
|
||||
logger.info(" filename -> %s" % self.filename)
|
||||
|
||||
wrq = TftpPacketWRQ()
|
||||
wrq.filename = self.filename
|
||||
wrq.mode = "octet" # FIXME - shouldn't hardcode this
|
||||
wrq.options = self.options
|
||||
self.sock.sendto(wrq.encode().buffer, (self.host, self.iport))
|
||||
|
|
Reference in New Issue