/******************************************************************************
 * kexec.c
 *
 * Support of kexec (reboot locally into new mini-os kernel).
 *
 * Copyright (c) 2024, Juergen Gross, SUSE Linux GmbH
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 * DEALINGS IN THE SOFTWARE.
 */

#ifdef CONFIG_KEXEC

#include <mini-os/os.h>
#include <mini-os/lib.h>
#include <mini-os/e820.h>
#include <mini-os/err.h>
#include <mini-os/kexec.h>

#include <xen/elfnote.h>
#include <xen/arch-x86/hvm/start_info.h>

/*
 * Physical entry point of the new kernel, taken from its
 * XEN_ELFNOTE_PHYS32_ENTRY ELF note by read_note_entry().
 * ~0UL means "no entry point found (yet)"; kexec_get_entry() refuses to
 * continue with ENOEXEC in that case.
 */
static unsigned long kernel_phys_entry = ~0UL;

/*
 * Final stage of kexec. Copies all data to the final destinations, zeroes
 * .bss and activates new kernel.
 * Must be called with interrupts off. Stack, code and data must be
 * accessible via identity mapped virtual addresses (virt == phys). Copying
 * and zeroing is done using virtual addresses.
 * No relocations inside the function are allowed, as it is copied to an
 * allocated page before being executed.
 *
 * actions: array of actions, processed strictly in order. The loop has no
 *          exit condition of its own: it relies on a KEXEC_CALL entry
 *          transferring control away, so the array must contain one.
 *          Entries with an unknown action code are skipped.
 * real:    jump target used by KEXEC_CALL (do_kexec() passes the address of
 *          the copied kexec_phys code here).
 */
static void __attribute__((__section__(".text.kexec")))
    kexec_final(struct kexec_action *actions, unsigned long real)
{
    char *src, *dest;
    unsigned int a, cnt;

    for ( a = 0; ; a++ )
    {
        switch ( actions[a].action )
        {
        case KEXEC_COPY:
            /*
             * Open-coded byte loops instead of memcpy()/memset(): calling
             * out of this function is not possible once it has been copied.
             */
            dest = actions[a].dest;
            src = actions[a].src;
            for ( cnt = 0; cnt < actions[a].len; cnt++ )
                *dest++ = *src++;
            break;

        case KEXEC_ZERO:
            dest = actions[a].dest;
            for ( cnt = 0; cnt < actions[a].len; cnt++ )
                *dest++ = 0;
            break;

        case KEXEC_CALL:
            /*
             * ebx := src, edi := dest, then jump to "real". kexec_get_entry()
             * queues this action with src = physical address of the start
             * info block and dest = physical entry point of the new kernel.
             * The "movl" loads only use the low 32 bits of the values, so
             * both must be below 4 GiB.
             */
            asm("movl %0, %%ebx\n\t"
                "movl %1, %%edi\n\t"
                "jmp *%2"
                : :"m" (actions[a].src), "m" (actions[a].dest), "m" (real));
            break;
        }
    }
}

/*
 * Small stack used while kexec_final() runs. It is placed in .data.kexec and
 * addressed relative to _kexec_start by do_kexec(), so it is presumably
 * part of the KEXEC_SECSIZE bytes copied to the kexec page (NOTE(review):
 * depends on the linker script, confirm). kexec_final() must not need more
 * than KEXEC_STACK_LONGS longs of stack.
 */
#define KEXEC_STACK_LONGS  8
static unsigned long __attribute__((__section__(".data.kexec")))
    kexec_stack[KEXEC_STACK_LONGS];

/*
 * Translate an address inside the kexec section (starting at _kexec_start)
 * into the corresponding address inside its copy at kexec_page.
 */
static unsigned long get_kexec_addr(void *kexec_page, void *addr)
{
    unsigned long base = (unsigned long)kexec_page;

    return base + ((unsigned long)addr - (unsigned long)_kexec_start);
}

void do_kexec(void *kexec_page)
{
    unsigned long actions;
    unsigned long stack;
    unsigned long final;
    unsigned long phys;

    actions = get_kexec_addr(kexec_page, kexec_actions);
    stack = get_kexec_addr(kexec_page, kexec_stack + KEXEC_STACK_LONGS);
    final = get_kexec_addr(kexec_page, kexec_final);
    phys = get_kexec_addr(kexec_page, kexec_phys);

    memcpy(kexec_page, _kexec_start, KEXEC_SECSIZE);
    asm("cli\n\t"
        "mov %0, %%"ASM_SP"\n\t"
        "mov %1, %%"ASM_ARG1"\n\t"
        "mov %2, %%"ASM_ARG2"\n\t"
        "jmp *%3"
        : :"m" (stack), "m" (actions), "m" (phys), "m" (final));
}

/*
 * Accept only x86 images (32- or 64-bit), based on the ELF header's machine
 * field.
 */
bool kexec_chk_arch(elf_ehdr *ehdr)
{
    switch ( ehdr->e32.e_machine )
    {
    case EM_386:
    case EM_X86_64:
        return true;
    default:
        return false;
    }
}

/*
 * Round an ELF note name/descriptor size up to the 4-byte alignment used
 * for note contents (unsigned wrap-around as with plain arithmetic).
 */
static unsigned int note_data_sz(unsigned int sz)
{
    const unsigned int align = 4;

    return (sz + align - 1) & ~(align - 1);
}

/*
 * Scan the ELF notes in the len bytes starting at start for the Xen
 * XEN_ELFNOTE_PHYS32_ENTRY note and store its value in kernel_phys_entry.
 *
 * ehdr:  ELF header of the image (selects the 32/64-bit note accessors).
 * start: first note of the notes area.
 * len:   size of the notes area in bytes.
 *
 * A note whose header, name or descriptor would extend beyond the notes
 * area ends the scan, so malformed notes can't make us read past the area.
 * A PHYS32_ENTRY note with an unsupported descriptor size (not 1, 2, 4 or
 * 8 bytes) is ignored and scanning continues.
 */
static void read_note_entry(elf_ehdr *ehdr, void *start, unsigned int len)
{
    elf_note *note = start;
    unsigned int off, left, hdr_len, note_len, raw_namesz, namesz, descsz;
    char *val;

    for ( off = 0; off < len; off += note_len )
    {
        left = len - off;

        /* Size of the fixed note header in front of name and descriptor. */
        val = note_val(ehdr, note, data);
        hdr_len = val - (char *)note;
        if ( hdr_len > left )
            break;
        left -= hdr_len;

        raw_namesz = note_val(ehdr, note, namesz);
        namesz = note_data_sz(raw_namesz);
        descsz = note_data_sz(note_val(ehdr, note, descsz));
        /* Checked separately to avoid unsigned overflow of the sum. */
        if ( namesz > left || descsz > left - namesz )
            break;
        note_len = hdr_len + namesz + descsz;

        /* The name must be exactly "Xen", including its terminating NUL. */
        if ( raw_namesz == sizeof("Xen") &&
             !memcmp(val, "Xen", sizeof("Xen")) &&
             note_val(ehdr, note, type) == XEN_ELFNOTE_PHYS32_ENTRY )
        {
            /* The descriptor follows the (padded) name. */
            val += namesz;
            switch ( note_val(ehdr, note, descsz) )
            {
            case 1:
                kernel_phys_entry = *(uint8_t *)val;
                return;
            case 2:
                kernel_phys_entry = *(uint16_t *)val;
                return;
            case 4:
                kernel_phys_entry = *(uint32_t *)val;
                return;
            case 8:
                kernel_phys_entry = *(uint64_t *)val;
                return;
            default:
                break;
            }
        }

        note = elf_ptr_add(note, note_len);
    }
}

/*
 * Program header hook: while no entry point is known yet, search PT_NOTE
 * segments for the PHYS32_ENTRY note. Always returns 0.
 */
int kexec_arch_analyze_phdr(elf_ehdr *ehdr, elf_phdr *phdr)
{
    bool is_note = phdr_val(ehdr, phdr, p_type) == PT_NOTE;

    if ( is_note && kernel_phys_entry == ~0UL )
    {
        void *seg = elf_ptr_add(ehdr, phdr_val(ehdr, phdr, p_offset));
        unsigned int seg_len = phdr_val(ehdr, phdr, p_filesz);

        read_note_entry(ehdr, seg, seg_len);
    }

    return 0;
}

/*
 * Section header hook: while no entry point is known yet, search SHT_NOTE
 * sections for the PHYS32_ENTRY note. Always returns 0.
 */
int kexec_arch_analyze_shdr(elf_ehdr *ehdr, elf_shdr *shdr)
{
    bool is_note = shdr_val(ehdr, shdr, sh_type) == SHT_NOTE;

    if ( is_note && kernel_phys_entry == ~0UL )
    {
        void *sec = elf_ptr_add(ehdr, shdr_val(ehdr, shdr, sh_offset));
        unsigned int sec_len = shdr_val(ehdr, shdr, sh_size);

        read_note_entry(ehdr, sec, sec_len);
    }

    return 0;
}

/*
 * Section headers only need to be scanned if the program headers did not
 * provide the new kernel's entry point.
 */
bool kexec_arch_need_analyze_shdrs(void)
{
    if ( kernel_phys_entry != ~0UL )
        return false;

    return true;
}

/* Physical address the start parameter block is finally copied to. */
static unsigned long kexec_param_pa;
/* Size in bytes of the parameter block (start info, memmap, module, cmdline). */
static unsigned int kexec_param_size;
/* Pages (from alloc_pages()) in which the parameter block is built; 0 if none. */
static unsigned long kexec_param_mem;

/*
 * Check whether the module passed to this kernel (if any) is a valid kexec
 * module.
 *
 * Returns a pointer to the module description (located at the start of the
 * module's last page) and sets kexec_mod_start to the module's address if
 * the module is valid, or NULL otherwise.
 */
static struct kexec_module *kexec_check_module(void)
{
    unsigned long mod_size;
    unsigned long mod;
    struct kexec_module *module_ptr;

    mod = get_module(&mod_size);
    if ( !mod )
        return NULL;
    /*
     * Size must be a non-zero multiple of PAGE_SIZE: the description lives
     * in the last page, so an empty module would put it before the module.
     */
    if ( mod_size < (unsigned long)PAGE_SIZE || (mod_size & ~PAGE_MASK) )
        return NULL;

    /* Kexec module description is at start of the last page of the module. */
    module_ptr = (void *)(mod + mod_size - (unsigned long)PAGE_SIZE);

    /* Check eye catcher. */
    if ( memcmp(module_ptr->eye_catcher, KEXECMOD_EYECATCHER,
                sizeof(module_ptr->eye_catcher)) )
        return NULL;
    /* All pages but the description page are data pages. */
    if ( module_ptr->n_pages != (mod_size >> PAGE_SHIFT) - 1 )
        return NULL;

    kexec_mod_start = mod;

    return module_ptr;
}

/*
 * Callback for iterate_memory_range(): if CONFIG_KEXEC_MODULE_PAGES pages
 * fit at the end of the range from..to, and that placement is higher than
 * the best one found so far, record it in kexec_mod_start.
 */
static void get_mod_addr(unsigned long from, unsigned long to)
{
    const unsigned long needed = PFN_PHYS(CONFIG_KEXEC_MODULE_PAGES);
    unsigned long candidate;

    if ( to - from < needed )
        return;

    candidate = to - needed;
    if ( candidate > kexec_mod_start )
        kexec_mod_start = candidate;
}

#define min(a, b) ((a) < (b) ? (a) : (b))
/*
 * Set up the kexec module. A valid kexec module passed to this kernel is
 * reused. Otherwise, if CONFIG_KEXEC_MODULE_PAGES is non-zero, a new empty
 * module is created at the highest suitable RAM location below 4 GiB within
 * the range given by start_pfn and max_pfn, and that memory is reserved.
 *
 * Afterwards mod_ptr points to the module description, mod_recs to the
 * record table, and mod_rec_start/mod_rec_end delimit the record data in
 * use. If no module exists and none is configured, mod_ptr stays NULL and
 * nothing else is set up.
 */
void kexec_module(unsigned long start_pfn, unsigned long max_pfn)
{
    unsigned int i;
    char *rec_end;

    /* Reuse already existing kexec module. */
    mod_ptr = kexec_check_module();
    if ( !mod_ptr && CONFIG_KEXEC_MODULE_PAGES )
    {
        /* Limit the module placement to memory below 4 GiB. */
        max_pfn = min(max_pfn, PHYS_PFN(0xffffffff));

        iterate_memory_range(PFN_PHYS(start_pfn), PFN_PHYS(max_pfn),
                             get_mod_addr);
        BUG_ON(!kexec_mod_start);

        /* The description is located in the last page of the module. */
        mod_ptr = (void *)(kexec_mod_start +
                           ((CONFIG_KEXEC_MODULE_PAGES - 1) << PAGE_SHIFT));
        memset(mod_ptr, 0, PAGE_SIZE);
        memcpy(mod_ptr->eye_catcher, KEXECMOD_EYECATCHER,
               sizeof(mod_ptr->eye_catcher));
        mod_ptr->n_pages = CONFIG_KEXEC_MODULE_PAGES - 1;
        memset(mod_ptr->pg2rec, KEXECMOD_PG_FREE, mod_ptr->n_pages);
        mod_ptr->n_records = 16;
        /* Record table follows the n_pages byte page map, at an even offset. */
        mod_ptr->recs_off = sizeof(struct kexec_module) +
                            mod_ptr->n_pages + (mod_ptr->n_pages & 1);

        set_reserved_range(kexec_mod_start, (unsigned long)mod_ptr + PAGE_SIZE);
    }

    /*
     * No module passed in and none configured (CONFIG_KEXEC_MODULE_PAGES
     * is 0): nothing to scan, and mod_ptr must not be dereferenced.
     */
    if ( !mod_ptr )
        return;

    mod_recs = (void *)((unsigned long)mod_ptr + mod_ptr->recs_off);
    mod_rec_start = (char *)(mod_recs + mod_ptr->n_records);
    mod_rec_end = mod_rec_start;
    /* Find the end of the record data used by existing records. */
    for ( i = 0; i < mod_ptr->n_records; i++ )
    {
        if ( mod_recs[i].type == KEXECMOD_REC_NONE )
            continue;
        rec_end = (char *)mod_ptr + mod_recs[i].offset + mod_recs[i].size;
        if ( mod_rec_end < rec_end )
            mod_rec_end = rec_end;
    }
}

/*
 * Reserve room for the new kernel's start parameters behind the data loaded
 * so far. The block consists of struct hvm_start_info, one memory map entry
 * per e820 entry, an optional module list entry and the NUL-terminated
 * command line. It is placed 8-byte aligned at kexec_last_addr, which is
 * then advanced to the next page boundary behind it.
 */
void kexec_set_param_loc(const char *cmdline)
{
    unsigned int size;

    size = sizeof(struct hvm_start_info) +
           e820_entries * sizeof(struct hvm_memmap_table_entry) +
           strlen(cmdline) + 1;
    if ( mod_ptr )
        size += sizeof(struct hvm_modlist_entry);
    kexec_param_size = size;

    kexec_param_pa = (kexec_last_addr + 7) & ~7UL;
    kexec_last_addr = round_pgup(kexec_param_pa + kexec_param_size);
}

int kexec_get_entry(const char *cmdline)
{
    void *next;
    struct hvm_start_info *info;
    struct hvm_memmap_table_entry *mmap;
    struct hvm_modlist_entry *mod;
    unsigned int order;
    unsigned int i;

    if ( kernel_phys_entry == ~0UL )
        return ENOEXEC;

    order = get_order(kexec_param_size);

    kexec_param_mem = alloc_pages(order);
    if ( !kexec_param_mem )
        return ENOMEM;

    next = (void *)kexec_param_mem;

    info = next;
    memset(info, 0, sizeof(*info));
    info->magic = XEN_HVM_START_MAGIC_VALUE;
    info->version = 1;
    next = info + 1;

    mmap = next;
    info->memmap_paddr = kexec_param_pa + (unsigned long)next - kexec_param_mem;
    for ( i = 0; i < e820_entries; i++ )
    {
        if ( e820_map[i].type == E820_TYPE_SOFT_RESERVED )
            continue;
        mmap->addr = e820_map[i].addr;
        mmap->size = e820_map[i].size;
        mmap->type = e820_map[i].type;
        mmap++;
    }
    info->memmap_entries = mmap - (struct hvm_memmap_table_entry *)next;
    next = mmap;

    if ( mod_ptr )
    {
        mod = next;
        memset(mod, 0, sizeof(*mod));
        info->nr_modules = 1;
        info->modlist_paddr = kexec_param_pa +
                              (unsigned long)next - kexec_param_mem;
        mod->paddr = kexec_mod_start;
        mod->size = PFN_PHYS(mod_ptr->n_pages + 1);
        mod->cmdline_paddr = 0;
        next = mod + 1;
    }

    info->cmdline_paddr = kexec_param_pa + (unsigned long)next - kexec_param_mem;
    strcpy(next, cmdline);

    if ( kexec_add_action(KEXEC_COPY, to_virt(kexec_param_pa), info,
                          kexec_param_size) )
        return ENOSPC;

    /* The call of the new kernel happens via the physical address! */
    if ( kexec_add_action(KEXEC_CALL, (void *)kernel_phys_entry,
                          (void *)kexec_param_pa, 0) )
        return ENOSPC;

    return 0;
}

/*
 * Release the parameter pages allocated by kexec_get_entry(), if any, and
 * mark them as released.
 */
void kexec_get_entry_undo(void)
{
    if ( !kexec_param_mem )
        return;

    free_pages((void *)kexec_param_mem, get_order(kexec_param_size));
    kexec_param_mem = 0;
}
#endif /* CONFIG_KEXEC */
