Skip to content

Commit

Permalink
scope and ole implementations revised
Browse files Browse the repository at this point in the history
  • Loading branch information
zzl committed May 22, 2022
1 parent 10c4b69 commit 1d34a38
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 50 deletions.
7 changes: 7 additions & 0 deletions com/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type Initializable interface {
Initialize()
}

type ComObjFreeListener interface {
OnComObjFree()
}

func NewComObj[T any, PT ComObjConstraint[T]](impl win32.IUnknownInterface) *T {
if roa, ok := impl.(RealObjectAware); ok {
roa.SetRealObject(impl)
Expand Down Expand Up @@ -135,6 +139,9 @@ func (this *IUnknownComObj) Impl() win32.IUnknownInterface {
}

func (this *IUnknownComObj) free() {
if lsnr, ok := this.Impl().(ComObjFreeListener); ok {
lsnr.OnComObjFree()
}
impls[this.implSlot] = nil
freeImplSlots = append(freeImplSlots, this.implSlot)
bOk, err := win32.HeapFree(hHeap, 0, unsafe.Pointer(this))
Expand Down
70 changes: 40 additions & 30 deletions com/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,47 @@ package com
import (
"github.com/zzl/go-win32api/win32"
"unsafe"
"log"
)

type scopedObject struct {
Ptr unsafe.Pointer
Type int //0:com interface, 1:bstr, 2:*variant, 3:safearray
}

var scopedObjs []scopedObject

var CurrentScope *Scope

type Scope struct {
index int
scopedObjs []scopedObject
ParentScope *Scope
}

func NewScope() *Scope {
scope := &Scope{
index: len(scopedObjs),
ParentScope: CurrentScope,
}
CurrentScope = scope
return scope
}

func (this *Scope) Add(pUnknown unsafe.Pointer) {
scopedObjs = append(scopedObjs, scopedObject{Ptr: pUnknown})
this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: pUnknown})
}

func (this *Scope) AddComPtr(iunknownObj win32.IUnknownObject, addRef ...bool) {
pUnknown := iunknownObj.GetIUnknown()
this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(pUnknown), Type: 0})
if len(addRef) == 1 && addRef[0] {
pUnknown.AddRef()
}
}

func (this *Scope) AddBstr(bstr win32.BSTR) {
scopedObjs = append(scopedObjs, scopedObject{Ptr: unsafe.Pointer(bstr), Type: 1})
this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(bstr), Type: 1})
}

func (this *Scope) AddVar(pVar *win32.VARIANT) {
scopedObjs = append(scopedObjs, scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2})
this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2})
}

func (this *Scope) AddVarIfNeeded(pVar *win32.VARIANT) {
Expand All @@ -47,26 +53,32 @@ func (this *Scope) AddVarIfNeeded(pVar *win32.VARIANT) {
default:
return
}
scopedObjs = append(scopedObjs, scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2})
this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(pVar), Type: 2})
}

func (this *Scope) AddArray(psa *win32.SAFEARRAY) {
scopedObjs = append(scopedObjs, scopedObject{Ptr: unsafe.Pointer(psa), Type: 3})
this.scopedObjs = append(this.scopedObjs, scopedObject{Ptr: unsafe.Pointer(psa), Type: 3})
}

func AddToScope(value interface{}) {
if CurrentScope == nil {
panic("no current scope") //?
func AddToScope(value interface{}, scope ...*Scope) {
var s *Scope
if len(scope) != 0 {
s = scope[0]
} else {
if CurrentScope == nil {
log.Panic("no current scope")
}
s = CurrentScope
}
switch v := value.(type) {
case win32.IUnknownObject:
CurrentScope.Add(unsafe.Pointer(v.GetIUnknown()))
s.Add(unsafe.Pointer(v.GetIUnknown()))
case win32.BSTR:
CurrentScope.AddBstr(v)
s.AddBstr(v)
case win32.VARIANT:
CurrentScope.AddVarIfNeeded(&v)
s.AddVarIfNeeded(&v)
case *win32.SAFEARRAY:
CurrentScope.AddArray(v)
s.AddArray(v)
default:
println("?")
}
Expand All @@ -84,20 +96,18 @@ func (this *Scope) Leave() {
}

func (this *Scope) Clear() {
count := len(scopedObjs)
for n := this.index; n < count; n++ {
obj := scopedObjs[n]
if obj.Ptr != nil {
if obj.Type == 0 {
(*win32.IUnknown)(obj.Ptr).Release()
} else if obj.Type == 1 {
win32.SysFreeString((win32.BSTR)(obj.Ptr))
} else if obj.Type == 2 {
win32.VariantClear((*win32.VARIANT)(obj.Ptr))
} else if obj.Type == 3 {
win32.SafeArrayDestroy((*win32.SAFEARRAY)(obj.Ptr))
}
count := len(this.scopedObjs)
for n := 0; n < count; n++ {
obj := this.scopedObjs[n]
if obj.Type == 0 {
(*win32.IUnknown)(obj.Ptr).Release()
} else if obj.Type == 1 {
win32.SysFreeString((win32.BSTR)(obj.Ptr))
} else if obj.Type == 2 {
win32.VariantClear((*win32.VARIANT)(obj.Ptr))
} else if obj.Type == 3 {
win32.SafeArrayDestroy((*win32.SAFEARRAY)(obj.Ptr))
}
}
scopedObjs = scopedObjs[:this.index]
this.scopedObjs = nil
}
40 changes: 40 additions & 0 deletions ole/oleimpl/boundmethod.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package oleimpl

import (
"github.com/zzl/go-com/ole"
"github.com/zzl/go-win32api/win32"
"syscall"
)

type BoundMethodDispImpl struct {
ole.IDispatchImpl
ownerDispObj *win32.IDispatch
memberDispId int32
}

func (this *BoundMethodDispImpl) GetIDsOfNames(riid *syscall.GUID, rgszNames *win32.PWSTR,
cNames uint32, lcid uint32, rgDispId *int32) win32.HRESULT {
return win32.DISP_E_UNKNOWNNAME
}

func (this *BoundMethodDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID,
lcid uint32, wFlags uint16, pDispParams *win32.DISPPARAMS, pVarResult *win32.VARIANT,
pExcepInfo *win32.EXCEPINFO, puArgErr *uint32) win32.HRESULT {
if dispIdMember == int32(win32.DISPID_VALUE) {
return this.ownerDispObj.Invoke(this.memberDispId, riid, lcid,
wFlags, pDispParams, pVarResult, pExcepInfo, puArgErr)
}
return win32.E_NOTIMPL
}

func (this *BoundMethodDispImpl) OnComObjFree() {
this.ownerDispObj.Release()
}

func NewBoundMethodDispatch(pDispOwner *win32.IDispatch, memberDispId int32) *win32.IDispatch {
pDisp := ole.NewIDispatchComObject(&BoundMethodDispImpl{
ownerDispObj: pDispOwner,
memberDispId: memberDispId,
}).IDispatch()
return pDisp
}
47 changes: 34 additions & 13 deletions ole/oleimpl/funcmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ type FuncMapDispImpl struct {
funcs []VariantFunc
pNames []string
props []VariantProp

OnFinalize func()
}

func (this *FuncMapDispImpl) OnComObjFree() {
if this.OnFinalize != nil {
this.OnFinalize()
}
}

func (this *FuncMapDispImpl) GetIDsOfNames(riid *syscall.GUID, rgszNames *win32.PWSTR,
Expand All @@ -33,13 +41,13 @@ func (this *FuncMapDispImpl) GetIDsOfNames(riid *syscall.GUID, rgszNames *win32.
name = strings.ToLower(name)
for n, fName := range this.fNames {
if fName == name {
*rgDispId = int32(n)
*rgDispId = int32(n) + 1
return win32.S_OK
}
}
for n, pName := range this.pNames {
if pName == name {
*rgDispId = int32(n + len(this.fNames))
*rgDispId = int32(n+len(this.fNames)) + 1
return win32.S_OK
}
}
Expand All @@ -51,21 +59,32 @@ func (this *FuncMapDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID,
pExcepInfo *win32.EXCEPINFO, puArgErr *uint32) win32.HRESULT {

vArgs, _ := ole.ProcessInvokeArgs(pDispParams, 9)
if wFlags == uint16(win32.DISPATCH_METHOD) {
funcIdx := int(dispIdMember)
if funcIdx < len(this.funcs) {
pvRet := this.funcs[funcIdx](vArgs...)
if pVarResult != nil && pvRet != nil {
*pVarResult = win32.VARIANT(*pvRet)

funcIdx := int(dispIdMember) - 1
if funcIdx >= 0 && funcIdx < len(this.funcs) {
if wFlags == uint16(win32.DISPATCH_PROPERTYGET) { //
if pDispParams.CArgs == 0 {
pDispThis := (*win32.IDispatch)(this.ComObject.Pointer())
pDisp := NewBoundMethodDispatch(pDispThis, dispIdMember)
*(*ole.Variant)(pVarResult) = *ole.NewVariantDispatch(pDisp)
return win32.S_OK
} else {
return win32.E_NOTIMPL
}
return win32.S_OK
}
} else if propIdx := int(dispIdMember) - len(this.funcs); propIdx >= 0 && propIdx < len(this.props) {

pvRet := this.funcs[funcIdx](vArgs...)
if pVarResult != nil && pvRet != nil {
*pVarResult = win32.VARIANT(*pvRet)
}
return win32.S_OK
} else if propIdx := funcIdx - len(this.funcs); propIdx >= 0 && propIdx < len(this.props) {
prop := this.props[propIdx]
var f VariantFunc
if wFlags == uint16(win32.DISPATCH_PROPERTYGET) && prop.Get != nil {
if wFlags == uint16(win32.DISPATCH_PROPERTYGET) {
f = prop.Get
} else {
} else if wFlags == uint16(win32.DISPATCH_PROPERTYPUT) ||
wFlags == uint16(win32.DISPATCH_PROPERTYPUTREF) {
f = prop.Set
}
if f == nil {
Expand All @@ -81,8 +100,9 @@ func (this *FuncMapDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID,
return win32.E_NOTIMPL
}

//
func NewFuncMapDispatch(funcMap map[string]VariantFunc,
propMap map[string]VariantProp) *win32.IDispatch {
propMap map[string]VariantProp, onFinalize func()) *win32.IDispatch {
var fNames []string
var funcs []VariantFunc
for name, f := range funcMap {
Expand All @@ -100,6 +120,7 @@ func NewFuncMapDispatch(funcMap map[string]VariantFunc,
funcs: funcs,
pNames: pNames,
props: props,
OnFinalize: onFinalize,
})
return pDisp
}
18 changes: 16 additions & 2 deletions ole/oleimpl/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (this *ReflectDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID,
pExcepInfo *win32.EXCEPINFO, puArgErr *uint32) win32.HRESULT {
dispId := int(dispIdMember)
if dispId == 0 {
//?
return win32.E_NOTIMPL //?
} else if dispId > len(this.members) {
return win32.E_INVALIDARG
}
Expand All @@ -102,6 +102,16 @@ func (this *ReflectDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID,
funcValue = member.CallFuncValue
} else if wFlags == uint16(win32.DISPATCH_PROPERTYGET) {
funcValue = member.GetFuncValue
if funcValue == nil {
if member.CallFuncValue != nil && pDispParams.CArgs == 0 {
pDispThis := (*win32.IDispatch)(this.ComObject.Pointer())
pDisp := NewBoundMethodDispatch(pDispThis, dispIdMember)
*(*ole.Variant)(pVarResult) = *ole.NewVariantDispatch(pDisp)
return win32.S_OK
} else {
return win32.E_NOTIMPL
}
}
} else {
funcValue = member.SetFuncValue
}
Expand Down Expand Up @@ -134,7 +144,11 @@ func (this *ReflectDispImpl) Invoke(dispIdMember int32, riid *syscall.GUID,
if numOut == 0 {
//
} else if numOut == 1 {
ole.SetVariantParam((*ole.Variant)(pVarResult), retVals[0].Interface(), nil)
var vResult ole.Variant
var unwrapActions ole.Actions
ole.SetVariantParam(&vResult, retVals[0].Interface(), &unwrapActions)
*(*ole.Variant)(pVarResult) = *vResult.Copy()
unwrapActions.Execute()
}
println("?")

Expand Down
Loading

0 comments on commit 1d34a38

Please sign in to comment.