/*
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Library General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
 */
/*
Test-code for Kernel-Modules with GRUB - ELF-Relocation
	by Soeren Bleikertz <sb@osdev.de>
	
This is just for demonstration and learning.
I won't explain how relocation works. Please read the ELF-specifications!

http://sac.cc || http://osdev.de || http://soeren.geekgate.org

works more or less. not fully tested!
any questions or comments?
*/

#include <stdio.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <fcntl.h>
#include <stdlib.h>
#include <string.h>
#include <elf.h>
#include <sys/mman.h>

/*
 * FAT_ERROR(s) - report a fatal error and terminate the program.
 *
 * Prints s followed by the text for the current errno (perror) and exits
 * with status -1.  The body is wrapped in do { ... } while (0) WITHOUT a
 * trailing semicolon so that "FAT_ERROR(x);" is exactly one statement and
 * can be used safely in an unbraced if/else.
 */
#define FAT_ERROR(s) \
	do { \
		perror(s); \
		exit(-1); \
	} while (0)
	
/* Return codes used by the helpers in this file. */
#define ESUCCESS 0
#define EFAILURE -1
/* The four identification bytes compared against e_ident in kmod_check(). */
#define ELF_MAGIC "\x7f""ELF"

/* Number of entries in the module symbol table sym_tab_mod[]. */
#define MAX_SYMS 255
	
/* Name of the module function that kmod_elf_reloc() selects as entry point. */
#define ENTRY_FUNC "mod_init"
	
/* TEST-FUNCTION */
	

/* test.c - I use this small source for relocation-testing. just compile it with
	gcc -c test.c
	
int lala(int *a) {
	*a = 2;
	return 0;
}

static void mod_init(void)
{	
	char *bla = "MOD_INIT!!!!\n";
	foo();
	bar(bla);
}

void hoo(void) {
	int bla;
	bla = 2;
	bla *= 2;
}	
	
 EOF */
 
	
/* Host-side test function; exported to modules through sym_tab. */
void foo(void)
{
	fputs("\nfoo() said: Hello\n", stdout);
}
	
/*
 * Host-side test function; exported to modules through sym_tab.
 * Prints the string b and the address it is stored at.
 */
void bar(char *b)
{
	/* %p requires a void * argument, so convert explicitly */
	printf("\nbar() said: %s(%p)\n", b, (void *)b);
}


/* ENTRY-STUFF */

/* Describes where execution of the loaded module starts. */
typedef struct entry {
	//Elf32_Shdr *shdr;
	ulong addr;	/* entry address as an integer; 0 means "not set" */
} entry_t;

/*
 * Entry point of the loaded module. Zeroed by main() before relocation,
 * filled in by kmod_elf_reloc() and finally called by main().
 */
entry_t main_entry;


/* SYMBOL-STUFF */
	
/* One name -> address pair of a symbol table. */
typedef struct sym_ent {
	char *name;	/* symbol name (NUL-terminated string) */
	ulong addr;	/* address of the symbol, stored as an integer */
} sym_t;

// symbol table from module
// filled by sym_tab_create(); the name pointers refer into the module image
sym_t sym_tab_mod[MAX_SYMS];
// number of entries of sym_tab_mod currently in use
int sym_count;

// global symbol-table
// Symbols of this program that a module may reference; searched by
// sym_lookup() when a relocation refers to an undefined (external) symbol.
sym_t sym_tab[] = {
	{ "foo", (ulong)foo },
	{ "bar", (ulong)bar } //test
};

/*
 * Find a symbol by name in the global table sym_tab.
 *
 * sym:     name to search for
 * sym_val: receives the address of the first matching entry
 *
 * Returns ESUCCESS on a match, EFAILURE (leaving *sym_val untouched) when
 * no entry carries that name.
 */
int sym_lookup(char *sym, ulong *sym_val)
{
	sym_t *cur = sym_tab;
	sym_t *end = sym_tab + sizeof(sym_tab)/sizeof(sym_tab[0]);

	while (cur != end) {
		if (strcmp(cur->name, sym) == 0) {
			*sym_val = cur->addr;
			return ESUCCESS;
		}
		cur++;
	}

	return EFAILURE;
}

/*
 * Decide whether a module symbol belongs into sym_tab_mod.
 *
 * Accepted are named symbols of type STT_FUNC with local or global binding
 * and a non-zero size. Returns ESUCCESS for those, EFAILURE otherwise.
 */
int sym_valid_name(Elf32_Sym *symt)
{
	int binding = ELF32_ST_BIND(symt->st_info);

	if (symt->st_name == 0)
		return EFAILURE;
	if (binding != STB_LOCAL && binding != STB_GLOBAL)
		return EFAILURE;
	if (symt->st_size == 0)
		return EFAILURE;
	if (ELF32_ST_TYPE(symt->st_info) != STT_FUNC)
		return EFAILURE;

	return ESUCCESS;
}
	

// create symbol-table from mod
int sym_tab_create(char *data)
{
	Elf32_Ehdr *elfhdr = (Elf32_Ehdr*)data;
	int nr_entry, i, nr_shdr = elfhdr->e_shnum;
	Elf32_Shdr *shdr, *off_shdr = (Elf32_Shdr*)(data + elfhdr->e_shoff);
	Elf32_Sym *symt;
	char *strtab;
	ulong text_ofs;
	
	sym_count = 0;
	
	//search Symbol-Table
	for (i=0; i < nr_shdr; i++) {
		if (off_shdr[i].sh_type == SHT_SYMTAB)
			shdr = &off_shdr[i];
	}
	
	printf("Debug: symtab: %p\n", shdr->sh_offset);
	symt = (Elf32_Sym*)(data+shdr->sh_offset);
	nr_entry = shdr->sh_size/sizeof(Elf32_Sym);
	
	//search String-Table
	for (i=0; i < nr_shdr; i++) {
		if ((off_shdr[i].sh_type == SHT_STRTAB) &&
			!(off_shdr[i].sh_flags) && (elfhdr->e_shstrndx != i))
			shdr = &off_shdr[i];
	}
	printf("Debug: strtab: %p\n", shdr->sh_offset);
	strtab = (char*)(data+shdr->sh_offset);
	
	//search .text
	for (i=0;i<nr_shdr; i++) {
		if (!(off_shdr[i].sh_flags & 0x0004))
			continue;
		text_ofs = (ulong)off_shdr[i].sh_offset;
	}
	
	
	printf("Debug: .text ofs: %p\n", text_ofs);
	
	// create symbol-table of module
	for (i=0; i<nr_entry; i++) {
		if ((sym_valid_name(&symt[i]) == ESUCCESS) && (sym_count <= MAX_SYMS)) {
			sym_tab_mod[sym_count].name = &strtab[symt[i].st_name];
			sym_tab_mod[sym_count].addr = (ulong)(symt[i].st_value + data + text_ofs);
			printf("Debug: Symbol: %s (%p)\n",
				sym_tab_mod[sym_count], sym_tab_mod[sym_count].addr);
			sym_count++;
			}
	}

	return ESUCCESS;
}
	

/* KMOD-STUFF */

// validate ELF
/*
 * Check that data starts with the header of an object this loader handles:
 * a 32-bit, little-endian, i386 relocatable ELF file of the current version
 * that has a section header table with entries of the expected size.
 *
 * Returns ESUCCESS when all checks pass, EFAILURE otherwise.
 */
int kmod_check(char *data)
{
	Elf32_Ehdr *elfhdr = (Elf32_Ehdr*)data;

	/* e_ident is an array of raw bytes, so compare it with memcmp */
	if (memcmp(elfhdr->e_ident, ELF_MAGIC, 4) != 0)
		return EFAILURE;

	if ((elfhdr->e_ident[EI_CLASS] != ELFCLASS32) ||
		(elfhdr->e_ident[EI_DATA] != ELFDATA2LSB) ||
		(elfhdr->e_type != ET_REL) ||
		(elfhdr->e_machine != EM_386) ||
		(elfhdr->e_version != EV_CURRENT))
			return EFAILURE;

	/* the rest of the loader indexes the section headers as Elf32_Shdr[] */
	if ((elfhdr->e_shoff == 0) ||
		(elfhdr->e_shentsize != sizeof(Elf32_Shdr)))
			return EFAILURE;

	return ESUCCESS;
}

/*
 * Validate the module image with kmod_check() and report the result.
 *
 * data:    start of the module image
 * modname: name used in the error message
 *
 * Returns ESUCCESS for a valid image, EFAILURE otherwise.
 */
int kmod_load(char *data, char *modname)
{
	int status = kmod_check(data);

	if (status != EFAILURE) {
		printf("Debug: Valid ELF\n");
		return ESUCCESS;
	}

	printf("Error: loading of '%s' failed!\n", modname);
	return EFAILURE;
}

/*
 * Resolve the symbol a relocation entry refers to.
 *
 * data:        start of the module image
 * symidx:      index of the symbol inside the symbol table
 * sym_val:     receives the resolved address
 * symtab_sect: section index of the symbol table (sh_link of the
 *              relocation section)
 *
 * Undefined symbols (SHN_UNDEF) are looked up by name in the host table via
 * sym_lookup(). Symbols defined in the module resolve to
 * data + <offset of the defining section> + st_value; absolute symbols
 * (SHN_ABS) resolve to st_value itself.
 *
 * Returns ESUCCESS, or EFAILURE for an invalid symbol table / index, an
 * unknown external symbol or an unsupported section index (e.g. SHN_COMMON).
 */
int kmod_elf_sym(char *data, uint symidx, ulong *sym_val, ulong symtab_sect)
{
	Elf32_Ehdr *elfhdr = (Elf32_Ehdr*)data;
	Elf32_Shdr *shdr = (Elf32_Shdr*)(data + elfhdr->e_shoff);
	Elf32_Shdr *sym_sec, *str_sec;
	Elf32_Sym *sym;
	char *name;

	printf("Debug: SymTab-Sect: %lu, SymTab-IDX: %u\n", symtab_sect, symidx);

	if ((symtab_sect >= elfhdr->e_shnum) ||
		(shdr[symtab_sect].sh_type != SHT_SYMTAB)) {
		printf("Error: Not a SymTab!\n");
		return EFAILURE;
	}
	sym_sec = &shdr[symtab_sect];

	// entries are accessed as Elf32_Sym; valid indices are 0 .. count-1
	if ((sym_sec->sh_entsize != sizeof(Elf32_Sym)) ||
		(symidx >= sym_sec->sh_size / sym_sec->sh_entsize))
		return EFAILURE;

	printf("Debug: Symtab-ofs: 0x%x\n", (unsigned)sym_sec->sh_offset);
	sym = (Elf32_Sym*)(data + sym_sec->sh_offset) + symidx;

	if (sym->st_shndx == SHN_UNDEF) {
		//external symbol
		printf("Debug: External Symbol\n");

		// the string table of a symbol table is named by its sh_link
		if ((sym_sec->sh_link >= elfhdr->e_shnum) ||
			(shdr[sym_sec->sh_link].sh_type != SHT_STRTAB)) {
			printf("Error: Not a StrTab!\n");
			return EFAILURE;
		}
		str_sec = &shdr[sym_sec->sh_link];
		printf("Debug: StrTab-ofs: 0x%x\n", (unsigned)str_sec->sh_offset);

		name = data + str_sec->sh_offset + sym->st_name;
		printf("Debug: Symbol: \"%s\"\n", name);
		if (sym_lookup(name, sym_val) == EFAILURE) {
			printf("Error: Unknown Symbol!\n");
			return EFAILURE;
		}
		printf("Debug: ext. Symbol-Addr: %p\n", (void*)*sym_val);
	} else if (sym->st_shndx == SHN_ABS) {
		//absolute symbol: the value is not relative to any section
		printf("Debug: Absolute Symbol\n");
		*sym_val = sym->st_value;
	} else if (sym->st_shndx >= elfhdr->e_shnum) {
		//SHN_COMMON and other reserved indices have no section header
		printf("Error: Unsupported Symbol-Section!\n");
		return EFAILURE;
	} else {
		//internal symbol: st_value is an offset into section st_shndx
		printf("Debug: Internal Symbol\n");
		*sym_val = sym->st_value +
			(ulong)(data + shdr[sym->st_shndx].sh_offset);
		printf("Debug: int. Symbol-Addr: %p, Symtab-Value: %u\n",
			(void*)*sym_val, (unsigned)sym->st_value);
	}
	return ESUCCESS;
}
	

int kmod_elf_reloc_do(char *data, Elf32_Rel *rel, Elf32_Shdr *shdr)
{
	ulong *rel_addr;
	Elf32_Shdr *rel_sect; //section for relocation
	Elf32_Ehdr *elfhdr;
	ulong sym_val;
	
	//ELf-hdr
	elfhdr = (Elf32_Ehdr*)data;
	
	
	//section for relocation
	rel_sect = (Elf32_Shdr*)(data + elfhdr->e_shoff + elfhdr->e_shentsize * shdr->sh_info);
	
	//relocation addr
	rel_addr = (ulong*)(data + rel_sect->sh_offset + rel->r_offset);
	printf("Debug: rel_offset: 0x%x, SYM: 0x%x, TYPE: 0x%x\n",
		rel->r_offset, ELF32_R_SYM(rel->r_info), ELF32_R_TYPE(rel->r_info));
	
	
	printf("Debug: sect_for_rel_ofs: %p, SymTab-Sect-IDX: %d\n", rel_sect->sh_offset, shdr->sh_link);
	
	// get addr of symbol
	if (kmod_elf_sym(data, ELF32_R_SYM(rel->r_info), &sym_val, shdr->sh_link) == EFAILURE)
		return EFAILURE;
	// omg..
	switch (ELF32_R_TYPE(rel->r_info)) {
		case R_386_32:
			*rel_addr = sym_val + *rel_addr;
			printf("Debug: R_386_32: rel_addr: %p(%p)\n", rel_addr, *rel_addr);
		break;
		case R_386_PC32:
			*rel_addr = sym_val + *rel_addr - (ulong)rel_addr;
			printf("Debug: R_386_PC32: rel_addr: %p(%p)\n", rel_addr, *rel_addr);
		break;
		default:
			printf("Error: wrong Reloc-Type\n");
			return EFAILURE;
	}
	
	return ESUCCESS;	
}


/*
 * Relocate the whole module and determine its entry point.
 *
 * data: start of the module image (the image is modified in place)
 *
 * Steps:
 *  1. Give the first SHT_NOBITS section (the BSS) zeroed memory and store
 *     its position in sh_offset as a distance from data, so that
 *     data + sh_offset addresses it like any other section. That memory is
 *     not released here; the module refers to it after this call returns.
 *  2. Apply every entry of every SHT_REL section via kmod_elf_reloc_do().
 *     SHT_RELA sections are rejected: their entries are larger and carry
 *     an explicit addend, which kmod_elf_reloc_do() does not handle.
 *  3. Set main_entry.addr to the module symbol ENTRY_FUNC or, if it is
 *     still 0 afterwards, to the start of the first executable section.
 *
 * Returns ESUCCESS, or EFAILURE when memory is exhausted, a relocation
 * section is malformed/unsupported or a relocation fails.
 */
int kmod_elf_reloc(char *data)
{
	Elf32_Ehdr *ehdr = (Elf32_Ehdr*)data;
	Elf32_Shdr *shdr = (Elf32_Shdr*)(data + ehdr->e_shoff);
	Elf32_Rel *rel;
	uint shdr_nr = ehdr->e_shnum;
	uint i, j, rel_sz, rel_nr;
	int k;
	char *bss;

	// search BSS; the index is checked before shdr[i] is read
	for (i = 0; i < shdr_nr; i++) {
		if (shdr[i].sh_type == SHT_NOBITS)
			break;
	}
	if (i == shdr_nr) {
		printf("No BSS found!\n");
	} else {
		printf("Debug: Found BSS in Section %u\n", i);
		// BSS must start out zeroed; never request 0 bytes
		bss = calloc(1, shdr[i].sh_size ? shdr[i].sh_size : 1);
		if (bss == NULL) {
			printf("Error: no memory for BSS!\n");
			return EFAILURE;
		}
		printf("Debug: new BSS: 0x%lx\n", (ulong)bss - (ulong)data);
		shdr[i].sh_offset = (ulong)bss - (ulong)data;
	}

	// search relocation-sections
	for (i = 0; i < shdr_nr; i++) {
		if (shdr[i].sh_type == SHT_RELA) {
			printf("Error: SHT_RELA sections are not supported\n");
			return EFAILURE;
		}
		if (shdr[i].sh_type != SHT_REL)
			continue;

		// entries are walked as Elf32_Rel[]; also rules out entsize 0
		rel_sz = shdr[i].sh_entsize;
		if (rel_sz != sizeof(Elf32_Rel)) {
			printf("Error: wrong Reloc-Entry-Size\n");
			return EFAILURE;
		}
		rel = (Elf32_Rel*)(data + shdr[i].sh_offset);
		rel_nr = shdr[i].sh_size / rel_sz;

		printf("Debug: rel_section: %u(0x%x), rel_ent_size: %u -> rel_type: %u, rel_ent_nr: %u\n",
			i, (unsigned)shdr[i].sh_offset, rel_sz,
			(unsigned)shdr[i].sh_type, rel_nr);
		for (j = 0; j < rel_nr; j++) {
			printf("\nDebug: Relocation-Nr: %u/%u\n", j + 1, rel_nr);
			printf("Debug: rel: 0x%lx, shdr: 0x%lx\n",
				(ulong)&rel[j] - (ulong)data, (ulong)&shdr[i] - (ulong)data);
			// do relocation
			if (kmod_elf_reloc_do(data, &rel[j], &shdr[i]) == EFAILURE)
				return EFAILURE;
		}
	}

	// search for ENTRY_FUNC, set main-entry-addr
	for (k = 0; k < sym_count; k++) {
		if (!strcmp(sym_tab_mod[k].name, ENTRY_FUNC)) {
			main_entry.addr = sym_tab_mod[k].addr;
			break;
		}
	}

	// ENTRY_FUNC not found, set main-entry-addr to beginning of .text
	if (main_entry.addr == 0) {
		printf("Debug: No Main-Entry! Set entry to beginning of .text!\n");
		for (i = 0; i < shdr_nr; i++) {
			if (!(shdr[i].sh_flags & SHF_EXECINSTR))
				continue;
			main_entry.addr = (ulong)(data + shdr[i].sh_offset);
			//main_entry.shdr = &shdr[i];
			printf("Debug: Entry: %p\n", (void*)main_entry.addr);
			break;
		}
	}

	return ESUCCESS;
}
	


/* OTHER STUFF */

int main(int argc, char **argv)
{
	ulong elf_sz, entry;
	char *elf_ptr;
	int fd;
	struct stat st;
	void (*ble)();
	
	if (argc != 2)
		return -1;
	
	//open file
	if ((fd = open(argv[1], O_RDONLY, 0)) == -1)
		FAT_ERROR("open failed");

	//getting file-infos
	if (fstat(fd, &st) == -1)
		FAT_ERROR("stat failed");

	elf_sz = st.st_size;
	if ((elf_ptr = malloc(elf_sz)) == NULL)
		FAT_ERROR("malloc() failed");
	//read file
	read(fd, elf_ptr, elf_sz);
	
	close(fd);
	
	printf("KMOD-Testing:\n");
	
	printf("Modul: %s:\n---\n", argv[1]);
	printf("[Load Mod]\n");
	kmod_load(elf_ptr, argv[1]);
	printf("> Done!\n\n");
	
	printf("[Create SymTab]\n");
	sym_tab_create(elf_ptr);
	printf("> Done!\n\n");
	
	//relocation
	bzero(&main_entry, sizeof(entry_t));
	printf("[Relocation]\n");
	if (kmod_elf_reloc(elf_ptr)==EFAILURE) {
		printf("Error: Relocation failed!\n");
		return EFAILURE;
	} else
		printf("> Done!\n\n");
	
	printf("[Start]\n");
	ble = (void*)main_entry.addr;
	printf("Debug: entry: %p\n", main_entry.addr);
	printf("<output>\n");
	ble();
	printf("</output\n");
	printf("> Done!\n");
	free(elf_ptr);
	return 0;
}

