Firstly, I want to hook CreateFile() and rewrite it. Then I want to recode the callstack of my new CreateFile() function. But when I use SymInitialize() to Initialize a handle, it falls into an endless loop. Through my debug, the reason is SymInitialize() invokes CreateFile(). So why does SymInitialize() involve CreateFile()? How to avoid this loop? Is there any alternative method to record callstack information to avoid this loop?
#include <windows.h>
#include <stdio.h>
#include "detours.h"
#include <fstream>
#include <io.h>
#pragma comment(lib, "detours.lib")
#include <DbgHelp.h> //SymInitialize
#pragma comment(lib,"dbghelp.lib")
#define STACK_INFO_LEN 200
struct stackInfo {
PDWORD hashValue; // hash value to identify same stack
char* szBriefInfo; // callstack info
};
stackInfo ShowTraceStack(char* szBriefInfo)
{
static const int MAX_STACK_FRAMES = 12;
void* pStack[MAX_STACK_FRAMES];
static char szStackInfo[STACK_INFO_LEN * MAX_STACK_FRAMES];
static char szFrameInfo[STACK_INFO_LEN];
HANDLE process = GetCurrentProcess(); // The handle used must be unique to avoid sharing a session with another component,
SymInitialize(process, NULL, TRUE);
PDWORD hashValue = (PDWORD)malloc(sizeof(DWORD)); // allow memory for hashVavlue, it will be rewrited in function CaptureStackBackTrace
WORD frames = CaptureStackBackTrace(0, MAX_STACK_FRAMES, pStack, hashValue);
//printf("hash value is: %ud \n", &hashValue);
if (szBriefInfo == NULL) {
strcpy_s(szStackInfo, "stack traceback:\n");
}
else {
strcpy_s(szStackInfo, szBriefInfo);
}
for (WORD i = 0; i < frames; ++i) {
DWORD64 address = (DWORD64)(pStack[i]);
DWORD64 displacementSym = 0;
char buffer[sizeof(SYMBOL_INFO) + MAX_SYM_NAME * sizeof(TCHAR)];
PSYMBOL_INFO pSymbol = (PSYMBOL_INFO)buffer;
pSymbol->SizeOfStruct = sizeof(SYMBOL_INFO);
pSymbol->MaxNameLen = MAX_SYM_NAME;
DWORD displacementLine = 0;
IMAGEHLP_LINE64 line;
line.SizeOfStruct = sizeof(IMAGEHLP_LINE64);
if (SymFromAddr(process, address, &displacementSym, pSymbol) &&
SymGetLineFromAddr64(process, address, &displacementLine, &line))
{
_snprintf_s(szFrameInfo, sizeof(szFrameInfo), "\t%s() at %s:%d(0x%x)\n",
pSymbol->Name, line.FileName, line.LineNumber, pSymbol->Address);
}
else
{
_snprintf_s(szFrameInfo, sizeof(szFrameInfo), "\terror: %d\n", GetLastError());
}
strcat_s(szStackInfo, szFrameInfo);
}
stackInfo traceStackInfo;
traceStackInfo.hashValue = hashValue;
traceStackInfo.szBriefInfo = szStackInfo;
return traceStackInfo;
//printf("%s", szStackInfo);
}
HANDLE (*__stdcall oldCreateFile)(LPCWSTR,
DWORD,
DWORD,
LPSECURITY_ATTRIBUTES,
DWORD,
DWORD,
HANDLE) = CreateFileW;
HANDLE WINAPI newCreateFile(
_In_ LPCWSTR lpFileName,
_In_ DWORD dwDesiredAccess,
_In_ DWORD dwShareMode,
_In_opt_ LPSECURITY_ATTRIBUTES lpSecurityAttributes,
_In_ DWORD dwCreationDisposition,
_In_ DWORD dwFlagsAndAttributes,
_In_opt_ HANDLE hTemplateFile
) {
ShowTraceStack((char*)"trace information.\n");
return oldCreateFile(
L".\\newFiles.txt", // L".\\NewFile.txt", // Filename
//lpFileName,
dwDesiredAccess, // Desired access
dwShareMode, // Share mode
lpSecurityAttributes, // Security attributes
dwCreationDisposition, // Creates a new file, only if it doesn't already exist
dwFlagsAndAttributes, // Flags and attributes
NULL);
}
void hook() {
DetourRestoreAfterWith();
DetourTransactionBegin();
DetourUpdateThread(GetCurrentThread());
DetourAttach(&(PVOID&)oldCreateFile, newCreateFile);
DetourTransactionCommit();
}
void unhook()
{
DetourTransactionBegin();
DetourUpdateThread(GetCurrentThread());
DetourDetach(&(PVOID&)oldCreateFile, newCreateFile);
DetourTransactionCommit();
}
void myProcess() {
HANDLE hFile = CreateFile(TEXT(".\\CreateFileDemo.txt"),
GENERIC_WRITE | GENERIC_READ,
0,
NULL,
CREATE_ALWAYS,
FILE_ATTRIBUTE_NORMAL, NULL);
if (hFile == INVALID_HANDLE_VALUE)
{
OutputDebugString(TEXT("CreateFile fail!\r\n"));
}
// write to file
const int BUFSIZE = 4096;
char chBuffer[BUFSIZE];
memcpy(chBuffer, "Test", 4);
DWORD dwWritenSize = 0;
BOOL bRet = WriteFile(hFile, chBuffer, 4, &dwWritenSize, NULL);
if (bRet) {
OutputDebugString(TEXT("WriteFile success!\r\n"));
}
}
int main(){
hook();
myProcess();
unhook();
}
The main problem is the call to SymInitialize where you pass through "TRUE" for fInvadeProcess parameter. This is causing it to SymLoadModuleEx to be called for each loaded module. This will cause a lot of file access to download / create / open PDB files for each loaded module. This is the reason for your infinite loop.
The "quick" fix for this sample is to move the call to SymInitialize into your main before the hook call as it only needs to be called once. This means all the PDB modules are loaded before the hooking / call to ShowTraceStack.
The other problems are:
dbghelp API is NOT thread safe - so this example will not work in a multi-threaded application
SymFromAddr may call CreateFile as well for the same reason to load a newly loaded module PDB information - so your hook not passing through the filename will cause PDB information to not work
If you are trying to make someone more useful I would:
Move SymInitialize to main before the hooking (called only once)
Only call CaptureStackBackTrace in the hook and queue the thread stack information to be processed at a later time
Create a separate thread the takes the CaptureStackBackTrace stack information output and convert it to a stack trace - this would is the only thread calling the dbghlp API making calls to dbghlp API thread safe
In your hook detect when being called from the dbghlp API usage thread and don't do the stack trace and don't modify the CreateFile parameters so you don't get into a infinite loop
My project launches a target process, injecting a library into it which is meant to hook Direct3D functions. I have done this successfully in the past, but decided to rewrite the library to a closer specification of the Direct3D interfaces. The idea here was to create a set of my own 1:1 wrapper classes, each inheriting a DirectX interface (ie, class CD3D8 : IDirect3D8, class CD3DDevice8 : IDirect3DDevice8 and so forth). Since all members of each underlying DirectX COM interface are pure virtual methods, I thought I could easily override them...
So when my library hooks Direct3DCreate8, I return a pointer to my CDirect3D8 instead of the standard IDirect3D8. And then my library uses the virtual table of that pointer to hook method #15, which is IDirect3D8::CreateDevice. And once that is called, I return a pointer to CDirect3DDevice8 as opposed to IDirect3DDevice8. All appears to be well, except none of the overrided functions are being called by the injected application! Somehow, the app appears to be calling the original interface functions instead of my own. Am I missing some sort of concept here? Do I need to manually remap the virtual table pointers to the ones in my custom wrapper classes, for example?
Here's what I got so far (only showing d3d8):
D3D.h
#pragma once
#define STDM(method) COM_DECLSPEC_NOTHROW HRESULT STDMETHODCALLTYPE method
#define STDM_(type,method) COM_DECLSPEC_NOTHROW type STDMETHODCALLTYPE method
template<class IDirect3D8>
class CD3D : public IDirect3D8
{
public:
CD3D();
~CD3D();
/*** IUnknown methods ***/
STDM(QueryInterface)(THIS_ REFIID riid, void** ppvObj);
STDM_(ULONG,AddRef)(THIS);
STDM_(ULONG,Release)(THIS);
/*** IDirect3D8 methods ***/
STDM(RegisterSoftwareDevice)(THIS_ void* pInitializeFunction);
STDM_(UINT, GetAdapterCount)(THIS);
STDM(GetAdapterIdentifier)(THIS_ UINT Adapter,DWORD Flags,D3DADAPTER_IDENTIFIER8* pIdentifier);
STDM_(UINT, GetAdapterModeCount)(THIS_ UINT Adapter);
STDM(EnumAdapterModes)(THIS_ UINT Adapter,UINT Mode,D3DDISPLAYMODE* pMode);
STDM(GetAdapterDisplayMode)(THIS_ UINT Adapter,D3DDISPLAYMODE* pMode);
STDM(CheckDeviceType)(THIS_ UINT Adapter,D3DDEVTYPE CheckType,D3DFORMAT DisplayFormat,D3DFORMAT BackBufferFormat,BOOL Windowed);
STDM(CheckDeviceFormat)(THIS_ UINT Adapter,D3DDEVTYPE DeviceType,D3DFORMAT AdapterFormat,DWORD Usage,D3DRESOURCETYPE RType,D3DFORMAT CheckFormat);
STDM(CheckDeviceMultiSampleType)(THIS_ UINT Adapter,D3DDEVTYPE DeviceType,D3DFORMAT SurfaceFormat,BOOL Windowed,D3DMULTISAMPLE_TYPE MultiSampleType);
STDM(CheckDepthStencilMatch)(THIS_ UINT Adapter,D3DDEVTYPE DeviceType,D3DFORMAT AdapterFormat,D3DFORMAT RenderTargetFormat,D3DFORMAT DepthStencilFormat);
STDM(GetDeviceCaps)(THIS_ UINT Adapter,D3DDEVTYPE DeviceType,D3DCAPS8* pCaps);
STDM_(HMONITOR, GetAdapterMonitor)(THIS_ UINT Adapter);
STDM(CreateDevice)(THIS_ UINT Adapter,D3DDEVTYPE DeviceType,HWND hFocusWindow,DWORD BehaviorFlags,D3DPRESENT_PARAMETERS* pPresentationParameters,IDirect3DDevice8** ppReturnedDeviceInterface);
};
D3D.cpp
#include "stdafx.h"
#define STDIMP(iface, method, ...) COM_DECLSPEC_NOTHROW HRESULT STDMETHODCALLTYPE CD3D<iface>::method(__VA_ARGS__)
#define STDIMP_(iface, type, method, ...) COM_DECLSPEC_NOTHROW type STDMETHODCALLTYPE CD3D<iface>::method(__VA_ARGS__)
#define STDIMP8(method, ...) STDIMP(IDirect3D8, method, __VA_ARGS__)
#define STDIMP8_(type, method, ...) STDIMP_(IDirect3D8, type, method, __VA_ARGS__)
CD3D<IDirect3D8>::CD3D()
{
}
CD3D<IDirect3D8>::~CD3D()
{
}
STDIMP8(QueryInterface, THIS_ REFIID riid, void** ppvObj) {
return QueryInterface( riid, ppvObj );
}
STDIMP8_(ULONG, AddRef, THIS) {
return AddRef();
}
STDIMP8_(ULONG, Release, THIS) {
return Release();
}
STDIMP8(RegisterSoftwareDevice, THIS_ void* pInitializeFunction) {
return RegisterSoftwareDevice( pInitializeFunction );
}
STDIMP8_(UINT, GetAdapterCount, THIS) {
return GetAdapterCount();
}
STDIMP8(GetAdapterIdentifier, THIS_ UINT Adapter, DWORD Flags, D3DADAPTER_IDENTIFIER8* pIdentifier) {
return GetAdapterIdentifier( Adapter, Flags, pIdentifier );
}
STDIMP8_(UINT, GetAdapterModeCount, THIS_ UINT Adapter) {
return GetAdapterModeCount( Adapter );
}
STDIMP8(EnumAdapterModes, THIS_ UINT Adapter, UINT Mode, D3DDISPLAYMODE* pMode) {
return EnumAdapterModes( Adapter, Mode, pMode );
}
STDIMP8(GetAdapterDisplayMode, THIS_ UINT Adapter, D3DDISPLAYMODE* pMode) {
return GetAdapterDisplayMode( Adapter, pMode );
}
STDIMP8(CheckDeviceType, THIS_ UINT Adapter, D3DDEVTYPE CheckType, D3DFORMAT DisplayFormat, D3DFORMAT BackBufferFormat, BOOL Windowed) {
return CheckDeviceType( Adapter, CheckType, DisplayFormat, BackBufferFormat, Windowed );
}
STDIMP8(CheckDeviceFormat, THIS_ UINT Adapter, D3DDEVTYPE DeviceType, D3DFORMAT AdapterFormat, DWORD Usage, D3DRESOURCETYPE RType, D3DFORMAT CheckFormat) {
return CheckDeviceFormat( Adapter, DeviceType, AdapterFormat, Usage, RType, CheckFormat );
}
STDIMP8(CheckDeviceMultiSampleType, THIS_ UINT Adapter, D3DDEVTYPE DeviceType, D3DFORMAT SurfaceFormat, BOOL Windowed, D3DMULTISAMPLE_TYPE MultiSampleType) {
return CheckDeviceMultiSampleType( Adapter, DeviceType, SurfaceFormat, Windowed, MultiSampleType );
}
STDIMP8(CheckDepthStencilMatch, THIS_ UINT Adapter, D3DDEVTYPE DeviceType, D3DFORMAT AdapterFormat, D3DFORMAT RenderTargetFormat, D3DFORMAT DepthStencilFormat) {
return CheckDepthStencilMatch( Adapter, DeviceType, AdapterFormat, RenderTargetFormat, DepthStencilFormat );
}
STDIMP8(GetDeviceCaps, THIS_ UINT Adapter, D3DDEVTYPE DeviceType, D3DCAPS8* pCaps) {
return GetDeviceCaps( Adapter, DeviceType, pCaps );
}
STDIMP8_(HMONITOR, GetAdapterMonitor, THIS_ UINT Adapter) {
return GetAdapterMonitor( Adapter );
}
STDIMP8(CreateDevice, THIS_ UINT Adapter, D3DDEVTYPE DeviceType, HWND hFocusWindow, DWORD BehaviorFlags, D3DPRESENT_PARAMETERS* pPresentationParameters, IDirect3DDevice8** ppReturnedDeviceInterface) {
return CreateDevice( Adapter, DeviceType, hFocusWindow, BehaviorFlags, pPresentationParameters, ppReturnedDeviceInterface );
}
Main.cpp (Using Microsoft Detours 3.0):
#include <detours.h>
#include "D3D.h"
typedef HMODULE (WINAPI * HookLoadLibraryA)( LPCSTR lpFileName );
typedef IDirect3D8 *(WINAPI * HookDirect3DCreate8)( UINT SdkVersion );
typedef HRESULT (WINAPI * HookCreateDevice8)( IDirect3DDevice8* pInterface, UINT Adapter, D3DDEVTYPE DeviceType, HWND hFocusWindow, DWORD BehaviorFlags, D3DPRESENT_PARAMETERS* pPresentationParameters, IDirect3DDevice8** ppReturnedDeviceInterface );
HookLoadLibraryA RealLoadLibraryA;
HookDirect3DCreate8 RealDirect3DCreate8;
HookCreateDevice8 RealCreateDevice8;
//...
CD3D<IDirect3D8> *m_d3d8;
CD3DDevice<IDirect3D8> *m_d3dDev8;
//...
RealLoadLibraryA = (HookLoadLibraryA)GetProcAddress(GetModuleHandleA("kernel32.dll"), "LoadLibraryA");
DetourTransactionBegin();
DetourUpdateThread(GetCurrentThread());
DetourAttach(&(PVOID&)RealLoadLibraryA, FakeLoadLibraryA);
DetourTransactionCommit();
//...
VOID VirtualHook( PVOID pInterface, PVOID pHookProc, PVOID pOldProc, int iIndex )
{
// Hook a procedure within an interface's virtual table
PDWORD pVtable = (PDWORD)*((PDWORD)pInterface);
DWORD lpflOldProtect;
VirtualProtect( (PVOID)&pVtable[iIndex], sizeof(DWORD), PAGE_READWRITE, &lpflOldProtect );
if( pOldProc ) *(DWORD*)pOldProc = pVtable[iIndex];
pVtable[iIndex] = (DWORD)pHookProc;
VirtualProtect( pVtable, sizeof(DWORD), lpflOldProtect, &lpflOldProtect );
}
HRESULT WINAPI FakeCreateDevice8( IDirect3DDevice8* pInterface, UINT Adapter, D3DDEVTYPE DeviceType, HWND hFocusWindow, DWORD BehaviorFlags, D3DPRESENT_PARAMETERS* pPresentationParameters, IDirect3DDevice8** ppReturnedDeviceInterface )
{
HRESULT ret = RealCreateDevice8( pInterface, Adapter, DeviceType, hFocusWindow, BehaviorFlags, pPresentationParameters, ppReturnedDeviceInterface );
// Save the registers
__asm pushad
if(*ppReturnedDeviceInterface != NULL)
m_d3dDev8 = reinterpret_cast<CD3DDevice<IDirect3D8> *>(*ppReturnedDeviceInterface);
// Restore the registers
__asm popad
return ret;
}
IDirect3D8 *WINAPI FakeDirect3DCreate8( UINT SdkVersion )
{
m_d3d8 = reinterpret_cast<CD3D<IDirect3D8> *>(RealDirect3DCreate8( SdkVersion ));
if( m_d3d8 ) {
// Hook CreateDevice (vftable index #15)
VirtualHook( m_d3d8, &FakeCreateDevice8, &RealCreateDevice8, 15 );
}
return m_d3d8;
}
HMODULE WINAPI FakeLoadLibraryA( LPCSTR lpFileName )
{
CStringA strFileName( lpFileName );
int i = strFileName.ReverseFind('\\');
if(i != -1) strFileName = strFileName.Right(i + 1);
if( strFileName.CompareNoCase("d3d8.dll") == 0 )
{
// Hook Direct3DCreate8
HMODULE m_hD3D = RealLoadLibraryA( lpFileName );
RealDirect3DCreate8 = (HookDirect3DCreate8)GetProcAddress(m_hD3D, "Direct3DCreate8");
DetourTransactionBegin();
DetourUpdateThread( GetCurrentThread() );
DetourAttach(&(PVOID&)RealDirect3DCreate8, FakeDirect3DCreate8);
DetourTransactionCommit();
return m_hD3D;
}
}
... The hooked functions get called, but none of my wrapper functions do, and the injected application runs just as it normally would. Why is this? Of course, I could manually set hooks for each and every function found in each and every DirectX interface (for each version of Direct3D), but the point of this was to try and prevent having to do that and to keep it a bit cleaner. So is there a way to get this to work as I had intended? Thanks!
You're using reinterpret_cast, which won't do much since the underlying object will still have a pointer to the 'real' IDirect3D8 or IDirect3DDevice8 COM vtable, which will be used for the calls.
Instead, instantiate your custom CD3D or CD3DDevice by passing it the original object instance. Then modify all your calls to call back the correct method on the original object -- your class effectively acting as a transparent proxy.
I mean something such as:
STDIMP8_(ULONG, AddRef, THIS) {
return realObject->AddRef();
}
with realObject being the original IDirect3D8*.
I am trying to hook a function to cmd.exe process
the dll is injected just fine the problem is i can't get the cmd.exe to call my function
when im trying to enter the word "dir" on the command prompt it's showing me the same results instade of changing the first name to 'dan'
what am i doing wrong?
HANDLE WINAPI newFindFirstFileA(__in LPCTSTR lpFileName, __out LPWIN32_FIND_DATA lpFindFileData)
{
WIN32_FIND_DATA FindFileData;
HANDLE hFind = FindFirstFile(lpFileName, &FindFileData);
*FindFileData.cFileName = L'Dan';
lpFindFileData = &FindFileData;
return hFind;
}
BOOL APIENTRY DllMain (HINSTANCE hInst /* Library instance handle. */ ,
DWORD reason /* Reason this function is being called. */ ,
LPVOID reserved /* Not used. */ )
{
switch (reason)
{
case DLL_PROCESS_ATTACH:
MessageBox(NULL,L"DLL Was injected!", L"Message" ,NULL);
/* Hooking function */
DWORD* dw = (DWORD*)GetProcAddress( GetModuleHandleA("kernel32.dll"), "FindFirstFileA" );
*dw = (DWORD)newFindFirstFileA;
break;
}
/* Returns TRUE on success, FALSE on failure */
return TRUE;
}
GetProcAddress does not return the pointer to IAT entry. Instead, it returns the location of the actual function. Thus, *dw = (DWORD) newFindFirstFileA would overwrite the prolog of the FindFirstFileA function, which would be disastrous. Refer to this article for detailed explanation for hooking an API
Is there a way to progammatically detect when a module - specifically a DLL - has been unloaded from a process?
I don't have the DLL source, so I can't change it's DLL entry point. Nor can I poll if the DLL is currently loaded because the DLL may be unloaded and then reloaded between polling.
RESULTS:
I ended up using jimharks solution of detouring the dll entry point and catching DLL_PROCESS_DETACH. I found detouring FreeLibrary() to work as well but code must be added to detect when the module is actually unloaded or if the reference count is just being decreased. Necrolis' link about finding the reference count was handy for on method of doing so.
I should note that I had problems with MSDetours not actually unloading the module from memory if a detour existed within it.
One very bad way(which was used by starcraft 2), is to make your program attach to itsself then monitor for the dll unload debug event(http://msdn.microsoft.com/en-us/library/ms679302(VS.85).aspx), else you'd either need to IAT hook FreeLibrary and FreeLibraryEx in the process or hotpatch the functions in kernel32 them monitor the names being passed and the global reference counts.
Try using LdrRegisterDllNotification if you're on Vista or above. It does require using GetProcAddress to find the function address from ntdll.dll, but it's the proper way of doing it.
Maybe a less bad way then Necrolis's would be to use Microsoft Research's Detours package to hook the dll's entry point to watch for DLL_PROCESS_DETACH notifications.
You can find the entry point given an HMODULE (as returned by LoadLibrary) using this function:
#include <windows.h>
#include <DelayImp.h>
PVOID GetAddressOfEntryPoint(HMODULE hmod)
{
PIMAGE_DOS_HEADER pidh = (PIMAGE_DOS_HEADER)hmod;
PIMAGE_NT_HEADERS pinth = (PIMAGE_NT_HEADERS)((PBYTE)hmod + pidh->e_lfanew);
PVOID pvEntry = (PBYTE)hmod + pinth->OptionalHeader.AddressOfEntryPoint;
return pvEntry;
}
Your entrypoint replacement could take direct action or increment a counter that you check for in your main loop or where it's important to you. (And should almost certainly call the original entrypoint.)
UPDATE: Thanks to #LeoDavidson for pointing this out in the comments below. Detours 4.0 is now licensed using the liberal MIT License.
I hope this helps.
#Necrolis, your link to “The covert way to find the Reference Count of DLL” was just too intriguing for me to ignore because it contains the technical details I needed to implement this alternate solution (that I thought of yesterday, but was lacking the Windows Internals). Thanks. I voted for your answer because of the link you shared.
The linked article shows how to get to the internal LDR_MODULE:
struct _LDR_MODULE
{
LIST_ENTRY InLoadOrderModuleList;
LIST_ENTRY InMemoryOrderModuleList;
LIST_ENTRY InInitializationOrderModuleList;
PVOID BaseAddress;
PVOID EntryPoint;
ULONG SizeOfImage;
UNICODE_STRING FullDllName;
UNICODE_STRING BaseDllName;
ULONG Flags;
USHORT LoadCount;
USHORT TlsIndex;
LIST_ENTRY HashTableEntry;
ULONG TimeDateStamp;
} LDR_MODULE, *PLDR_MODULE;
Right here we have EntryPoint, Window's internal pointer to the module’s entry point. For a dll that’s DllMain (or the language run time function that eventually calls DllMain). What if we just change that? I wrote a test and it seems to work, at least on XP. The DllMain hook gets called with reason DLL_PROCESS_DETACH just before the DLL unloads.
The BaseAddress is the same value as an HMODULE and is useful for finding the right LDR_MODULE. The LoadCount is here so we can track that. And finally FullDllName is helpful for debugging and makes it possible to search for DLL name instead of HMODULE.
This is all Windows internals. It’s (mostly) documented, but the MSDN documentation warns “ZwQueryInformationProcess may be altered or unavailable in future versions of Windows.”
Here’s a full example (but without full error checking). It seems to work but hasn’t seen much testing.
// HookDllEntryPoint.cpp by Jim Harkins (jimhark), Nov 2010
#include "stdafx.h"
#include <stdio.h>
#include <winternl.h>
#include <process.h> // for _beginthread, only needed for testing
typedef NTSTATUS(WINAPI *pfnZwQueryInformationProcess)(
__in HANDLE ProcessHandle,
__in PROCESSINFOCLASS ProcessInformationClass,
__out PVOID ProcessInformation,
__in ULONG ProcessInformationLength,
__out_opt PULONG ReturnLength);
HMODULE hmodNtdll = LoadLibrary(_T("ntdll.dll"));
// Should test pZwQueryInformationProcess for NULL if you
// might ever run in an environment where this function
// is not available (like future version of Windows).
pfnZwQueryInformationProcess pZwQueryInformationProcess =
(pfnZwQueryInformationProcess)GetProcAddress(
hmodNtdll,
"ZwQueryInformationProcess");
typedef BOOL(WINAPI *PDLLMAIN) (
__in HINSTANCE hinstDLL,
__in DWORD fdwReason,
__in LPVOID lpvReserved);
// Note: It's possible for pDllMainNew to be called before
// HookDllEntryPoint returns. If pDllMainNew calls the old
// function, it should pass a pointer to the variable used
// so we can set it here before we hook.
VOID HookDllEntryPoint(
HMODULE hmod, PDLLMAIN pDllMainNew, PDLLMAIN *ppDllMainOld)
{
PROCESS_BASIC_INFORMATION pbi = {0};
ULONG ulcbpbi = 0;
NTSTATUS nts = (*pZwQueryInformationProcess)(
GetCurrentProcess(),
ProcessBasicInformation,
&pbi,
sizeof(pbi),
&ulcbpbi);
BOOL fFoundMod = FALSE;
PLIST_ENTRY pcurModule =
pbi.PebBaseAddress->Ldr->InMemoryOrderModuleList.Flink;
while (!fFoundMod && pcurModule !=
&pbi.PebBaseAddress->Ldr->InMemoryOrderModuleList)
{
PLDR_DATA_TABLE_ENTRY pldte = (PLDR_DATA_TABLE_ENTRY)
(CONTAINING_RECORD(
pcurModule, LDR_DATA_TABLE_ENTRY, InMemoryOrderLinks));
// Note: pldte->FullDllName.Buffer is Unicode full DLL name
// *(PUSHORT)&pldte->Reserved5[1] is LoadCount
if (pldte->DllBase == hmod)
{
fFoundMod = TRUE;
*ppDllMainOld = (PDLLMAIN)pldte->Reserved3[0];
pldte->Reserved3[0] = pDllMainNew;
}
pcurModule = pcurModule->Flink;
}
return;
}
PDLLMAIN pDllMain_advapi32 = NULL;
BOOL WINAPI DllMain_advapi32(
__in HINSTANCE hinstDLL,
__in DWORD fdwReason,
__in LPVOID lpvReserved)
{
char *pszReason;
switch (fdwReason)
{
case DLL_PROCESS_ATTACH:
pszReason = "DLL_PROCESS_ATTACH";
break;
case DLL_PROCESS_DETACH:
pszReason = "DLL_PROCESS_DETACH";
break;
case DLL_THREAD_ATTACH:
pszReason = "DLL_THREAD_ATTACH";
break;
case DLL_THREAD_DETACH:
pszReason = "DLL_THREAD_DETACH";
break;
default:
pszReason = "*UNKNOWN*";
break;
}
printf("\n");
printf("DllMain(0x%.8X, %s, 0x%.8X)\n",
(int)hinstDLL, pszReason, (int)lpvReserved);
printf("\n");
if (NULL == pDllMain_advapi32)
{
return FALSE;
}
else
{
return (*pDllMain_advapi32)(
hinstDLL,
fdwReason,
lpvReserved);
}
}
void TestThread(void *)
{
// Do nothing
}
// Test HookDllEntryPoint
int _tmain(int argc, _TCHAR* argv[])
{
HMODULE hmodAdvapi = LoadLibrary(L"advapi32.dll");
printf("advapi32.dll Base Addr: 0x%.8X\n", (int)hmodAdvapi);
HookDllEntryPoint(
hmodAdvapi, DllMain_advapi32, &pDllMain_advapi32);
_beginthread(TestThread, 0, NULL);
Sleep(1000);
return 0;
}