Skip to content

Commit 095a124

Browse files
committed
Safer memory access
1 parent b83fa91 commit 095a124

File tree

3 files changed

+72
-15
lines changed

3 files changed

+72
-15
lines changed

Client/multiplayer_sa/StdInc.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// StdInc.h
22
#include "StdInc.h"
33
#define ALLOC_STATS_MODULE_NAME "multiplayer_sa"
4+
#define MTASA_EXPORT_SHARED_UTIL
45
#include "SharedUtil.hpp"
56
#include "SharedUtil.MemAccess.hpp"

Shared/sdk/SharedUtil.MemAccess.h

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@
1111

1212
#include <SharedUtil.IntTypes.h>
1313

14+
#ifdef _WIN32
15+
#ifdef MTASA_EXPORT_SHARED_UTIL
16+
#define SHARED_UTIL_API __declspec(dllexport)
17+
#else
18+
#define SHARED_UTIL_API __declspec(dllimport)
19+
#endif
20+
#else
21+
#define SHARED_UTIL_API
22+
#endif
23+
1424
namespace SharedUtil
1525
{
1626
struct SMemWrite
@@ -20,12 +30,20 @@ namespace SharedUtil
2030
DWORD oldProt;
2131
};
2232

23-
void SetInitialVirtualProtect();
24-
bool IsSlowMem(const void* pAddr, uint uiAmount);
25-
SMemWrite OpenMemWrite(const void* pAddr, uint uiAmount);
26-
void CloseMemWrite(SMemWrite& hMem);
27-
bool ismemset(const void* pAddr, int cValue, uint uiAmount);
33+
SHARED_UTIL_API void SetInitialVirtualProtect();
34+
SHARED_UTIL_API bool IsSlowMem(const void* pAddr, uint uiAmount);
35+
SHARED_UTIL_API SMemWrite OpenMemWrite(const void* pAddr, uint uiAmount);
36+
SHARED_UTIL_API void CloseMemWrite(SMemWrite& hMem);
37+
SHARED_UTIL_API bool ismemset(const void* pAddr, int cValue, uint uiAmount);
38+
39+
bool IsProtectedSlowMem(const void* pAddr);
2840

29-
#define DEBUG_CHECK_IS_FAST_MEM(addr,size) { dassert( !IsSlowMem( (const void*)(addr), size ) ); }
30-
#define DEBUG_CHECK_IS_SLOW_MEM(addr,size) { dassert( IsSlowMem( (const void*)(addr), size ) ); }
41+
#define DEBUG_CHECK_IS_FAST_MEM(addr, size) \
42+
{ \
43+
dassert(!IsSlowMem((const void*)(addr), size)); \
44+
}
45+
#define DEBUG_CHECK_IS_SLOW_MEM(addr, size) \
46+
{ \
47+
dassert(IsSlowMem((const void*)(addr), size)); \
48+
}
3149
} // namespace SharedUtil

Shared/sdk/SharedUtil.MemAccess.hpp

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88
*
99
*****************************************************************************/
1010

11+
#pragma once
12+
#include <windows.h>
13+
#include <cassert>
14+
#include "SharedUtil.MemAccess.h"
15+
#include "SharedUtil.Logging.h"
16+
#include "SString.h"
17+
1118
namespace SharedUtil
1219
{
1320
// Returns true if matching memset would have no affect
@@ -74,13 +81,12 @@ namespace SharedUtil
7481
// Temporarily unprotect slow mem area
7582
SMemWrite OpenMemWrite(const void* pAddr, uint uiAmount)
7683
{
77-
SMemWrite hMem;
84+
SMemWrite hMem{};
7885

7986
// Check for incorrect use of function
8087
if (!IsSlowMem(pAddr, uiAmount))
8188
{
8289
dassert(0 && "Should use Mem*Fast function");
83-
hMem.dwFirstPage = 0;
8490
return hMem;
8591
}
8692

@@ -91,10 +97,18 @@ namespace SharedUtil
9197
hMem.dwFirstPage = ((DWORD)pAddr) & ~0xFFF;
9298
DWORD dwLastPage = (((DWORD)pAddr) + uiAmount - 1) & ~0xFFF;
9399
hMem.dwSize = dwLastPage - hMem.dwFirstPage + 0x1000;
94-
VirtualProtect((LPVOID)hMem.dwFirstPage, 0x1000, PAGE_EXECUTE_READWRITE, &hMem.oldProt);
100+
101+
if (!VirtualProtect((LPVOID)hMem.dwFirstPage, 0x1000, PAGE_EXECUTE_READWRITE, &hMem.oldProt))
102+
{
103+
DWORD error = GetLastError();
104+
OutputDebugLine(SString("MemAccess::OpenMemWrite: VirtualProtect failed at %08x, error: %d", hMem.dwFirstPage, error));
105+
hMem = {};
106+
assert(!"Failed to unprotect memory");
107+
return hMem;
108+
}
95109

96110
// Make sure not using this slow function too much
97-
OutputDebugLine(SString("[Mem] OpenMemWrite at %08x for %d bytes (oldProt:%04x)", pAddr, uiAmount, hMem.oldProt));
111+
OutputDebugLine(SString("[MemAccess] OpenMemWrite at %08x for %d bytes (oldProt:%04x)", pAddr, uiAmount, hMem.oldProt));
98112

99113
#ifdef MTA_DEBUG
100114
#if 0 // Annoying
@@ -103,14 +117,25 @@ namespace SharedUtil
103117
assert( hMem.oldProt == PAGE_EXECUTE_READ || hMem.oldProt == PAGE_READONLY );
104118
else
105119
assert( hMem.oldProt == PAGE_EXECUTE_READWRITE || hMem.oldProt == PAGE_EXECUTE_WRITECOPY );
106-
#endif
120+
#endif
107121
#endif
108122

109123
// Extra if more than one page
110124
for (uint i = 0x1000; i < hMem.dwSize; i += 0x1000)
111125
{
112126
DWORD oldProtCheck;
113-
VirtualProtect((LPVOID)(hMem.dwFirstPage + i), 0x1000, PAGE_EXECUTE_READWRITE, &oldProtCheck);
127+
if (!VirtualProtect((LPVOID)(hMem.dwFirstPage + i), 0x1000, PAGE_EXECUTE_READWRITE, &oldProtCheck))
128+
{
129+
// Try to rollback
130+
DWORD temp;
131+
VirtualProtect((LPVOID)hMem.dwFirstPage, i, hMem.oldProt, &temp);
132+
133+
DWORD error = GetLastError();
134+
OutputDebugLine(SString("[MemAccess] OpenMemWrite VirtualProtect failed at %08x, error: %d", hMem.dwFirstPage + i, error));
135+
hMem = {};
136+
assert(!"Failed to unprotect multi-page memory region");
137+
return hMem;
138+
}
114139
dassert(hMem.oldProt == oldProtCheck);
115140
}
116141

@@ -122,9 +147,22 @@ namespace SharedUtil
122147
{
123148
if (hMem.dwFirstPage == 0)
124149
return;
150+
125151
DWORD oldProt;
126-
VirtualProtect((LPVOID)hMem.dwFirstPage, hMem.dwSize, hMem.oldProt, &oldProt);
127-
dassert(oldProt == PAGE_EXECUTE_READWRITE);
152+
BOOL result = VirtualProtect((LPVOID)hMem.dwFirstPage, hMem.dwSize, hMem.oldProt, &oldProt);
153+
154+
if (!result)
155+
{
156+
DWORD error = GetLastError();
157+
OutputDebugLine(SString("MemAccess::CloseMemWrite: VirtualProtect failed at %08x, size %08x, error: %d", hMem.dwFirstPage, hMem.dwSize, error));
158+
assert(!"Failed to restore memory protection - critical");
159+
}
160+
else
161+
{
162+
dassert(oldProt == PAGE_EXECUTE_READWRITE);
163+
}
164+
165+
hMem.dwFirstPage = 0; // Invalidate handle
128166
}
129167

130168
} // namespace SharedUtil

0 commit comments

Comments
 (0)