Skip to content

Commit

Permalink
[Runtime] Add a new compiled format CompiledApp (#408)
Browse files Browse the repository at this point in the history
**Motivation**
In LLM serving, there are both prefill and decode stages which have
different computation graphs. We need to save this two graphs and load
them together, and hope to share the weights of the two graphs.

**CompiledApp**
`hidet.runtime.CompiledApp` is such a runtime object that may contain
multiple compiled graphs and will deal with the weight sharing.

**Usage**
```python
import pytest
import hidet
from hidet.testing.models import resnet18
from hidet.runtime import CompiledApp, save_compiled_app, load_compiled_app

module_1 = resnet18().cuda()
module_2 = resnet18().cuda()

x1 = hidet.symbol(['batch_size', 3, 224, 224], dtype='float32', device='cuda:0')
x2 = hidet.symbol([1, 3, 224, 224], dtype='float32', device='cuda:0')

y1 = module_1(x1)
y2 = module_2(x2)

# the two compiled graphs share the weights
cgraph_1 = hidet.trace_from(y1, inputs=[x1]).build()
cgraph_2 = hidet.trace_from(y2, inputs=[x2]).build()

# we create a compiled app with two compiled graphs
app = create_compiled_app(graphs={'graph_1': cgraph_1, 'graph_2': cgraph_2}, name='demo_app')

save_compiled_app(app, 'app.hidet')

app = load_compiled_app('app.hidet')

x = hidet.randn([1, 3, 224, 224], device='cuda')
y1 = app.graphs['graph_1'](x)
y2 = app.graphs['graph_2'](x)
hidet.utils.assert_close(y1, y2)

# check if they share the weights
# this is one important feature of compiled app that share the weights of graphs if they are numerically identical
assert len(set(app.graphs['graph_1'].weights) ^ set(app.graphs['graph_2'].weights)) == 0
```
  • Loading branch information
yaoyaoding committed Jan 7, 2024
1 parent 70765da commit 2f5e3f1
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 42 deletions.
2 changes: 1 addition & 1 deletion python/hidet/graph/nn/convolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ def extra_str(self) -> str:

def forward(self, x):
x = ops.pad(x, ops.utils.normalize_padding(self.padding))
return ops.conv2d(x, self.weight, self.stride, self.groups)
return ops.conv2d(x, self.weight, stride=self.stride, groups=self.groups)
1 change: 1 addition & 0 deletions python/hidet/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from .compiled_module import CompiledModule, CompiledFunction, load_compiled_module
from .compiled_task import CompiledTask, load_compiled_task
from .compiled_graph import CompiledGraph, save_compiled_graph, load_compiled_graph
from .compiled_app import CompiledApp, load_compiled_app, save_compiled_app, create_compiled_app
171 changes: 171 additions & 0 deletions python/hidet/runtime/compiled_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
from typing import Dict, List, Optional
import json
import dataclasses
import os
import zipfile
import tempfile
import hashlib
from collections import defaultdict
from dataclasses import asdict
import numpy as np

import hidet.utils
from hidet.runtime.compiled_graph import CompiledGraph, save_compiled_graph, load_compiled_graph, GraphExecution


@dataclasses.dataclass
class AppMetaData:
name: str
hidet_version: str
graphs: List[str]
app_hash: str


class CompiledApp:
def __init__(self, meta: AppMetaData, graphs: Dict[str, CompiledGraph] = None):
self.meta: AppMetaData = meta
self.graphs: Dict[str, CompiledGraph] = graphs


def create_compiled_app(graphs: Dict[str, CompiledGraph], name: Optional[str] = None) -> CompiledApp:
"""
Create a compiled app from a dict of compiled graphs.
Parameters
----------
graphs: Dict[str, CompiledGraph]
The compiled graphs.
name: Optional[str]
The name of the app. If None, the name will be set to 'app'.
Returns
-------
ret: CompiledApp
The compiled app.
"""
if name is None:
name = 'app'

hash_obj = hashlib.sha256()
hash_obj.update(name.encode())
for graph_name, graph in graphs.items():
hash_obj.update(graph_name.encode())
hash_obj.update(graph.meta.graph_hash.encode())
app_hash: str = hash_obj.hexdigest()[:16]

meta = AppMetaData(name=name, hidet_version=hidet.__version__, graphs=list(graphs.keys()), app_hash=app_hash)
return CompiledApp(meta=meta, graphs=graphs)


def save_compiled_app(app: CompiledApp, path: str):
"""
Save a compiled app to a file.
Parameters
----------
app: CompiledApp
The compiled app to save.
path: str
The path to save the compiled app.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
# save the meta data
with open(os.path.join(tmp_dir, 'meta.json'), 'w') as f:
meta_bytes = json.dumps(asdict(app.meta), indent=4)
f.write(meta_bytes)

# save the kernel-only graphs to files
for name, graph in app.graphs.items():
graph_path = os.path.join(tmp_dir, '{}.hidet'.format(name))
save_compiled_graph(graph, file=graph_path, save_dispatch_table=False, save_weights=False)
with zipfile.ZipFile(graph_path, 'r') as zip_file:
graph_dir = os.path.join(tmp_dir, 'graphs', name)
os.makedirs(graph_dir)
zip_file.extractall(path=graph_dir)
os.remove(graph_path)

# save the weights
weights: List[np.ndarray] = []
weight_hash_map: Dict[str, int] = {} # the hash of the weight -> the index of the weight in the weights list
for name, graph in app.graphs.items():
with open(os.path.join(tmp_dir, 'graphs', '{}-weights-index.txt'.format(name)), 'w') as weight_index_file:
for weight in graph.weights:
weight_ndarray = weight.cpu().numpy()
hash_obj = hashlib.sha256()
hash_obj.update(weight_ndarray.tobytes())
hash_obj.update(weight.signature().encode())
weight_hash: str = hash_obj.hexdigest()
if weight_hash not in weight_hash_map:
weight_hash_map[weight_hash] = len(weights)
weights.append(weight_ndarray)
weight_index = weight_hash_map[weight_hash]
weight_index_file.write('{}\n'.format(weight_index))

np.savez(os.path.join(tmp_dir, 'weights.npz'), *weights)

# save the contents of the current dir to a zip file
with zipfile.ZipFile(path, 'w') as zip_file:
for root, _, files in os.walk(tmp_dir):
for file in files:
zip_file.write(os.path.join(root, file), arcname=os.path.relpath(os.path.join(root, file), tmp_dir))


def load_compiled_app(path: str) -> CompiledApp:
"""
Load a compiled app from a file.
Parameters
----------
path: str
The path to the compiled app file.
Returns
-------
ret: CompiledApp
The loaded compiled app.
"""
from hidet import Tensor
from hidet.utils.dataclass import from_dict

with zipfile.ZipFile(path, 'r') as zip_file:
# load the meta data
with zip_file.open('meta.json', 'r') as f:
meta_bytes = f.read()
meta: AppMetaData = from_dict(AppMetaData, json.loads(meta_bytes))

# extract the app if needed
app_dir = hidet.utils.cache_file('apps', meta.app_hash)
meta_path = os.path.join(app_dir, 'meta.json')
if not os.path.exists(meta_path):
# we only extract the app when it is not in our cache dir.
# we used 'meta.json' as the indicator whether the app is there or not.
# if the app is not there, we extract everything but the weights in the app to the cache dir
files_to_extract = [name for name in zip_file.namelist() if name != 'weights.npz']
zip_file.extractall(app_dir, files_to_extract)

# load the compiled graphs
graphs: Dict[str, CompiledGraph] = {}
for graph_name in meta.graphs:
graphs[graph_name] = load_compiled_graph(os.path.join(app_dir, 'graphs', graph_name))

# load the weights from the app file
device2weights: Dict[str, Dict[int, Tensor]] = defaultdict(dict)
with zip_file.open('weights.npz', 'r') as npz:
weights: List[np.ndarray] = list(np.load(npz).values())
for graph_name in meta.graphs:
graph: CompiledGraph = graphs[graph_name]
weight_index_file = os.path.join(app_dir, 'graphs', '{}-weights-index.txt'.format(graph_name))
graph_weights = []
with open(weight_index_file, 'r') as f:
weight_indices = [int(line.strip()) for line in f.readlines()]
for idx, weight_index in enumerate(weight_indices):
execution: GraphExecution = graph.graph_execution
device: str = execution.tensor_device[execution.weights_index[idx]]
if weight_index not in device2weights[device]:
device2weights[device][weight_index] = hidet.asarray(weights[weight_index], device=device)
graph_weights.append(device2weights[device][weight_index])
graphs[graph_name].set_weights(graph_weights)

return CompiledApp(meta=meta, graphs=graphs)
Loading

0 comments on commit 2f5e3f1

Please sign in to comment.