8
8
*
9
9
*****************************************************************************/
10
10
11
+ #pragma once
12
+ #include < windows.h>
13
+ #include < cassert>
14
+ #include " SharedUtil.MemAccess.h"
15
+ #include " SharedUtil.Logging.h"
16
+ #include " SString.h"
17
+
11
18
namespace SharedUtil
12
19
{
13
20
// Returns true if matching memset would have no affect
@@ -74,13 +81,12 @@ namespace SharedUtil
74
81
// Temporarily unprotect slow mem area
75
82
SMemWrite OpenMemWrite (const void * pAddr, uint uiAmount)
76
83
{
77
- SMemWrite hMem;
84
+ SMemWrite hMem{} ;
78
85
79
86
// Check for incorrect use of function
80
87
if (!IsSlowMem (pAddr, uiAmount))
81
88
{
82
89
dassert (0 && " Should use Mem*Fast function" );
83
- hMem.dwFirstPage = 0 ;
84
90
return hMem;
85
91
}
86
92
@@ -91,10 +97,18 @@ namespace SharedUtil
91
97
hMem.dwFirstPage = ((DWORD)pAddr) & ~0xFFF ;
92
98
DWORD dwLastPage = (((DWORD)pAddr) + uiAmount - 1 ) & ~0xFFF ;
93
99
hMem.dwSize = dwLastPage - hMem.dwFirstPage + 0x1000 ;
94
- VirtualProtect ((LPVOID)hMem.dwFirstPage , 0x1000 , PAGE_EXECUTE_READWRITE, &hMem.oldProt );
100
+
101
+ if (!VirtualProtect ((LPVOID)hMem.dwFirstPage , 0x1000 , PAGE_EXECUTE_READWRITE, &hMem.oldProt ))
102
+ {
103
+ DWORD error = GetLastError ();
104
+ OutputDebugLine (SString (" MemAccess::OpenMemWrite: VirtualProtect failed at %08x, error: %d" , hMem.dwFirstPage , error));
105
+ hMem = {};
106
+ assert (!" Failed to unprotect memory" );
107
+ return hMem;
108
+ }
95
109
96
110
// Make sure not using this slow function too much
97
- OutputDebugLine (SString (" [Mem ] OpenMemWrite at %08x for %d bytes (oldProt:%04x)" , pAddr, uiAmount, hMem.oldProt ));
111
+ OutputDebugLine (SString (" [MemAccess ] OpenMemWrite at %08x for %d bytes (oldProt:%04x)" , pAddr, uiAmount, hMem.oldProt ));
98
112
99
113
#ifdef MTA_DEBUG
100
114
#if 0 // Annoying
@@ -103,14 +117,25 @@ namespace SharedUtil
103
117
assert( hMem.oldProt == PAGE_EXECUTE_READ || hMem.oldProt == PAGE_READONLY );
104
118
else
105
119
assert( hMem.oldProt == PAGE_EXECUTE_READWRITE || hMem.oldProt == PAGE_EXECUTE_WRITECOPY );
106
- #endif
120
+ #endif
107
121
#endif
108
122
109
123
// Extra if more than one page
110
124
for (uint i = 0x1000 ; i < hMem.dwSize ; i += 0x1000 )
111
125
{
112
126
DWORD oldProtCheck;
113
- VirtualProtect ((LPVOID)(hMem.dwFirstPage + i), 0x1000 , PAGE_EXECUTE_READWRITE, &oldProtCheck);
127
+ if (!VirtualProtect ((LPVOID)(hMem.dwFirstPage + i), 0x1000 , PAGE_EXECUTE_READWRITE, &oldProtCheck))
128
+ {
129
+ // Try to rollback
130
+ DWORD temp;
131
+ VirtualProtect ((LPVOID)hMem.dwFirstPage , i, hMem.oldProt , &temp);
132
+
133
+ DWORD error = GetLastError ();
134
+ OutputDebugLine (SString (" [MemAccess] OpenMemWrite VirtualProtect failed at %08x, error: %d" , hMem.dwFirstPage + i, error));
135
+ hMem = {};
136
+ assert (!" Failed to unprotect multi-page memory region" );
137
+ return hMem;
138
+ }
114
139
dassert (hMem.oldProt == oldProtCheck);
115
140
}
116
141
@@ -122,9 +147,22 @@ namespace SharedUtil
122
147
{
123
148
if (hMem.dwFirstPage == 0 )
124
149
return ;
150
+
125
151
DWORD oldProt;
126
- VirtualProtect ((LPVOID)hMem.dwFirstPage , hMem.dwSize , hMem.oldProt , &oldProt);
127
- dassert (oldProt == PAGE_EXECUTE_READWRITE);
152
+ BOOL result = VirtualProtect ((LPVOID)hMem.dwFirstPage , hMem.dwSize , hMem.oldProt , &oldProt);
153
+
154
+ if (!result)
155
+ {
156
+ DWORD error = GetLastError ();
157
+ OutputDebugLine (SString (" MemAccess::CloseMemWrite: VirtualProtect failed at %08x, size %08x, error: %d" , hMem.dwFirstPage , hMem.dwSize , error));
158
+ assert (!" Failed to restore memory protection - critical" );
159
+ }
160
+ else
161
+ {
162
+ dassert (oldProt == PAGE_EXECUTE_READWRITE);
163
+ }
164
+
165
+ hMem.dwFirstPage = 0 ; // Invalidate handle
128
166
}
129
167
130
168
} // namespace SharedUtil
0 commit comments