diff --git a/com/context.go b/com/context.go index bfdcc98..3e9df19 100644 --- a/com/context.go +++ b/com/context.go @@ -12,12 +12,22 @@ import ( var tlsIndex uint32 type Context struct { - ID int32 //could be reused - TID uint32 - CurrentScope *Scope + ID int32 //could be reused + TID uint32 + + currentScope unsafe.Pointer // *Scope LastError *ErrorInfo } +func (this *Context) GetCurrentScope() *Scope { + p := atomic.LoadPointer(&this.currentScope) + return (*Scope)(p) +} + +func (this *Context) SetCurrentScope(s *Scope) { + atomic.StorePointer(&this.currentScope, unsafe.Pointer(s)) +} + var contexts []*Context var muContext sync.Mutex @@ -32,7 +42,6 @@ func init() { } func InitializeContext() { - //win32.TlsSetValue() index := -1 context := &Context{} context.TID = win32.GetCurrentThreadId() @@ -110,9 +119,15 @@ func Initialize() Initialized { InitializeContext() win32.CoInitialize(nil) - //atomic.LoadInt32() - //tId := win32.GetCurrentThreadId() - //comThreadIds.Store(tId, true) + return Initialized{} +} + +func InitializeMt() Initialized { + runtime.LockOSThread() + + InitializeContext() + win32.CoInitializeEx(nil, win32.COINIT_MULTITHREADED) + return Initialized{} } diff --git a/com/error.go b/com/error.go index 3025d08..b1391db 100644 --- a/com/error.go +++ b/com/error.go @@ -29,6 +29,9 @@ func NewError(hr win32.HRESULT) Error { return Error(hr) } +const OK = Error(win32.S_OK) +const FAIL = Error(win32.E_FAIL) + func NewErrorOrNil(hr win32.HRESULT) error { if win32.SUCCEEDED(hr) { return nil diff --git a/com/impl.go b/com/impl.go index 30bade6..1e567de 100644 --- a/com/impl.go +++ b/com/impl.go @@ -176,6 +176,10 @@ type IUnknownComObj struct { Parent *IUnknownComObj } +func (this *IUnknownComObj) GetIUnknown() *win32.IUnknown { + return this.IUnknown() +} + func (this *IUnknownComObj) AssignPpvObject(ppvObject unsafe.Pointer) { *(*unsafe.Pointer)(ppvObject) = unsafe.Pointer(this) } diff --git a/com/scope.go b/com/scope.go index de6531e..44b63d7 100644 --- a/com/scope.go +++ b/com/scope.go @@ -3,6 +3,7 @@ package com import ( "github.com/zzl/go-win32api/win32" "log" + "sync" "unsafe" ) @@ -11,7 +12,7 @@ type scopedObject struct { Type int //0:com interface, 1:bstr, 2:*variant, 3:safearray } -//var CurrentScope *Scope +var muScope sync.Mutex type Scope struct { scopedObjs []scopedObject @@ -21,30 +22,36 @@ type Scope struct { func NewScope() *Scope { context := GetContext() scope := &Scope{ - ParentScope: context.CurrentScope, + ParentScope: context.GetCurrentScope(), } - context.CurrentScope = scope + context.SetCurrentScope(scope) return scope } +func (this *Scope) _add(so scopedObject) { + muScope.Lock() + this.scopedObjs = append(this.scopedObjs, so) + muScope.Unlock() +} + func (this *Scope) Add(pUnknown unsafe.Pointer) { - this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: pUnknown}) + this._add(scopedObject{Ptr: pUnknown}) } func (this *Scope) AddComPtr(iunknownObj win32.IUnknownObject, addRef ...bool) { pUnknown := iunknownObj.GetIUnknown() - this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(pUnknown), Type: 0}) if len(addRef) == 1 && addRef[0] { pUnknown.AddRef() } + this._add(scopedObject{Ptr: unsafe.Pointer(pUnknown), Type: 0}) } func (this *Scope) AddBstr(bstr win32.BSTR) { - this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(bstr), Type: 1}) + this._add(scopedObject{Ptr: unsafe.Pointer(bstr), Type: 1}) } func (this *Scope) AddVar(pVar *win32.VARIANT) { - this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2}) + this._add(scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2}) } func (this *Scope) AddVarIfNeeded(pVar *win32.VARIANT) { @@ -54,11 +61,11 @@ func (this *Scope) AddVarIfNeeded(pVar *win32.VARIANT) { default: return } - this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2}) + this._add(scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2}) } func (this *Scope) AddArray(psa *win32.SAFEARRAY) { - this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(psa), Type: 3}) + this._add(scopedObject{Ptr: unsafe.Pointer(psa), Type: 3}) } type VariantCompatible interface { @@ -70,7 +77,7 @@ func AddToScope(value interface{}, scope ...*Scope) { if len(scope) != 0 { s = scope[0] } else { - currentScope := GetContext().CurrentScope + currentScope := GetContext().GetCurrentScope() if currentScope == nil { log.Panic("no current scope") } @@ -104,13 +111,19 @@ func WithScope(action func()) { func (this *Scope) Leave() { this.Clear() - GetContext().CurrentScope = this.ParentScope + GetContext().SetCurrentScope(this.ParentScope) } func (this *Scope) Clear() { + muScope.Lock() count := len(this.scopedObjs) - for n := 0; n < count; n++ { - obj := this.scopedObjs[n] + scopedObjs := make([]scopedObject, count) + copy(scopedObjs, this.scopedObjs) + this.scopedObjs = nil + muScope.Unlock() + + for n := count - 1; n >= 0; n-- { + obj := scopedObjs[n] if obj.Type == 0 { (*win32.IUnknown)(obj.Ptr).Release() } else if obj.Type == 1 { @@ -121,5 +134,5 @@ func (this *Scope) Clear() { win32.SafeArrayDestroy((*win32.SAFEARRAY)(obj.Ptr)) } } - this.scopedObjs = nil + } diff --git a/go.mod b/go.mod index 7e2d87d..c81c862 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,6 @@ module github.com/zzl/go-com go 1.18 -replace github.com/zzl/go-win32api v1.1.3 => C:\Users\zzl\Documents\GitHub\go-win32api - -require github.com/zzl/go-win32api v1.1.3 +require github.com/zzl/go-win32api v1.2.0 require golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f // indirect diff --git a/go.sum b/go.sum index 669aa19..857f8c3 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ +github.com/zzl/go-win32api v1.2.0 h1:8cUPn/io9S0H/Ah4pnuM7xmElL2MDR9bKgl8AREG/AE= +github.com/zzl/go-win32api v1.2.0/go.mod h1:iWVjU/KzuwzqGpgBZdQ6Z4JqFXeSPIzantVIkcyD4b4= golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f h1:rlezHXNlxYWvBCzNses9Dlc7nGFaNMJeqLolcmQSSZY= golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/ole/oleclient.go b/ole/oleclient.go index 90c37e0..df0c98d 100644 --- a/ole/oleclient.go +++ b/ole/oleclient.go @@ -161,8 +161,7 @@ func SetVariantParam(v *Variant, value interface{}, unwrapActions *Actions) { *v.PvarVal() = (*win32.VARIANT)(val) case int8: v.Vt = uint16(win32.VT_I1) - //*v.CVal() = win32.CHAR(val) - *v.CVal() = val + *v.CVal() = win32.CHAR(val) case uint8: v.Vt = uint16(win32.VT_UI1) *v.BVal() = val @@ -202,8 +201,7 @@ func SetVariantParam(v *Variant, value interface{}, unwrapActions *Actions) { }) case *int8: v.Vt = uint16(win32.VT_I1 | win32.VT_BYREF) - //*v.PcVal() = (*win32.CHAR)(unsafe.Pointer(val)) - *v.PcVal() = val + *v.PcVal() = (*win32.CHAR)(unsafe.Pointer(val)) case *uint8: v.Vt = uint16(win32.VT_UI1 | win32.VT_BYREF) *v.PbVal() = val diff --git a/ole/variant.go b/ole/variant.go index d6ca0b0..b74d060 100644 --- a/ole/variant.go +++ b/ole/variant.go @@ -58,8 +58,7 @@ func NewVariant(value interface{}) *Variant { case int8: v := &Variant{} v.Vt = uint16(win32.VT_I1) - //*v.CVal() = win32.CHAR(val) - *v.CVal() = val + *v.CVal() = win32.CHAR(val) return v case uint8: v := &Variant{}