From 1d34a38c06c786e487fad57a887bc4e9ab25053a Mon Sep 17 00:00:00 2001 From: zzl Date: Mon, 23 May 2022 01:45:47 +0800 Subject: [PATCH] scope and ole implementations revised --- com/impl.go | 7 ++++ com/scope.go | 70 ++++++++++++++++++++++---------------- ole/oleimpl/boundmethod.go | 40 ++++++++++++++++++++++ ole/oleimpl/funcmap.go | 47 ++++++++++++++++++------- ole/oleimpl/reflect.go | 18 ++++++++-- ole/variant.go | 23 ++++++++++--- 6 files changed, 155 insertions(+), 50 deletions(-) create mode 100644 ole/oleimpl/boundmethod.go diff --git a/com/impl.go b/com/impl.go index af02ea5..174ba84 100644 --- a/com/impl.go +++ b/com/impl.go @@ -37,6 +37,10 @@ type Initializable interface { Initialize() } +type ComObjFreeListener interface { + OnComObjFree() +} + func NewComObj[T any, PT ComObjConstraint[T]](impl win32.IUnknownInterface) *T { if roa, ok := impl.(RealObjectAware); ok { roa.SetRealObject(impl) @@ -135,6 +139,9 @@ func (this *IUnknownComObj) Impl() win32.IUnknownInterface { } func (this *IUnknownComObj) free() { + if lsnr, ok := this.Impl().(ComObjFreeListener); ok { + lsnr.OnComObjFree() + } impls[this.implSlot] = nil freeImplSlots = append(freeImplSlots, this.implSlot) bOk, err := win32.HeapFree(hHeap, 0, unsafe.Pointer(this)) diff --git a/com/scope.go b/com/scope.go index 0a28192..35623fe 100644 --- a/com/scope.go +++ b/com/scope.go @@ -3,6 +3,7 @@ package com import ( "github.com/zzl/go-win32api/win32" "unsafe" + "log" ) type scopedObject struct { @@ -10,18 +11,15 @@ type scopedObject struct { Type int //0:com interface, 1:bstr, 2:*variant, 3:safearray } -var scopedObjs []scopedObject - var CurrentScope *Scope type Scope struct { - index int + scopedObjs []scopedObject ParentScope *Scope } func NewScope() *Scope { scope := &Scope{ - index: len(scopedObjs), ParentScope: CurrentScope, } CurrentScope = scope @@ -29,15 +27,23 @@ func NewScope() *Scope { } func (this *Scope) Add(pUnknown unsafe.Pointer) { - scopedObjs = append(scopedObjs, scopedObject{Ptr: pUnknown}) + this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: pUnknown}) +} + +func (this *Scope) AddComPtr(iunknownObj win32.IUnknownObject, addRef ...bool) { + pUnknown := iunknownObj.GetIUnknown() + this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(pUnknown), Type: 0}) + if len(addRef) == 1 && addRef[0] { + pUnknown.AddRef() + } } func (this *Scope) AddBstr(bstr win32.BSTR) { - scopedObjs = append(scopedObjs, scopedObject{Ptr: unsafe.Pointer(bstr), Type: 1}) + this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(bstr), Type: 1}) } func (this *Scope) AddVar(pVar *win32.VARIANT) { - scopedObjs = append(scopedObjs, scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2}) + this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2}) } func (this *Scope) AddVarIfNeeded(pVar *win32.VARIANT) { @@ -47,26 +53,32 @@ func (this *Scope) AddVarIfNeeded(pVar *win32.VARIANT) { default: return } - scopedObjs = append(scopedObjs, scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2}) + this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2}) } func (this *Scope) AddArray(psa *win32.SAFEARRAY) { - scopedObjs = append(scopedObjs, scopedObject{Ptr: unsafe.Pointer(psa), Type: 3}) + this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(psa), Type: 3}) } -func AddToScope(value interface{}) { - if CurrentScope == nil { - panic("no current scope") //? +func AddToScope(value interface{}, scope ...*Scope) { + var s *Scope + if len(scope) != 0 { + s = scope[0] + } else { + if CurrentScope == nil { + log.Panic("no current scope") + } + s = CurrentScope } switch v := value.(type) { case win32.IUnknownObject: - CurrentScope.Add(unsafe.Pointer(v.GetIUnknown())) + s.Add(unsafe.Pointer(v.GetIUnknown())) case win32.BSTR: - CurrentScope.AddBstr(v) + s.AddBstr(v) case win32.VARIANT: - CurrentScope.AddVarIfNeeded(&v) + s.AddVarIfNeeded(&v) case *win32.SAFEARRAY: - CurrentScope.AddArray(v) + s.AddArray(v) default: println("?") } @@ -84,20 +96,18 @@ func (this *Scope) Leave() { } func (this *Scope) Clear() { - count := len(scopedObjs) - for n := this.index; n < count; n++ { - obj := scopedObjs[n] - if obj.Ptr != nil { - if obj.Type == 0 { - (*win32.IUnknown)(obj.Ptr).Release() - } else if obj.Type == 1 { - win32.SysFreeString((win32.BSTR)(obj.Ptr)) - } else if obj.Type == 2 { - win32.VariantClear((*win32.VARIANT)(obj.Ptr)) - } else if obj.Type == 3 { - win32.SafeArrayDestroy((*win32.SAFEARRAY)(obj.Ptr)) - } + count := len(this.scopedObjs) + for n := 0; n < count; n++ { + obj := this.scopedObjs[n] + if obj.Type == 0 { + (*win32.IUnknown)(obj.Ptr).Release() + } else if obj.Type == 1 { + win32.SysFreeString((win32.BSTR)(obj.Ptr)) + } else if obj.Type == 2 { + win32.VariantClear((*win32.VARIANT)(obj.Ptr)) + } else if obj.Type == 3 { + win32.SafeArrayDestroy((*win32.SAFEARRAY)(obj.Ptr)) } } - scopedObjs = scopedObjs[:this.index] + this.scopedObjs = nil } diff --git a/ole/oleimpl/boundmethod.go b/ole/oleimpl/boundmethod.go new file mode 100644 index 0000000..40745be --- /dev/null +++ b/ole/oleimpl/boundmethod.go @@ -0,0 +1,40 @@ +package oleimpl + +import ( + "github.com/zzl/go-com/ole" + "github.com/zzl/go-win32api/win32" + "syscall" +) + +type BoundMethodDispImpl struct { + ole.IDispatchImpl + ownerDispObj *win32.IDispatch + memberDispId int32 +} + +func (this *BoundMethodDispImpl) GetIDsOfNames(riid *syscall.GUID, rgszNames *win32.PWSTR, + cNames uint32, lcid uint32, rgDispId *int32) win32.HRESULT { + return win32.DISP_E_UNKNOWNNAME +} + +func (this *BoundMethodDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID, + lcid uint32, wFlags uint16, pDispParams *win32.DISPPARAMS, pVarResult *win32.VARIANT, + pExcepInfo *win32.EXCEPINFO, puArgErr *uint32) win32.HRESULT { + if dispIdMember == int32(win32.DISPID_VALUE) { + return this.ownerDispObj.Invoke(this.memberDispId, riid, lcid, + wFlags, pDispParams, pVarResult, pExcepInfo, puArgErr) + } + return win32.E_NOTIMPL +} + +func (this *BoundMethodDispImpl) OnComObjFree() { + this.ownerDispObj.Release() +} + +func NewBoundMethodDispatch(pDispOwner *win32.IDispatch, memberDispId int32) *win32.IDispatch { + pDisp := ole.NewIDispatchComObject(&BoundMethodDispImpl{ + ownerDispObj: pDispOwner, + memberDispId: memberDispId, + }).IDispatch() + return pDisp +} diff --git a/ole/oleimpl/funcmap.go b/ole/oleimpl/funcmap.go index b399aef..40e5816 100644 --- a/ole/oleimpl/funcmap.go +++ b/ole/oleimpl/funcmap.go @@ -22,6 +22,14 @@ type FuncMapDispImpl struct { funcs []VariantFunc pNames []string props []VariantProp + + OnFinalize func() +} + +func (this *FuncMapDispImpl) OnComObjFree() { + if this.OnFinalize != nil { + this.OnFinalize() + } } func (this *FuncMapDispImpl) GetIDsOfNames(riid *syscall.GUID, rgszNames *win32.PWSTR, @@ -33,13 +41,13 @@ func (this *FuncMapDispImpl) GetIDsOfNames(riid *syscall.GUID, rgszNames *win32. name = strings.ToLower(name) for n, fName := range this.fNames { if fName == name { - *rgDispId = int32(n) + *rgDispId = int32(n) + 1 return win32.S_OK } } for n, pName := range this.pNames { if pName == name { - *rgDispId = int32(n + len(this.fNames)) + *rgDispId = int32(n+len(this.fNames)) + 1 return win32.S_OK } } @@ -51,21 +59,32 @@ func (this *FuncMapDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID, pExcepInfo *win32.EXCEPINFO, puArgErr *uint32) win32.HRESULT { vArgs, _ := ole.ProcessInvokeArgs(pDispParams, 9) - if wFlags == uint16(win32.DISPATCH_METHOD) { - funcIdx := int(dispIdMember) - if funcIdx < len(this.funcs) { - pvRet := this.funcs[funcIdx](vArgs...) - if pVarResult != nil && pvRet != nil { - *pVarResult = win32.VARIANT(*pvRet) + + funcIdx := int(dispIdMember) - 1 + if funcIdx >= 0 && funcIdx < len(this.funcs) { + if wFlags == uint16(win32.DISPATCH_PROPERTYGET) { // + if pDispParams.CArgs == 0 { + pDispThis := (*win32.IDispatch)(this.ComObject.Pointer()) + pDisp := NewBoundMethodDispatch(pDispThis, dispIdMember) + *(*ole.Variant)(pVarResult) = *ole.NewVariantDispatch(pDisp) + return win32.S_OK + } else { + return win32.E_NOTIMPL } - return win32.S_OK } - } else if propIdx := int(dispIdMember) - len(this.funcs); propIdx >= 0 && propIdx < len(this.props) { + + pvRet := this.funcs[funcIdx](vArgs...) + if pVarResult != nil && pvRet != nil { + *pVarResult = win32.VARIANT(*pvRet) + } + return win32.S_OK + } else if propIdx := funcIdx - len(this.funcs); propIdx >= 0 && propIdx < len(this.props) { prop := this.props[propIdx] var f VariantFunc - if wFlags == uint16(win32.DISPATCH_PROPERTYGET) && prop.Get != nil { + if wFlags == uint16(win32.DISPATCH_PROPERTYGET) { f = prop.Get - } else { + } else if wFlags == uint16(win32.DISPATCH_PROPERTYPUT) || + wFlags == uint16(win32.DISPATCH_PROPERTYPUTREF) { f = prop.Set } if f == nil { @@ -81,8 +100,9 @@ func (this *FuncMapDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID, return win32.E_NOTIMPL } +// func NewFuncMapDispatch(funcMap map[string]VariantFunc, - propMap map[string]VariantProp) *win32.IDispatch { + propMap map[string]VariantProp, onFinalize func()) *win32.IDispatch { var fNames []string var funcs []VariantFunc for name, f := range funcMap { @@ -100,6 +120,7 @@ func NewFuncMapDispatch(funcMap map[string]VariantFunc, funcs: funcs, pNames: pNames, props: props, + OnFinalize: onFinalize, }) return pDisp } diff --git a/ole/oleimpl/reflect.go b/ole/oleimpl/reflect.go index 2f5b0dc..7aeb917 100644 --- a/ole/oleimpl/reflect.go +++ b/ole/oleimpl/reflect.go @@ -92,7 +92,7 @@ func (this *ReflectDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID, pExcepInfo *win32.EXCEPINFO, puArgErr *uint32) win32.HRESULT { dispId := int(dispIdMember) if dispId == 0 { - //? + return win32.E_NOTIMPL //? } else if dispId > len(this.members) { return win32.E_INVALIDARG } @@ -102,6 +102,16 @@ func (this *ReflectDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID, funcValue = member.CallFuncValue } else if wFlags == uint16(win32.DISPATCH_PROPERTYGET) { funcValue = member.GetFuncValue + if funcValue == nil { + if member.CallFuncValue != nil && pDispParams.CArgs == 0 { + pDispThis := (*win32.IDispatch)(this.ComObject.Pointer()) + pDisp := NewBoundMethodDispatch(pDispThis, dispIdMember) + *(*ole.Variant)(pVarResult) = *ole.NewVariantDispatch(pDisp) + return win32.S_OK + } else { + return win32.E_NOTIMPL + } + } } else { funcValue = member.SetFuncValue } @@ -134,7 +144,11 @@ func (this *ReflectDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID, if numOut == 0 { // } else if numOut == 1 { - ole.SetVariantParam((*ole.Variant)(pVarResult), retVals[0].Interface(), nil) + var vResult ole.Variant + var unwrapActions ole.Actions + ole.SetVariantParam(&vResult, retVals[0].Interface(), &unwrapActions) + *(*ole.Variant)(pVarResult) = *vResult.Copy() + unwrapActions.Execute() } println("?") diff --git a/ole/variant.go b/ole/variant.go index 748999e..2e95af8 100644 --- a/ole/variant.go +++ b/ole/variant.go @@ -28,6 +28,7 @@ func NewVariant(value interface{}) *Variant { case *win32.IDispatch: return NewVariantDispatch(val) case *win32.IUnknown: + val.AddRef() //ref added v := &Variant{} v.Vt = uint16(win32.VT_UNKNOWN) *v.PunkVal() = val @@ -37,7 +38,7 @@ func NewVariant(value interface{}) *Variant { case int: v := &Variant{} if val >= math.MinInt32 && val <= math.MaxInt32 { - v.Vt = uint16(win32.VT_INT) + v.Vt = uint16(win32.VT_I4) *v.IntVal() = int32(val) } else { v.Vt = uint16(win32.VT_I8) @@ -47,7 +48,7 @@ func NewVariant(value interface{}) *Variant { case uint: v := &Variant{} if val <= math.MaxUint32 { - v.Vt = uint16(win32.VT_UINT) + v.Vt = uint16(win32.VT_UI4) *v.UintVal() = uint32(val) } else { v.Vt = uint16(win32.VT_UI8) @@ -107,7 +108,7 @@ func NewVariant(value interface{}) *Variant { func (this *Variant) Copy() *Variant { var v2 win32.VARIANT - win32.VariantCopy((*win32.VARIANT)(this), &v2) + win32.VariantCopy(&v2, (*win32.VARIANT)(this)) return (*Variant)(&v2) } @@ -855,12 +856,13 @@ func (this *Variant) ToIDispatch() (*win32.IDispatch, error) { return v.PdispValVal(), nil } -func (this *Variant) ToIUnknonw() (*win32.IUnknown, error) { +func (this *Variant) ToIUnknown() (*win32.IUnknown, error) { v, vt := this, uint16(win32.VT_UNKNOWN) if v.Vt != vt { if v.Vt == vt|uint16(win32.VT_BYREF) { return *v.PpunkValVal(), nil } + //? v = &Variant{} hr := win32.VariantChangeType((*win32.VARIANT)(v), (*win32.VARIANT)(this), 0, vt) //defer v.Clear() @@ -1330,8 +1332,9 @@ type VariantDispatch struct { //24 _pad3 int64 //8@16 } -//addref? +//ref added func NewVariantDispatch(pDisp *win32.IDispatch) *Variant { + pDisp.AddRef() return (*Variant)(unsafe.Pointer(&VariantDispatch{ Vt: win32.VT_DISPATCH, Value: pDisp})) } @@ -1341,12 +1344,22 @@ func Var(value interface{}) Variant { return *v } +func NewVar(value interface{}) *Variant { + return NewVariant(value) +} + func VarScoped(value interface{}) Variant { v := NewVariant(value) com.CurrentScope.AddVarIfNeeded((*win32.VARIANT)(v)) return *v } +func NewVarScoped(value interface{}) *Variant { + v := NewVariant(value) + com.CurrentScope.AddVarIfNeeded((*win32.VARIANT)(v)) + return v +} + func CheckVarType(value interface{}) win32.VARENUM { switch value.(type) { case win32.VARIANT, Variant: