Skip to content

Commit

Permalink
Optimize model library (#1739)
Browse files Browse the repository at this point in the history
* Adapt experiment models api

* Add model preview
  • Loading branch information
hayden-fr authored Dec 16, 2024
1 parent 601b739 commit 5b72fc7
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 21 deletions.
26 changes: 18 additions & 8 deletions src/components/sidebar/tabs/modelLibrary/ModelTreeLeaf.vue
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
<div ref="container" class="model-lib-node-container h-full w-full">
<TreeExplorerTreeNode :node="node">
<template #before-label>
<span
v-if="modelDef && modelDef.image"
class="model-lib-model-icon-container"
>
<span v-if="modelPreviewUrl" class="model-lib-model-icon-container">
<span
class="model-lib-model-icon"
:style="{ backgroundImage: `url(${modelDef.image})` }"
:style="{ backgroundImage: `url(${modelPreviewUrl})` }"
>
</span>
</span>
Expand Down Expand Up @@ -44,6 +41,18 @@ const props = defineProps<{
const modelDef = computed(() => props.node.data)
const modelPreviewUrl = computed(() => {
if (modelDef.value?.image) {
return modelDef.value.image
}
const folder = modelDef.value.directory
const path_index = modelDef.value.path_index
const extension = modelDef.value.file_name.split('.').pop()
const filename = modelDef.value.file_name.replace(`.${extension}`, '.webp')
const encodedFilename = encodeURIComponent(filename).replace(/%2F/g, '/')
return `/api/experiment/models/preview/${folder}/${path_index}/${encodedFilename}`
})
const previewRef = ref<InstanceType<typeof ModelPreview> | null>(null)
const modelPreviewStyle = ref<CSSProperties>({
position: 'absolute',
Expand Down Expand Up @@ -129,9 +138,10 @@ onUnmounted(() => {
background-position: center;
display: inline-block;
position: relative;
left: -2.5rem;
height: 2rem;
width: 2rem;
left: -2.2rem;
top: -0.1rem;
height: 1.7rem;
width: 1.7rem;
vertical-align: top;
}
</style>
10 changes: 6 additions & 4 deletions src/scripts/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,8 @@ export class ComfyApi extends EventTarget {
* Gets a list of model folder keys (eg ['checkpoints', 'loras', ...])
* @returns The list of model folder keys
*/
async getModelFolders(): Promise<string[]> {
const res = await this.fetchApi(`/models`)
async getModelFolders(): Promise<{ name: string; folders: string[] }[]> {
const res = await this.fetchApi(`/experiment/models`)
if (res.status === 404) {
return []
}
Expand All @@ -497,8 +497,10 @@ export class ComfyApi extends EventTarget {
* @param {string} folder The folder to list models from, such as 'checkpoints'
* @returns The list of model filenames within the specified folder
*/
async getModels(folder: string): Promise<string[]> {
const res = await this.fetchApi(`/models/${folder}`)
async getModels(
folder: string
): Promise<{ name: string; pathIndex: number }[]> {
const res = await this.fetchApi(`/experiment/models/${folder}`)
if (res.status === 404) {
return []
}
Expand Down
14 changes: 11 additions & 3 deletions src/stores/modelStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ function _findInMetadata(metadata: any, ...keys: string[]): string | null {

/** Defines and holds metadata for a model */
export class ComfyModelDef {
/** Path to the model */
readonly path_index: number
/** Proper filename of the model */
readonly file_name: string
/** Normalized filename of the model, with all backslashes replaced with forward slashes */
Expand Down Expand Up @@ -54,7 +56,8 @@ export class ComfyModelDef {
/** A string full of auto-computed lowercase-only searchable text for this model */
searchable: string = ''

constructor(name: string, directory: string) {
constructor(name: string, directory: string, pathIndex: number) {
this.path_index = pathIndex
this.file_name = name
this.normalized_file_name = name.replaceAll('\\', '/')
this.simplified_file_name = this.normalized_file_name.split('/').pop() ?? ''
Expand Down Expand Up @@ -165,7 +168,11 @@ export class ModelFolder {
this.state = ResourceState.Loading
const models = await api.getModels(this.directory)
for (const model of models) {
this.models[model] = new ComfyModelDef(model, this.directory)
this.models[`${model.pathIndex}/${model.name}`] = new ComfyModelDef(
model.name,
this.directory,
model.pathIndex
)
}
this.state = ResourceState.Loaded
return this
Expand All @@ -189,7 +196,8 @@ export const useModelStore = defineStore('models', () => {
* Loads the model folders from the server
*/
async function loadModelFolders() {
modelFolderNames.value = await api.getModelFolders()
const resData = await api.getModelFolders()
modelFolderNames.value = resData.map((folder) => folder.name)
modelFolderByName.value = {}
for (const folderName of modelFolderNames.value) {
modelFolderByName.value[folderName] = new ModelFolder(folderName)
Expand Down
15 changes: 9 additions & 6 deletions tests-ui/tests/fast/store/modelStore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ jest.mock('@/scripts/api', () => ({

function enableMocks() {
;(api.getModels as jest.Mock).mockResolvedValue([
'sdxl.safetensors',
'sdv15.safetensors',
'noinfo.safetensors'
{ name: 'sdxl.safetensors', pathIndex: 0 },
{ name: 'sdv15.safetensors', pathIndex: 0 },
{ name: 'noinfo.safetensors', pathIndex: 0 }
])
;(api.getModelFolders as jest.Mock).mockResolvedValue([
{ name: 'checkpoints', folders: ['/path/to/checkpoints'] },
{ name: 'vae', folders: ['/path/to/vae'] }
])
;(api.getModelFolders as jest.Mock).mockResolvedValue(['checkpoints', 'vae'])
;(api.viewMetadata as jest.Mock).mockImplementation((_, model) => {
if (model === 'noinfo.safetensors') {
return Promise.resolve({})
Expand Down Expand Up @@ -59,7 +62,7 @@ describe('useModelStore', () => {
const folderStore = await store.getLoadedModelFolder('checkpoints')
expect(folderStore).not.toBeNull()
if (!folderStore) return
const model = folderStore.models['sdxl.safetensors']
const model = folderStore.models['0/sdxl.safetensors']
await model.load()
expect(model.title).toBe('Title of sdxl.safetensors')
expect(model.architecture_id).toBe('stable-diffusion-xl-base-v1')
Expand All @@ -77,7 +80,7 @@ describe('useModelStore', () => {
const folderStore = await store.getLoadedModelFolder('checkpoints')
expect(folderStore).not.toBeNull()
if (!folderStore) return
const model = folderStore.models['noinfo.safetensors']
const model = folderStore.models['0/noinfo.safetensors']
await model.load()
expect(model.file_name).toBe('noinfo.safetensors')
expect(model.title).toBe('noinfo')
Expand Down

0 comments on commit 5b72fc7

Please sign in to comment.