Skip to content

Commit

Permalink
Merge pull request #722 from Guovin/dev
Browse files Browse the repository at this point in the history
refactor:get_speed_m3u8(#719)
  • Loading branch information
Guovin authored Dec 23, 2024
2 parents c2717b5 + ab5683e commit 32efa80
Showing 1 changed file with 89 additions and 49 deletions.
138 changes: 89 additions & 49 deletions utils/speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,76 @@

import m3u8
from aiohttp import ClientSession, TCPConnector
from multidict import CIMultiDictProxy

from utils.config import config
from utils.tools import is_ipv6, remove_cache_info


async def get_speed_with_download(url: str, timeout: int = config.sort_timeout) -> dict[str, float | None]:
async def get_speed_with_download(url: str, session: ClientSession = None, timeout: int = config.sort_timeout) -> dict[
str, float | None]:
"""
Get the speed of the url with a total timeout
"""
start_time = time()
total_size = 0
total_time = 0
info = {'speed': None, 'delay': None}
if session is None:
session = ClientSession(connector=TCPConnector(ssl=False), trust_env=True)
created_session = True
else:
created_session = False
try:
async with ClientSession(
connector=TCPConnector(ssl=False), trust_env=True
) as session:
async with session.get(url, timeout=timeout) as response:
if response.status == 404:
return info
info['delay'] = int(round((time() - start_time) * 1000))
async for chunk in response.content.iter_any():
if chunk:
total_size += len(chunk)
async with session.get(url, timeout=timeout) as response:
if response.status == 404:
return info
info['delay'] = int(round((time() - start_time) * 1000))
async for chunk in response.content.iter_any():
if chunk:
total_size += len(chunk)
except Exception as e:
pass
finally:
end_time = time()
total_time += end_time - start_time
info['speed'] = (total_size / total_time if total_time > 0 else 0) / 1024 / 1024
return info
if created_session:
await session.close()
end_time = time()
total_time += end_time - start_time
info['speed'] = (total_size / total_time if total_time > 0 else 0) / 1024 / 1024
return info


async def get_m3u8_headers(url: str, session: ClientSession = None, timeout: int = 5) -> CIMultiDictProxy[str] | dict[
any, any]:
"""
Get the headers of the m3u8 url
"""
if session is None:
session = ClientSession(connector=TCPConnector(ssl=False), trust_env=True)
created_session = True
else:
created_session = False
try:
async with session.head(url, timeout=timeout) as response:
return response.headers
except:
pass
finally:
if created_session:
await session.close()
return {}


def check_m3u8_valid(headers: CIMultiDictProxy[str] | dict[any, any]) -> bool:
"""
Check the m3u8 url is valid
"""
content_type = headers.get('Content-Type')
if content_type:
content_type = content_type.lower()
if 'application/vnd.apple.mpegurl' in content_type:
return True
return False


async def get_speed_m3u8(url: str, timeout: int = config.sort_timeout) -> dict[str, float | None]:
Expand All @@ -47,44 +86,45 @@ async def get_speed_m3u8(url: str, timeout: int = config.sort_timeout) -> dict[s
try:
url = quote(url, safe=':/?$&=@[]').partition('$')[0]
async with ClientSession(connector=TCPConnector(ssl=False), trust_env=True) as session:
async with session.head(url, timeout=5) as response:
content_type = response.headers.get('Content-Type')
if content_type:
content_type = content_type.lower()
location = response.headers.get('Location')
if 'application/vnd.apple.mpegurl' in content_type:
url = location or url
headers = await get_m3u8_headers(url, session)
if check_m3u8_valid(headers):
location = headers.get('Location')
if location:
info.update(await get_speed_m3u8(location, timeout))
else:
m3u8_obj = m3u8.load(url, timeout=2)
playlists = m3u8_obj.data.get('playlists')
segments = m3u8_obj.segments
if not segments and playlists:
parsed_url = urlparse(url)
url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.rsplit('/', 1)[0]}/{playlists[0].get('uri', '')}"
uri_headers = await get_m3u8_headers(url, session)
if not check_m3u8_valid(uri_headers):
if uri_headers.get('Content-Length'):
info.update(await get_speed_with_download(url, session, timeout))
return info
m3u8_obj = m3u8.load(url, timeout=2)
playlists = m3u8_obj.data.get('playlists')
segments = m3u8_obj.segments
if not segments and playlists:
parsed_url = urlparse(url)
url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.rsplit('/', 1)[0]}/{playlists[0].get('uri', '')}"
m3u8_obj = m3u8.load(url, timeout=2)
segments = m3u8_obj.segments
if not segments:
return info
ts_urls = [segment.absolute_uri for segment in segments]
speed_list = []
start_time = time()
for ts_url in ts_urls:
if time() - start_time > timeout:
break
download_info = await get_speed_with_download(ts_url, timeout)
speed_list.append(download_info['speed'])
if info['delay'] is None and download_info['delay'] is not None:
info['delay'] = download_info['delay']
info['speed'] = sum(speed_list) / len(speed_list) if speed_list else 0
elif location:
info.update(await get_speed_m3u8(location, timeout))
elif response.headers.get('Content-Length'):
info.update(await get_speed_with_download(url, timeout))
else:
return info
if not segments:
return info
ts_urls = [segment.absolute_uri for segment in segments]
speed_list = []
start_time = time()
for ts_url in ts_urls:
if time() - start_time > timeout:
break
download_info = await get_speed_with_download(ts_url, session, timeout)
speed_list.append(download_info['speed'])
if info['delay'] is None and download_info['delay'] is not None:
info['delay'] = download_info['delay']
info['speed'] = sum(speed_list) / len(speed_list) if speed_list else 0
elif headers.get('Content-Length'):
info.update(await get_speed_with_download(url, session, timeout))
else:
return info
except:
pass
finally:
return info
return info


async def get_delay_requests(url, timeout=config.sort_timeout, proxy=None):
Expand Down

0 comments on commit 32efa80

Please sign in to comment.