Skip to content

Commit 731c147

Browse files
adamomainz authored and facebook-github-bot committed
weights added to logs from production data!!
Summary: Weights are now being pulled from production data + a bit of refactoring Reviewed By: nmacchioni Differential Revision: D65637832 fbshipit-source-id: 078aaed94f340f8bc00adbe21e433bc57f810b6f
1 parent f4039ec commit 731c147

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

run.py

+4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorRe
8383
"device": args.device,
8484
"logging_group": args.logging_group,
8585
}
86+
if args.production_shapes:
87+
from tritonbench.utils.fb.durin_data import productionDataLoader
88+
89+
kwargs["weights_loader"] = productionDataLoader
8690

8791
if "hardware" in args:
8892
kwargs["hardware"] = args.hardware

tritonbench/operators/flash_attention/operator.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,6 @@ def get_ctx_vals():
427427
shapes = self.__additional_example_input(ctx_vals)
428428
requires_grad = True
429429
for shape in shapes:
430-
print(shape)
431430
BATCH, H, N_CTX, D_HEAD = shape
432431
q = torch.randn(
433432
(BATCH, H, N_CTX, D_HEAD),
@@ -459,9 +458,14 @@ def __additional_example_input(self, standard_shapes: Generator) -> Generator:
459458
]
460459
shapes = chain(standard_shapes, llama_shapes)
461460
if self.add_production_shapes:
462-
from ...utils.fb.durin_data import get_shapes_from_frozen_durin
461+
from ...utils.fb.durin_data import productionDataLoader
463462

464-
shapes = chain(shapes, get_shapes_from_frozen_durin("attention"))
463+
shapes = chain(
464+
shapes,
465+
productionDataLoader.get_shapes_from_frozen_durin(
466+
self.name, "attention"
467+
),
468+
)
465469
return shapes
466470

467471
@register_x_val(label="(Batch, Heads, SeqLen, Dhead)")

0 commit comments

Comments
 (0)