Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor stream variants #91

Merged
merged 1 commit into from
May 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion backend/api/src/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ pub mod channel_role;
pub mod channel_role_grant;
pub mod global_role;
pub mod global_role_grant;
pub mod protobuf;
pub mod session;
pub mod stream;
pub mod stream_bitrate_update;
pub mod stream_event;
pub mod stream_variant;
pub mod user;
64 changes: 64 additions & 0 deletions backend/api/src/database/protobuf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#[derive(Debug, Clone, Default)]
pub enum ProtobufValue<T: prost::Message + std::default::Default> {
#[default]
None,
Some(T),
Err(prost::DecodeError),
}

impl<T: prost::Message + std::default::Default> ProtobufValue<T> {
#[allow(dead_code)]
pub fn unwrap(self) -> Option<T> {
match self {
Self::Some(data) => Some(data),
Self::None => None,
Self::Err(err) => panic!(
"called `ProtobufValue::unwrap()` on a `Err` value: {:?}",
err
),
}
}
}

impl<T: prost::Message + std::default::Default, F> From<Option<F>> for ProtobufValue<T>
where
ProtobufValue<T>: From<F>,
{
fn from(data: Option<F>) -> Self {
match data {
Some(data) => Self::from(data),
None => Self::None,
}
}
}

impl<T: prost::Message + std::default::Default> From<Vec<u8>> for ProtobufValue<T> {
fn from(data: Vec<u8>) -> Self {
match T::decode(data.as_slice()) {
Ok(variants) => Self::Some(variants),
Err(e) => Self::Err(e),
}
}
}

impl<T: prost::Message + std::default::Default + PartialEq> PartialEq for ProtobufValue<T> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::None, Self::None) => true,
(Self::Some(a), Self::Some(b)) => a == b,
_ => false,
}
}
}

impl<T: prost::Message + std::default::Default + PartialEq> PartialEq<Option<T>>
for ProtobufValue<T>
{
fn eq(&self, other: &Option<T>) -> bool {
match (self, other) {
(Self::None, None) => true,
(Self::Some(a), Some(b)) => a == b,
_ => false,
}
}
}
5 changes: 5 additions & 0 deletions backend/api/src/database/stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::pb::scuffle::types::StreamVariants;
use chrono::{DateTime, Utc};
use uuid::Uuid;

use super::protobuf::ProtobufValue;

#[derive(Debug, Clone, Default, Copy, Eq, PartialEq)]
#[repr(i64)]
pub enum State {
Expand Down Expand Up @@ -62,6 +65,8 @@ pub struct Model {
pub ingest_address: String,
/// The connection which owns the stream.
pub connection_id: Uuid,
/// The Stream Variants
pub variants: ProtobufValue<StreamVariants>,
/// The time the stream was created.
pub created_at: DateTime<Utc>,
/// The time the stream was last updated.
Expand Down
35 changes: 0 additions & 35 deletions backend/api/src/database/stream_variant.rs

This file was deleted.

1 change: 0 additions & 1 deletion backend/api/src/dataloader/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
pub mod session;
pub mod stream;
pub mod stream_variant;
pub mod user;
pub mod user_permissions;
47 changes: 0 additions & 47 deletions backend/api/src/dataloader/stream_variant.rs

This file was deleted.

3 changes: 0 additions & 3 deletions backend/api/src/global/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use common::context::Context;

use crate::config::AppConfig;
use crate::dataloader::stream::StreamByIdLoader;
use crate::dataloader::stream_variant::StreamVariantsByStreamIdLoader;
use crate::dataloader::user_permissions::UserPermissionsByIdLoader;
use crate::dataloader::{
session::SessionByIdLoader, user::UserByIdLoader, user::UserByUsernameLoader,
Expand All @@ -22,7 +21,6 @@ pub struct GlobalState {
pub session_by_id_loader: DataLoader<SessionByIdLoader>,
pub user_permisions_by_id_loader: DataLoader<UserPermissionsByIdLoader>,
pub stream_by_id_loader: DataLoader<StreamByIdLoader>,
pub stream_variants_by_stream_id_loader: DataLoader<StreamVariantsByStreamIdLoader>,
pub rmq: common::rmq::ConnectionPool,
}

Expand All @@ -41,7 +39,6 @@ impl GlobalState {
session_by_id_loader: SessionByIdLoader::new(db.clone()),
user_permisions_by_id_loader: UserPermissionsByIdLoader::new(db.clone()),
stream_by_id_loader: StreamByIdLoader::new(db.clone()),
stream_variants_by_stream_id_loader: StreamVariantsByStreamIdLoader::new(db.clone()),
db,
rmq,
}
Expand Down
1 change: 1 addition & 0 deletions backend/api/src/gql.nocov.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod config;
mod database;
mod dataloader;
mod global;
mod pb;

use api::v1::gql::schema;
use async_graphql::SDLExportOptions;
Expand Down
113 changes: 12 additions & 101 deletions backend/api/src/grpc/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,18 @@ use std::sync::{Arc, Weak};
use crate::database::{
global_role,
stream::{self, State},
stream_event, stream_variant,
stream_event,
};
use chrono::{Duration, TimeZone, Utc};
use sqlx::{Executor, Postgres, QueryBuilder};
use prost::Message;
use tonic::{async_trait, Request, Response, Status};
use uuid::Uuid;

use super::pb::scuffle::{
backend::{
api_server,
update_live_stream_request::{event::Level, update::Update},
AuthenticateLiveStreamRequest, AuthenticateLiveStreamResponse, LiveStreamState,
NewLiveStreamRequest, NewLiveStreamResponse, UpdateLiveStreamRequest,
UpdateLiveStreamResponse,
},
types::StreamVariant,
use crate::pb::scuffle::backend::{
api_server,
update_live_stream_request::{event::Level, update::Update},
AuthenticateLiveStreamRequest, AuthenticateLiveStreamResponse, LiveStreamState,
NewLiveStreamRequest, NewLiveStreamResponse, UpdateLiveStreamRequest, UpdateLiveStreamResponse,
};

type Result<T> = std::result::Result<T, Status>;
Expand All @@ -38,88 +34,6 @@ impl ApiServer {
pub fn into_service(self) -> api_server::ApiServer<Self> {
api_server::ApiServer::new(self)
}

async fn insert_stream_variants<'c, T: Executor<'c, Database = Postgres>>(
tx: T,
stream_id: Uuid,
variants: &Vec<StreamVariant>,
) -> Result<()> {
// Insert the new stream variants
let mut values = Vec::new();

// Unfortunately, we can't use the `sqlx::query!` macro here because it doesn't support
// batch inserts. So we have to build the query manually. This is a bit of a pain, because
// the query is not compile time checked, but it's better than nothing.
let mut query_builder = QueryBuilder::new(
"
INSERT INTO stream_variants (
id,
stream_id,
name,
video_framerate,
video_height,
video_width,
video_bitrate,
video_codec,
audio_bitrate,
audio_channels,
audio_sample_rate,
audio_codec,
metadata,
created_at
) ",
);

for variant in variants {
let variant_id = variant.id.parse::<Uuid>().map_err(|_| {
Status::invalid_argument("invalid variant ID: must be a valid UUID")
})?;

values.push(stream_variant::Model {
id: variant_id,
stream_id,
name: variant.name.clone(),
video_framerate: variant.video_settings.as_ref().map(|v| v.framerate as i64),
video_height: variant.video_settings.as_ref().map(|v| v.height as i64),
video_width: variant.video_settings.as_ref().map(|v| v.width as i64),
video_bitrate: variant.video_settings.as_ref().map(|v| v.bitrate as i64),
video_codec: variant.video_settings.as_ref().map(|v| v.codec.clone()),
audio_bitrate: variant.audio_settings.as_ref().map(|a| a.bitrate as i64),
audio_channels: variant.audio_settings.as_ref().map(|a| a.channels as i64),
audio_sample_rate: variant
.audio_settings
.as_ref()
.map(|a| a.sample_rate as i64),
audio_codec: variant.audio_settings.as_ref().map(|a| a.codec.clone()),
metadata: serde_json::from_str(&variant.metadata).unwrap_or_default(),
created_at: Utc::now(),
})
}

query_builder.push_values(values, |mut b, variant| {
b.push_bind(variant.id)
.push_bind(variant.stream_id)
.push_bind(variant.name)
.push_bind(variant.video_framerate)
.push_bind(variant.video_height)
.push_bind(variant.video_width)
.push_bind(variant.video_bitrate)
.push_bind(variant.video_codec)
.push_bind(variant.audio_bitrate)
.push_bind(variant.audio_channels)
.push_bind(variant.audio_sample_rate)
.push_bind(variant.audio_codec)
.push_bind(variant.metadata)
.push_bind(variant.created_at);
});

query_builder.build().execute(tx).await.map_err(|e| {
tracing::error!("failed to insert stream variants: {}", e);
Status::internal("internal server error")
})?;

Ok(())
}
}
#[async_trait]
impl api_server::Api for ApiServer {
Expand Down Expand Up @@ -230,8 +144,7 @@ impl api_server::Api for ApiServer {
stream_id: stream.id.to_string(),
record,
transcode,
try_resume: false,
variants: vec![],
variants: None,
}))
}

Expand Down Expand Up @@ -396,11 +309,10 @@ impl api_server::Api for ApiServer {
})?;
}
Update::Variants(v) => {
ApiServer::insert_stream_variants(&mut *tx, stream_id, &v.variants).await?;

sqlx::query!(
"UPDATE streams SET updated_at = NOW() WHERE id = $1",
"UPDATE streams SET updated_at = NOW(), variants = $2 WHERE id = $1",
stream_id,
v.encode_to_vec(),
)
.execute(&mut *tx)
.await
Expand Down Expand Up @@ -485,11 +397,10 @@ impl api_server::Api for ApiServer {
Status::internal("internal server error")
})?;

ApiServer::insert_stream_variants(&mut *tx, stream_id, &request.variants).await?;

sqlx::query!(
"UPDATE streams SET updated_at = NOW() WHERE id = $1",
"UPDATE streams SET updated_at = NOW(), variants = $2 WHERE id = $1",
stream_id,
request.variants.unwrap_or_default().encode_to_vec(),
)
.execute(&mut *tx)
.await
Expand Down
Loading