diff --git a/mkosi/__init__.py b/mkosi/__init__.py index 55c8ccb9b..35a18fb90 100644 --- a/mkosi/__init__.py +++ b/mkosi/__init__.py @@ -5,7 +5,6 @@ import datetime import functools import hashlib -import io import itertools import json import logging @@ -93,6 +92,7 @@ copy_ephemeral, finalize_credentials, finalize_kernel_command_line_extra, + join_initrds, run_qemu, run_ssh, start_journal_remote, @@ -138,7 +138,6 @@ make_executable, one_zero, read_env_file, - round_up, scopedenv, ) from mkosi.versioncomp import GenericVersion @@ -1500,25 +1499,6 @@ def find_devicetree(context: Context, kver: str) -> Path: die(f"Requested devicetree {context.config.devicetree} not found") -def join_initrds(initrds: Sequence[Path], output: Path) -> Path: - assert initrds - - if len(initrds) == 1: - shutil.copy2(initrds[0], output) - return output - - seq = io.BytesIO() - for p in initrds: - initrd = p.read_bytes() - n = len(initrd) - padding = b"\0" * (round_up(n, 4) - n) # pad to 32 bit alignment - seq.write(initrd) - seq.write(padding) - - output.write_bytes(seq.getbuffer()) - return output - - def want_signed_pcrs(config: Config) -> bool: return config.sign_expected_pcr == ConfigFeature.enabled or ( config.sign_expected_pcr == ConfigFeature.auto @@ -2338,7 +2318,7 @@ def copy_initrd(context: Context) -> None: if context.config.kernel_modules_initrd: kver = next(gen_kernel_images(context))[0] initrds += [build_kernel_modules_initrd(context, kver)] - join_initrds(initrds, context.staging / context.config.output_split_initrd) + join_initrds(context.config, initrds, context.staging / context.config.output_split_initrd) break diff --git a/mkosi/qemu.py b/mkosi/qemu.py index bef8d7c66..84bfad818 100644 --- a/mkosi/qemu.py +++ b/mkosi/qemu.py @@ -7,6 +7,7 @@ import errno import fcntl import hashlib +import io import json import logging import os @@ -638,6 +639,25 @@ def rm() -> None: fork_and_wait(rm) +def join_initrds(config: Config, initrds: Sequence[Path], output: Path) -> Path: + assert initrds + + if len(initrds) == 1: + copy_tree(initrds[0], output, sandbox=config.sandbox) + return output + + seq = io.BytesIO() + for p in initrds: + initrd = p.read_bytes() + n = len(initrd) + padding = b"\0" * (round_up(n, 4) - n) # pad to 32 bit alignment + seq.write(initrd) + seq.write(padding) + + output.write_bytes(seq.getbuffer()) + return output + + def qemu_version(config: Config, binary: Path) -> GenericVersion: return GenericVersion( run( @@ -1343,9 +1363,14 @@ def add_virtiofs_mount( kernel and KernelType.identify(config, kernel) != KernelType.uki and "-initrd" not in args.cmdline - and (config.output_dir_or_cwd() / config.output_split_initrd).exists() ): - cmdline += ["-initrd", config.output_dir_or_cwd() / config.output_split_initrd] + if (config.output_dir_or_cwd() / config.output_split_initrd).exists(): + cmdline += ["-initrd", config.output_dir_or_cwd() / config.output_split_initrd] + elif config.initrds: + initrd = config.output_dir_or_cwd() / f"initrd-{uuid.uuid4().hex}" + join_initrds(config, config.initrds, initrd) + stack.callback(lambda: initrd.unlink()) + cmdline += ["-initrd", fname] if config.output_format in (OutputFormat.disk, OutputFormat.esp): direct = fname.stat().st_size % resource.getpagesize() == 0