WIP: MIC: VmTests: Add SSH connection
This commit is contained in:
Родитель
b005e8b6b2
Коммит
4747e29bf2
|
@ -1,5 +1,6 @@
|
|||
docker == 7.1.0
|
||||
libvirt-python == 10.9.0
|
||||
paramiko == 3.5.0
|
||||
pytest == 8.3.3
|
||||
PyYAML == 6.0.2
|
||||
types-docker == 7.1.0.20240827
|
||||
|
|
|
@ -6,4 +6,5 @@ black == 24.8.0
|
|||
flake8 == 7.1.0
|
||||
isort == 5.13.2
|
||||
mypy == 1.13.0
|
||||
types-paramiko == 3.5.0.20240928
|
||||
types-PyYAML == 6.0.12.20240917
|
||||
|
|
|
@ -11,7 +11,8 @@ from docker import DockerClient
|
|||
|
||||
from .conftest import TEST_CONFIGS_DIR
|
||||
from .utils.imagecustomizer import run_image_customizer
|
||||
from .utils.libvirt_utils import VmSpec, create_libvirt_domain_xml
|
||||
from .utils.libvirt_utils import VmSpec, create_libvirt_domain_xml, get_vm_ip_address
|
||||
from .utils.ssh_client import SshClient
|
||||
|
||||
|
||||
def test_no_change(
|
||||
|
@ -73,3 +74,13 @@ def test_no_change(
|
|||
|
||||
# Start the VM.
|
||||
domain.resume()
|
||||
|
||||
vm_ip_address = get_vm_ip_address(libvirt_conn, vm_name, timeout=15)
|
||||
|
||||
# All the VM are know to be new.
|
||||
# So, it is fine using an empty known_hosts file.
|
||||
ssh_known_hosts_path = test_temp_dir.joinpath("known_hosts")
|
||||
open(ssh_known_hosts_path, 'w').close()
|
||||
|
||||
with SshClient(vm_ip_address, key_path=ssh_private_key_path, known_hosts_path=ssh_known_hosts_path) as vm_ssh:
|
||||
vm_ssh.run("cat /etc/os-release")
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import time
|
||||
import xml.etree.ElementTree as ET # noqa: N817
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import libvirt # type: ignore
|
||||
|
||||
|
||||
class VmSpec:
|
||||
|
@ -150,3 +153,42 @@ def _gen_disk_device_name(prefix: str, next_disk_indexes: Dict[str, int]) -> str
|
|||
|
||||
case _:
|
||||
return f"{prefix}{disk_index}"
|
||||
|
||||
|
||||
# Wait for the VM to boot and then get the IP address.
|
||||
def get_vm_ip_address(
|
||||
libvirt_conn: libvirt.virConnect,
|
||||
vm_name: str,
|
||||
timeout: float,
|
||||
) -> str:
|
||||
timeout_time = time.time() + timeout
|
||||
|
||||
while True:
|
||||
addr = try_get_vm_ip_address(libvirt_conn, vm_name)
|
||||
if addr:
|
||||
return addr
|
||||
|
||||
if time.time() > timeout_time:
|
||||
raise Exception(f"No IP addresses found for '{vm_name}'. OS might have failed to boot.")
|
||||
|
||||
|
||||
# Try to get the IP address of the VM.
|
||||
def try_get_vm_ip_address(
|
||||
libvirt_conn: libvirt.virConnect,
|
||||
vm_name: str,
|
||||
) -> Optional[str]:
|
||||
domain = libvirt_conn.lookupByName(vm_name)
|
||||
|
||||
# Acquire IP address from libvirt's DHCP server.
|
||||
interfaces = domain.interfaceAddresses(libvirt.VIR_DOMAIN_INTERFACE_ADDRESSES_SRC_LEASE)
|
||||
if len(interfaces) < 1:
|
||||
return None
|
||||
|
||||
interface_name = next(iter(interfaces))
|
||||
addrs = interfaces[interface_name]["addrs"]
|
||||
if len(addrs) < 1:
|
||||
return None
|
||||
|
||||
addr = addrs[0]["addr"]
|
||||
assert isinstance(addr, str)
|
||||
return addr
|
||||
|
|
|
@ -0,0 +1,246 @@
|
|||
import logging
|
||||
import shlex
|
||||
import time
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from paramiko import SSHClient, AutoAddPolicy
|
||||
from paramiko.channel import ChannelFile, ChannelStderrFile
|
||||
|
||||
|
||||
class SshExecutableResult:
|
||||
def __init__(
|
||||
self,
|
||||
stdout: str,
|
||||
stderr: str,
|
||||
exit_code: Optional[int],
|
||||
cmd: Union[str, List[str]],
|
||||
elapsed: float,
|
||||
is_timeout: bool,
|
||||
) -> None:
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
self.exit_code = exit_code
|
||||
self.cmd = cmd
|
||||
self.elapsed = elapsed
|
||||
self.is_timeout = is_timeout
|
||||
|
||||
def check_exit_code(self) -> None:
|
||||
if self.is_timeout:
|
||||
raise Exception("SSH process timed out")
|
||||
|
||||
elif self.exit_code is not None and self.exit_code != 0:
|
||||
raise Exception(f"SSH process failed with exit code: {self.exit_code}")
|
||||
|
||||
|
||||
class _SshChannelFileReader:
|
||||
def __init__(self, channel_file: ChannelFile, log_level: int, log_name: str) -> None:
|
||||
self._channel_file = channel_file
|
||||
self._log_level = log_level
|
||||
self._log_name = log_name
|
||||
self._output: Optional[str] = None
|
||||
|
||||
self._thread: Thread = Thread(target=self._read_thread)
|
||||
self._thread.start()
|
||||
|
||||
def close(self) -> None:
|
||||
self._thread.join()
|
||||
|
||||
def __enter__(self) -> "_SshChannelFileReader":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def wait_for_output(self) -> str:
|
||||
self._thread.join()
|
||||
|
||||
assert self._output is not None
|
||||
return self._output
|
||||
|
||||
def _read_thread(self) -> None:
|
||||
log_enabled = logging.getLogger().isEnabledFor(self._log_level)
|
||||
|
||||
with StringIO() as output:
|
||||
while True:
|
||||
# Read output one line at a time.
|
||||
line = self._channel_file.readline()
|
||||
if not line:
|
||||
break
|
||||
|
||||
# Store the line.
|
||||
output.write(line)
|
||||
|
||||
# Log the line.
|
||||
if log_enabled:
|
||||
line_strip_newline = line[:-1] if line.endswith("\n") else line
|
||||
logging.log(self._log_level, "%s: %s", self._log_name, line_strip_newline)
|
||||
|
||||
self._channel_file.close()
|
||||
self._output = output.getvalue()
|
||||
|
||||
|
||||
class SshProcess:
|
||||
def __init__(
|
||||
self,
|
||||
cmd: str,
|
||||
stdout: ChannelFile,
|
||||
stderr: ChannelStderrFile,
|
||||
stdout_log_level: int,
|
||||
stderr_log_level: int,
|
||||
) -> None:
|
||||
self.cmd = cmd
|
||||
self._channel = stdout.channel
|
||||
self._result: Optional[SshExecutableResult] = None
|
||||
|
||||
self._start_time = time.monotonic()
|
||||
|
||||
chanid = self._channel.chanid
|
||||
|
||||
logging.debug("[ssh][%d][cmd]: %s", chanid, cmd)
|
||||
|
||||
self._stdout_reader = _SshChannelFileReader(stdout, stdout_log_level, f"[ssh][{chanid}][stdout]")
|
||||
self._stderr_reader = _SshChannelFileReader(stderr, stderr_log_level, f"[ssh][{chanid}][stderr]")
|
||||
|
||||
def close(self) -> None:
|
||||
self._channel.close()
|
||||
self._stdout_reader.close()
|
||||
self._stderr_reader.close()
|
||||
|
||||
def __enter__(self) -> "SshProcess":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def wait(
|
||||
self,
|
||||
timeout: float = 600,
|
||||
) -> SshExecutableResult:
|
||||
result = self._result
|
||||
if result is None:
|
||||
# Wait for the process to exit.
|
||||
completed = self._channel.status_event.wait(timeout)
|
||||
|
||||
if completed:
|
||||
exit_code = self._channel.recv_exit_status()
|
||||
|
||||
else:
|
||||
# Close channel.
|
||||
self._channel.close()
|
||||
|
||||
# Get the process's output.
|
||||
stdout = self._stdout_reader.wait_for_output()
|
||||
stderr = self._stderr_reader.wait_for_output()
|
||||
|
||||
elapsed_time = time.monotonic() - self._start_time
|
||||
|
||||
logging.debug(
|
||||
"[ssh][%d][cmd]: execution time: %f, exit code: %d", self._channel.chanid, elapsed_time, exit_code
|
||||
)
|
||||
|
||||
result = SshExecutableResult(stdout, stderr, exit_code, self.cmd, elapsed_time, not completed)
|
||||
self._result = result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class SshClient:
|
||||
def __init__(
|
||||
self,
|
||||
hostname: str,
|
||||
port: int = 22,
|
||||
username: Optional[str] = None,
|
||||
key_path: Optional[Path] = None,
|
||||
gateway: "Optional[SshClient]" = None,
|
||||
known_hosts_path: Optional[Path] = None,
|
||||
) -> None:
|
||||
self.ssh_client: SSHClient
|
||||
|
||||
# Handle gateway.
|
||||
# (That is, proxying an SSH connection through another SSH connection.)
|
||||
sock = None
|
||||
if gateway:
|
||||
gateway_transport = gateway.ssh_client.get_transport()
|
||||
assert gateway_transport
|
||||
sock = gateway_transport.open_channel("direct-tcpip", (hostname, port), ("", 0))
|
||||
|
||||
self.ssh_client = SSHClient()
|
||||
|
||||
# Handle known hosts.
|
||||
self.ssh_client.set_missing_host_key_policy(AutoAddPolicy)
|
||||
if known_hosts_path:
|
||||
self.ssh_client.load_host_keys(str(known_hosts_path))
|
||||
else:
|
||||
self.ssh_client.load_system_host_keys()
|
||||
|
||||
key_filename = None if key_path is None else str(key_path.absolute())
|
||||
|
||||
# Open SSH connection.
|
||||
self.ssh_client.connect(hostname=hostname, port=port, username=username, key_filename=key_filename, sock=sock)
|
||||
|
||||
def close(self) -> None:
|
||||
self.ssh_client.close()
|
||||
|
||||
def __enter__(self) -> "SshClient":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def run(
|
||||
self,
|
||||
cmd: str,
|
||||
shell: bool = False,
|
||||
cwd: Optional[Path] = None,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
stdout_log_level: int = logging.DEBUG,
|
||||
stderr_log_level: int = logging.DEBUG,
|
||||
timeout: float = 600,
|
||||
) -> SshExecutableResult:
|
||||
with self.popen(
|
||||
cmd,
|
||||
shell=shell,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout_log_level=stdout_log_level,
|
||||
stderr_log_level=stderr_log_level,
|
||||
) as process:
|
||||
return process.wait(
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def popen(
|
||||
self,
|
||||
cmd: Union[str, List[str]],
|
||||
shell: bool = False,
|
||||
cwd: Optional[Path] = None,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
stdout_log_level: int = logging.DEBUG,
|
||||
stderr_log_level: int = logging.DEBUG,
|
||||
) -> SshProcess:
|
||||
if isinstance(cmd, list):
|
||||
cmd = shlex.join(cmd)
|
||||
|
||||
elif not shell:
|
||||
# SSH runs all commands in shell sessions.
|
||||
# So, to remove shell symantics, use shlex to escape all the shell symbols.
|
||||
cmd = shlex.join(shlex.split(cmd))
|
||||
|
||||
if cwd is not None:
|
||||
cmd = f"cd {shlex.quote(str(cwd))}; {cmd}"
|
||||
|
||||
stdin, stdout, stderr = self.ssh_client.exec_command(cmd, environment=env)
|
||||
stdin.close()
|
||||
|
||||
return SshProcess(cmd, stdout, stderr, stdout_log_level, stderr_log_level)
|
||||
|
||||
def put_file(self, local_path: Path, node_path: Path) -> None:
|
||||
with self.ssh_client.open_sftp() as sftp:
|
||||
sftp.put(str(local_path), str(node_path))
|
||||
|
||||
def get_file(self, node_path: Path, local_path: Path) -> None:
|
||||
with self.ssh_client.open_sftp() as sftp:
|
||||
sftp.get(str(node_path), str(local_path))
|
Загрузка…
Ссылка в новой задаче