1818
1919package org .apache .flink .api .connector .source .mocks ;
2020
21+ import org .apache .flink .api .connector .source .ReaderInfo ;
2122import org .apache .flink .api .connector .source .SourceEvent ;
2223import org .apache .flink .api .connector .source .SplitEnumerator ;
2324import org .apache .flink .api .connector .source .SplitEnumeratorContext ;
2829
2930import java .io .IOException ;
3031import java .util .ArrayList ;
32+ import java .util .Collection ;
3133import java .util .Collections ;
32- import java .util .Comparator ;
3334import java .util .HashMap ;
3435import java .util .HashSet ;
3536import java .util .List ;
3637import java .util .Map ;
3738import java .util .Set ;
38- import java .util .SortedSet ;
39- import java .util .TreeSet ;
39+ import java .util .stream .Collectors ;
4040
4141/** A mock {@link SplitEnumerator} for unit tests. */
4242public class MockSplitEnumerator
4343 implements SplitEnumerator <MockSourceSplit , Set <MockSourceSplit >>, SupportsBatchSnapshot {
44- private final SortedSet <MockSourceSplit > unassignedSplits ;
44+ private final Map <Integer , Set <MockSourceSplit >> pendingSplitAssignment ;
45+ private final Map <String , Integer > globalSplitAssignment ;
4546 private final SplitEnumeratorContext <MockSourceSplit > enumContext ;
4647 private final List <SourceEvent > handledSourceEvent ;
4748 private final List <Long > successfulCheckpoints ;
@@ -50,22 +51,24 @@ public class MockSplitEnumerator
5051
5152 public MockSplitEnumerator (int numSplits , SplitEnumeratorContext <MockSourceSplit > enumContext ) {
5253 this (new HashSet <>(), enumContext );
54+ List <MockSourceSplit > unassignedSplits = new ArrayList <>();
5355 for (int i = 0 ; i < numSplits ; i ++) {
5456 unassignedSplits .add (new MockSourceSplit (i ));
5557 }
58+ recalculateAssignments (unassignedSplits );
5659 }
5760
/**
 * Creates an enumerator from an existing set of unassigned splits.
 *
 * @param unassignedSplits splits that still have to be handed out to readers
 * @param enumContext the enumerator context used for parallelism lookup and split assignment
 */
public MockSplitEnumerator(
        Set<MockSourceSplit> unassignedSplits,
        SplitEnumeratorContext<MockSourceSplit> enumContext) {
    // The context and both assignment maps must exist before recalculateAssignments runs.
    this.enumContext = enumContext;
    this.globalSplitAssignment = new HashMap<>();
    this.pendingSplitAssignment = new HashMap<>();

    this.successfulCheckpoints = new ArrayList<>();
    this.handledSourceEvent = new ArrayList<>();
    this.closed = false;
    this.started = false;

    recalculateAssignments(unassignedSplits);
}
7073
7174 @ Override
@@ -83,25 +86,36 @@ public void handleSourceEvent(int subtaskId, SourceEvent sourceEvent) {
8386
8487 @ Override
8588 public void addSplitsBack (List <MockSourceSplit > splits , int subtaskId ) {
86- unassignedSplits .addAll (splits );
89+ // add back to same subtaskId.
90+ putPendingAssignments (subtaskId , splits );
8791 }
8892
8993 @ Override
9094 public void addReader (int subtaskId ) {
91- List <MockSourceSplit > assignment = new ArrayList <>();
92- for (MockSourceSplit split : unassignedSplits ) {
93- if (Integer .parseInt (split .splitId ()) % enumContext .currentParallelism () == subtaskId ) {
94- assignment .add (split );
95+ ReaderInfo readerInfo = enumContext .registeredReaders ().get (subtaskId );
96+ List <MockSourceSplit > splitsOnRecovery = readerInfo .getReportedSplitsOnRegistration ();
97+
98+ List <MockSourceSplit > redistributedSplits = new ArrayList <>();
99+ List <MockSourceSplit > addBackSplits = new ArrayList <>();
100+ for (MockSourceSplit split : splitsOnRecovery ) {
101+ if (!globalSplitAssignment .containsKey (split .splitId ())) {
102+ // if the split is not present in globalSplitAssignment, it means that this split is
103+ // being registered for the first time and is eligible for redistribution.
104+ redistributedSplits .add (split );
105+ } else if (!globalSplitAssignment .containsKey (split .splitId ())) {
106+ // if split is already assigned to other sub-task, just ignore it. Otherwise, add
107+ // back to this sub-task again.
108+ addBackSplits .add (split );
95109 }
96110 }
97- enumContext . assignSplits (
98- new SplitsAssignment <>( Collections . singletonMap ( subtaskId , assignment )) );
99- unassignedSplits . removeAll ( assignment );
111+ recalculateAssignments ( redistributedSplits );
112+ putPendingAssignments ( subtaskId , addBackSplits );
113+ assignAllSplits ( );
100114 }
101115
102116 @ Override
103117 public Set <MockSourceSplit > snapshotState (long checkpointId ) {
104- return unassignedSplits ;
118+ return getUnassignedSplits () ;
105119 }
106120
107121 @ Override
@@ -114,11 +128,6 @@ public void close() throws IOException {
114128 this .closed = true ;
115129 }
116130
117- public void addNewSplits (List <MockSourceSplit > newSplits ) {
118- unassignedSplits .addAll (newSplits );
119- assignAllSplits ();
120- }
121-
122131 // --------------------
123132
124133 public boolean started () {
@@ -130,7 +139,9 @@ public boolean closed() {
130139 }
131140
132141 public Set <MockSourceSplit > getUnassignedSplits () {
133- return unassignedSplits ;
142+ return pendingSplitAssignment .values ().stream ()
143+ .flatMap (Set ::stream )
144+ .collect (Collectors .toSet ());
134145 }
135146
136147 public List <SourceEvent > getHandledSourceEvent () {
@@ -145,17 +156,27 @@ public List<Long> getSuccessfulCheckpoints() {
145156
146157 private void assignAllSplits () {
147158 Map <Integer , List <MockSourceSplit >> assignment = new HashMap <>();
148- unassignedSplits .forEach (
149- split -> {
150- int subtaskId =
151- Integer .parseInt (split .splitId ()) % enumContext .currentParallelism ();
152- if (enumContext .registeredReaders ().containsKey (subtaskId )) {
153- assignment
154- .computeIfAbsent (subtaskId , ignored -> new ArrayList <>())
155- .add (split );
156- }
157- });
159+ for (Map .Entry <Integer , Set <MockSourceSplit >> iter : pendingSplitAssignment .entrySet ()) {
160+ Integer subtaskId = iter .getKey ();
161+ if (enumContext .registeredReaders ().containsKey (subtaskId )) {
162+ assignment .put (subtaskId , new ArrayList <>(iter .getValue ()));
163+ }
164+ }
158165 enumContext .assignSplits (new SplitsAssignment <>(assignment ));
159- assignment .values ().forEach (l -> unassignedSplits .removeAll (l ));
166+ assignment .keySet ().forEach (pendingSplitAssignment ::remove );
167+ }
168+
169+ private void recalculateAssignments (Collection <MockSourceSplit > newSplits ) {
170+ for (MockSourceSplit split : newSplits ) {
171+ int subtaskId = Integer .parseInt (split .splitId ()) % enumContext .currentParallelism ();
172+ putPendingAssignments (subtaskId , Collections .singletonList (split ));
173+ }
174+ }
175+
176+ private void putPendingAssignments (int subtaskId , Collection <MockSourceSplit > splits ) {
177+ Set <MockSourceSplit > pendingSplits =
178+ pendingSplitAssignment .computeIfAbsent (subtaskId , HashSet ::new );
179+ pendingSplits .addAll (splits );
180+ splits .forEach (split -> globalSplitAssignment .put (split .splitId (), subtaskId ));
160181 }
161182}
0 commit comments