Sample code for computing the sequence x_t = a_t x_{t-1} + b_t in parallel
Sequences of this form are ubiquitous in science and engineering. For example, in the natural sciences, these sequences model quantities or populations that decay or grow at a rate that varies over time.
It's common to see code that computes sequences of this form one element at a time. Read on to find out how to compute them efficiently in parallel!
All snippets of code assume you have a working installation of Python 3.8+ with PyTorch 2.1+.
The following snippet of code computes 10,000,000 elements, one element at a time. Copy and paste it to execute it. Warning: It will be painfully slow. Execution time is linear in the number of elements.
import torch
def naively_compute_sequentially(coeffs, values):
    x = [values[0]]  # x_0
    for a, b in zip(coeffs, values[1:]):
        x.append(a * x[-1] + b)
    return torch.stack(x)
device = 'cuda:0' # change as necessary
seq_len = 10_000_000 # change as you wish
# Generate some random input data:
coeffs = torch.randn(seq_len, device=device)
values = torch.randn(1 + seq_len, device=device) # includes initial value
# Compute the sequence:
x = naively_compute_sequentially(coeffs, values) # includes initial value
Note: Even if you rewrite the above snippet of interpreted Python code as efficient GPU code (say, with Triton), execution will still be slow, because all elements are computed sequentially, which is inefficient on a GPU.
The snippets of code below execute the same computations in parallel -- or more precisely, as a composition of two prefix sums, each of which is executable in parallel. (See also this post on implementing parallel prefix sum in CUDA.) The first snippet is for the general case, in which the inputs may be positive or negative.
If any inputs are negative, their logarithms are complex numbers:
import torch
import torch.nn.functional as F
def complex_log(float_input, eps=1e-6):
    eps = float_input.new_tensor(eps)
    real = float_input.abs().maximum(eps).log()
    imag = (float_input < 0).to(float_input.dtype) * torch.pi
    return torch.complex(real, imag)

def compute_in_parallel(coeffs, values):
    log_coeffs = complex_log(coeffs)
    log_values = complex_log(values)
    a_star = F.pad(torch.cumsum(log_coeffs, dim=-1), (1, 0))             # eq (2) in paper
    log_x0_plus_b_star = torch.logcumsumexp(log_values - a_star, dim=-1) # eq (7) in paper
    log_x = a_star + log_x0_plus_b_star                                  # eq (1) in paper
    return torch.exp(log_x).real
device = 'cuda:0' # change as necessary
seq_len = 10_000_000 # change as you wish
# Generate some random input data:
coeffs = torch.randn(seq_len, device=device)  # positive or negative values
values = torch.randn(1 + seq_len, device=device) # negative or positive values
# Compute the sequence:
x = compute_in_parallel(coeffs, values) # includes initial value
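As a quick sanity check, you can compare the parallel and sequential implementations on a short sequence. The sketch below assumes naively_compute_sequentially from the first snippet is also defined; the sequence length and tolerances are illustrative choices of ours, loose enough for float32:

small_coeffs = torch.randn(1000, device=device)
small_values = torch.randn(1 + 1000, device=device)
expected = naively_compute_sequentially(small_coeffs, small_values)
obtained = compute_in_parallel(small_coeffs, small_values)
print(torch.allclose(expected, obtained, rtol=1e-3, atol=1e-3))  # should typically print True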
If no inputs are negative, their logarithms are floats:
import torch
import torch.nn.functional as F
def compute_in_parallel_special_case(coeffs, values):
    log_coeffs = torch.log(coeffs)
    log_values = torch.log(values)
    a_star = F.pad(torch.cumsum(log_coeffs, dim=-1), (1, 0))             # eq (2) in paper
    log_x0_plus_b_star = torch.logcumsumexp(log_values - a_star, dim=-1) # eq (7) in paper
    log_x = a_star + log_x0_plus_b_star                                  # eq (1) in paper
    return torch.exp(log_x)  # already a float
device = 'cuda:0' # change as necessary
seq_len = 10_000_000 # change as you wish
# Generate some random input data:
coeffs = torch.rand(seq_len, device=device) + 1e-6 # eps for numerical stability
values = torch.rand(1 + seq_len, device=device) * 3 # all values >= 0
# Compute the sequence:
x = compute_in_parallel_special_case(coeffs, values) # includes initial value
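If you wish, you can also check on a short non-negative sequence that this special case closely matches the general case. The sketch below assumes compute_in_parallel from the previous snippet is still defined; the length and tolerances are again illustrative:

small_coeffs = torch.rand(1000, device=device) + 1e-6
small_values = torch.rand(1 + 1000, device=device) * 3
print(torch.allclose(compute_in_parallel(small_coeffs, small_values),
                     compute_in_parallel_special_case(small_coeffs, small_values),
                     rtol=1e-3, atol=1e-3))  # should typically print True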
For computing the sequence incrementally, chunk by chunk, we recommend you cache each chunk's final log-element, log_x[-1], and use it as the subsequent chunk's initial log-value, log_values[0]. Caching state in the domain of logarithms is more numerically stable.
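For instance, here is a minimal sketch of chunked computation, reusing complex_log, coeffs, values, and seq_len from the general-case snippet above. The helper compute_chunk_in_log_space and the chunk length are illustrative choices of ours, not part of the method:

import torch
import torch.nn.functional as F

def compute_chunk_in_log_space(log_coeffs, log_values):
    # Same computation as compute_in_parallel above, but taking and returning
    # log-domain tensors, so state can be carried from one chunk to the next:
    a_star = F.pad(torch.cumsum(log_coeffs, dim=-1), (1, 0))
    log_x0_plus_b_star = torch.logcumsumexp(log_values - a_star, dim=-1)
    return a_star + log_x0_plus_b_star

chunk_len = 1_000_000                    # illustrative chunk length
log_x_last = complex_log(values[:1])     # initial value x_0, in the domain of logarithms
chunks = []
for i in range(0, seq_len, chunk_len):
    log_coeffs = complex_log(coeffs[i:i + chunk_len])
    log_values = torch.cat([log_x_last, complex_log(values[1 + i:1 + i + chunk_len])])
    log_x = compute_chunk_in_log_space(log_coeffs, log_values)
    log_x_last = log_x[-1:]                   # cache final log-element for the next chunk
    chunks.append(torch.exp(log_x[1:]).real)  # drop each chunk's repeated initial element
x_chunked = torch.cat(chunks)                 # corresponds to x[1:] from the snippet above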
The snippets of code above are meant to be easy-to-follow recipes. For use in production, make sure to compute all logarithms with the most efficient and numerically stable methods available. For example, if the coefficients are gating probabilities computed from given logits, you should use F.logsigmoid(logits) instead of torch.log(F.sigmoid(logits)) to compute the log-coefficients. If one of the input sequences has no negative numbers, don't cast it to complex in advance; instead, wait to cast its logarithms until after they have been summed. If you are using lower precision, don't assume numerical stability; instead, make sure both input sequences will always be within acceptable bounds. Use your common sense.
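As a small illustration of the first recommendation above (the logits below are made up for the example):

import torch
import torch.nn.functional as F

logits = torch.tensor([-200.0, -1.0, 0.0, 1.0])  # hypothetical gating logits
print(torch.log(torch.sigmoid(logits)))          # first element underflows to -inf
print(F.logsigmoid(logits))                      # first element is -200, as it should be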
In certain production environments, it may be more efficient to represent each complex number as a (float, int) tuple to take advantage of the fact that all sums of imaginary components in our proposed method are multiples of π.
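For reference, here is a minimal sketch of that representation, limited to the conversion itself; the helper names are ours, and complex_log is assumed to be defined as in the general-case snippet:

import torch

def to_float_int(x, eps=1e-6):
    # log-magnitude as a float, plus the number of pi's in the imaginary component (0 or 1):
    return x.abs().clamp(min=eps).log(), (x < 0).long()

def to_complex_log(log_mag, num_pis):
    return torch.complex(log_mag, num_pis.to(log_mag.dtype) * torch.pi)

sample = torch.randn(5)
log_mag, num_pis = to_float_int(sample)
print(torch.allclose(to_complex_log(log_mag, num_pis), complex_log(sample)))  # True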
As always, you should test all available options to find out which one will work best for your use case.
See this thread for computing non-diagonal recurrences of the form x_t = A_t x_{t-1} + b_t, in which the coefficients A_t are matrices instead of scalars.
To cite our work, please use:

@misc{heinsen2023parallelization,
    title={Efficient Parallelization of a Ubiquitous Sequential Computation},
    author={Franz A. Heinsen},
    year={2023},
    eprint={2311.06281},
    archivePrefix={arXiv},
    primaryClass={cs.DS}
}
We originally conceived and implemented these methods as part of our AI software, nicknamed Graham. Most of the original work we do at GlassRoom tends to be either proprietary in nature or tightly coupled to internal code, so we cannot share it with outsiders. In this case, however, we were able to isolate our code, clean it up, and release it as stand-alone open-source software without having to disclose any key intellectual property.
We hope others find our work and our code useful.