Skip to content

Commit

Permalink
Added OLMoModel class and config.architectures detection, and temporary fake layernorm
Browse files Browse the repository at this point in the history
…y fake layernorm
  • Loading branch information
shobrienDMA committed Nov 12, 2024
1 parent c94ee92 commit f3e4dbf
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/python/py/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,13 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location):
skip_input = self.layernorm_attrs["skip_input"]

weight = f"model.layers.{layer_id}.{location}_layernorm.weight"
#ShaneTim
if layernorm.weight is None:
layernorm.weight = torch.ones(2048)
self.make_external_tensor(layernorm.weight.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]) + self.layernorm_attrs["add_offset"], weight)
#ShaneTim
if layernorm.bias is None:
layernorm.bias = torch.ones(2048)
bias = f"model.layers.{layer_id}.{location}_layernorm.bias"
if not simple:
self.make_external_tensor(layernorm.bias.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]), bias)
Expand Down Expand Up @@ -3040,6 +3046,10 @@ def make_layer(self, layer_id, layer):
layer.self_attn = layer.self_attn if hasattr(layer, 'self_attn') else layer.self_attention
super().make_layer(layer_id, layer)

class OLMoModel(Model):
    """Builder subclass for OLMo checkpoints (HF architecture ``OlmoForCausalLM``).

    The class adds no behavior of its own — it exists so ``create_model`` has a
    distinct dispatch target for the OLMo architecture. Presumably the generic
    decoder export path in ``Model`` is sufficient for OLMo; any OLMo-specific
    overrides (e.g. its non-parametric layernorm) would be added here.
    TODO(review): confirm the base pipeline handles OLMo's layernorm without
    the temporary ``torch.ones`` fallback elsewhere in this file.
    """
    # NOTE: the original explicit __init__ only forwarded all arguments to
    # super().__init__ unchanged (useless-super-delegation, pylint W0235),
    # so it is omitted; the inherited constructor signature is identical.

def check_extra_options(kv_pairs):
if "int4_op_types_to_quantize" in kv_pairs:
op_types_to_quantize = ()
Expand Down Expand Up @@ -3144,6 +3154,8 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
# Quantized ChatGLM model has ChatGLMForConditionalGeneration as architecture whereas HF model as the latter
config.hidden_act = "swiglu"
onnx_model = ChatGLMModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "OlmoForCausalLM":
onnx_model = OLMoModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
else:
raise NotImplementedError(f"The {hf_name} model is not currently supported.")

Expand Down

0 comments on commit f3e4dbf

Please sign in to comment.