Skip to content

Commit

Permalink
Custom headers in cors
Browse files Browse the repository at this point in the history
  • Loading branch information
getroot committed Dec 5, 2024
1 parent a57fa45 commit f97e3e4
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 71 deletions.
2 changes: 1 addition & 1 deletion src/projects/api_server/api_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace api
bool Server::SetupCORS(const cfg::mgr::api::API &api_config)
{
bool is_cors_parsed;
auto cross_domains = api_config.GetCrossDomainList(&is_cors_parsed);
auto cross_domains = api_config.GetCrossDomains(&is_cors_parsed);

if (is_cors_parsed)
{
Expand Down
45 changes: 44 additions & 1 deletion src/projects/config/items/common/cross_domains.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,57 @@ namespace cfg
{
protected:
std::vector<ov::String> _url_list;
std::vector<cmn::Option> _custom_headers;
std::map<ov::String, ov::String> _custom_header_map;

public:
CFG_DECLARE_CONST_REF_GETTER_OF(GetUrls, _url_list)
CFG_DECLARE_CONST_REF_GETTER_OF(GetCustomHeaders, _custom_headers)

// Set * to allow all domains
void AllowAll()
{
_url_list.clear();
_url_list.emplace_back("*");
}

// Return original case of the key and value
std::optional<std::tuple<ov::String, ov::String>> GetCustomHeader(ov::String key) const
{
auto it = _custom_header_map.find(key.LowerCaseString());
if (it == _custom_header_map.end())
{
return std::nullopt;
}

auto key_value = it->second;
auto key_value_list = key_value.Split(":");
if (key_value_list.size() != 2)
{
return std::nullopt;
}

auto origin_key = key_value_list[0];
auto value = key_value_list[1];

return std::make_tuple(origin_key, value);
}

protected:
void MakeList() override
{
Register<OmitJsonName>("Url", &_url_list);
Register<Optional>("Url", &_url_list);
Register<Optional>({"Header", "headers"}, &_custom_headers, nullptr,
[=]() -> std::shared_ptr<ConfigError> {
for (auto &item : _custom_headers)
{
auto key = item.GetKey().LowerCaseString();
auto key_value = ov::String::FormatString("%s:%s", item.GetKey().CStr(), item.GetValue().CStr());
// To keep the original case of the key
_custom_header_map.emplace(key, key_value);
}
return nullptr;
});
}
};
} // namespace cmn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
//==============================================================================
#pragma once

#include "default_query_string_item.h"
#include "../../../../common/option.h"

namespace cfg
{
Expand All @@ -21,7 +21,7 @@ namespace cfg
struct DefaultQueryString : public Item
{
protected:
std::vector<DefaultQueryStringItem> _items;
std::vector<cmn::Option> _items;
std::map<ov::String, ov::String> _item_map;

public:
Expand Down

This file was deleted.

75 changes: 67 additions & 8 deletions src/projects/modules/http/cors/cors_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@

namespace http
{
void CorsManager::SetCrossDomains(const info::VHostAppName &vhost_app_name, const std::vector<ov::String> &url_list)
void CorsManager::SetCrossDomains(const info::VHostAppName &vhost_app_name, const cfg::cmn::CrossDomains &cross_domain_cfg)
{
std::lock_guard lock_guard(_cors_mutex);

auto url_list = cross_domain_cfg.GetUrls();
_cors_cfg_map[vhost_app_name] = cross_domain_cfg;
auto &cors_policy = _cors_policy_map[vhost_app_name];
auto &cors_regex_list = _cors_item_list_map[vhost_app_name];
ov::String cors_rtmp;
Expand Down Expand Up @@ -147,11 +149,14 @@ namespace http
ov::String origin_header = request->GetHeader("ORIGIN");
ov::String cors_header = "";

std::unordered_map<info::VHostAppName, cfg::cmn::CrossDomains>::const_iterator cors_cfg_iterator;

{
std::lock_guard lock_guard(_cors_mutex);

auto cors_policy_iterator = _cors_policy_map.find(vhost_app_name);
auto cors_regex_list_iterator = _cors_item_list_map.find(vhost_app_name);
cors_cfg_iterator = _cors_cfg_map.find(vhost_app_name);

if (
(cors_policy_iterator == _cors_policy_map.end()) ||
Expand Down Expand Up @@ -210,20 +215,74 @@ namespace http
response->SetHeader("Access-Control-Allow-Origin", cors_header);
response->SetHeader("Vary", "Origin");

std::vector<ov::String> method_list;
if (cors_cfg_iterator == _cors_cfg_map.end())
{
// This happens in the following situations:
//
// 1) Request to an application that doesn't exist
// 2) Request while the application is being created
return false;
}

auto &cors_cfg = _cors_cfg_map.at(vhost_app_name);

for (const auto &method : allowed_methods)
// Access-Control-Allow-Credentials
auto custom_header_opt = cors_cfg.GetCustomHeader("Access-Control-Allow-Credentials");
if (custom_header_opt.has_value())
{
response->SetHeader(std::get<0>(custom_header_opt.value()), std::get<1>(custom_header_opt.value()));
}
else
{
method_list.push_back(http::StringFromMethod(method));
response->SetHeader("Access-Control-Allow-Credentials", "true");
}

// Access-Control-Allow-Methods
custom_header_opt = cors_cfg.GetCustomHeader("Access-Control-Allow-Methods");
if (custom_header_opt.has_value())
{
response->SetHeader(std::get<0>(custom_header_opt.value()), std::get<1>(custom_header_opt.value()));
}
else
{
std::vector<ov::String> method_list;
for (const auto &method : allowed_methods)
{
method_list.push_back(http::StringFromMethod(method));
}

if (method_list.empty() == false)
{
response->SetHeader("Access-Control-Allow-Methods", ov::String::Join(method_list, ", "));
}
}

response->SetHeader("Access-Control-Allow-Credentials", "true");
// Access-Control-Allow-Headers
custom_header_opt = cors_cfg.GetCustomHeader("Access-Control-Allow-Headers");
if (custom_header_opt.has_value())
{
response->SetHeader(std::get<0>(custom_header_opt.value()), std::get<1>(custom_header_opt.value()));
}
else
{
response->SetHeader("Access-Control-Allow-Headers", "*");
}

if (method_list.empty() == false)
// Remaining custom headers
auto custom_headers = cors_cfg.GetCustomHeaders();
for (const auto &item : custom_headers)
{
response->SetHeader("Access-Control-Allow-Methods", ov::String::Join(method_list, ", "));
auto key = item.GetKey();
if (key.LowerCaseString() == "access-control-allow-origin" ||
key.LowerCaseString() == "access-control-allow-credentials" ||
key.LowerCaseString() == "access-control-allow-methods" ||
key.LowerCaseString() == "access-control-allow-headers")
{
continue;
}

response->SetHeader(key, item.GetValue());
}
response->SetHeader("Access-Control-Allow-Headers", "*");

return true;
}
Expand Down
5 changes: 4 additions & 1 deletion src/projects/modules/http/cors/cors_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <base/ovlibrary/ovlibrary.h>

#include "../server/http_server.h"
#include "config/items/common/cross_domains.h"

namespace http
{
Expand All @@ -28,7 +29,7 @@ namespace http
// Empty url_list means 'Do not set any CORS header'
//
// NOTE - SetCrossDomains() isn't thread-safe.
void SetCrossDomains(const info::VHostAppName &vhost_app_name, const std::vector<ov::String> &url_list);
void SetCrossDomains(const info::VHostAppName &vhost_app_name, const cfg::cmn::CrossDomains &cross_domain_cfg);

bool SetupRtmpCorsXml(const std::shared_ptr<http::svr::HttpResponse> &response) const;

Expand Down Expand Up @@ -84,6 +85,8 @@ namespace http
// key: VHostAppName, value: regex
std::unordered_map<info::VHostAppName, std::vector<CorsItem>> _cors_item_list_map;

std::unordered_map<info::VHostAppName, cfg::cmn::CrossDomains> _cors_cfg_map;

// CORS for RTMP
//
// NOTE - The RTMP CORS setting follows the first declared <CrossDomains> setting,
Expand Down
9 changes: 2 additions & 7 deletions src/projects/modules/whip/whip_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,9 @@ bool WhipServer::RemoveCertificate(const std::shared_ptr<const info::Certificate
return true;
}

void WhipServer::SetCors(const info::VHostAppName &vhost_app_name, const std::vector<ov::String> &url_list)
void WhipServer::SetCors(const info::VHostAppName &vhost_app_name, const cfg::cmn::CrossDomains &cross_domain_cfg)
{
_cors_manager.SetCrossDomains(vhost_app_name, url_list);
}

void WhipServer::EraseCors(const info::VHostAppName &vhost_app_name)
{
_cors_manager.SetCrossDomains(vhost_app_name, {});
_cors_manager.SetCrossDomains(vhost_app_name, cross_domain_cfg);
}

ov::String WhipServer::GetIceServerLinkValue(const ov::String &URL, const ov::String &username, const ov::String &credential)
Expand Down
3 changes: 1 addition & 2 deletions src/projects/modules/whip/whip_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ class WhipServer : public ov::EnableSharedFromThis<WhipServer>
bool InsertCertificate(const std::shared_ptr<const info::Certificate> &certificate);
bool RemoveCertificate(const std::shared_ptr<const info::Certificate> &certificate);

void SetCors(const info::VHostAppName &vhost_app_name, const std::vector<ov::String> &url_list);
void EraseCors(const info::VHostAppName &vhost_app_name);
void SetCors(const info::VHostAppName &vhost_app_name, const cfg::cmn::CrossDomains &cross_domain_cfg);

protected:
struct TurnIP
Expand Down
2 changes: 1 addition & 1 deletion src/projects/orchestrator/virtual_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ namespace ocst

// CORS
bool is_cors_parsed;
auto cross_domains = _host_info.GetCrossDomainList(&is_cors_parsed);
auto cross_domains = _host_info.GetCrossDomains(&is_cors_parsed);

if (is_cors_parsed)
{
Expand Down
7 changes: 4 additions & 3 deletions src/projects/providers/webrtc/webrtc_provider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,12 @@ namespace pvd
if (_whip_server != nullptr)
{
auto webrtc_cfg = application_info.GetConfig().GetProviders().GetWebrtcProvider();
auto cross_domains = webrtc_cfg.GetCrossDomainList();
if (cross_domains.empty())
bool is_parsed;
auto cross_domains = webrtc_cfg.GetCrossDomains(&is_parsed);
if (is_parsed == false)
{
// There is no CORS setting in the WebRTC Provider in the already deployed Server.xml. In this case, provide * to avoid confusion.
cross_domains.push_back("*");
cross_domains.AllowAll();
}
_whip_server->SetCors(application_info.GetVHostAppName(), cross_domains);
}
Expand Down
4 changes: 2 additions & 2 deletions src/projects/publishers/hls/hls_application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ HlsApplication::HlsApplication(const std::shared_ptr<pub::Publisher> &publisher,
{
auto ts_config = application_info.GetConfig().GetPublishers().GetHlsPublisher();
bool is_parsed;
const auto &cross_domains = ts_config.GetCrossDomainList(&is_parsed);
const auto &cross_domains = ts_config.GetCrossDomains(&is_parsed);

if (is_parsed)
{
_cors_manager.SetCrossDomains(application_info.GetVHostAppName(), cross_domains);
}
else
{
const auto &default_cross_domains = application_info.GetHostInfo().GetCrossDomainList(&is_parsed);
const auto &default_cross_domains = application_info.GetHostInfo().GetCrossDomains(&is_parsed);
if (is_parsed)
{
_cors_manager.SetCrossDomains(application_info.GetVHostAppName(), default_cross_domains);
Expand Down
4 changes: 2 additions & 2 deletions src/projects/publishers/llhls/llhls_application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ LLHlsApplication::LLHlsApplication(const std::shared_ptr<pub::Publisher> &publis
{
auto llhls_config = application_info.GetConfig().GetPublishers().GetLLHlsPublisher();
bool is_parsed;
const auto &cross_domains = llhls_config.GetCrossDomainList(&is_parsed);
const auto &cross_domains = llhls_config.GetCrossDomains(&is_parsed);

if (is_parsed)
{
_cors_manager.SetCrossDomains(application_info.GetVHostAppName(), cross_domains);
}
else
{
const auto &default_cross_domains = application_info.GetHostInfo().GetCrossDomainList(&is_parsed);
const auto &default_cross_domains = application_info.GetHostInfo().GetCrossDomains(&is_parsed);
if (is_parsed)
{
_cors_manager.SetCrossDomains(application_info.GetVHostAppName(), default_cross_domains);
Expand Down
4 changes: 2 additions & 2 deletions src/projects/publishers/thumbnail/thumbnail_application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ ThumbnailApplication::ThumbnailApplication(const std::shared_ptr<pub::Publisher>
auto thumbnail_config = application_info.GetConfig().GetPublishers().GetThumbnailPublisher();

bool is_parsed;
const auto &cross_domains = thumbnail_config.GetCrossDomainList(&is_parsed);
const auto &cross_domains = thumbnail_config.GetCrossDomains(&is_parsed);

if (is_parsed)
{
_cors_manager.SetCrossDomains(application_info.GetVHostAppName(), cross_domains);
}
else
{
const auto &default_cross_domains = application_info.GetHostInfo().GetCrossDomainList(&is_parsed);
const auto &default_cross_domains = application_info.GetHostInfo().GetCrossDomains(&is_parsed);
if (is_parsed)
{
_cors_manager.SetCrossDomains(application_info.GetVHostAppName(), default_cross_domains);
Expand Down

0 comments on commit f97e3e4

Please sign in to comment.