bscp.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. #!/usr/bin/python3
  2. # Copyright (C) 2012-2016
  3. #
  4. # * Volker Diels-Grabsch <v@njh.eu>
  5. # * art0int <zvn_mail@mail.ru>
  6. # * guillaume <guillaume@atto.be>
  7. #
  8. # Permission to use, copy, modify, and/or distribute this software for any
  9. # purpose with or without fee is hereby granted, provided that the above
  10. # copyright notice and this permission notice appear in all copies.
  11. #
  12. # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  13. # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  14. # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  15. # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  16. # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  17. # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  18. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  19. import hashlib
  20. import os
  21. import os.path
  22. import struct
  23. import subprocess
  24. import sys
  25. # device_function: filters or transforms device names
  26. device_function = lambda dev : dev
  27. # remote_command: command to run bscp on the remote host
  28. remote_command = __file__
  29. def serve():
  30. global allowed_devices
  31. (size, blocksize, filename_len, hashname_len) = struct.unpack('<QQQQ', sys.stdin.buffer.read(8+8+8+8))
  32. filename = sys.stdin.buffer.read(filename_len).decode()
  33. hashname = sys.stdin.buffer.read(hashname_len).decode()
  34. filename = device_function(filename)
  35. if filename == None or len(filename) == 0:
  36. return
  37. if not os.path.exists(filename):
  38. # Create sparse file
  39. with open(filename, 'wb') as f:
  40. f.truncate(size)
  41. os.chmod(filename, 0o600)
  42. with open(filename, 'rb+') as f:
  43. f.seek(0, 2)
  44. sys.stdout.buffer.write(struct.pack('<Q', f.tell()))
  45. readremain = size
  46. rblocksize = blocksize
  47. f.seek(0)
  48. while True:
  49. if readremain <= blocksize:
  50. rblocksize = readremain
  51. block = f.read(rblocksize)
  52. if len(block) == 0:
  53. break
  54. digest = hashlib.new(hashname, block).digest()
  55. sys.stdout.buffer.write(digest)
  56. readremain -= rblocksize
  57. if readremain == 0:
  58. break
  59. sys.stdout.flush()
  60. while True:
  61. position_s = sys.stdin.buffer.read(8)
  62. if len(position_s) == 0:
  63. break
  64. (position,) = struct.unpack('<Q', position_s)
  65. block = sys.stdin.buffer.read(blocksize)
  66. f.seek(position)
  67. f.write(block)
  68. readremain = size
  69. rblocksize = blocksize
  70. hash_total = hashlib.new(hashname)
  71. f.seek(0)
  72. while True:
  73. if readremain <= blocksize:
  74. rblocksize = readremain
  75. block = f.read(rblocksize)
  76. if len(block) == 0:
  77. break
  78. hash_total.update(block)
  79. readremain -= rblocksize
  80. if readremain == 0:
  81. break
  82. sys.stdout.buffer.write(hash_total.digest())
  83. class IOCounter:
  84. def __init__(self, in_stream, out_stream):
  85. self.in_stream = in_stream
  86. self.out_stream = out_stream
  87. self.in_total = 0
  88. self.out_total = 0
  89. def read(self, size=None):
  90. if size is None:
  91. s = self.in_stream.read()
  92. else:
  93. s = self.in_stream.read(size)
  94. self.in_total += len(s)
  95. return s
  96. def write(self, s):
  97. self.out_stream.write(s)
  98. self.out_total += len(s)
  99. def flush(self):
  100. self.out_stream.flush()
  101. def bscp(local_filename, remote_host, remote_filename, blocksize, hashname):
  102. hash_total = hashlib.new(hashname)
  103. with open(local_filename, 'rb') as f:
  104. f.seek(0, 2)
  105. size = f.tell()
  106. f.seek(0)
  107. # Calculate number of blocks, including the last block which may be smaller
  108. blockcount = int((size + blocksize - 1) / blocksize)
  109. command = ('ssh', remote_host, '--', remote_command)
  110. p = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
  111. io = IOCounter(p.stdout, p.stdin)
  112. remote_filename_bytes = remote_filename.encode()
  113. hashname_bytes = hashname.encode()
  114. io.write(struct.pack('<QQQQ', size, blocksize, len(remote_filename_bytes), len(hashname_bytes)))
  115. io.write(remote_filename_bytes)
  116. io.write(hashname_bytes)
  117. io.flush()
  118. (remote_size,) = struct.unpack('<Q', io.read(8))
  119. if remote_size < size:
  120. raise RuntimeError('Remote size less than local (local: %i, remote: %i)' % (size, remote_size))
  121. remote_digest_list = [io.read(hash_total.digest_size) for i in range(blockcount)]
  122. for remote_digest in remote_digest_list:
  123. position = f.tell()
  124. block = f.read(blocksize)
  125. hash_total.update(block)
  126. digest = hashlib.new(hashname, block).digest()
  127. if digest != remote_digest:
  128. try:
  129. io.write(struct.pack('<Q', position))
  130. io.write(block)
  131. except IOError:
  132. break
  133. io.flush()
  134. p.stdin.close()
  135. remote_digest_total = io.read()
  136. p.wait()
  137. if remote_digest_total != hash_total.digest():
  138. raise RuntimeError('Checksum mismatch after transfer')
  139. return (io.in_total, io.out_total, size)
  140. if __name__ == '__main__':
  141. try:
  142. local_filename = sys.argv[1]
  143. (remote_host, remote_filename) = sys.argv[2].split(':')
  144. if len(sys.argv) >= 4:
  145. blocksize = int(sys.argv[3])
  146. else:
  147. blocksize = 64 * 1024
  148. if len(sys.argv) >= 5:
  149. hashname = sys.argv[4]
  150. else:
  151. hashname = 'sha1'
  152. assert len(sys.argv) <= 5
  153. except:
  154. usage = 'bscp SRC HOST:DEST [BLOCKSIZE] [HASH]'
  155. sys.stderr.write('Usage:\n\n %s\n\n' % (usage,))
  156. sys.exit(1)
  157. (in_total, out_total, size) = bscp(local_filename, remote_host, remote_filename, blocksize, hashname)
  158. speedup = size * 1.0 / (in_total + out_total)
  159. sys.stderr.write('in=%i out=%i size=%i speedup=%.2f\n' % (in_total, out_total, size, speedup))