-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
194 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from __future__ import annotations | ||
import typing | ||
import base64 | ||
|
||
import strawberry | ||
from strawberry.scalars import JSON, Base64 | ||
from strawberry_sqlalchemy_mapper import strawberry_dataclass_from_model | ||
import orjson | ||
from redis_streamer import utils, ctx | ||
from . import streams | ||
from redis_streamer.config import * | ||
from redis_streamer.models import session, RecordingModel | ||
|
||
|
||
# ---------------------------------------------------------------------------- # | ||
# Queries # | ||
# ---------------------------------------------------------------------------- # | ||
|
||
@strawberry.type | ||
@strawberry_dataclass_from_model(RecordingModel) | ||
class Recording: | ||
pass | ||
|
||
@strawberry.type | ||
class RecordingsQuery: | ||
def recordings(self) -> typing.List[Recording]: | ||
return session.query(RecordingModel).all() | ||
|
||
@strawberry.field | ||
def recording(self, id: strawberry.ID) -> Recording: | ||
return session.query(RecordingModel).get(id) | ||
|
||
|
||
@strawberry.type | ||
class RecordingMutation: | ||
@strawberry.mutation | ||
async def start(self, device_id: str, meta: JSON) -> JSON: | ||
return | ||
|
||
@strawberry.mutation | ||
async def stop(self, device_id: str, meta: JSON) -> JSON: | ||
return | ||
|
||
@strawberry.mutation | ||
async def rename(self, device_id: str) -> JSON: | ||
return await disconnect_device(device_id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import os | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import scoped_session, sessionmaker | ||
|
||
|
||
engine = create_engine(os.getenv("DB_CONNECTION") or "sqlite:///") | ||
session = scoped_session(sessionmaker( | ||
autocommit=False, autoflush=False, bind=engine)) | ||
|
||
|
||
Base = declarative_base() | ||
from .recording import RecordingModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from sqlalchemy import Column, Integer, String, ForeignKey, DateTime | ||
from . import Base | ||
|
||
class RecordingModel(Base): | ||
__tablename__ = "recordings" | ||
id: int = Column(Integer, primary_key=True, index=True) | ||
name: str = Column(String, nullable=True) | ||
start_time: DateTime = Column(DateTime, nullable=False) | ||
end_time: DateTime = Column(DateTime, nullable=False) | ||
device_name: str = Column(String, nullable=True) | ||
|
||
def as_dict(self): | ||
return { | ||
"id": self.id, | ||
"name": self.name, | ||
"start_time": self.start_time, | ||
"end_time": self.end_time, | ||
"device_name": self.device_name, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import os | ||
import tqdm | ||
import asyncio | ||
import ray | ||
import datetime | ||
from .core import ctx, Agent | ||
from .models import session, RecordingModel | ||
from .graphql_schema.streams import get_stream_ids | ||
|
||
ray.init() | ||
|
||
|
||
@ray.remote | ||
class Recorder: | ||
def __init__(self): | ||
pass | ||
|
||
async def record(self, name, prefix='', last_entry_id="$", batch=1, block=5000): | ||
self.stopped = False | ||
|
||
session.add(RecordingModel(name=name, start_time=datetime.datetime.now())) | ||
session.commit() | ||
rec_entry = session.query(RecordingModel).filter(RecordingModel.name == name).order_by(RecordingModel.start_time).last() | ||
|
||
agent = Agent() | ||
stream_ids = await get_stream_ids() | ||
cursor = agent.init_cursor({s: last_entry_id for s in stream_ids}) | ||
try: | ||
while not self.stopped: | ||
# read data from redis | ||
results, cursor = await agent.read(cursor, latest=False, count=batch or 1, block=block) | ||
for sid, xs in results: | ||
for ts, data in xs: | ||
await writer[sid].write(data, ts) | ||
finally: | ||
rec_entry.end_time = datetime.datetime.now() | ||
session.commit() | ||
|
||
def stop(self): | ||
self.stopped = True | ||
|
||
def replay(self): | ||
pass | ||
|
||
|
||
class Writers: | ||
def __init__(self, cls): | ||
self.cls = cls | ||
self.writers = {} | ||
self.is_entered = False | ||
|
||
async def __aenter__(self): | ||
self.is_entered = True | ||
await asyncio.gather(*(w.__aenter__() for w in self.writers.values())) | ||
return self | ||
|
||
async def get_writer(self, sid): | ||
if sid not in self.writers: | ||
self.writers[sid] = w = self.cls(sid) | ||
if self.is_entered: | ||
await w.__aenter__() | ||
return self.writers[sid] | ||
|
||
async def write(self, sid, data, ts): | ||
writer = await self.get_writer(sid) | ||
await writer.write(data, ts) | ||
|
||
async def __aexit__(self, *a): | ||
await asyncio.gather(*(w.__aexit__(*a) for w in self.writers.values())) | ||
self.is_entered = False | ||
|
||
|
||
|
||
class RawWriter: | ||
raw=True | ||
def __init__(self, name, store_dir='', max_len=1000, max_size=9.5*MB, **kw): | ||
super().__init__(**kw) | ||
#self.fname = os.path.join(store_dir, f'{name}.zip') | ||
self.dir = os.path.join(store_dir, name) | ||
os.makedirs(self.dir, exist_ok=True) | ||
self.name = name | ||
self.max_len = max_len | ||
self.max_size = max_size | ||
|
||
def context(self, sample=None, t_start=None): | ||
try: | ||
self.size = 0 | ||
self.buffer = [] | ||
with tqdm.tqdm(total=self.max_len, desc=self.name) as self.pbar: | ||
yield self | ||
finally: | ||
if self.buffer: | ||
self._dump(self.buffer) | ||
self.buffer.clear() | ||
|
||
def _dump(self, data): | ||
if not data: | ||
return | ||
import zipfile | ||
fname = os.path.join(self.dir, f'{data[0][1]}_{data[-1][1]}.zip') | ||
tqdm.tqdm.write(f"writing {fname}") | ||
with zipfile.ZipFile(fname, 'a', zipfile.ZIP_STORED, False) as zf: | ||
for d, ts in data: | ||
zf.writestr(ts, d) | ||
|
||
def write(self, data, ts): | ||
self.pbar.update() | ||
self.size += len(data) | ||
self.buffer.append([data, ts]) | ||
if len(self.buffer) >= self.max_len or self.size >= self.max_size: | ||
self._dump(self.buffer) | ||
self.buffer.clear() | ||
self.pbar.reset() | ||
self.size = 0 | ||
|
||
|