Skip to content

Commit c599c7e

Browse files
committed
add readme
1 parent 94c39c2 commit c599c7e

File tree

3 files changed

+48
-9
lines changed

3 files changed

+48
-9
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,13 @@
11
# RegionalAdaptiveSampling
2+
3+
[Regional Adaptive Sampling](https://github.com/microsoft/RAS) is a new technique for accelerating the inference of diffusion transformers.
4+
It essentially works as a KV Cache inside the model, picking regions that are likely to be updated by each diffusion step and passing in only those tokens.
5+
6+
This implementation is simple to use, and compatible with Flux (dev & schnell) at HunYuanVideo. I may add support for other models in the future.
7+
8+
## Usage
9+
Apply the `Regional Adaptive Sampling` node to the desired model. It has the following parameters:
10+
- **sample_ratio**: The percent of tokens to keep in the model on a RAS pass. Anything below 0.3 is usually very bad quality.
11+
- **warmup_steps**: The number of steps to do without RAS at the beginning. Setting higher will decrease the speedup, and setting it lower will degrade the composition.
12+
- **hydrate_every**: Every `hydrate_every` steps, we do a full run through the model with all tokens, to refresh the stale cache. Set to 0 to disable and do full RAS.
13+
- **starvation_scale**: Controls how the model decides which part of the image to focus on. Increasing it will probably shift quality from the main subject to the background. The default of 0.1 is what's used in the paper, and I haven't tried anything else.

__init__.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,19 @@ def INPUT_TYPES(s):
1010
"model": ("MODEL",),
1111
"sample_ratio": (
1212
"FLOAT",
13-
{"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05},
13+
{"default": 0.5, "min": 0.05, "max": 1.0, "step": 0.05},
14+
),
15+
"warmup_steps": (
16+
"INT",
17+
{"default": 4, "min": 0, "max": 100},
18+
),
19+
"hydrate_every": (
20+
"INT",
21+
{"default": 4, "min": 0, "max": 100},
22+
),
23+
"starvation_scale": (
24+
"FLOAT",
25+
{"default": 0.1, "min": 0.01, "max": 1.0, "step": 0.01},
1426
),
1527
}
1628
}
@@ -19,13 +31,25 @@ def INPUT_TYPES(s):
1931
FUNCTION = "apply_ras"
2032
CATEGORY = "ras"
2133

22-
def apply_ras(self, model: ModelPatcher, sample_ratio: float):
34+
def apply_ras(
35+
self,
36+
model: ModelPatcher,
37+
sample_ratio: float,
38+
warmup_steps: int,
39+
hydrate_every: int,
40+
starvation_scale: float,
41+
):
2342
model = model.clone()
2443
# unpatch the model
2544
# this makes sure that we're wrapping the model "in a pure state"
2645
# the model will repatch itself later
2746
model.unpatch_model()
28-
config = RASConfig(sample_ratio=sample_ratio)
47+
config = RASConfig(
48+
sample_ratio=sample_ratio,
49+
warmup_steps=warmup_steps,
50+
hydrate_every=hydrate_every,
51+
starvation_scale=starvation_scale,
52+
)
2953
manager = RASManager(config)
3054
manager.wrap_model(model)
3155
return (model,)

ras.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def take_attributes_from(source, target, keys):
3333

3434
@dataclass
3535
class RASConfig:
36-
start_step: int = 4
36+
warmup_steps: int = 4
3737
hydrate_every: int = 5
3838
sample_ratio: float = 0.5
3939
starvation_scale: float = 0.1
@@ -131,10 +131,13 @@ def timestep_from_sigmas(sigmas: Tensor, sample_sigmas: Tensor):
131131
return int(i.item())
132132

133133
def skip_ratio(self, timestep: int) -> float:
134-
if timestep < self.config.start_step or (
135-
timestep % self.config.hydrate_every == 0
136-
):
137-
return 1
134+
if timestep < self.config.warmup_steps:
135+
return 0
136+
if self.config.hydrate_every:
137+
if (
138+
1 + timestep - self.config.warmup_steps
139+
) % self.config.hydrate_every == 0:
140+
return 0
138141
return 1.0 - self.config.sample_ratio
139142

140143
def select_indices(self, diff: Tensor, timestep: int):
@@ -172,7 +175,7 @@ def select_indices(self, diff: Tensor, timestep: int):
172175
metric *= torch.exp(self.config.starvation_scale * self.drop_count)
173176
indices = torch.sort(metric, dim=-1, descending=False).indices
174177
skip_ratio = self.skip_ratio(timestep)
175-
if skip_ratio >= 0.99:
178+
if skip_ratio <= 0.01:
176179
# we're not dropping anything -- remove the live_indices
177180
# we use the value None to indicate a full hydrate
178181
self.live_img_indices = None

0 commit comments

Comments
 (0)