[Fixbug] Fix for softmax cpu causing issues #437

Open
wants to merge 81 commits into base: main
Commits (81)
14cbb3b
initial commit
fishingguy456 Jul 6, 2023
7896c45
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
ff90ed5
initial commit
fishingguy456 Jul 6, 2023
fc61204
change imports
fishingguy456 Jul 20, 2023
f84201f
fix for diff size, compiledmodule error fix
fishingguy456 Jul 21, 2023
6f2e43c
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
25f22cf
initial commit
fishingguy456 Jul 6, 2023
aafbb0f
initial commit
fishingguy456 Jul 6, 2023
44993e2
change imports
fishingguy456 Jul 20, 2023
a86d866
fix for diff size, compiledmodule error fix
fishingguy456 Jul 21, 2023
b59ffa2
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
7edf0eb
wrap up softmax, starting layernorm
fishingguy456 Jul 28, 2023
44c04b3
layernorm kinda works but not rly
fishingguy456 Jul 31, 2023
2ccc4b6
better code for softmax
fishingguy456 Jul 31, 2023
13ea5dc
layernorm works for last layer
fishingguy456 Aug 1, 2023
d89036d
move find sum and find max to registered function
fishingguy456 Aug 1, 2023
b0659f6
find max in registered func
fishingguy456 Aug 1, 2023
904760b
not working softmax on not last dim, minor changes
fishingguy456 Aug 3, 2023
29b7ba7
layernorm works for any dims
fishingguy456 Aug 3, 2023
0c8dc3a
comments
fishingguy456 Aug 4, 2023
77fe8d9
tuning, fix for flowgraph operator resolve
fishingguy456 Aug 4, 2023
ac40695
softmax works
fishingguy456 Aug 5, 2023
4938a1f
commented tensors dont work, i.e. axis is not last 2 AND not multiple…
fishingguy456 Aug 5, 2023
1d447cf
actually works rn frfr so fast :100:
fishingguy456 Aug 8, 2023
30224ce
cleanup
fishingguy456 Aug 8, 2023
67d4d56
more cleanup
fishingguy456 Aug 9, 2023
09ca2f8
random testing stuff
fishingguy456 Aug 11, 2023
8352dd8
allow epilogue
fishingguy456 Aug 18, 2023
27f6cbb
better epiloguing
fishingguy456 Aug 18, 2023
cce1d42
janky matmul resolve
fishingguy456 Aug 25, 2023
f92de53
still epilogue problem?
fishingguy456 Aug 25, 2023
63dfed4
initial commit
fishingguy456 Jul 6, 2023
73a063a
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
1c129c0
initial commit
fishingguy456 Jul 6, 2023
bf8a5b5
change imports
fishingguy456 Jul 20, 2023
3aa5cb6
fix for diff size, compiledmodule error fix
fishingguy456 Jul 21, 2023
b849ebf
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
12fdbd1
initial commit
fishingguy456 Jul 6, 2023
9c7ecd0
initial commit
fishingguy456 Jul 6, 2023
b155bbd
change imports
fishingguy456 Jul 20, 2023
de72bc6
fix for diff size, compiledmodule error fix
fishingguy456 Jul 21, 2023
17b8d76
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
1b52167
wrap up softmax, starting layernorm
fishingguy456 Jul 28, 2023
e479db7
layernorm kinda works but not rly
fishingguy456 Jul 31, 2023
c623630
better code for softmax
fishingguy456 Jul 31, 2023
b44b69e
layernorm works for last layer
fishingguy456 Aug 1, 2023
29ea558
move find sum and find max to registered function
fishingguy456 Aug 1, 2023
339e549
find max in registered func
fishingguy456 Aug 1, 2023
88c423c
not working softmax on not last dim, minor changes
fishingguy456 Aug 3, 2023
9c91875
layernorm works for any dims
fishingguy456 Aug 3, 2023
6e0d8e5
comments
fishingguy456 Aug 4, 2023
552aebb
tuning, fix for flowgraph operator resolve
fishingguy456 Aug 4, 2023
dc258e3
softmax works
fishingguy456 Aug 5, 2023
95f6be7
commented tensors dont work, i.e. axis is not last 2 AND not multiple…
fishingguy456 Aug 5, 2023
d0b99a4
actually works rn frfr so fast :100:
fishingguy456 Aug 8, 2023
67a43a5
cleanup
fishingguy456 Aug 8, 2023
4443780
more cleanup
fishingguy456 Aug 9, 2023
4088fc6
random testing stuff
fishingguy456 Aug 11, 2023
7430696
allow epilogue
fishingguy456 Aug 18, 2023
8a1167e
better epiloguing
fishingguy456 Aug 18, 2023
0f4876f
janky matmul resolve
fishingguy456 Aug 25, 2023
49c072f
still epilogue problem?
fishingguy456 Aug 25, 2023
0bd13d8
Merge remote-tracking branch 'origin/main'
fishingguy456 Sep 14, 2023
de74231
clean up for pr
fishingguy456 Sep 14, 2023
9ab0bac
fix test
fishingguy456 Sep 18, 2023
f779a1d
lint
fishingguy456 Sep 18, 2023
124fb09
minor pr edits
fishingguy456 Sep 19, 2023
6c4efd9
pytests, cpu child class
fishingguy456 Sep 19, 2023
40fd71f
potential fix for failing tests? but prob not will have to investigat…
fishingguy456 Sep 19, 2023
90c4ffb
weird diff
fishingguy456 Jan 4, 2024
587ba64
merge conflict resolve build.py
fishingguy456 Jan 4, 2024
89d5646
remove shady batch mat mul
fishingguy456 Jan 4, 2024
a3a4b03
lint thing
fishingguy456 Jan 5, 2024
aec95d2
move helpers to new file
fishingguy456 Jan 8, 2024
7a41b5c
lint
fishingguy456 Jan 8, 2024
dcc6a45
change tolerance for flaky test for test_dynamic_shape
fishingguy456 Jan 8, 2024
dbfcc56
Merge branch 'upstream/main'
fishingguy456 Jan 16, 2024
703f49e
Merge branch 'upstream/main'
fishingguy456 Jan 22, 2024
8f73d4c
fused scheduler task name
fishingguy456 Mar 9, 2024
713b016
Merge branch 'upstream/main'
fishingguy456 Mar 9, 2024
eafa10f
implement cpu issue softmax
fishingguy456 Mar 9, 2024
10 changes: 5 additions & 5 deletions python/hidet/graph/ops/softmax.py
@@ -157,11 +157,6 @@ def softmax_kernel(xs: xdtype[shape], ys: xdtype[shape]):
 
         return ir_module
 
-    def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
-        if self.inputs[0].type.dtype != float32:
-            return NotImplemented  # use auto-scheduler
-        return tune.extract_ir_modules(self.schedule_softmax_cpu)
-
 
 class CPUSoftmaxTask(SoftmaxTask):
     def allow_epilogue(self) -> bool:
@@ -170,6 +165,11 @@ def allow_epilogue(self) -> bool:
     def allow_prologue(self) -> bool:
         return False
 
+    def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
+        if self.inputs[0].type.dtype != float32:
+            return NotImplemented  # use auto-scheduler
+        return tune.extract_ir_modules(self.schedule_softmax_cpu)
+
     @tune.space(2, nthreads=['', 4, 8, 16, 32, 64, 96])
     @tune.space(1, nthreads=['', 8, 16])
     def schedule_softmax_cpu(self, nthreads='') -> IRModule:
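The moved implement_cpu means only CPUSoftmaxTask (with prologue disabled and epilogue fusion controlled) offers the hand-tuned CPU schedule; the base SoftmaxTask no longer does. A minimal usage sketch, assuming hidet's public hidet.randn and hidet.ops.softmax API and a float32 input (other dtypes return NotImplemented and fall back to the auto-scheduler):

import hidet

# Minimal sketch (assumes this branch is installed): a float32 CPU tensor should be
# lowered through CPUSoftmaxTask.schedule_softmax_cpu, while non-float32 inputs hit
# the NotImplemented branch and fall back to the auto-scheduler.
x = hidet.randn([8, 1000], dtype='float32', device='cpu')
y = hidet.ops.softmax(x, axis=-1)
print(y.shape)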
8 changes: 7 additions & 1 deletion python/hidet/ir/schedulers/cpu/scheduler.py
@@ -26,7 +26,13 @@ def schedule_grid_compute(self, node: GridCompute, tensor_map: Dict[TensorNode,
         params, param_map, call_args = self.grid_compute_params_and_args(node, tensor_map)
 
         if self.task is not None:
-            name = f'{self.task.name}_compute_{node.name}'
+            from hidet.graph.ops.fusion.fused_operator import FusedTask
+
+            if isinstance(self.task, FusedTask):
+                fused_name = self.task.attrs['fused_ops'].replace(' ', '_')
+                name = f'fused_{fused_name}_{node.name}'
+            else:
+                name = f'{self.task.name}_{node.name}'
         else:
             name = f'compute_{node.name}'
 
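The scheduler change only affects how generated compute functions are named: a FusedTask derives its kernel name from its 'fused_ops' attribute instead of the generic task name. A standalone sketch of that string logic, using hypothetical stand-in classes rather than hidet's real Task/FusedTask types:

# Hypothetical stand-ins for illustration only; in hidet the real objects are
# Task / FusedTask and node.name comes from a GridCompute node.
class Task:
    name = 'softmax'
    attrs: dict = {}

class FusedTask(Task):
    name = 'fused'
    attrs = {'fused_ops': 'matmul add relu'}

def kernel_name(task: Task, node_name: str) -> str:
    # Mirrors the branch added in schedule_grid_compute above.
    if isinstance(task, FusedTask):
        fused_name = task.attrs['fused_ops'].replace(' ', '_')
        return f'fused_{fused_name}_{node_name}'
    return f'{task.name}_{node_name}'

print(kernel_name(FusedTask(), 'c'))  # fused_matmul_add_relu_c
print(kernel_name(Task(), 'c'))       # softmax_c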