diff --git a/core/runtime/src/factory/EggContainerFactory.ts b/core/runtime/src/factory/EggContainerFactory.ts index de5f0e9b..2e2765fe 100644 --- a/core/runtime/src/factory/EggContainerFactory.ts +++ b/core/runtime/src/factory/EggContainerFactory.ts @@ -1,7 +1,7 @@ import { EggPrototype } from '@eggjs/tegg-metadata'; import { EggContainer } from '../model/EggContainer'; import { LifecycleContext } from '@eggjs/tegg-lifecycle'; -import { EggObjectName, ObjectInitTypeLike } from '@eggjs/core-decorator'; +import { EggObjectName, ObjectInitTypeLike, PrototypeUtil, EggProtoImplClass } from '@eggjs/core-decorator'; import { EggObject } from '../model/EggObject'; import { ContextHandler } from '../model/ContextHandler'; import { ContextInitiator } from '../impl/ContextInitiator'; @@ -40,6 +40,19 @@ export class EggContainerFactory { return obj; } + /** + * get or create egg object from the Class + * If get singleton egg object in context, + * will create context egg object for it. + */ + static async getOrCreateEggObjectFromClazz(clazz: EggProtoImplClass, name?: EggObjectName): Promise { + const proto = PrototypeUtil.getClazzProto(clazz) as EggPrototype; + if (!proto) { + throw new Error(`can not get proto for clazz ${clazz.name}`); + } + return await this.getOrCreateEggObject(proto, name); + } + static getEggObject(proto: EggPrototype, name?: EggObjectName): EggObject { const container = this.getContainer(proto); name = name || proto.name; diff --git a/core/runtime/test/EggObject.test.ts b/core/runtime/test/EggObject.test.ts index 6ce2f003..44a95613 100644 --- a/core/runtime/test/EggObject.test.ts +++ b/core/runtime/test/EggObject.test.ts @@ -69,6 +69,11 @@ describe('test/EggObject.test.ts', () => { const barProto = EggPrototypeFactory.instance.getPrototype('bar'); const barObj = await EggContainerFactory.getOrCreateEggObject(barProto, barProto.name); const bar = barObj.obj as Bar; + // get obj from class + const barObj2 = await EggContainerFactory.getOrCreateEggObjectFromClazz((barProto as any).clazz, barProto.name); + assert.equal(barObj2, barObj); + assert.equal(barObj2.obj, barObj.obj); + await TestUtil.destroyLoadUnitInstance(instance); const called = bar.getLifecycleCalled(); assert.deepStrictEqual(called, [