#!/usr/bin/env python
"""Tests for editing the dynamic section of ELF binaries with LIEF.

Each test builds a tiny shared library (``libadd.so``) and an executable
linked against it (``binadd.bin``) with the compiler returned by
``utils.get_compiler()``, then modifies their dynamic entries through LIEF.
"""
import logging
import os
import re
import shutil
import stat
import subprocess
import sys
import tempfile
import unittest
from subprocess import Popen
from unittest import TestCase

import lief
from lief import Logger

from utils import get_sample, get_compiler

Logger.set_level(lief.LOGGING_LEVEL.INFO)

CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
STUB = lief.parse(os.path.join(CURRENT_DIRECTORY, "hello_lief.bin"))

# C source of the shared library: exports ``add`` which prints and returns
# the sum of its arguments.
LIBADD_C = """\
#include <stdio.h>
#include <stdlib.h>

int add(int a, int b);

int add(int a, int b) {
  printf("%d + %d = %d\\n", a, b, a + b);
  return a + b;
}
"""

# C source of the executable: calls ``add`` from the shared library on its
# two command-line arguments.
BINADD_C = """\
#include <stdio.h>
#include <stdlib.h>

int add(int a, int b);

int main(int argc, char **argv) {
  if (argc != 3) {
    printf("Usage: %s <a> <b>\\n", argv[0]);
    exit(-1);
  }

  int res = add(atoi(argv[1]), atoi(argv[2]));
  printf("From myLIb, a + b = %d\\n", res);
  return 0;
}
"""


class LibAddSample(object):
    """Build ``libadd.so`` and ``binadd.bin`` inside a fresh temp directory.

    The sources are written and compiled when the object is constructed.
    Call :meth:`remove` (also invoked from ``__del__``) to delete the
    directory and everything in it.
    """

    # Number of samples created so far; only used to build the temp-dir suffix.
    COUNT = 0

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.tmp_dir = tempfile.mkdtemp(
            suffix='_lief_sample_{:d}'.format(LibAddSample.COUNT))
        self.logger.debug("temp dir: %s", self.tmp_dir)
        LibAddSample.COUNT += 1

        self.binadd_path = os.path.join(self.tmp_dir, "binadd.c")
        self.libadd_path = os.path.join(self.tmp_dir, "libadd.c")

        self.libadd_so = os.path.join(self.tmp_dir, "libadd.so")
        self.binadd_bin = os.path.join(self.tmp_dir, "binadd.bin")

        self.compiler = get_compiler()
        self.logger.debug("Compiler: %s", self.compiler)

        with open(self.binadd_path, 'w') as f:
            f.write(BINADD_C)

        with open(self.libadd_path, 'w') as f:
            f.write(LIBADD_C)

        # The library must exist before the executable can link against it.
        self._compile_libadd()
        self._compile_binadd()

    def _run_compiler(self, label, cmd):
        """Run the compiler command ``cmd`` and log its combined output.

        Raises:
            RuntimeError: if the compiler exits with a non-zero status.
        """
        self.logger.debug("Compile '%s' with: %s", label, " ".join(cmd))
        p = Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        stdout, _ = p.communicate()
        output = stdout.decode("utf8", "replace")
        self.logger.debug(output)
        if p.returncode != 0:
            # Fail here with the compiler diagnostics rather than later with
            # an unrelated error when the missing artifact is parsed.
            raise RuntimeError(
                "Compilation of '{}' failed with exit code {}:\n{}".format(
                    label, p.returncode, output))

    def _compile_libadd(self, extra_flags=None):
        """(Re)build ``libadd.so``, appending ``extra_flags`` to the defaults."""
        extra_flags = [] if extra_flags is None else list(extra_flags)

        if os.path.isfile(self.libadd_so):
            os.remove(self.libadd_so)

        cc_flags = ['-fPIC', '-shared', '-Wl,-soname,libadd.so'] + extra_flags
        cmd = [self.compiler, '-o', self.libadd_so] + cc_flags + [self.libadd_path]
        self._run_compiler("libadd", cmd)

    def _compile_binadd(self, extra_flags=None):
        """(Re)build ``binadd.bin`` linked against ``libadd.so``."""
        extra_flags = [] if extra_flags is None else list(extra_flags)

        if os.path.isfile(self.binadd_bin):
            os.remove(self.binadd_bin)

        cc_flags = ['-L', self.tmp_dir] + extra_flags
        cmd = [self.compiler, '-o', self.binadd_bin] + cc_flags + [self.binadd_path, '-ladd']
        self._run_compiler("binadd", cmd)

    @property
    def libadd(self):
        """Path of the compiled shared library."""
        return self.libadd_so

    @property
    def binadd(self):
        """Path of the compiled executable."""
        return self.binadd_bin

    @property
    def directory(self):
        """Temporary directory holding the sources and build artifacts."""
        return self.tmp_dir

    def remove(self):
        """Delete the temporary directory; safe to call more than once."""
        # ``tmp_dir`` is missing if ``mkdtemp`` itself failed in ``__init__``.
        directory = getattr(self, "tmp_dir", None)
        if directory and os.path.isdir(directory):
            shutil.rmtree(directory)

    def __del__(self):
        self.remove()


class TestDynamic(TestCase):
    """Exercise removal and modification of ELF dynamic entries."""

    def setUp(self):
        self.logger = logging.getLogger(__name__)

    def _log_dynamic_entries(self, title, binary):
        """Log ``title`` followed by every dynamic entry of ``binary``."""
        self.logger.debug(title)
        for entry in binary.dynamic_entries:
            self.logger.debug(entry)
        self.logger.debug("")

    @unittest.skipUnless(sys.platform.startswith("linux"), "requires Linux")
    def test_remove_library(self):
        """Removing the entry for ``libadd.so`` drops the dependency."""
        sample = LibAddSample()
        self.addCleanup(sample.remove)
        libadd = lief.parse(sample.libadd)
        binadd = lief.parse(sample.binadd)

        libadd_needed = binadd.get_library("libadd.so")
        binadd -= libadd_needed
        self.assertFalse(binadd.has_library("libadd.so"))

    @unittest.skipUnless(sys.platform.startswith("linux"), "requires Linux")
    def test_remove_tag(self):
        """Removing by tag leaves no ``NEEDED`` entry behind."""
        sample = LibAddSample()
        self.addCleanup(sample.remove)
        libadd = lief.parse(sample.libadd)
        binadd = lief.parse(sample.binadd)

        self._log_dynamic_entries("BEFORE", binadd)

        binadd -= lief.ELF.DYNAMIC_TAGS.NEEDED

        self._log_dynamic_entries("AFTER", binadd)

        self.assertTrue(all(e.tag != lief.ELF.DYNAMIC_TAGS.NEEDED
                            for e in binadd.dynamic_entries))

    @unittest.skipUnless(sys.platform.startswith("linux"), "requires Linux")
    def test_runpath_api(self):
        """Check append, insert, assignment and removal on a run-path entry."""
        sample = LibAddSample()
        self.addCleanup(sample.remove)
        libadd = lief.parse(sample.libadd)
        binadd = lief.parse(sample.binadd)

        rpath = lief.ELF.DynamicEntryRunPath()
        # Continue with the entry returned by ``add`` rather than the local
        # object that was passed in.
        rpath = binadd.add(rpath)
        self.logger.debug(rpath)

        rpath += "/tmp"
        self.logger.debug(rpath)
        self.assertEqual(rpath.paths, ["/tmp"])
        self.assertEqual(rpath.runpath, "/tmp")

        rpath.insert(0, "/foo")
        self.assertEqual(rpath.paths, ["/foo", "/tmp"])
        self.assertEqual(rpath.runpath, "/foo:/tmp")

        rpath.paths = ["/foo", "/tmp", "/bar"]
        self.logger.debug(rpath)
        self.assertEqual(rpath.paths, ["/foo", "/tmp", "/bar"])
        self.assertEqual(rpath.runpath, "/foo:/tmp:/bar")

        rpath -= "/tmp"
        self.logger.debug(rpath)
        self.assertEqual(rpath.runpath, "/foo:/bar")

        rpath.remove("/foo").remove("/bar")
        self.logger.debug(rpath)
        self.assertEqual(rpath.runpath, "")

    @unittest.skipUnless(sys.platform.startswith("linux"), "requires Linux")
    def test_change_libname(self):
        """Rename the library, repoint the executable at it, and run it."""
        sample = LibAddSample()
        self.addCleanup(sample.remove)
        libadd = lief.parse(sample.libadd)
        binadd = lief.parse(sample.binadd)

        new_name = "libwhichhasalongverylongname.so"

        self.assertIn(lief.ELF.DYNAMIC_TAGS.SONAME, libadd)

        so_name = libadd[lief.ELF.DYNAMIC_TAGS.SONAME]
        self.logger.debug("DT_SONAME: %s", so_name.name)
        so_name.name = new_name

        libfoo_path = os.path.join(sample.directory, new_name)
        self.logger.debug(libfoo_path)
        libadd.write(libfoo_path)

        libfoo = lief.parse(libfoo_path)

        # Check builder did the job right: read the name back from the file
        # that was written, not from the in-memory object modified above.
        new_so_name = libfoo[lief.ELF.DYNAMIC_TAGS.SONAME]
        self.assertEqual(new_so_name.name, new_name)

        libadd_needed = binadd.get_library("libadd.so")
        libadd_needed.name = new_name

        # Add a RPATH entry
        rpath = lief.ELF.DynamicEntryRunPath(sample.directory)
        rpath = binadd.add(rpath)
        self.logger.debug(rpath)

        new_binadd_path = os.path.join(sample.directory, "binadd_updated.bin")
        self.logger.debug(new_binadd_path)
        binadd.write(new_binadd_path)

        # Remove original binaries so the run below can only succeed with
        # the rewritten files.
        os.remove(sample.libadd)
        os.remove(sample.binadd)

        # Run the new executable
        for path in (libfoo_path, new_binadd_path):
            st = os.stat(path)
            os.chmod(path, st.st_mode | stat.S_IEXEC)

        # LD_LIBRARY_PATH is deliberately not set: it shouldn't be needed if
        # the RPATH injection succeeded.
        p = Popen([new_binadd_path, '1', '2'],
                  stdout=subprocess.PIPE,
                  stderr=subprocess.STDOUT)
        stdout, _ = p.communicate()
        output = stdout.decode("utf8")

        # A negative return code means the process was killed by a signal,
        # which is as much a failure as a positive exit status.
        if p.returncode != 0:
            self.logger.critical(output)
        self.assertEqual(p.returncode, 0)

        self.logger.debug(output)
        self.assertIsNotNone(re.search(r'From myLIb, a \+ b = 3', output))


if __name__ == '__main__':
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    root_logger.addHandler(ch)

    unittest.main(verbosity=2)