Skip to content

Commit 28cc8dc

Browse files
authored
_put_file: upload file atomically (#129)
1 parent e76d09f commit 28cc8dc

File tree

3 files changed

+31
-12
lines changed

3 files changed

+31
-12
lines changed

dvc_ssh/__init__.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,22 +105,13 @@ def _prepare_credentials(self, **config):
105105
@wrap_prop(threading.Lock())
106106
@cached_property
107107
def fs(self):
108-
from sshfs import SSHFileSystem as _SSHFileSystem
108+
from . import spec
109109

110-
return _SSHFileSystem(**self.fs_args)
110+
return spec.SSHFileSystem(**self.fs_args)
111111

112112
# Ensure that if an interrupt happens during the transfer, we don't
113113
# pollute the cache.
114114

115115
def upload_fobj(self, fobj, to_info, **kwargs):
116116
with as_atomic(self, to_info) as tmp_file:
117117
super().upload_fobj(fobj, tmp_file, **kwargs)
118-
119-
def put_file(
120-
self,
121-
from_file,
122-
to_info,
123-
**kwargs,
124-
):
125-
with as_atomic(self, to_info) as tmp_file:
126-
super().put_file(from_file, tmp_file, **kwargs)

dvc_ssh/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ async def public_key_auth_requested(
7575
pubkey: Optional[SSHKey] = read_public_key(pubkey_to_load)
7676
except (OSError, KeyImportError):
7777
pubkey = None
78-
return SSHLocalKeyPair(key, pubkey)
78+
return SSHLocalKeyPair(key, pubkey, cert=None, enc_key=None)
7979
return None
8080

8181
async def _read_private_key_interactive(self, path: "FilePath") -> "SSHKey":

dvc_ssh/spec.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import posixpath
2+
from secrets import token_urlsafe
3+
4+
from sshfs import SSHFileSystem as _SSHFileSystem
5+
6+
7+
def tmp_fname(prefix: str = "") -> str:
8+
"""Temporary name for a partial download"""
9+
return f"{prefix}.{token_urlsafe(16)}.tmp"
10+
11+
12+
class SSHFileSystem(_SSHFileSystem):
13+
async def _put_file(self, lpath, rpath, *args, **kwargs):
14+
parent = posixpath.dirname(rpath)
15+
tmp_info = posixpath.join(parent, tmp_fname())
16+
try:
17+
await super()._put_file(lpath, tmp_info, *args, **kwargs)
18+
except BaseException:
19+
# Handle stuff like KeyboardInterrupt
20+
# as well as other errors that might
21+
# arise during file transfer.
22+
try:
23+
await self._rm_file(tmp_info)
24+
except FileNotFoundError:
25+
pass
26+
raise
27+
else:
28+
await self._mv(tmp_info, rpath)

0 commit comments

Comments
 (0)