-
Notifications
You must be signed in to change notification settings - Fork 3
/
928.cpp
87 lines (84 loc) · 2.53 KB
/
928.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// Union-find (disjoint-set) over the integers 0..n-1.
// Uses union by rank and path compression, and tracks the number of
// elements in every set so the size of a set can be queried via its root.
class DisjointSet {
private:
    vector<int> parent_;  // parent_[v] == v marks v as the root of its set
    vector<int> rank_;    // rank of each root, used to decide who becomes parent
    vector<int> size_;    // element count of the set; only meaningful at roots

public:
    // Creates n singleton sets {0}, {1}, ..., {n-1}.
    DisjointSet(int n) : parent_(n), rank_(n, 0), size_(n, 1) {
        for (int v = 0; v < n; ++v) {
            parent_[v] = v;
        }
    }

    // Returns the root of the set containing x, and re-points every node on
    // the walked path directly at that root.
    int find(int x) {
        // First pass: locate the root.
        int root = x;
        while (parent_[root] != root) {
            root = parent_[root];
        }
        // Second pass: compress the path from x up to the root.
        while (x != root) {
            const int next = parent_[x];
            parent_[x] = root;
            x = next;
        }
        return root;
    }

    // Merges the sets containing x and y; does nothing if they already share
    // a root. The root with the higher rank survives; on equal ranks the root
    // of x survives and its rank grows by one.
    void join(int x, int y) {
        const int rootX = find(x);
        const int rootY = find(y);
        if (rootX == rootY) {
            return;
        }

        const bool yOutranksX = rank_[rootX] < rank_[rootY];
        const int keep = yOutranksX ? rootY : rootX;
        const int drop = yOutranksX ? rootX : rootY;

        if (rank_[keep] == rank_[drop]) {
            ++rank_[keep];
        }
        parent_[drop] = keep;
        size_[keep] += size_[drop];
    }

    // Returns the number of elements in the set containing x.
    int getSize(int x) {
        return size_[find(x)];
    }
};
class Solution {
public:
    // Chooses which node of `initial` to remove so that the largest number of
    // non-initial nodes is kept free of malware.
    //
    // graph   - adjacency matrix; graph[i][j] non-zero means i and j are linked.
    // initial - indices of the initially infected nodes. Must be non-empty,
    //           because initial[0] is read unconditionally.
    //
    // Approach: nodes outside `initial` are grouped into connected components
    // using only the edges among themselves. A component is credited to an
    // initial node only when that node is the sole initial node adjacent to
    // it. The initial node with the largest total credited size is returned;
    // ties (including "no node saves anything") go to the smallest index.
    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        const int n = graph.size();

        // Nodes that are not initially infected.
        const unordered_set<int> initialHash(initial.begin(), initial.end());
        vector<int> naive;
        for (int i = 0; i < n; ++i) {
            if (initialHash.find(i) == initialHash.end()) naive.push_back(i);
        }

        // Automatic storage: the structure is released when the function
        // returns (previously it was heap-allocated and never freed).
        DisjointSet disjointSet(n);

        // Connect all naive nodes that share an edge.
        for (const int node1 : naive) {
            for (const int node2 : naive) {
                if (graph[node1][node2]) disjointSet.join(node1, node2);
            }
        }

        // malware2groups[m]  : roots of the naive components adjacent to m.
        // group2malwareCnt[r]: number of initial nodes adjacent to component r.
        unordered_map<int, unordered_set<int>> malware2groups;
        unordered_map<int, int> group2malwareCnt;
        for (const int malware : initial) {
            // operator[] guarantees an entry for every initial node, even one
            // with no naive neighbours, so the lookup below cannot miss.
            unordered_set<int>& groups = malware2groups[malware];
            for (const int node : naive) {
                if (graph[malware][node]) groups.insert(disjointSet.find(node));
            }
            for (const int group : groups) {
                ++group2malwareCnt[group];
            }
        }

        int maxCount = 0;
        int res = initial[0];
        for (const int malware : initial) {
            // Sum the sizes of the components reached by this node alone.
            int count = 0;
            for (const int group : malware2groups.at(malware)) {
                if (group2malwareCnt.at(group) == 1) count += disjointSet.getSize(group);
            }
            // Strictly better wins; on a tie prefer the smaller node index.
            if (count > maxCount || (count == maxCount && malware < res)) {
                maxCount = count;
                res = malware;
            }
        }
        return res;
    }
};