Commit 04bb9112 authored by Arkadiusz Hiler's avatar Arkadiusz Hiler Committed by Alexandre Julliard

msvcrt: Increase module's reference count before returning from _beginthread[ex]().

Increasing DLL's reference count from the trampoline function makes it prone to race conditions. The thread can start executing after we have already returned from _beginthread[ex]() and the DLL might have been freed. Fixes rare crash on launch with Baldur's Gate 3. Signed-off-by: 's avatarArkadiusz Hiler <ahiler@codeweavers.com> Signed-off-by: 's avatarPiotr Caban <piotr@codeweavers.com> Signed-off-by: 's avatarAlexandre Julliard <julliard@winehq.org>
parent 2bec77cb
...@@ -32,6 +32,9 @@ typedef struct { ...@@ -32,6 +32,9 @@ typedef struct {
_beginthreadex_start_routine_t start_address_ex; _beginthreadex_start_routine_t start_address_ex;
}; };
void *arglist; void *arglist;
#if _MSVCR_VER >= 140
HMODULE module;
#endif
} _beginthread_trampoline_t; } _beginthread_trampoline_t;
/********************************************************************* /*********************************************************************
...@@ -113,16 +116,10 @@ static DWORD CALLBACK _beginthread_trampoline(LPVOID arg) ...@@ -113,16 +116,10 @@ static DWORD CALLBACK _beginthread_trampoline(LPVOID arg)
thread_data_t *data = msvcrt_get_thread_data(); thread_data_t *data = msvcrt_get_thread_data();
memcpy(&local_trampoline,arg,sizeof(local_trampoline)); memcpy(&local_trampoline,arg,sizeof(local_trampoline));
data->handle = local_trampoline.thread;
free(arg); free(arg);
data->handle = local_trampoline.thread;
#if _MSVCR_VER >= 140 #if _MSVCR_VER >= 140
if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, data->module = local_trampoline.module;
(void*)local_trampoline.start_address, &data->module))
{
data->module = NULL;
WARN("failed to get module for the start_address: %d\n", GetLastError());
}
#endif #endif
local_trampoline.start_address(local_trampoline.arglist); local_trampoline.start_address(local_trampoline.arglist);
...@@ -162,7 +159,19 @@ uintptr_t CDECL _beginthread( ...@@ -162,7 +159,19 @@ uintptr_t CDECL _beginthread(
trampoline->start_address = start_address; trampoline->start_address = start_address;
trampoline->arglist = arglist; trampoline->arglist = arglist;
#if _MSVCR_VER >= 140
if(!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
(void*)start_address, &trampoline->module))
{
trampoline->module = NULL;
WARN("failed to get module for the start_address: %d\n", GetLastError());
}
#endif
if(ResumeThread(thread) == -1) { if(ResumeThread(thread) == -1) {
#if _MSVCR_VER >= 140
FreeLibrary(trampoline->module);
#endif
free(trampoline); free(trampoline);
*_errno() = EAGAIN; *_errno() = EAGAIN;
return -1; return -1;
...@@ -181,19 +190,10 @@ static DWORD CALLBACK _beginthreadex_trampoline(LPVOID arg) ...@@ -181,19 +190,10 @@ static DWORD CALLBACK _beginthreadex_trampoline(LPVOID arg)
thread_data_t *data = msvcrt_get_thread_data(); thread_data_t *data = msvcrt_get_thread_data();
memcpy(&local_trampoline, arg, sizeof(local_trampoline)); memcpy(&local_trampoline, arg, sizeof(local_trampoline));
data->handle = local_trampoline.thread;
free(arg); free(arg);
data->handle = local_trampoline.thread;
#if _MSVCR_VER >= 140 #if _MSVCR_VER >= 140
{ data->module = local_trampoline.module;
thread_data_t *data = msvcrt_get_thread_data();
if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
(void*)local_trampoline.start_address_ex, &data->module))
{
data->module = NULL;
WARN("failed to get module for the start_address: %d\n", GetLastError());
}
}
#endif #endif
retval = local_trampoline.start_address_ex(local_trampoline.arglist); retval = local_trampoline.start_address_ex(local_trampoline.arglist);
...@@ -225,9 +225,21 @@ uintptr_t CDECL _beginthreadex( ...@@ -225,9 +225,21 @@ uintptr_t CDECL _beginthreadex(
trampoline->start_address_ex = start_address; trampoline->start_address_ex = start_address;
trampoline->arglist = arglist; trampoline->arglist = arglist;
#if _MSVCR_VER >= 140
if(!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
(void*)start_address, &trampoline->module))
{
trampoline->module = NULL;
WARN("failed to get module for the start_address: %d\n", GetLastError());
}
#endif
thread = CreateThread(security, stack_size, _beginthreadex_trampoline, thread = CreateThread(security, stack_size, _beginthreadex_trampoline,
trampoline, initflag, thrdaddr); trampoline, initflag, thrdaddr);
if(!thread) { if(!thread) {
#if _MSVCR_VER >= 140
FreeLibrary(trampoline->module);
#endif
free(trampoline); free(trampoline);
msvcrt_set_errno(GetLastError()); msvcrt_set_errno(GetLastError());
return 0; return 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment