WIP: MIC: VmTests: Add SSH connection

This commit is contained in:
Chris Gunn 2024-11-19 16:40:36 -08:00
Родитель b005e8b6b2
Коммит 4747e29bf2
6 изменённых файлов: 303 добавлений и 2 удалений

Просмотреть файл

@ -1,5 +1,6 @@
docker == 7.1.0
libvirt-python == 10.9.0
paramiko == 3.5.0
pytest == 8.3.3
PyYAML == 6.0.2
types-docker == 7.1.0.20240827

Просмотреть файл

@ -6,4 +6,5 @@ black == 24.8.0
flake8 == 7.1.0
isort == 5.13.2
mypy == 1.13.0
types-paramiko == 3.5.0.20240928
types-PyYAML == 6.0.12.20240917

Просмотреть файл

@ -11,7 +11,8 @@ from docker import DockerClient
from .conftest import TEST_CONFIGS_DIR
from .utils.imagecustomizer import run_image_customizer
from .utils.libvirt_utils import VmSpec, create_libvirt_domain_xml
from .utils.libvirt_utils import VmSpec, create_libvirt_domain_xml, get_vm_ip_address
from .utils.ssh_client import SshClient
def test_no_change(
@ -73,3 +74,13 @@ def test_no_change(
# Start the VM.
domain.resume()
vm_ip_address = get_vm_ip_address(libvirt_conn, vm_name, timeout=15)
# All the VM are know to be new.
# So, it is fine using an empty known_hosts file.
ssh_known_hosts_path = test_temp_dir.joinpath("known_hosts")
open(ssh_known_hosts_path, 'w').close()
with SshClient(vm_ip_address, key_path=ssh_private_key_path, known_hosts_path=ssh_known_hosts_path) as vm_ssh:
vm_ssh.run("cat /etc/os-release")

Просмотреть файл

@ -1,6 +1,9 @@
import time
import xml.etree.ElementTree as ET # noqa: N817
from pathlib import Path
from typing import Dict
from typing import Dict, Optional
import libvirt # type: ignore
class VmSpec:
@ -150,3 +153,42 @@ def _gen_disk_device_name(prefix: str, next_disk_indexes: Dict[str, int]) -> str
case _:
return f"{prefix}{disk_index}"
# Wait for the VM to boot and then get the IP address.
def get_vm_ip_address(
libvirt_conn: libvirt.virConnect,
vm_name: str,
timeout: float,
) -> str:
timeout_time = time.time() + timeout
while True:
addr = try_get_vm_ip_address(libvirt_conn, vm_name)
if addr:
return addr
if time.time() > timeout_time:
raise Exception(f"No IP addresses found for '{vm_name}'. OS might have failed to boot.")
# Try to get the IP address of the VM.
def try_get_vm_ip_address(
libvirt_conn: libvirt.virConnect,
vm_name: str,
) -> Optional[str]:
domain = libvirt_conn.lookupByName(vm_name)
# Acquire IP address from libvirt's DHCP server.
interfaces = domain.interfaceAddresses(libvirt.VIR_DOMAIN_INTERFACE_ADDRESSES_SRC_LEASE)
if len(interfaces) < 1:
return None
interface_name = next(iter(interfaces))
addrs = interfaces[interface_name]["addrs"]
if len(addrs) < 1:
return None
addr = addrs[0]["addr"]
assert isinstance(addr, str)
return addr

Просмотреть файл

@ -0,0 +1,246 @@
import logging
import shlex
import time
from io import StringIO
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
from paramiko import SSHClient, AutoAddPolicy
from paramiko.channel import ChannelFile, ChannelStderrFile
class SshExecutableResult:
def __init__(
self,
stdout: str,
stderr: str,
exit_code: Optional[int],
cmd: Union[str, List[str]],
elapsed: float,
is_timeout: bool,
) -> None:
self.stdout = stdout
self.stderr = stderr
self.exit_code = exit_code
self.cmd = cmd
self.elapsed = elapsed
self.is_timeout = is_timeout
def check_exit_code(self) -> None:
if self.is_timeout:
raise Exception("SSH process timed out")
elif self.exit_code is not None and self.exit_code != 0:
raise Exception(f"SSH process failed with exit code: {self.exit_code}")
class _SshChannelFileReader:
def __init__(self, channel_file: ChannelFile, log_level: int, log_name: str) -> None:
self._channel_file = channel_file
self._log_level = log_level
self._log_name = log_name
self._output: Optional[str] = None
self._thread: Thread = Thread(target=self._read_thread)
self._thread.start()
def close(self) -> None:
self._thread.join()
def __enter__(self) -> "_SshChannelFileReader":
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.close()
def wait_for_output(self) -> str:
self._thread.join()
assert self._output is not None
return self._output
def _read_thread(self) -> None:
log_enabled = logging.getLogger().isEnabledFor(self._log_level)
with StringIO() as output:
while True:
# Read output one line at a time.
line = self._channel_file.readline()
if not line:
break
# Store the line.
output.write(line)
# Log the line.
if log_enabled:
line_strip_newline = line[:-1] if line.endswith("\n") else line
logging.log(self._log_level, "%s: %s", self._log_name, line_strip_newline)
self._channel_file.close()
self._output = output.getvalue()
class SshProcess:
def __init__(
self,
cmd: str,
stdout: ChannelFile,
stderr: ChannelStderrFile,
stdout_log_level: int,
stderr_log_level: int,
) -> None:
self.cmd = cmd
self._channel = stdout.channel
self._result: Optional[SshExecutableResult] = None
self._start_time = time.monotonic()
chanid = self._channel.chanid
logging.debug("[ssh][%d][cmd]: %s", chanid, cmd)
self._stdout_reader = _SshChannelFileReader(stdout, stdout_log_level, f"[ssh][{chanid}][stdout]")
self._stderr_reader = _SshChannelFileReader(stderr, stderr_log_level, f"[ssh][{chanid}][stderr]")
def close(self) -> None:
self._channel.close()
self._stdout_reader.close()
self._stderr_reader.close()
def __enter__(self) -> "SshProcess":
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.close()
def wait(
self,
timeout: float = 600,
) -> SshExecutableResult:
result = self._result
if result is None:
# Wait for the process to exit.
completed = self._channel.status_event.wait(timeout)
if completed:
exit_code = self._channel.recv_exit_status()
else:
# Close channel.
self._channel.close()
# Get the process's output.
stdout = self._stdout_reader.wait_for_output()
stderr = self._stderr_reader.wait_for_output()
elapsed_time = time.monotonic() - self._start_time
logging.debug(
"[ssh][%d][cmd]: execution time: %f, exit code: %d", self._channel.chanid, elapsed_time, exit_code
)
result = SshExecutableResult(stdout, stderr, exit_code, self.cmd, elapsed_time, not completed)
self._result = result
return result
class SshClient:
def __init__(
self,
hostname: str,
port: int = 22,
username: Optional[str] = None,
key_path: Optional[Path] = None,
gateway: "Optional[SshClient]" = None,
known_hosts_path: Optional[Path] = None,
) -> None:
self.ssh_client: SSHClient
# Handle gateway.
# (That is, proxying an SSH connection through another SSH connection.)
sock = None
if gateway:
gateway_transport = gateway.ssh_client.get_transport()
assert gateway_transport
sock = gateway_transport.open_channel("direct-tcpip", (hostname, port), ("", 0))
self.ssh_client = SSHClient()
# Handle known hosts.
self.ssh_client.set_missing_host_key_policy(AutoAddPolicy)
if known_hosts_path:
self.ssh_client.load_host_keys(str(known_hosts_path))
else:
self.ssh_client.load_system_host_keys()
key_filename = None if key_path is None else str(key_path.absolute())
# Open SSH connection.
self.ssh_client.connect(hostname=hostname, port=port, username=username, key_filename=key_filename, sock=sock)
def close(self) -> None:
self.ssh_client.close()
def __enter__(self) -> "SshClient":
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.close()
def run(
self,
cmd: str,
shell: bool = False,
cwd: Optional[Path] = None,
env: Optional[Dict[str, str]] = None,
stdout_log_level: int = logging.DEBUG,
stderr_log_level: int = logging.DEBUG,
timeout: float = 600,
) -> SshExecutableResult:
with self.popen(
cmd,
shell=shell,
cwd=cwd,
env=env,
stdout_log_level=stdout_log_level,
stderr_log_level=stderr_log_level,
) as process:
return process.wait(
timeout=timeout,
)
def popen(
self,
cmd: Union[str, List[str]],
shell: bool = False,
cwd: Optional[Path] = None,
env: Optional[Dict[str, str]] = None,
stdout_log_level: int = logging.DEBUG,
stderr_log_level: int = logging.DEBUG,
) -> SshProcess:
if isinstance(cmd, list):
cmd = shlex.join(cmd)
elif not shell:
# SSH runs all commands in shell sessions.
# So, to remove shell symantics, use shlex to escape all the shell symbols.
cmd = shlex.join(shlex.split(cmd))
if cwd is not None:
cmd = f"cd {shlex.quote(str(cwd))}; {cmd}"
stdin, stdout, stderr = self.ssh_client.exec_command(cmd, environment=env)
stdin.close()
return SshProcess(cmd, stdout, stderr, stdout_log_level, stderr_log_level)
def put_file(self, local_path: Path, node_path: Path) -> None:
with self.ssh_client.open_sftp() as sftp:
sftp.put(str(local_path), str(node_path))
def get_file(self, node_path: Path, local_path: Path) -> None:
with self.ssh_client.open_sftp() as sftp:
sftp.get(str(node_path), str(local_path))

Просмотреть файл