-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Runtime] Add a new compiled format CompiledApp (#408)
**Motivation** In LLM serving, there are both prefill and decode stages which have different computation graphs. We need to save this two graphs and load them together, and hope to share the weights of the two graphs. **CompiledApp** `hidet.runtime.CompiledApp` is such a runtime object that may contain multiple compiled graphs and will deal with the weight sharing. **Usage** ```python import pytest import hidet from hidet.testing.models import resnet18 from hidet.runtime import CompiledApp, save_compiled_app, load_compiled_app module_1 = resnet18().cuda() module_2 = resnet18().cuda() x1 = hidet.symbol(['batch_size', 3, 224, 224], dtype='float32', device='cuda:0') x2 = hidet.symbol([1, 3, 224, 224], dtype='float32', device='cuda:0') y1 = module_1(x1) y2 = module_2(x2) # the two compiled graphs share the weights cgraph_1 = hidet.trace_from(y1, inputs=[x1]).build() cgraph_2 = hidet.trace_from(y2, inputs=[x2]).build() # we create a compiled app with two compiled graphs app = create_compiled_app(graphs={'graph_1': cgraph_1, 'graph_2': cgraph_2}, name='demo_app') save_compiled_app(app, 'app.hidet') app = load_compiled_app('app.hidet') x = hidet.randn([1, 3, 224, 224], device='cuda') y1 = app.graphs['graph_1'](x) y2 = app.graphs['graph_2'](x) hidet.utils.assert_close(y1, y2) # check if they share the weights # this is one important feature of compiled app that share the weights of graphs if they are numerically identical assert len(set(app.graphs['graph_1'].weights) ^ set(app.graphs['graph_2'].weights)) == 0 ```
- Loading branch information
1 parent
70765da
commit 2f5e3f1
Showing
6 changed files
with
319 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
from typing import Dict, List, Optional | ||
import json | ||
import dataclasses | ||
import os | ||
import zipfile | ||
import tempfile | ||
import hashlib | ||
from collections import defaultdict | ||
from dataclasses import asdict | ||
import numpy as np | ||
|
||
import hidet.utils | ||
from hidet.runtime.compiled_graph import CompiledGraph, save_compiled_graph, load_compiled_graph, GraphExecution | ||
|
||
|
||
@dataclasses.dataclass | ||
class AppMetaData: | ||
name: str | ||
hidet_version: str | ||
graphs: List[str] | ||
app_hash: str | ||
|
||
|
||
class CompiledApp: | ||
def __init__(self, meta: AppMetaData, graphs: Dict[str, CompiledGraph] = None): | ||
self.meta: AppMetaData = meta | ||
self.graphs: Dict[str, CompiledGraph] = graphs | ||
|
||
|
||
def create_compiled_app(graphs: Dict[str, CompiledGraph], name: Optional[str] = None) -> CompiledApp: | ||
""" | ||
Create a compiled app from a dict of compiled graphs. | ||
Parameters | ||
---------- | ||
graphs: Dict[str, CompiledGraph] | ||
The compiled graphs. | ||
name: Optional[str] | ||
The name of the app. If None, the name will be set to 'app'. | ||
Returns | ||
------- | ||
ret: CompiledApp | ||
The compiled app. | ||
""" | ||
if name is None: | ||
name = 'app' | ||
|
||
hash_obj = hashlib.sha256() | ||
hash_obj.update(name.encode()) | ||
for graph_name, graph in graphs.items(): | ||
hash_obj.update(graph_name.encode()) | ||
hash_obj.update(graph.meta.graph_hash.encode()) | ||
app_hash: str = hash_obj.hexdigest()[:16] | ||
|
||
meta = AppMetaData(name=name, hidet_version=hidet.__version__, graphs=list(graphs.keys()), app_hash=app_hash) | ||
return CompiledApp(meta=meta, graphs=graphs) | ||
|
||
|
||
def save_compiled_app(app: CompiledApp, path: str): | ||
""" | ||
Save a compiled app to a file. | ||
Parameters | ||
---------- | ||
app: CompiledApp | ||
The compiled app to save. | ||
path: str | ||
The path to save the compiled app. | ||
""" | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
# save the meta data | ||
with open(os.path.join(tmp_dir, 'meta.json'), 'w') as f: | ||
meta_bytes = json.dumps(asdict(app.meta), indent=4) | ||
f.write(meta_bytes) | ||
|
||
# save the kernel-only graphs to files | ||
for name, graph in app.graphs.items(): | ||
graph_path = os.path.join(tmp_dir, '{}.hidet'.format(name)) | ||
save_compiled_graph(graph, file=graph_path, save_dispatch_table=False, save_weights=False) | ||
with zipfile.ZipFile(graph_path, 'r') as zip_file: | ||
graph_dir = os.path.join(tmp_dir, 'graphs', name) | ||
os.makedirs(graph_dir) | ||
zip_file.extractall(path=graph_dir) | ||
os.remove(graph_path) | ||
|
||
# save the weights | ||
weights: List[np.ndarray] = [] | ||
weight_hash_map: Dict[str, int] = {} # the hash of the weight -> the index of the weight in the weights list | ||
for name, graph in app.graphs.items(): | ||
with open(os.path.join(tmp_dir, 'graphs', '{}-weights-index.txt'.format(name)), 'w') as weight_index_file: | ||
for weight in graph.weights: | ||
weight_ndarray = weight.cpu().numpy() | ||
hash_obj = hashlib.sha256() | ||
hash_obj.update(weight_ndarray.tobytes()) | ||
hash_obj.update(weight.signature().encode()) | ||
weight_hash: str = hash_obj.hexdigest() | ||
if weight_hash not in weight_hash_map: | ||
weight_hash_map[weight_hash] = len(weights) | ||
weights.append(weight_ndarray) | ||
weight_index = weight_hash_map[weight_hash] | ||
weight_index_file.write('{}\n'.format(weight_index)) | ||
|
||
np.savez(os.path.join(tmp_dir, 'weights.npz'), *weights) | ||
|
||
# save the contents of the current dir to a zip file | ||
with zipfile.ZipFile(path, 'w') as zip_file: | ||
for root, _, files in os.walk(tmp_dir): | ||
for file in files: | ||
zip_file.write(os.path.join(root, file), arcname=os.path.relpath(os.path.join(root, file), tmp_dir)) | ||
|
||
|
||
def load_compiled_app(path: str) -> CompiledApp: | ||
""" | ||
Load a compiled app from a file. | ||
Parameters | ||
---------- | ||
path: str | ||
The path to the compiled app file. | ||
Returns | ||
------- | ||
ret: CompiledApp | ||
The loaded compiled app. | ||
""" | ||
from hidet import Tensor | ||
from hidet.utils.dataclass import from_dict | ||
|
||
with zipfile.ZipFile(path, 'r') as zip_file: | ||
# load the meta data | ||
with zip_file.open('meta.json', 'r') as f: | ||
meta_bytes = f.read() | ||
meta: AppMetaData = from_dict(AppMetaData, json.loads(meta_bytes)) | ||
|
||
# extract the app if needed | ||
app_dir = hidet.utils.cache_file('apps', meta.app_hash) | ||
meta_path = os.path.join(app_dir, 'meta.json') | ||
if not os.path.exists(meta_path): | ||
# we only extract the app when it is not in our cache dir. | ||
# we used 'meta.json' as the indicator whether the app is there or not. | ||
# if the app is not there, we extract everything but the weights in the app to the cache dir | ||
files_to_extract = [name for name in zip_file.namelist() if name != 'weights.npz'] | ||
zip_file.extractall(app_dir, files_to_extract) | ||
|
||
# load the compiled graphs | ||
graphs: Dict[str, CompiledGraph] = {} | ||
for graph_name in meta.graphs: | ||
graphs[graph_name] = load_compiled_graph(os.path.join(app_dir, 'graphs', graph_name)) | ||
|
||
# load the weights from the app file | ||
device2weights: Dict[str, Dict[int, Tensor]] = defaultdict(dict) | ||
with zip_file.open('weights.npz', 'r') as npz: | ||
weights: List[np.ndarray] = list(np.load(npz).values()) | ||
for graph_name in meta.graphs: | ||
graph: CompiledGraph = graphs[graph_name] | ||
weight_index_file = os.path.join(app_dir, 'graphs', '{}-weights-index.txt'.format(graph_name)) | ||
graph_weights = [] | ||
with open(weight_index_file, 'r') as f: | ||
weight_indices = [int(line.strip()) for line in f.readlines()] | ||
for idx, weight_index in enumerate(weight_indices): | ||
execution: GraphExecution = graph.graph_execution | ||
device: str = execution.tensor_device[execution.weights_index[idx]] | ||
if weight_index not in device2weights[device]: | ||
device2weights[device][weight_index] = hidet.asarray(weights[weight_index], device=device) | ||
graph_weights.append(device2weights[device][weight_index]) | ||
graphs[graph_name].set_weights(graph_weights) | ||
|
||
return CompiledApp(meta=meta, graphs=graphs) |
Oops, something went wrong.