Skip to content

Commit

Permalink
f-sa
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Oct 17, 2023
1 parent 1184f99 commit 8ae9ec1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 39 deletions.
64 changes: 28 additions & 36 deletions aiidalab_widgets_base/computational_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import copy
import enum
import os
import re
import subprocess
import threading
from collections import namedtuple
from pathlib import Path
from uuid import UUID

import ipywidgets as ipw
import jinja2
import pexpect
import shortuuid
import traitlets as tl
Expand All @@ -18,7 +20,7 @@
from aiida.transports.plugins import ssh as aiida_ssh_plugin
from humanfriendly import InvalidSize, parse_size
from IPython.display import clear_output, display
from jinja2 import Environment, meta
from jinja2 import meta as jinja2_meta

from .databases import ComputationalResourcesDatabaseWidget
from .utils import MessageLevel, StatusHTML, wrap_message
Expand Down Expand Up @@ -236,16 +238,17 @@ class SshComputerSetup(ipw.VBox):
message = tl.Unicode()
password_message = tl.Unicode("The passwordless enabling log.")

def __init__(self, ssh_folder=None, **kwargs):
def __init__(self, ssh_folder: Path | None = None, **kwargs):
"""Setup a passwordless access to a computer."""
# ssh folder init
if ssh_folder is None:
ssh_folder = Path.home() / ".ssh"
if not ssh_folder.exists():
ssh_folder.mkdir()
ssh_folder.chmod(0o700)

self.ssh_folder = ssh_folder
if not ssh_folder.exists():
ssh_folder.mkdir()
ssh_folder.chmod(0o700)

self._ssh_folder = ssh_folder

self._ssh_connection_message = None
self._password_message = ipw.HTML()
Expand Down Expand Up @@ -277,11 +280,11 @@ def __init__(self, ssh_folder=None, **kwargs):
)

# Username.
self.username = ipw.Text(description="username:", layout=LAYOUT, style=STYLE)
self.username = ipw.Text(description="Username:", layout=LAYOUT, style=STYLE)

# Port.
self.port = ipw.IntText(
description="port:",
description="Port:",
value=22,
layout=LAYOUT,
style=STYLE,
Expand Down Expand Up @@ -357,7 +360,7 @@ def _ssh_keygen(self):
"Generating SSH key pair.",
MessageLevel.SUCCESS,
)
fpath = Path.home() / ".ssh" / "id_rsa"
fpath = self._ssh_folder / "id_rsa"
keygen_cmd = [
"ssh-keygen",
"-f",
Expand Down Expand Up @@ -393,7 +396,7 @@ def _can_login(self):

def _is_in_config(self):
"""Check if the SSH config file contains host information."""
config_path = self.ssh_folder / "config"
config_path = self._ssh_folder / "config"
if not config_path.exists():
return False
sshcfg = aiida_ssh_plugin.parse_sshconfig(self.hostname.value)
Expand All @@ -406,7 +409,7 @@ def _is_in_config(self):

def _write_ssh_config(self, private_key_abs_fname=None):
"""Put host information into the config file."""
config_path = self.ssh_folder / "config"
config_path = self._ssh_folder / "config"

self.message = wrap_message(
f"Adding {self.hostname.value} section to {config_path}",
Expand Down Expand Up @@ -483,12 +486,10 @@ def _on_setup_ssh_button_pressed(self, _=None):

# if the private key filename is exist, generate random string and append to filename subfix
# then override current name.
if filename in [str(p.name) for p in Path(self.ssh_folder).iterdir()]:
private_key_fname = str(
Path(self.ssh_folder) / f"{filename}-{shortuuid.uuid()}"
)
if filename in [str(p.name) for p in Path(self._ssh_folder).iterdir()]:
private_key_fpath = self._ssh_folder / f"{filename}-{shortuuid.uuid()}"

self._add_private_key(private_key_fname, private_key_content)
self._add_private_key(private_key_fpath, private_key_content)

# TODO(danielhollas): I am not sure this is correct. What if the user wants
# to overwrite the private key? Or any other config? The configuration would never be written.
Expand Down Expand Up @@ -600,7 +601,7 @@ def _on_verification_mode_change(self, change):
if self._verification_mode.value == "private_key":
display(self._inp_private_key)
elif self._verification_mode.value == "public_key":
public_key = Path.home() / ".ssh" / "id_rsa.pub"
public_key = self._ssh_folder / "id_rsa.pub"
if public_key.exists():
display(
ipw.HTML(
Expand All @@ -610,7 +611,7 @@ def _on_verification_mode_change(self, change):
)

@property
def _private_key(self):
def _private_key(self) -> tuple[str | None, bytes | None]:
"""Unwrap private key file and setting filename and file content."""
if self._inp_private_key.value:
(fname, _value), *_ = self._inp_private_key.value.items()
Expand All @@ -621,17 +622,10 @@ def _private_key(self):
return None, None

@staticmethod
def _add_private_key(private_key_fname, private_key_content):
"""
param private_key_fname: string
param private_key_content: bytes
"""
fpath = Path.home() / ".ssh" / private_key_fname
fpath.write_bytes(private_key_content)

fpath.chmod(0o600)

return fpath
def _add_private_key(private_key_fpath: Path, private_key_content: bytes):
"""Write private key to the private key file in the ssh folder."""
private_key_fpath.write_bytes(private_key_content)
private_key_fpath.chmod(0o600)

def _reset(self):
self.hostname.value = ""
Expand Down Expand Up @@ -1248,8 +1242,6 @@ def _observe_code_setup(self, _=None):
try:
self.default_calc_job_plugin.value = value
except tl.TraitError:
import re

# If is a template then don't raise the error message.
if not re.match(r".*{{.+}}.*", value):
self.message = wrap_message(
Expand Down Expand Up @@ -1431,11 +1423,11 @@ def _render(self):
self._help_text.value = f"""<div>{tooltip}</div>"""

for line_key, line_str in self.templates.items():
env = Environment()
env = jinja2.Environment()
parsed_content = env.parse(line_str)

# vars is a set of variables in the template
line_vars = meta.find_undeclared_variables(parsed_content)
line_vars = jinja2_meta.find_undeclared_variables(parsed_content)

# Create a widget for each variable.
# The var is the name in a template string
Expand Down Expand Up @@ -1518,7 +1510,7 @@ def _on_template_variable_filled(self, change):
}

# re-render the template
env = Environment()
env = jinja2.Environment()
filled_str = env.from_string(line.str).render(**inp_dict)

# Update the filled template.
Expand Down Expand Up @@ -1890,9 +1882,9 @@ def _fill_template(self):
filled_templates = copy.deepcopy(w_tmp.filled_templates)

for k, v in w_tmp.filled_templates.items():
env = Environment()
env = jinja2.Environment()
parsed_content = env.parse(v)
vs = meta.find_undeclared_variables(parsed_content)
vs = jinja2_meta.find_undeclared_variables(parsed_content)

# No variables in the template, all filled.
if len(vs) == 0:
Expand Down
7 changes: 4 additions & 3 deletions tests/test_computational_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ def test_ssh_computer_setup_widget(monkeypatch, tmp_path):
widget._ssh_keygen()

# Create non-default private key file.
fpath = widget._add_private_key("my_key_name", b"my_key_content")
assert fpath.exists()
with open(fpath) as f:
private_key_path = tmp_path / ".ssh" / "my_key_name"
widget._add_private_key(private_key_path, b"my_key_content")
assert private_key_path.exists()
with open(private_key_path) as f:
assert f.read() == "my_key_content"

# set private key with same name to trigger the rename operation
Expand Down

0 comments on commit 8ae9ec1

Please sign in to comment.