WUMSLoader/wumsloader/src/module/RelocationUtils.cpp

203 lines
7.9 KiB
C++

#include "RelocationUtils.h"
#include "ElfUtils.h"
#include "ImportRPLInformation.h"
#include "globals.h"
#include "module/ModuleContainer.h"
#include "module/RelocationData.h"
#include "utils/OnLeavingScope.h"
#include "utils/StringTools.h"
#include "utils/logger.h"
#include "utils/memory.h"
#include <wums/defines/relocation_defines.h>
#include <coreinit/debug.h>
#include <coreinit/dynload.h>
#include <algorithm>
#include <map>
#include <span>
#include <string>
#include <vector>
#include <cstdint>
#include <malloc.h>
namespace WUMSLoader::Modules::RelocationUtils {
namespace {
/**
 * Fallback OSDynLoad allocation callback.
 *
 * Allocates @p size bytes aligned to @p align (alignments below 4 are raised to 4) and
 * records the block in gAllocatedAddresses, so it can be cleaned up later even if the
 * RPLs are never unloaded properly.
 *
 * @param size     number of bytes requested.
 * @param align    requested alignment; values below 4 are treated as 4.
 * @param outAddr  receives the allocated block (nullptr if memalign fails).
 * @return OS_DYNLOAD_INVALID_ALLOCATOR_PTR if @p outAddr is null,
 *         OS_DYNLOAD_OUT_OF_MEMORY if memalign fails, OS_DYNLOAD_OK otherwise.
 */
OSDynLoad_Error CustomDynLoadAlloc(int32_t size, int32_t align, void **outAddr) {
    if (outAddr == nullptr) {
        return OS_DYNLOAD_INVALID_ALLOCATOR_PTR;
    }

    const int32_t effectiveAlign = (align < 4) ? 4 : align;
    void *block                  = memalign(effectiveAlign, size);
    *outAddr                     = block;
    if (block == nullptr) {
        return OS_DYNLOAD_OUT_OF_MEMORY;
    }

    // Track the block so it can be released even if the owning RPL is never unloaded.
    gAllocatedAddresses.push_back(block);
    return OS_DYNLOAD_OK;
}
/**
 * Fallback OSDynLoad free callback, paired with CustomDynLoadAlloc.
 *
 * Removes @p addr from gAllocatedAddresses (if it is tracked) and then releases it with free().
 *
 * @param addr  block previously handed out by the allocator; may be nullptr.
 */
void CustomDynLoadFree(void *addr) {
    // Untrack before freeing: once free() has returned, addr is an invalid pointer value and
    // comparing it against the tracked entries is only implementation-defined behaviour.
    if (const auto it = std::ranges::find(gAllocatedAddresses, addr); it != gAllocatedAddresses.end()) {
        gAllocatedAddresses.erase(it);
    }
    free(addr);
}
/**
 * Resolves and applies every import relocation in @p relocData.
 *
 * Each relocation's target symbol is looked up in this order:
 *  1. the export list of a module in @p moduleMap whose key equals the import's RPL name
 *     (if such a module exists but lacks the export, this is an error);
 *  2. MEMAllocFromDefaultHeap / MEMAllocFromDefaultHeapEx / MEMFreeToDefaultHeap, which are
 *     redirected to MEMAlloc / MEMAllocEx / MEMFree;
 *  3. OSDynLoad_FindExport on the RPL, acquiring it first if it is not yet in @p usedRPls.
 *
 * @param relocData           relocations to apply.
 * @param moduleMap           loaded modules, keyed by their export name.
 * @param rplLoadingStrategy  with IGNORE_EXTERNAL_RPLS, an RPL that is not already loaded is an error.
 * @param trampData           trampoline slots passed through to ElfUtils::elfLinkOne.
 * @param usedRPls            cache of acquired RPL handles; newly acquired handles are added to it.
 * @return false on the first relocation that cannot be resolved or linked, true otherwise.
 */
bool doRelocation(const std::vector<RelocationData> &relocData,
                  const std::map<std::string, const ModuleContainer *, std::less<>> &moduleMap,
                  const ExternalRPLLoadingStrategy rplLoadingStrategy,
                  std::span<relocation_trampoline_entry_t> trampData,
                  std::map<std::string, OSDynLoad_Module, std::less<>> &usedRPls) {
    for (const auto &curReloc : relocData) {
        const auto &functionName = curReloc.getName();
        // Owning copy: the OSDynLoad_* calls below need a NUL-terminated C string, which a
        // std::string_view does not guarantee (and a view could dangle if getRPLName()
        // returned by value).
        const std::string rplName(curReloc.getImportRPLInformation().getRPLName());
        uint32_t functionAddress = 0;

        // 1. Exports of already loaded modules.
        if (const auto it = moduleMap.find(rplName); it != moduleMap.end()) {
            const auto *exportingModule = it->second;
            bool found                  = false;
            for (const auto &exportData : exportingModule->getLinkInformation().getExportDataList()) {
                if (functionName == exportData.getName()) {
                    functionAddress = reinterpret_cast<uint32_t>(exportData.getAddress());
                    found           = true;
                    break;
                }
            }
            if (!found) {
                DEBUG_FUNCTION_LINE_ERR("Failed to find export %.*s of module: %.*s",
                                        static_cast<int>(functionName.length()), functionName.data(),
                                        static_cast<int>(rplName.length()), rplName.data());
                return false;
            }
        }

        // 2. Default-heap functions are redirected to MEMAlloc / MEMAllocEx / MEMFree.
        if (functionAddress == 0) {
            if (functionName == "MEMAllocFromDefaultHeap") {
                functionAddress = reinterpret_cast<uint32_t>(&MEMAlloc);
            } else if (functionName == "MEMAllocFromDefaultHeapEx") {
                functionAddress = reinterpret_cast<uint32_t>(&MEMAllocEx);
            } else if (functionName == "MEMFreeToDefaultHeap") {
                functionAddress = reinterpret_cast<uint32_t>(&MEMFree);
            }
        }

        // 3. Exports of system RPLs via OSDynLoad.
        if (functionAddress == 0) {
            const int32_t isData       = curReloc.getImportRPLInformation().isData();
            OSDynLoad_Module rplHandle = nullptr;
            if (const auto rplIt = usedRPls.find(rplName); rplIt != usedRPls.end()) {
                rplHandle = rplIt->second;
            } else {
                // With IGNORE_EXTERNAL_RPLS the RPL has to be loaded already before it is acquired.
                // The loaded-state query is only needed for that strategy.
                if (rplLoadingStrategy == ExternalRPLLoadingStrategy::IGNORE_EXTERNAL_RPLS) {
                    OSDynLoad_Module loadedHandle = nullptr;
                    if (OSDynLoad_IsModuleLoaded(rplName.c_str(), &loadedHandle) != OS_DYNLOAD_OK || loadedHandle == nullptr) {
                        DEBUG_FUNCTION_LINE_ERR("Relocation requires %s which isn't loaded and loading is not allowed", rplName.c_str());
                        return false;
                    }
                }
                if (OSDynLoad_Acquire(rplName.c_str(), &rplHandle) != OS_DYNLOAD_OK) {
                    DEBUG_FUNCTION_LINE_ERR("Failed to acquire %.*s", static_cast<int>(rplName.length()), rplName.data());
                    return false;
                }
                usedRPls.emplace(rplName, rplHandle);
            }

            // FindExport writes a pointer: receive it in a void* instead of writing through
            // uint32_t storage reinterpreted as void* (strict-aliasing violation).
            void *exportAddress = nullptr;
            const auto res      = OSDynLoad_FindExport(rplHandle, static_cast<OSDynLoad_ExportType>(isData),
                                                       functionName.c_str(), &exportAddress);
            functionAddress     = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(exportAddress));
            if (res != OS_DYNLOAD_OK || functionAddress == 0) {
                DEBUG_FUNCTION_LINE_ERR("Failed to find export %.*s of %.*s",
                                        static_cast<int>(functionName.length()), functionName.c_str(),
                                        static_cast<int>(rplName.length()), rplName.data());
                return false;
            }
        }

        if (!ElfUtils::elfLinkOne(curReloc.getType(), curReloc.getOffset(), curReloc.getAddend(),
                                  reinterpret_cast<uint32_t>(curReloc.getDestination()),
                                  functionAddress, trampData, RELOC_TYPE_IMPORT)) {
            return false;
        }
    }
    return true;
}
} // namespace
bool ResolveRelocations(const std::vector<ModuleContainer> &loadedModules, const ExternalRPLLoadingStrategy rplLoadingStrategy, std::map<std::string, OSDynLoad_Module, std::less<>> &usedRPls) {
PROFILE_FUNCTION();
bool wasSuccessful = true;
OSDynLoadAllocFn prevDynLoadAlloc = nullptr;
OSDynLoadFreeFn prevDynLoadFree = nullptr;
if (rplLoadingStrategy == ExternalRPLLoadingStrategy::LOAD_EXTERNAL_RPLS) {
OSDynLoad_GetAllocator(&prevDynLoadAlloc, &prevDynLoadFree);
if (gCustomRPLAllocatorAllocFn != nullptr && gCustomRPLAllocatorFreeFn != nullptr) {
OSDynLoad_SetAllocator(reinterpret_cast<OSDynLoadAllocFn>(gCustomRPLAllocatorAllocFn), gCustomRPLAllocatorFreeFn);
} else {
OSDynLoad_SetAllocator(CustomDynLoadAlloc, CustomDynLoadFree);
}
}
auto restoreOSDynLoadAllocator = onLeavingScope([rplLoadingStrategy, prevDynLoadAlloc, prevDynLoadFree]() {
if (rplLoadingStrategy == ExternalRPLLoadingStrategy::LOAD_EXTERNAL_RPLS) {
OSDynLoad_SetAllocator(prevDynLoadAlloc, prevDynLoadFree);
}
});
std::map<std::string, const ModuleContainer *, std::less<>> moduleMap;
for (const auto &mod : loadedModules) {
if (mod.isLinkedAndLoaded()) {
moduleMap[mod.getMetaInformation().getExportName()] = &mod;
}
}
for (auto &curModule : loadedModules) {
if (!curModule.isLinkedAndLoaded()) {
DEBUG_FUNCTION_LINE_VERBOSE("Skip doing relocations for %s as it's not linked properly", curModule.getMetaInformation().getExportName().c_str());
continue;
}
const auto &trampData = curModule.getLinkInformation().getTrampData();
for (auto &cur : trampData) {
if (cur.status == RELOC_TRAMP_IMPORT_DONE) {
cur.status = RELOC_TRAMP_FREE;
}
}
DEBUG_FUNCTION_LINE("Let's do the relocations for %s", curModule.getMetaInformation().getExportName().c_str());
const auto &relocData = curModule.getLinkInformation().getRelocationDataList();
if (!doRelocation(relocData,
moduleMap,
rplLoadingStrategy,
trampData,
usedRPls)) {
wasSuccessful = false;
const auto errMsg = string_format("Failed to do Relocations for %s", curModule.getMetaInformation().getExportName().c_str());
DEBUG_FUNCTION_LINE_ERR("%s", errMsg.c_str());
OSFatal(errMsg.c_str());
}
curModule.getLinkInformation().flushCache();
}
return wasSuccessful;
}
} // namespace WUMSLoader::Modules::RelocationUtils