
Commit 4d20a84

begin bulk fetch from cache
1 parent 4b741cc commit 4d20a84

3 files changed: +123 -42 lines changed


conda_libmamba_solver/shards.py

Lines changed: 20 additions & 16 deletions
@@ -174,7 +174,7 @@ def build_repodata(self) -> RepodataDict:
         return repodata
 
 
-class Shards(ShardLike):
+class ShardsIndex(ShardLike):
     def __init__(self, shards_index: ShardsIndex, url: str, shards_cache: shards_cache.ShardCache):
         """
         Args:
@@ -185,13 +185,16 @@ def __init__(self, shards_index: ShardsIndex, url: str, shards_cache: shards_cac
         self.url = url
         self.shards_cache = shards_cache
 
-        self.session = get_session(self.base_url)
+        # can we share a session for multiple subdirs of the same channel, or
+        # any time self.shards_base_url is similar to another Shards() instance?
+        self.session = get_session(self.shards_base_url)
 
         self.repodata_no_packages = {
             k: v for k, v in self.shards_index.items() if k not in ("shards",)
         }
 
         # used to write out repodata subset
+        # not used in traversal algorithm
         self.visited: dict[str, Shard | None] = {}
 
     @property
@@ -203,7 +206,7 @@ def packages_index(self):
         return self.shards_index["shards"]
 
     @property
-    def base_url(self) -> str:
+    def shards_base_url(self) -> str:
         """
         Return self.url joined with shards_base_url.
         Note shards_base_url can be a relative or an absolute url.
@@ -218,7 +221,7 @@ def shard_url(self, package: str) -> str:
         """
         shard_name = f"{self.packages_index[package].hex()}.msgpack.zst"
         # "Individual shards are stored under the URL <shards_base_url><sha256>.msgpack.zst"
-        return urljoin(self.base_url, shard_name)
+        return urljoin(self.shards_base_url, shard_name)
 
     def fetch_shard(self, package: str) -> Shard:
         """
@@ -341,42 +344,43 @@ def repodata_shards(url, cache: RepodataCache) -> bytes:
     return response_bytes
 
 
-def fetch_shards(sd: SubdirData) -> Shards | None:
+def fetch_shards(
+    sd: SubdirData, cache: shards_cache.ShardCache | None = None
+) -> ShardsIndex | None:
     """
     Check a SubdirData's URL for shards.
     Return shards index bytes from cache or network.
     Return None if not found; caller should fetch normal repodata.
     """
 
     fetch = sd.repo_fetch
-    cache = fetch.repo_cache
+    repo_cache = fetch.repo_cache
     # cache.load_state() will clear the file on JSONDecodeError but cache.load()
     # will raise the exception
-    cache.load_state(binary=True)
-    cache_state = cache.state
+    repo_cache.load_state(binary=True)
+    cache_state = repo_cache.state
+
+    if cache is None:
+        cache = shards_cache.ShardCache(Path(conda.gateways.repodata.create_cache_dir()))
 
     if cache_state.should_check_format("shards"):
         try:
             # look for shards index
             shards_index_url = f"{sd.url_w_subdir}/repodata_shards.msgpack.zst"
-            found = repodata_shards(shards_index_url, cache)
+            found = repodata_shards(shards_index_url, repo_cache)
             cache_state.set_has_format("shards", True)
             # this will also set state["refresh_ns"] = time.time_ns(); we could
             # call cache.refresh() if we got a 304 instead:
-            cache.save(found)
+            repo_cache.save(found)
 
             # basic parse (move into caller?)
             shards_index: ShardsIndex = msgpack.loads(zstandard.decompress(found))  # type: ignore
-            shards = Shards(
-                shards_index,
-                shards_index_url,
-                shards_cache.ShardCache(Path(conda.gateways.repodata.create_cache_dir())),
-            )
+            shards = ShardsIndex(shards_index, shards_index_url, cache)
             return shards
 
         except (HTTPError, conda.gateways.repodata.RepodataIsEmpty):
             # fetch repodata.json / repodata.json.zst instead
             cache_state.set_has_format("shards", False)
-            cache.refresh(refresh_ns=1)  # expired but not falsy
+            repo_cache.refresh(refresh_ns=1)  # expired but not falsy
 
     return None
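
The base_url -> shards_base_url rename tracks the spec wording quoted in shard_url(): individual shards live at <shards_base_url><sha256>.msgpack.zst, and shards_base_url may be relative to the index URL or absolute. A standalone sketch of the urljoin semantics this relies on; the URLs are made up for illustration:

from urllib.parse import urljoin

index_url = "https://example.invalid/channel/linux-64/repodata_shards.msgpack.zst"

# a relative shards_base_url resolves next to the index file
base = urljoin(index_url, "shards/")
print(base)  # https://example.invalid/channel/linux-64/shards/

# an absolute shards_base_url (e.g. a CDN) replaces the base entirely
print(urljoin(index_url, "https://cdn.example.invalid/shards/"))
# https://cdn.example.invalid/shards/

# an individual shard: <shards_base_url><sha256>.msgpack.zst
print(urljoin(base, "0123abcd.msgpack.zst"))
# https://example.invalid/channel/linux-64/shards/0123abcd.msgpack.zst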

conda_libmamba_solver/shards_subset.py

Lines changed: 24 additions & 10 deletions
@@ -36,11 +36,15 @@
 import json
 import sys
 from dataclasses import dataclass
+from pathlib import Path
 
+import conda.gateways.repodata
 from conda.base.context import context
 from conda.core.subdir_data import SubdirData
 from conda.models.channel import Channel
 
+from conda_libmamba_solver import shards_cache
+
 from .shards import RepodataDict, ShardLike, fetch_shards, shard_mentioned_packages
 
 
@@ -100,6 +104,7 @@ def shortest(self, start_packages):
         self.nodes = {package: Node(0, package) for package in start_packages}
         unvisited = [(n.distance, n) for n in self.nodes.values()]
         while unvisited:
+            # parallel fetch all unvisited shards but don't mark as visited
             original_priority, node = heapq.heappop(unvisited)
             if (
                 original_priority != node.distance
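
The original_priority != node.distance guard is the standard lazy-deletion workaround for heapq having no decrease-key: when a node's distance improves, a fresh (distance, node) entry is pushed, and stale entries are skipped as they surface. A self-contained sketch of the idiom on a toy dependency graph (illustrative names, not this codebase):

import heapq

# toy graph: package -> direct dependencies
graph = {
    "twine": ["requests", "rich"],
    "requests": ["urllib3", "idna"],
    "rich": ["pygments"],
    "urllib3": [],
    "idna": [],
    "pygments": [],
}

distance = {"twine": 0}
heap = [(0, "twine")]
while heap:
    d, package = heapq.heappop(heap)
    if d != distance[package]:
        continue  # stale entry; this node was re-pushed with a smaller distance
    for dep in graph[package]:
        if distance.get(dep, float("inf")) > d + 1:
            distance[dep] = d + 1
            heapq.heappush(heap, (d + 1, dep))  # push anew instead of decrease-key

print(distance)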
@@ -116,16 +121,7 @@ def shortest(self, start_packages):
 
 
 def build_repodata_subset(tmp_path, root_packages, channels):
-    channel_data: dict[str, ShardLike] = {}
-    for channel in channels:
-        for channel_url in Channel(channel).urls(True, context.subdirs):
-            subdir_data = SubdirData(Channel(channel_url))
-            found = fetch_shards(subdir_data)
-            if not found:
-                repodata_json, _ = subdir_data.repo_fetch.fetch_latest_parsed()
-                repodata_json = RepodataDict(repodata_json)  # type: ignore
-                found = ShardLike(repodata_json, channel_url)
-            channel_data[channel_url] = found
+    channel_data = fetch_channels(channels)
 
     subset = RepodataSubset((*channel_data.values(),))
     subset.shortest(root_packages)
@@ -147,3 +143,21 @@ def build_repodata_subset(tmp_path, root_packages, channels):
         subset_paths[channel] = repodata_path
 
     return subset_paths, repodata_size
+
+
+def fetch_channels(channels):
+    channel_data: dict[str, ShardLike] = {}
+
+    # share single disk cache for all Shards() instances
+    cache = shards_cache.ShardCache(Path(conda.gateways.repodata.create_cache_dir()))
+
+    for channel in channels:
+        for channel_url in Channel(channel).urls(True, context.subdirs):
+            subdir_data = SubdirData(Channel(channel_url))
+            found = fetch_shards(subdir_data, cache)
+            if not found:
+                repodata_json, _ = subdir_data.repo_fetch.fetch_latest_parsed()
+                repodata_json = RepodataDict(repodata_json)  # type: ignore
+                found = ShardLike(repodata_json, channel_url)
+            channel_data[channel_url] = found
+    return channel_data
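
For reference, a minimal way to exercise the new helper; the channel names are illustrative. fetch_channels() returns a mapping of subdir URL to either a ShardsIndex (sharded channel) or a ShardLike wrapping full repodata, all sharing one disk cache:

from conda_libmamba_solver.shards import ShardsIndex
from conda_libmamba_solver.shards_subset import fetch_channels

channel_data = fetch_channels(["defaults", "conda-forge-sharded"])
for url, shardlike in channel_data.items():
    kind = "sharded" if isinstance(shardlike, ShardsIndex) else "repodata.json"
    print(url, kind)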

tests/test_shards.py

Lines changed: 79 additions & 16 deletions
@@ -13,6 +13,7 @@
 import random
 import time
 import urllib.parse
+from contextlib import contextmanager
 from hashlib import sha256
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -27,13 +28,12 @@
 from conda_libmamba_solver import shards, shards_cache
 from conda_libmamba_solver.index import LibMambaIndexHelper
 from conda_libmamba_solver.shards import (
-    RepodataDict,
     ShardLike,
-    Shards,
+    ShardsIndex,
     fetch_shards,
     shard_mentioned_packages,
 )
-from conda_libmamba_solver.shards_subset import build_repodata_subset
+from conda_libmamba_solver.shards_subset import Node, build_repodata_subset, fetch_channels
 from tests.channel_testing.helpers import _dummy_http_server
 
 if TYPE_CHECKING:
@@ -157,19 +157,10 @@ def test_fetch_shards(conda_no_token: None):
 
     channels.append(Channel("conda-forge-sharded"))
 
-    channel_data: dict[str, ShardLike] = {}
-    for channel in channels:
-        for channel_url in Channel(channel).urls(True, context.subdirs):
-            subdir_data = SubdirData(Channel(channel_url))
-            found = fetch_shards(subdir_data)
-            if not found:
-                repodata_json, _ = subdir_data.repo_fetch.fetch_latest_parsed()
-                repodata_json = RepodataDict(repodata_json)  # type: ignore
-                found = ShardLike(repodata_json, channel_url)
-            channel_data[channel_url] = found
+    channel_data = fetch_channels(channels)
 
     # at least one should be real shards, not repodata.json presented as shards.
-    assert any(isinstance(channel, Shards) for channel in channel_data.values())
+    assert any(isinstance(channel, ShardsIndex) for channel in channel_data.values())
 
 
 def test_shard_cache(tmp_path: Path):
@@ -317,6 +308,9 @@ def test_shardlike():
 
 
 def test_shardlike_repr():
+    """
+    Code coverage for ShardLike.__repr__()
+    """
     shardlike = ShardLike(
         {
             "packages": {},
@@ -325,7 +319,7 @@ def test_shardlike_repr():
         },
         "https://conda.anaconda.org/",
     )
-    cls, url, *rest = repr(shardlike).split()
+    cls, url, *_ = repr(shardlike).split()
     assert "ShardLike" in cls
     assert shardlike.url == url
 
@@ -361,7 +355,8 @@ def test_shardlike_repr():
 
 def test_traverse_shards_3(conda_no_token: None, tmp_path):
     """
-    Another go at the dependency traversal algorithm.
+    Build repodata subset using the third attempt at a dependency traversal
+    algorithm.
     """
 
     logging.basicConfig(level=logging.INFO)
@@ -390,6 +385,9 @@ def test_traverse_shards_3(conda_no_token: None, tmp_path):
 
 
 def test_shards_indexhelper(conda_no_token):
+    """
+    Load LibMambaIndexHelper with parameters that will enable sharded repodata.
+    """
     channels = [*context.default_channels, Channel("conda-forge-sharded")]
 
     class fake_in_state:
@@ -407,3 +405,68 @@ class fake_in_state:
     )
 
     print(helper.repos)
+
+
+@contextmanager
+def _timer(name: str):
+    begin = time.monotonic_ns()
+    yield
+    end = time.monotonic_ns()
+    print(f"{name} took {(end - begin) / 1e9:0.6f}s")
+
+
+def test_parallel_fetcherator(conda_no_token: None):
+    channels = [*context.default_channels, Channel("conda-forge-sharded")]
+    roots = [
+        Node(distance=0, package="ca-certificates", visited=False),
+        Node(distance=0, package="icu", visited=False),
+        Node(distance=0, package="expat", visited=False),
+        Node(distance=0, package="libexpat", visited=False),
+        Node(distance=0, package="libffi", visited=False),
+        Node(distance=0, package="libmpdec", visited=False),
+        Node(distance=0, package="libzlib", visited=False),
+        Node(distance=0, package="openssl", visited=False),
+        Node(distance=0, package="python", visited=False),
+        Node(distance=0, package="readline", visited=False),
+        Node(distance=0, package="liblzma", visited=False),
+        Node(distance=0, package="xz", visited=False),
+        Node(distance=0, package="libsqlite", visited=False),
+        Node(distance=0, package="tk", visited=False),
+        Node(distance=0, package="ncurses", visited=False),
+        Node(distance=0, package="zlib", visited=False),
+        Node(distance=0, package="pip", visited=False),
+        Node(distance=0, package="twine", visited=False),
+        Node(distance=0, package="python_abi", visited=False),
+        Node(distance=0, package="tzdata", visited=False),
+    ]
+
+    with _timer("repodata.json/shards index fetch"):
+        channel_data = fetch_channels(channels)
+
+    with _timer("Shard fetch"):
+        sharded = [
+            channel for channel in channel_data.values() if isinstance(channel, ShardsIndex)
+        ]
+        assert sharded, "No sharded repodata found"
+
+        wanted = []
+        for shard in sharded:
+            for root in roots:
+                if root.package in shard:
+                    wanted.append((shard, root.package, shard.shard_url(root.package)))
+
+        print(len(wanted), "shards to fetch")
+
+        shared_shard_cache = sharded[0].shards_cache
+        from_cache = shared_shard_cache.retrieve_multiple([shard_url for *_, shard_url in wanted])
+
+        for url, shard_or_none in from_cache.items():
+            if shard_or_none is not None:
+                print(f"Cache hit for {url}")
+
+        # add fetched Shard objects to Shards objects visited dict
+        for shard, package, shard_url in wanted:
+            if from_cache_shard := from_cache.get(shard_url):
+                shard.visited[package] = from_cache_shard
+
+    # XXX don't call everything Shard/Shards