#!/usr/bin/env sh
exec poke -L "$0" "$@"
!#

/* prelinkr.pk - Prepare ELF64 DSOs to load to a given address.  */

/* Copyright (C) 2024 Free Software Foundation Inc.  */

/* Written by Jose E. Marchesi <jose.marchesi@oracle.com>
   Contributed by Oracle Inc.  */

/* This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/* This program is basically a Poke version of `prelink -r'.  The
   original prelink program was written by Jakub Jelinek.  */

load argp;
load elf;

var verbose = 0;

fun read_load_address = (Elf64_File elf, Elf64_Addr addr) Elf64_Addr:
{
  var off = elf.vaddr_to_file_offset (addr);
  return Elf64_Addr @ off;
}

fun write_load_address = (Elf64_File elf, Elf64_Addr addr, Elf64_Addr val) void:
{
  var off = elf.vaddr_to_file_offset (addr);

  if (verbose)
    {
      var sec = elf.vaddr_to_sec (addr);

      assert (sec != -1);
      printf ("adjusting data %s(%i32d):%v[%v->%v]\n",
              elf.get_section_name (elf.shdr[sec].sh_name), sec,
              addr, Elf64_Addr @ off, val);
    }
  Elf64_Addr @ off = val;
}

fun adjust_section = (Elf64_File elf, int<32> idx,
                      Elf64_Addr start, Elf64_Addr adjust) void:
{
  var shdr = elf.shdr[idx];

  /* First adjust section contents. */
  if (shdr.sh_type == ELF_SHT_PROGBITS)
    {
      if (elf.get_section_name (shdr.sh_name) == ".stab")
        {
          raise Exception { code = EC_inval,
                            msg = "relocating .stab sections is not supported" };
        }
      else if (elf.get_section_name (shdr.sh_name) == ".debug_info")
        {
          raise Exception { code = EC_inval,
                            msg = "relocating .debug_info sections is not supported" };
        }
    }
  else if (shdr.sh_type in [ELF_SHT_SYMTAB, ELF_SHT_DYNSYM])
    {
      for (sym in Elf64_Sym [shdr.sh_size] @ shdr.sh_offset)
        {
          if ((sym.st_shndx == ELF_SHN_ABS
               && sym.st_value != 0#B
               && sym.st_info.st_type <= ELF_STT_FUNC)
              || (sym.st_shndx > ELF_SHN_UNDEF
                  && sym.st_shndx < elf.ehdr.e_shnum
                  && sym.st_info.st_type != ELF_STT_TLS
                  && elf.shdr[sym.st_shndx].sh_flags & (ELF_SHF_WRITE | ELF_SHF_ALLOC | ELF_SHF_EXECINSTR)))
            {
              sym.st_value += adjust;
              if (verbose)
                printf ("adjusting symbol %s st_value[%v->%v]\n",
                        elf.get_string (sym.st_name),
                        sym.st_value - adjust, sym.st_value);
            }
        }
    }
  else if (shdr.sh_type == ELF_SHT_DYNAMIC)
    {
      /* XXX non-strict due to a poke bug regarding unions below.  */
      for (dyn in Elf64_Dyn[shdr.sh_size] @ shdr.sh_offset)
        {
          if (elf_mach == ELF_EM_X86_64 && dyn.d_tag == ELF_DT_PLTGOT)
            {
              var sec = elf.vaddr_to_sec (dyn.d_data.d_ptr);
              if (sec != -1)
                {
                  var data = read_load_address (elf, dyn.d_data.d_ptr);
                  /* If .got.plt[0] points to _DYNAMIC, it needs to
                     be adjusted.  */
                  if (data == shdr.sh_addr && data >= start)
                    write_load_address (elf, dyn.d_data.d_ptr, data + adjust);

                  data = read_load_address (elf, dyn.d_data.d_ptr + 8#B);
                  /* If .got.plt[1] points to .plt + 0x16, it needs to
                     be adjusted.  */
                  if (data > 0#B && data >= start)
                    {
                      for (s in elf.shdr)
                        if (data == s.sh_addr + 0x16#B
                            && s.sh_type == ELF_SHT_PROGBITS
                            && elf.get_section_name (s.sh_name) == ".plt")
                          write_load_address (elf, dyn.d_data.d_ptr + 8#B, data + adjust);
                    }
                }
            }

          if (dyn.d_tag as uint<32> in [ELF_DT_REL, ELF_DT_RELA])
            {
              if (dyn.d_data.d_ptr != 0#B && dyn.d_data.d_ptr >= start)
                {
                  if (verbose)
                    printf ("adjusting dynamic tag d_ptr[%v->%v] with tag %v\n",
                            dyn.d_data.d_ptr, dyn.d_data.d_ptr + adjust, dyn.d_tag);
                  dyn.d_data.d_ptr += adjust;
                }
            }
          else if (!(dyn.d_data.d_ptr ?! E_elem))
            {
              if (dyn.d_data.d_ptr >= start)
                {
                  if (verbose)
                    printf ("adjusting dynamic tag d_ptr[%v->%v] with tag %v\n",
                            dyn.d_data.d_ptr, dyn.d_data.d_ptr + adjust, dyn.d_tag);
                  printf ("DYN %v\n", dyn);
                  dyn.d_data.d_ptr += adjust;
                  printf ("POST\n");
                }
            }
        }
    }
  else if (shdr.sh_type == ELF_SHT_REL)
    {
      for (rel in Elf64_Rel [shdr.sh_size] @ shdr.sh_offset
           where elf.vaddr_to_sec (rel.r_offset) != -1)
        {
          if (rel.r_offset >= start)
            rel.r_offset += adjust;
        }
    }
  else if (shdr.sh_type == ELF_SHT_RELA)
    {
      for (rela in Elf64_Rela[shdr.sh_size] @ shdr.sh_offset
           where elf.vaddr_to_sec (rela.r_offset) != -1)
        {
          if (elf_mach == ELF_EM_X86_64)
            {
              if (rela.r_info.r_type == ELF_R_X86_64_RELATIVE)
                {
                   if (rela.r_addend >= start)
                     {
                       if (verbose)
                         {
                           print "adjusting reloc R_X86_64_RELATIVE";
                           printf " addend[%v->%v]\n", rela.r_addend, rela.r_addend + adjust;
                         }
                       if (read_load_address (elf, rela.r_offset) == rela.r_addend)
                         write_load_address (elf, rela.r_offset, rela.r_addend + adjust);
                       rela.r_addend += adjust;
                     }
                }
              else if (rela.r_info.r_type == ELF_R_X86_64_IRELATIVE)
                {
                  if (rela.r_addend >= start)
                    {
                      if (verbose)
                        printf "adjusting reloc R_X86_64_IRELATIVE addend[%v->%v]\n",
                               rela.r_addend, rela.r_addend + adjust;
                      rela.r_addend += adjust;
                    }
                  var addr = read_load_address (elf, rela.r_offset);
                  if (addr >= start)
                    write_load_address (elf, rela.r_offset, addr + adjust);
                }
              else if (rela.r_info.r_type == ELF_R_X86_64_JUMP_SLOT)
                {
                  var addr = read_load_address (elf, rela.r_offset);
                  if (addr >= start)
                    write_load_address (elf, rela.r_offset, addr + adjust);
                }
            }

          if (rela.r_offset >= start)
            rela.r_offset += adjust;
        }
    }

  /* Now the section itself.  */
  if (shdr.sh_flags & (ELF_SHF_WRITE | ELF_SHF_ALLOC | ELF_SHF_EXECINSTR))
    {
      if (shdr.sh_addr >= start)
        {
          shdr.sh_addr += adjust;
          if (start > 0#B)
            {
              shdr.sh_offset += adjust;
              if (verbose)
                printf ("adjusting section [%i32d] %s sh_offset[%v->%v]\n",
                        idx, elf.get_section_name (shdr.sh_name),
                        shdr.sh_offset - adjust, shdr.sh_offset);
            }
          if (verbose)
            printf ("adjusting section [%i32d] %s sh_addr[%v->%v]\n",
                    idx, elf.get_section_name (shdr.sh_name),
                    shdr.sh_addr - adjust, shdr.sh_addr);
        }
    }
}

fun adjust_segment = (Elf64_File elf, int<32> idx,
                      Elf64_Addr start, Elf64_Addr adjust) void:
{
  var phdr = elf.phdr[idx];

  /* The STACK segment doesn't need to be adjusted, since all its
     addresses and offsets shall be zero.  */
  if (phdr.p_type == ELF_PT_GNU_STACK)
    return;

  fun check_alignment = (Elf64_Addr vaddr, Elf64_Addr offset) void:
  {
    if (phdr.p_type == ELF_PT_LOAD && ((vaddr - offset) % phdr.p_align) != 0#B)
      raise Exception { code = EC_inval,
                        msg = "specified address leads to unaligned PT_LOAD segment in DSO" };
  }

  fun note_adjust = (string field, Elf64_Addr from, Elf64_Addr to) void:
  {
    if (verbose)
      printf "adjusting segment [%i32d] %s[%v->%v]\n", idx, field, from, to;
  }

  if (start == 0#B)
    {
      check_alignment (phdr.p_vaddr + adjust, phdr.p_offset);
      phdr.p_vaddr += adjust;
      phdr.p_paddr += adjust;
      note_adjust ("p_paddr", phdr.p_paddr - adjust, phdr.p_paddr);
      note_adjust ("p_vaddr", phdr.p_vaddr - adjust, phdr.p_vaddr);
    }
  else if (start <= phdr.p_vaddr)
    {
      check_alignment (phdr.p_vaddr + adjust, phdr.p_offset + adjust);
      phdr.p_vaddr += adjust;
      phdr.p_paddr += adjust;
      phdr.p_offset += adjust;
      note_adjust ("p_paddr", phdr.p_paddr - adjust, phdr.p_paddr);
      note_adjust ("p_vaddr", phdr.p_vaddr - adjust, phdr.p_vaddr);
      note_adjust ("p_offset", phdr.p_offset - adjust, phdr.p_offset);
    }
  else if (start < phdr.p_vaddr + phdr.p_filesz)
    {
      phdr.p_filesz += adjust;
      phdr.p_memsz += adjust;
      note_adjust ("p_filesz", phdr.p_filesz - adjust, phdr.p_filesz);
      note_adjust ("p_memsz", phdr.p_memsz - adjust, phdr.p_memsz);
    }
  else if (start < phdr.p_vaddr + phdr.p_memsz)
    {
      phdr.p_memsz += adjust;
      note_adjust ("p_memsz", phdr.p_memsz - adjust, phdr.p_memsz);
    }
  else
    return;
}

fun adjust_file = (Elf64_File elf,
                   Elf64_Addr start, Elf64_Addr adjust) void:
{
  if (elf.ehdr.e_entry >= start)
    elf.ehdr.e_entry += adjust;

  for (var i = 0; i < elf.ehdr.e_phnum; ++i)
    adjust_segment (elf, i, start, adjust);
  for (var i = 0; i < elf.ehdr.e_shnum; ++i)
    adjust_section (elf, i, start, adjust);
}

/******************** Main program.  **************************/

var address = 0#B;

argv
  = argp_parse ("prelinkr", "", "Prepare ELF64 DSOs to load to a given address.",
                [Argp_Option { name = "a", long_name = "address",
                               arg_required = 1,
                               handler = lambda (string arg) void:
                                 {
                                   var base = 10;
                                   if (arg'length > 1 && arg[0:2] == "0x")
                                     {
                                       base = 16;
                                       arg = arg[2:];
                                     }
                                   address = (strtoi (arg, base).val)#B;
                                 }
                             },
                 Argp_Option { name = "v", long_name = "verbose",
                               handler = lambda (string arg) void: { verbose = 1; }
                             }],
                argv);

if (argv'length != 1)
  {
    print "Usage: prelinkr [-v|--verbose] [-a ADDRESS] FILE\n";
    exit (1);
  }

vm_set_obase (16);
vm_set_autoremap (0);
var file = argv[0];

try
  {
    var fd = open (file, IOS_F_READ|IOS_F_WRITE);
    var dso = Elf64_File @ fd : 0#B;

    /* We handle only DSOs.  */
    if (dso.ehdr.e_type != ELF_ET_DYN)
      {
        print "Invalid ELF file.  Expected a DSO.\n";
        exit (1);
      }

    /* We support x86_64 and aarch64.  */
    if (!(elf_mach as int<32> in [ELF_EM_X86_64, ELF_EM_AARCH64]))
      {
        print "Unsupported ELF architecture.\n";
        exit (1);
      }

   /* Determine the current base of the DSO and relocate
      if necessary.  */
    var elf_base = dso.get_load_base;
    if (address > elf_base)
      adjust_file (dso, 0#B, address - elf_base);
  }
catch (Exception e)
  {
    if (e.code == EC_constraint)
      print "error: invalid DSO or transformation leads to invalid ELF\n";
    if (e == E_io)
      printf ("error: could not read `%s'\n", file);
    else if (e == E_perm)
      printf ("error: not enough permissions to read `%s'\n", file);
    else if (e.code == EC_inval)
      printf ("error: %s\n", e.msg);
    else
      raise e;

    exit (1);
  }

/*
 * Local Variables:
 * mode: poke
 * End:
 */
