Allow mutation through a transform_iterator #2006

Draft
wants to merge 1 commit into base: main

bernhardmgruber
Contributor

@bernhardmgruber bernhardmgruber commented Jul 18, 2024

Mutation is allowed only if the transform iterator's base iterator returns a true l-value reference (and not a proxy reference).
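To illustrate the intent, here is a minimal sketch in plain C++ that does not use Thrust at all. `simple_transform_iterator`, `simple_copy` and `write_x_through_transform` are hypothetical names made up for this example; `foo` and `access_x` mirror the names in the test from this PR, but their definitions here are assumptions. The point is only that when the functor returns an l-value reference and dereferencing forwards it unchanged, assignments through the iterator reach the underlying element.

```cpp
#include <cassert>
#include <vector>

// Assumed shape of the element type and accessor used in the PR's test.
struct foo { int x; int y; };

struct access_x {
  // Returns a true l-value reference into the element.
  int& operator()(foo& f) const { return f.x; }
};

// Hypothetical stand-in for a transform iterator (NOT thrust::transform_iterator).
template <class BaseIt, class Fn>
struct simple_transform_iterator {
  BaseIt base;
  Fn fn;

  // decltype(auto) keeps the int& returned by fn instead of decaying it to int.
  decltype(auto) operator*() const { return fn(*base); }
  simple_transform_iterator& operator++() { ++base; return *this; }
};

// Simplified copy loop that writes through the output iterator.
template <class InIt, class OutIt>
void simple_copy(InIt first, InIt last, OutIt out) {
  for (; first != last; ++first, ++out) {
    *out = *first;
  }
}

// Writes 1234 into every foo::x via the transform iterator and returns the result.
inline std::vector<foo> write_x_through_transform() {
  std::vector<int> src(10, 1234);
  std::vector<foo> dst(10, foo{1, 2});
  using out_it = simple_transform_iterator<std::vector<foo>::iterator, access_x>;
  simple_copy(src.begin(), src.end(), out_it{dst.begin(), access_x{}});
  assert(dst.size() == src.size());
  return dst;
}
```

After `write_x_through_transform()`, every element's `x` holds the copied value while `y` is untouched. With a proxy reference as the base iterator's reference type, the same forwarding does not work out of the box, which is what the restriction above is about.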

@bernhardmgruber bernhardmgruber added the thrust For all items related to Thrust. label Jul 18, 2024
@bernhardmgruber bernhardmgruber force-pushed the transform_lvalue_ref branch 3 times, most recently from ec3c5ae to 6b6d930 Compare July 18, 2024 21:33
@bernhardmgruber bernhardmgruber marked this pull request as ready for review July 18, 2024 23:46
@bernhardmgruber bernhardmgruber requested review from a team as code owners July 18, 2024 23:46
Contributor

🟩 CI finished in 5h 20m: Pass: 100%/250 | Total: 5d 02h | Avg: 29m 26s | Max: 1h 02m | Hits: 54%/248341
  • 🟩 cub: Pass: 100%/131 | Total: 2d 18h | Avg: 30m 30s | Max: 53m 33s | Hits: 70%/109429

    🟩 cpu
      🟩 amd64              Pass: 100%/123 | Total:  2d 14h | Avg: 30m 16s | Max: 53m 33s | Hits:  71%/102597
      🟩 arm64              Pass: 100%/8   | Total:  4h 33m | Avg: 34m 10s | Max: 36m 42s | Hits:  60%/6832  
    🟩 ctk
      🟩 11.1               Pass: 100%/15  | Total:  7h 37m | Avg: 30m 28s | Max: 45m 15s | Hits:  61%/11598 
      🟩 11.8               Pass: 100%/3   | Total:  2h 15m | Avg: 45m 10s | Max: 46m 06s | Hits:  60%/2562  
      🟩 12.5               Pass: 100%/113 | Total:  2d 08h | Avg: 30m 07s | Max: 53m 33s | Hits:  71%/95269 
    🟩 cudacxx
      🟩 ClangCUDA17        Pass: 100%/2   | Total: 42m 22s | Avg: 21m 11s | Max: 22m 30s | Hits:  66%/1412  
      🟩 nvcc11.1           Pass: 100%/15  | Total:  7h 37m | Avg: 30m 28s | Max: 45m 15s | Hits:  61%/11598 
      🟩 nvcc11.8           Pass: 100%/3   | Total:  2h 15m | Avg: 45m 10s | Max: 46m 06s | Hits:  60%/2562  
      🟩 nvcc12.5           Pass: 100%/111 | Total:  2d 08h | Avg: 30m 17s | Max: 53m 33s | Hits:  72%/93857 
    🟩 cudacxx_family
      🟩 ClangCUDA          Pass: 100%/2   | Total: 42m 22s | Avg: 21m 11s | Max: 22m 30s | Hits:  66%/1412  
      🟩 nvcc               Pass: 100%/129 | Total:  2d 17h | Avg: 30m 39s | Max: 53m 33s | Hits:  70%/108017
    🟩 cxx
      🟩 Clang9             Pass: 100%/6   | Total:  3h 08m | Avg: 31m 25s | Max: 35m 17s | Hits:  61%/4902  
      🟩 Clang10            Pass: 100%/3   | Total:  1h 46m | Avg: 35m 26s | Max: 38m 35s | Hits:  61%/2568  
      🟩 Clang11            Pass: 100%/4   | Total:  2h 14m | Avg: 33m 30s | Max: 34m 56s | Hits:  61%/3424  
      🟩 Clang12            Pass: 100%/4   | Total:  2h 13m | Avg: 33m 29s | Max: 34m 59s | Hits:  61%/3424  
      🟩 Clang13            Pass: 100%/4   | Total:  2h 09m | Avg: 32m 26s | Max: 33m 25s | Hits:  61%/3424  
      🟩 Clang14            Pass: 100%/4   | Total:  2h 16m | Avg: 34m 11s | Max: 35m 34s | Hits:  61%/3424  
      🟩 Clang15            Pass: 100%/4   | Total:  2h 22m | Avg: 35m 40s | Max: 36m 10s | Hits:  61%/3416  
      🟩 Clang16            Pass: 100%/4   | Total:  2h 20m | Avg: 35m 03s | Max: 38m 40s | Hits:  61%/3416  
      🟩 Clang17            Pass: 100%/26  | Total: 10h 17m | Avg: 23m 45s | Max: 34m 09s | Hits:  85%/21908 
      🟩 GCC6               Pass: 100%/2   | Total: 59m 33s | Avg: 29m 46s | Max: 30m 01s | Hits:  60%/1556  
      🟩 GCC7               Pass: 100%/6   | Total:  3h 09m | Avg: 31m 33s | Max: 34m 36s | Hits:  60%/4905  
      🟩 GCC8               Pass: 100%/6   | Total:  3h 12m | Avg: 32m 09s | Max: 35m 47s | Hits:  60%/4905  
      🟩 GCC9               Pass: 100%/6   | Total:  3h 15m | Avg: 32m 34s | Max: 36m 57s | Hits:  60%/4905  
      🟩 GCC10              Pass: 100%/4   | Total:  2h 18m | Avg: 34m 37s | Max: 35m 37s | Hits:  60%/3424  
      🟩 GCC11              Pass: 100%/7   | Total:  4h 28m | Avg: 38m 21s | Max: 46m 06s | Hits:  60%/5978  
      🟩 GCC12              Pass: 100%/4   | Total:  2h 22m | Avg: 35m 41s | Max: 38m 55s | Hits:  60%/3416  
      🟩 GCC13              Pass: 100%/28  | Total: 11h 37m | Avg: 24m 53s | Max: 53m 33s | Hits:  81%/23912 
      🟩 Intel2023.2.0      Pass: 100%/3   | Total:  1h 58m | Avg: 39m 25s | Max: 42m 22s | Hits:  61%/2340  
      🟩 MSVC14.16          Pass: 100%/1   | Total: 45m 15s | Avg: 45m 15s | Max: 45m 15s | Hits:  65%/697   
      🟩 MSVC14.29          Pass: 100%/2   | Total:  1h 25m | Avg: 42m 56s | Max: 43m 38s | Hits:  65%/1394  
      🟩 MSVC14.39          Pass: 100%/3   | Total:  2h 13m | Avg: 44m 38s | Max: 48m 07s | Hits:  65%/2091  
    🟩 cxx_family
      🟩 Clang              Pass: 100%/59  | Total:  1d 04h | Avg: 29m 19s | Max: 38m 40s | Hits:  71%/49906 
      🟩 GCC                Pass: 100%/63  | Total:  1d 07h | Avg: 29m 54s | Max: 53m 33s | Hits:  70%/53001 
      🟩 Intel              Pass: 100%/3   | Total:  1h 58m | Avg: 39m 25s | Max: 42m 22s | Hits:  61%/2340  
      🟩 MSVC               Pass: 100%/6   | Total:  4h 25m | Avg: 44m 10s | Max: 48m 07s | Hits:  65%/4182  
    🟩 gpu
      🟩 v100               Pass: 100%/131 | Total:  2d 18h | Avg: 30m 30s | Max: 53m 33s | Hits:  70%/109429
    🟩 jobs
      🟩 Build              Pass: 100%/99  | Total:  2d 07h | Avg: 33m 49s | Max: 48m 07s | Hits:  61%/82101 
      🟩 DeviceLaunch       Pass: 100%/8   | Total:  2h 23m | Avg: 17m 52s | Max: 21m 41s | Hits:  99%/6832  
      🟩 GraphCapture       Pass: 100%/8   | Total:  2h 01m | Avg: 15m 14s | Max: 18m 03s | Hits:  99%/6832  
      🟩 HostLaunch         Pass: 100%/8   | Total:  2h 18m | Avg: 17m 18s | Max: 18m 51s | Hits:  99%/6832  
      🟩 TestGPU            Pass: 100%/8   | Total:  4h 04m | Avg: 30m 36s | Max: 53m 33s | Hits:  94%/6832  
    🟩 sm
      🟩 60;70;80;90        Pass: 100%/3   | Total:  2h 15m | Avg: 45m 10s | Max: 46m 06s | Hits:  60%/2562  
      🟩 90a                Pass: 100%/4   | Total:  1h 12m | Avg: 18m 10s | Max: 19m 09s | Hits:  60%/3416  
    🟩 std
      🟩 11                 Pass: 100%/34  | Total: 16h 45m | Avg: 29m 34s | Max: 46m 06s | Hits:  70%/28605 
      🟩 14                 Pass: 100%/37  | Total: 19h 36m | Avg: 31m 47s | Max: 53m 33s | Hits:  68%/30696 
      🟩 17                 Pass: 100%/36  | Total: 18h 34m | Avg: 30m 56s | Max: 44m 43s | Hits:  70%/29927 
      🟩 20                 Pass: 100%/24  | Total: 11h 41m | Avg: 29m 14s | Max: 48m 07s | Hits:  74%/20201 
    
  • 🟩 thrust: Pass: 100%/118 | Total: 2d 07h | Avg: 28m 23s | Max: 1h 02m | Hits: 41%/138912

    🟩 cpu
      🟩 amd64              Pass: 100%/110 | Total:  2d 03h | Avg: 28m 16s | Max:  1h 02m | Hits:  42%/129492
      🟩 arm64              Pass: 100%/8   | Total:  3h 59m | Avg: 29m 58s | Max: 34m 17s | Hits:  27%/9420  
    🟩 ctk
      🟩 11.1               Pass: 100%/15  | Total:  7h 09m | Avg: 28m 36s | Max: 55m 12s | Hits:  27%/17660 
      🟩 11.8               Pass: 100%/3   | Total:  1h 53m | Avg: 37m 53s | Max: 40m 51s | Hits:  27%/3534  
      🟩 12.5               Pass: 100%/100 | Total:  1d 22h | Avg: 28m 04s | Max:  1h 02m | Hits:  43%/117718
    🟩 cudacxx
      🟩 ClangCUDA17        Pass: 100%/2   | Total: 59m 17s | Avg: 29m 38s | Max: 30m 39s | Hits:  26%/2354  
      🟩 nvcc11.1           Pass: 100%/15  | Total:  7h 09m | Avg: 28m 36s | Max: 55m 12s | Hits:  27%/17660 
      🟩 nvcc11.8           Pass: 100%/3   | Total:  1h 53m | Avg: 37m 53s | Max: 40m 51s | Hits:  27%/3534  
      🟩 nvcc12.5           Pass: 100%/98  | Total:  1d 21h | Avg: 28m 02s | Max:  1h 02m | Hits:  44%/115364
    🟩 cudacxx_family
      🟩 ClangCUDA          Pass: 100%/2   | Total: 59m 17s | Avg: 29m 38s | Max: 30m 39s | Hits:  26%/2354  
      🟩 nvcc               Pass: 100%/116 | Total:  2d 06h | Avg: 28m 22s | Max:  1h 02m | Hits:  41%/136558
    🟩 cxx
      🟩 Clang9             Pass: 100%/6   | Total:  2h 53m | Avg: 28m 59s | Max: 35m 09s | Hits:  27%/7062  
      🟩 Clang10            Pass: 100%/3   | Total:  1h 31m | Avg: 30m 23s | Max: 32m 53s | Hits:  27%/3531  
      🟩 Clang11            Pass: 100%/4   | Total:  2h 01m | Avg: 30m 18s | Max: 34m 03s | Hits:  27%/4708  
      🟩 Clang12            Pass: 100%/4   | Total:  2h 06m | Avg: 31m 34s | Max: 36m 03s | Hits:  27%/4708  
      🟩 Clang13            Pass: 100%/4   | Total:  2h 03m | Avg: 30m 54s | Max: 34m 05s | Hits:  27%/4708  
      🟩 Clang14            Pass: 100%/4   | Total:  2h 00m | Avg: 30m 09s | Max: 33m 46s | Hits:  27%/4708  
      🟩 Clang15            Pass: 100%/4   | Total:  1h 58m | Avg: 29m 33s | Max: 31m 49s | Hits:  27%/4708  
      🟩 Clang16            Pass: 100%/4   | Total:  2h 04m | Avg: 31m 13s | Max: 34m 21s | Hits:  27%/4708  
      🟩 Clang17            Pass: 100%/18  | Total:  6h 16m | Avg: 20m 55s | Max: 33m 10s | Hits:  60%/21186 
      🟩 GCC6               Pass: 100%/2   | Total: 50m 36s | Avg: 25m 18s | Max: 27m 42s | Hits:  27%/2354  
      🟩 GCC7               Pass: 100%/6   | Total:  2h 44m | Avg: 27m 21s | Max: 33m 09s | Hits:  27%/7068  
      🟩 GCC8               Pass: 100%/6   | Total:  2h 54m | Avg: 29m 07s | Max: 31m 49s | Hits:  27%/7068  
      🟩 GCC9               Pass: 100%/6   | Total:  3h 06m | Avg: 31m 00s | Max: 36m 45s | Hits:  27%/7068  
      🟩 GCC10              Pass: 100%/4   | Total:  2h 04m | Avg: 31m 08s | Max: 33m 37s | Hits:  27%/4712  
      🟩 GCC11              Pass: 100%/7   | Total:  3h 45m | Avg: 32m 12s | Max: 40m 51s | Hits:  44%/8246  
      🟩 GCC12              Pass: 100%/4   | Total:  2h 09m | Avg: 32m 23s | Max: 35m 25s | Hits:  27%/4712  
      🟩 GCC13              Pass: 100%/20  | Total:  6h 30m | Avg: 19m 30s | Max: 34m 17s | Hits:  64%/23560 
      🟩 Intel2023.2.0      Pass: 100%/3   | Total:  2h 03m | Avg: 41m 03s | Max: 43m 39s | Hits:  27%/3540  
      🟩 MSVC14.16          Pass: 100%/1   | Total: 55m 12s | Avg: 55m 12s | Max: 55m 12s | Hits:  25%/1173  
      🟩 MSVC14.29          Pass: 100%/2   | Total:  2h 03m | Avg:  1h 01m | Max:  1h 02m | Hits:  25%/2346  
      🟩 MSVC14.39          Pass: 100%/6   | Total:  3h 46m | Avg: 37m 43s | Max:  1h 00m | Hits:  62%/7038  
    🟩 cxx_family
      🟩 Clang              Pass: 100%/51  | Total: 22h 56m | Avg: 26m 59s | Max: 36m 03s | Hits:  39%/60027 
      🟩 GCC                Pass: 100%/55  | Total:  1d 00h | Avg: 26m 16s | Max: 40m 51s | Hits:  43%/64788 
      🟩 Intel              Pass: 100%/3   | Total:  2h 03m | Avg: 41m 03s | Max: 43m 39s | Hits:  27%/3540  
      🟩 MSVC               Pass: 100%/9   | Total:  6h 44m | Avg: 44m 59s | Max:  1h 02m | Hits:  49%/10557 
    🟩 gpu
      🟩 v100               Pass: 100%/118 | Total:  2d 07h | Avg: 28m 23s | Max:  1h 02m | Hits:  41%/138912
    🟩 jobs
      🟩 Build              Pass: 100%/99  | Total:  2d 04h | Avg: 31m 43s | Max:  1h 02m | Hits:  30%/116553
      🟩 TestCPU            Pass: 100%/11  | Total:  1h 43m | Avg:  9m 26s | Max: 18m 25s | Hits:  99%/12939 
      🟩 TestGPU            Pass: 100%/8   | Total:  1h 45m | Avg: 13m 14s | Max: 15m 26s | Hits:  99%/9420  
    🟩 sm
      🟩 60;70;80;90        Pass: 100%/3   | Total:  1h 53m | Avg: 37m 53s | Max: 40m 51s | Hits:  27%/3534  
      🟩 90a                Pass: 100%/4   | Total:  1h 16m | Avg: 19m 08s | Max: 20m 08s | Hits:  27%/4712  
    🟩 std
      🟩 11                 Pass: 100%/30  | Total: 11h 51m | Avg: 23m 42s | Max: 37m 00s | Hits:  41%/35328 
      🟩 14                 Pass: 100%/34  | Total: 17h 14m | Avg: 30m 25s | Max:  1h 00m | Hits:  39%/40020 
      🟩 17                 Pass: 100%/33  | Total: 16h 48m | Avg: 30m 34s | Max:  1h 02m | Hits:  40%/38847 
      🟩 20                 Pass: 100%/21  | Total:  9h 55m | Avg: 28m 21s | Max:  1h 00m | Hits:  46%/24717 
    
  • 🟩 pycuda: Pass: 100%/1 | Total: 11m 02s | Avg: 11m 02s | Max: 11m 02s

    🟩 cpu
      🟩 amd64              Pass: 100%/1   | Total: 11m 02s | Avg: 11m 02s | Max: 11m 02s
    🟩 ctk
      🟩 12.5               Pass: 100%/1   | Total: 11m 02s | Avg: 11m 02s | Max: 11m 02s
    🟩 cudacxx
      🟩 nvcc12.5           Pass: 100%/1   | Total: 11m 02s | Avg: 11m 02s | Max: 11m 02s
    🟩 cudacxx_family
      🟩 nvcc               Pass: 100%/1   | Total: 11m 02s | Avg: 11m 02s | Max: 11m 02s
    🟩 cxx
      🟩 GCC13              Pass: 100%/1   | Total: 11m 02s | Avg: 11m 02s | Max: 11m 02s
    🟩 cxx_family
      🟩 GCC                Pass: 100%/1   | Total: 11m 02s | Avg: 11m 02s | Max: 11m 02s
    🟩 gpu
      🟩 v100               Pass: 100%/1   | Total: 11m 02s | Avg: 11m 02s | Max: 11m 02s
    🟩 jobs
      🟩 Test               Pass: 100%/1   | Total: 11m 02s | Avg: 11m 02s | Max: 11m 02s
    

👃 Inspect Changes

Modifications in project?

Project
CCCL Infrastructure
libcu++
CUB
+/- Thrust
CUDA Experimental
pycuda

Modifications in project or dependencies?

Project
CCCL Infrastructure
libcu++
+/- CUB
+/- Thrust
CUDA Experimental
+/- pycuda

🏃‍ Runner counts (total jobs: 250)

# Runner
178 linux-amd64-cpu16
41 linux-amd64-gpu-v100-latest-1
16 linux-arm64-cpu16
15 windows-amd64-cpu16

Collaborator

@alliepiper alliepiper left a comment

  • We should figure out and document the edge cases of how this feature behaves with cross-system vectors/references.
  • A new thrust example would be a nice-to-have.
  • The transform_iterator should document allowed and disallowed usages of this feature.

void TestTransformIteratorAsDestination()
{
  constexpr auto n = 10;
  thrust::host_vector<int> src(n, 1234);

Collaborator

Suggestion: Also test with device_vector. The CUDA-backed device_vector references, pointers, etc. are significantly more complex than the host_vector implementations.

Contributor Author

Yeah, that opened a can of worms.

_CCCL_HOST_DEVICE typename super_t::reference dereference() const
{
  // TODO(bgruber): I am not sure this is the correct check here. There is also the trait
  // thrust::detail::is_wrapped_reference that sounds fitting. Only allowing to pass through l-value references

Collaborator

IIRC, the device_vector references are proxy references that is_wrapped_reference is designed to detect. We may need to handle these specially.
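For readers unfamiliar with the distinction, the following sketch shows how a true l-value reference can be told apart from a proxy reference using only the standard library. It is not the Thrust check under discussion and does not use `thrust::detail::is_wrapped_reference`; `proxy_ref`, `proxy_iterator`, `deref_t` and `yields_true_lvalue_reference` are hypothetical names introduced here, and `proxy_ref` is only a loose model of what a device reference does.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical proxy reference: an object that stands in for a T stored elsewhere.
template <class T>
struct proxy_ref {
  T* ptr;
  operator T() const { return *ptr; }  // reading converts to a value
  const proxy_ref& operator=(const T& v) const {  // writing goes through the proxy
    *ptr = v;
    return *this;
  }
};

// Hypothetical iterator whose dereference yields the proxy by value.
template <class T>
struct proxy_iterator {
  T* ptr;
  proxy_ref<T> operator*() const { return proxy_ref<T>{ptr}; }
};

// The type produced by dereferencing an iterator of type It.
template <class It>
using deref_t = decltype(*std::declval<It&>());

// True only when dereferencing yields a real l-value reference (T&).
template <class It>
struct yields_true_lvalue_reference : std::is_lvalue_reference<deref_t<It>> {};

static_assert(yields_true_lvalue_reference<int*>::value, "raw pointers yield T&");
static_assert(yields_true_lvalue_reference<std::vector<int>::iterator>::value,
              "vector<int> iterators yield int&");
static_assert(!yields_true_lvalue_reference<std::vector<bool>::iterator>::value,
              "vector<bool> iterators yield a proxy object");
static_assert(!yields_true_lvalue_reference<proxy_iterator<int>>::value,
              "the hypothetical proxy iterator yields a proxy object");
```

A check like this rejects every proxy, including ones that could in principle be written through, so a dedicated trait for wrapped references may still be needed to treat them specially.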

thrust::host_vector<int> src(n, 1234);
thrust::host_vector<foo> dst(n, foo{1, 2});

thrust::copy(src.begin(), src.end(), thrust::make_transform_iterator(dst.begin(), access_x{}));

Collaborator

Question: What happens if this is attempted on cross-system host<->device vectors?

@bernhardmgruber
Contributor Author

This feature requires some more changes to improve support for proxy references (we currently force a proxy reference to decay into its value type and pass that further). Fixing that breaks some "actors", whatever that is. I'll see when I have some time to give it another go.
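As a rough illustration of why the decay is a problem for mutation, consider the sketch below. It is plain C++, not Thrust code, and it assumes the decay behaves like converting the proxy into a copy of the element before the functor sees it. `foo_proxy` and `write_through_decayed_proxy` are hypothetical names; `foo` and `access_x` are assumed definitions matching the names used in the test.

```cpp
#include <cassert>

// Assumed shape of the element type and accessor used in the PR's test.
struct foo { int x; int y; };

struct access_x {
  int& operator()(foo& f) const { return f.x; }
};

// Hypothetical proxy reference that can only be read by converting to a value.
struct foo_proxy {
  foo* target;
  operator foo() const { return *target; }
};

// Models the decay step: the proxy becomes a local copy, the functor is applied
// to that copy, and the write never reaches the original storage.
inline int write_through_decayed_proxy(foo& storage, int new_x) {
  foo_proxy ref{&storage};
  foo decayed = ref;            // proxy decays into its value type
  access_x{}(decayed) = new_x;  // the assignment lands in the local copy
  assert(decayed.x == new_x);
  return storage.x;             // still the old value
}
```

In this model the functor returns a perfectly valid `int&`, but it refers to the temporary copy, so the destination element keeps its old value.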

But only if the transform iterator's base iterator does not return a wrapped reference and is not a device_vector
Labels
thrust For all items related to Thrust.
Projects
Status: In Progress

2 participants