Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,27 @@ public class AzPubSubConfig extends AbstractConfig {
"",
Importance.MEDIUM,
DSTS_METADATA_FILE_DOC)
;
.define("azpubsub.topic.max.qps",
Type.INT,
1000,
Importance.MEDIUM,
"Topic Qps")
.define("azpubsub.qps.throttling.level",
Type.INT,
0,
Importance.MEDIUM,
"Topic Qps throttling level. 0: throttling is disabled; 1: throttling at topic level; 2: throttling at clientId + topic level.")
.define("azpubsub.clientid.topic.max.qps",
Type.INT,
1000,
Importance.MEDIUM,
"Topic Qps")
.define("azpubsub.timer.task.execution.interval.in.ms",
Type.LONG,
300000,
Importance.MEDIUM,
"The interval of timer background task executions, in milliseconds")
;
}

public static AzPubSubConfig fromProps(Map<String, ?> configProviderProps) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.microsoft.azpubsub.security.auth;

import java.util.TimerTask;

/*
 * TimerTask that periodically records which I/O thread touched the
 * currently-set topic (and clientId) in the shared TopicThreadCounter
 * singleton, and caches the resulting distinct-thread count.
 *
 * Thread-safety: the setters and getters are called from broker
 * request-handler threads while run() executes on the Timer's background
 * thread. All mutable fields are therefore declared volatile so writes on one
 * thread are visible to the other (the original fields had no such guarantee,
 * allowing stale threadCount/topicName reads).
 */
public class ThreadCounterTimerTask extends TimerTask {
    private volatile String topicName = null;       // topic most recently seen by authorize()
    private volatile String clientId = null;        // clientId most recently seen by authorize()
    private volatile int threadCount = 0;           // last sampled distinct-thread count
    private volatile Long ioThreadId = 0L;          // id of the thread reported to the counter
    private volatile int throttlingLevel = 0;       // see azpubsub.qps.throttling.level
    private final Long intervalInMs;                // sliding-window length, in milliseconds
    private final TopicThreadCounter topicThreadCounterInstance;

    /**
     * @param interval sliding-window length in milliseconds, forwarded to the counter
     * @param level    throttling level, forwarded to the counter's key construction
     */
    public ThreadCounterTimerTask(long interval, int level) {
        intervalInMs = interval;
        throttlingLevel = level;
        topicThreadCounterInstance = TopicThreadCounter.getInstance(this.intervalInMs, this.throttlingLevel);
    }

    public void setTopicName(String topic) {
        topicName = topic;
    }

    public void setClientId(String client) {
        clientId = client;
    }

    public void setIoThreadId(Long threadId) {
        ioThreadId = threadId;
    }

    /** Returns the distinct-thread count computed by the most recent run(). */
    public int getThreadCount() {
        return threadCount;
    }

    public void setThrottlingLevel(int level) {
        throttlingLevel = level;
    }

    public int getThrottlingLevel() {
        return throttlingLevel;
    }

    @Override
    public void run() {
        // Snapshot the field once so the null check and the use cannot observe
        // two different values if setTopicName() runs concurrently.
        String topic = topicName;
        if (null != topic) {
            threadCount = topicThreadCounterInstance.add(topic, System.currentTimeMillis(), ioThreadId, clientId);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.microsoft.azpubsub.security.auth;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/*
 * Tracks, per throttling key, the timestamps at which threads touched a topic,
 * and reports how many distinct threads were active inside the sliding window
 * ending at the supplied timestamp.
 */
public class TopicThreadCounter {
    private static TopicThreadCounter instance = null;
    private static final Object lock = new Object();
    private final Long interval;          // sliding-window length, in milliseconds
    private final int throttlingLevel;    // selects the key format, see makeKey()
    // key -> (timestampMs -> threadId); TreeMap keeps timestamps ordered so old
    // entries can be pruned with tailMap.
    ConcurrentHashMap<String, TreeMap<Long, Long>> topicThreadMap = new ConcurrentHashMap<>();

    public TopicThreadCounter(Long intvl, int level) {
        interval = intvl;
        throttlingLevel = level;
    }

    /**
     * Lazily creates the process-wide singleton. The parameters only take
     * effect on the first call; later calls return the existing instance.
     * NOTE(review): the public constructor still allows non-singleton
     * instances - confirm that is intended.
     */
    static TopicThreadCounter getInstance(Long interval, int level) {
        synchronized (lock) {
            if (null == instance) {
                instance = new TopicThreadCounter(interval, level);
            }
        }
        return instance;
    }

    /**
     * Records that {@code threadId} touched {@code topic} at
     * {@code currentTimeInMs}, prunes entries older than the window, and
     * returns the number of distinct threads seen inside the window.
     */
    public int add(String topic, Long currentTimeInMs, Long threadId, String clientId) {
        String key = TopicThreadCounter.makeKey(this.throttlingLevel, topic, clientId, threadId);
        // computeIfAbsent replaces the original containsKey/put pair, which was
        // a check-then-act race on the concurrent map.
        TreeMap<Long, Long> timeline = topicThreadMap.computeIfAbsent(key, k -> new TreeMap<>());
        // NOTE(review): two samples in the same millisecond overwrite each
        // other because the timestamp is the map key - confirm acceptable.
        timeline.put(currentTimeInMs, threadId);
        // Keep only entries strictly newer than (now - interval).
        NavigableMap<Long, Long> window = timeline.tailMap(currentTimeInMs - interval, false);
        HashSet<Long> distinctThreads = new HashSet<>(window.values());
        // Replace the stored timeline with the pruned copy so memory does not
        // grow without bound.
        topicThreadMap.put(key, new TreeMap<>(window));
        return distinctThreads.size();
    }

    /**
     * Builds the map key for the given throttling level.
     * NOTE(review): levels here (1: clientId+thread+topic, 2: clientId+thread,
     * 3: topic+thread) do not match the AzPubSubConfig description of
     * azpubsub.qps.throttling.level (1: topic, 2: clientId+topic) - confirm
     * which is authoritative. Key formats are preserved byte-for-byte.
     */
    public static String makeKey(int throttlingLevel, String topic, String clientId, Long threadId) {
        switch (throttlingLevel) {
            case 1:
                return String.format("ClientId:%s|ThreadId:%d|Topic:%s", clientId, threadId, topic);
            case 2:
                return String.format("ClientId:%s|ThreadId:%d", clientId, threadId);
            case 3:
                return String.format("Topic:%s|ThreadId:%d", topic, threadId);
            default:
                return topic;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package com.microsoft.azpubsub.security.auth

import java.net.InetAddress
import java.util
import java.util.Timer
import java.util.concurrent._

import com.yammer.metrics.core.{Meter, MetricName}
import kafka.metrics.KafkaMetricsGroup
import kafka.security.authorizer.AclAuthorizer
import kafka.utils.Logging
import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
import org.apache.kafka.server.authorizer.{Action, AuthorizableRequestContext, AuthorizationResult}

import scala.collection.JavaConverters.asScalaSetConverter
import scala.collection.mutable

/*
* AzPubSub ACL Authorizer to handle the certificate & role based principal type
*/

/*
 * Configuration keys and metric names used by the AzPubSub ACL authorizer.
 */
object AzPubSubAclAuthorizerV2 {
  // Throttling level key: per AzPubSubConfig, 0 = disabled, 1 = topic level,
  // 2 = clientId + topic level
  val configThrottlingLevel = "azpubsub.qps.throttling.level"
  // Maximum QPS allowed per topic
  val configTopicThrottlingQps = "azpubsub.topic.max.qps"
  // Maximum QPS allowed per clientId + topic
  val configClientidTopicThrottlingQps = "azpubsub.clientid.topic.max.qps"
  // Meter names ("Metter" is a historical spelling kept so metric names and
  // call sites stay unchanged)
  val configMetterSuccessRatePerSec = "AuthorizerSuccessPerSec"
  val configMetterFailureRatePerSec = "AuthorizerFailurePerSec"
  // Background counter-task execution interval key, value in milliseconds
  val configTimerTaskExecutionInterval = "azpubsub.timer.task.execution.interval.in.ms"
}


/*
 * AzPubSub ACL authorizer. Extends Kafka's AclAuthorizer with:
 *  1. optional sliding-window QPS throttling for TOPIC resources, and
 *  2. role expansion for AzPubSubPrincipal: each role carried by the
 *     principal is tried until one is authorized.
 */
class AzPubSubAclAuthorizerV2 extends AclAuthorizer with Logging with KafkaMetricsGroup {

  /** Publishes metrics under azpubsub.security:type=AuthorizerMetrics. */
  override def metricName(name: String, metricTags: scala.collection.Map[String, String]): MetricName = {
    explicitMetricName("azpubsub.security", "AuthorizerMetrics", name, metricTags)
  }

  override def configure(javaConfigs: util.Map[String, _]): Unit = {
    super.configure(javaConfigs)
    config = AzPubSubConfig.fromProps(javaConfigs)
    throttlingLevel = config.getInt(AzPubSubAclAuthorizerV2.configThrottlingLevel)
    throttlingTopicQps = config.getInt(AzPubSubAclAuthorizerV2.configTopicThrottlingQps)
    // Fix: this previously re-read configTopicThrottlingQps, so the
    // clientId+topic limit silently mirrored the topic limit.
    throttlingClientIdTopicQps = config.getInt(AzPubSubAclAuthorizerV2.configClientidTopicThrottlingQps)
    timerTaskExecutionInterval = config.getLong(AzPubSubAclAuthorizerV2.configTimerTaskExecutionInterval)
    topicThreadCounterTimerTask = new ThreadCounterTimerTask(timerTaskExecutionInterval, throttlingLevel)
    trigger.scheduleAtFixedRate(topicThreadCounterTimerTask, 2, timerTaskExecutionInterval)
    // NOTE(review): this records the id of the thread running configure(),
    // not necessarily an I/O thread - confirm this is intended.
    topicThreadCounterTimerTask.setIoThreadId(Thread.currentThread().getId)
  }

  private var config: AzPubSubConfig = null
  private var throttlingLevel: Int = 0
  private var throttlingTopicQps: Int = 0
  private var throttlingClientIdTopicQps: Int = 0
  private var timerTaskExecutionInterval: Long = 0
  private val successRate: Meter = newMeter(AzPubSubAclAuthorizerV2.configMetterSuccessRatePerSec, "success", TimeUnit.SECONDS)
  private val failureRate: Meter = newMeter(AzPubSubAclAuthorizerV2.configMetterFailureRatePerSec, "failure", TimeUnit.SECONDS)
  private var topicThreadCounterTimerTask: ThreadCounterTimerTask = null
  private val trigger: Timer = new Timer(true) // daemon timer for the counter task

  // Per-key request timestamps forming the QPS sliding window.
  // NOTE(review): mutable.HashMap is not thread-safe while authorize() may run
  // on multiple request-handler threads - confirm or add synchronization.
  var ints = new mutable.HashMap[String, mutable.TreeSet[Long]]

  override def authorize(requestContext: AuthorizableRequestContext, actions: util.List[Action]): util.List[AuthorizationResult] = {
    if (throttlingLevel > 0) {
      throttle(requestContext, actions)
    }

    if (requestContext.principal().getClass == classOf[AzPubSubPrincipal]) {
      val azPrincipal = requestContext.principal().asInstanceOf[AzPubSubPrincipal]
      var lastResult: util.List[AuthorizationResult] = null
      for (role <- azPrincipal.getRoles.asScala) {
        val context = withPrincipal(requestContext, new KafkaPrincipal(azPrincipal.getPrincipalType, role))
        val result = super.authorize(context, actions)
        if (result.contains(AuthorizationResult.ALLOWED)) {
          successRate.mark()
          // Fix: the original discarded this result, kept iterating the
          // remaining roles, and then marked failureRate even on success.
          return result
        }
        lastResult = result
      }
      // No role was authorized.
      failureRate.mark()
      return lastResult
    }

    val result = super.authorize(requestContext, actions)
    if (null != result && result.contains(AuthorizationResult.ALLOWED)) {
      successRate.mark()
    } else {
      failureRate.mark()
    }
    result
  }

  /** Sliding-window QPS throttling for TOPIC resources: while the observed
   *  rate exceeds the configured limit, evict the oldest timestamps and sleep
   *  with exponential backoff. */
  private def throttle(requestContext: AuthorizableRequestContext, actions: util.List[Action]): Unit = {
    actions.forEach(action => if (action.resourcePattern().resourceType() == org.apache.kafka.common.resource.ResourceType.TOPIC) {
      topicThreadCounterTimerTask.setTopicName(action.resourcePattern.name())
      topicThreadCounterTimerTask.setClientId(requestContext.clientId())

      val key = makeKey(action.resourcePattern().name(), requestContext.clientId())
      if (!ints.contains(key)) ints.put(key, new mutable.TreeSet[Long]())

      // Scale the locally observed rate by the number of distinct I/O threads
      // reported by the background counter task (at least 1).
      val threadCount = Math.max(topicThreadCounterTimerTask.getThreadCount, 1)

      var backoffMs = 1
      // Fix: the original condition used ints.get(key).size, i.e. Option.size,
      // which is always 0 or 1 - the configured limits were never enforced.
      while ((throttlingLevel == 1 && ints(key).size * threadCount > throttlingTopicQps)
          || (throttlingLevel == 2 && ints(key).size * threadCount > throttlingClientIdTopicQps)) {
        // NOTE(review): minBy over a Boolean predicate selects the oldest
        // timestamp; the partition then drops everything <= that pivot. This
        // preserves the original eviction behavior - confirm whether a plain
        // filter(_ > now - 1000) was intended instead.
        val pivot = ints(key).minBy(ts => ts > System.currentTimeMillis - 1000)
        val (_, recent) = ints(key).partition(ts => ts > pivot)
        ints.put(key, recent)
        Thread.sleep(backoffMs) // NOTE(review): blocks a request-handler thread
        backoffMs *= 2
      }

      ints(key) += System.currentTimeMillis
    })
  }

  /** Wraps a request context, replacing only the principal. */
  private def withPrincipal(requestContext: AuthorizableRequestContext, newPrincipal: KafkaPrincipal): AuthorizableRequestContext =
    new AuthorizableRequestContext {
      override def listenerName(): String = requestContext.listenerName()

      override def securityProtocol(): SecurityProtocol = requestContext.securityProtocol()

      override def principal(): KafkaPrincipal = newPrincipal

      override def clientAddress(): InetAddress = requestContext.clientAddress()

      override def requestType(): Int = requestContext.requestType()

      override def requestVersion(): Int = requestContext.requestVersion()

      override def clientId(): String = requestContext.clientId()

      override def correlationId(): Int = requestContext.correlationId()
    }

  /** Builds the throttling-window key for the configured throttling level.
   *  NOTE(review): level semantics differ from TopicThreadCounter.makeKey and
   *  from the config description - confirm which mapping is authoritative. */
  private def makeKey(topic: String, clientId: String): String = {
    throttlingLevel match {
      // Fix: the format arguments were swapped, producing
      // "ClientId:<topic>|Topic:<clientId>".
      case 1 => s"ClientId:$clientId|Topic:$topic"
      case 2 => s"ClientId:$clientId"
      case _ => s"Topic:$topic"
    }
  }
}
19 changes: 10 additions & 9 deletions config/log4j.properties
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# limitations under the License.

# Unspecified loggers and loggers with additivity=true output to server.log and stdout
# Note that INFO only applies to unspecified loggers, the log level of the child logger is used otherwise
log4j.rootLogger=INFO, stdout, kafkaAppender
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be reverted.

# Note that DEBUG only applies to unspecified loggers, the log level of the child logger is used otherwise
log4j.rootLogger=DEBUG, stdout, kafkaAppender

log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
Expand Down Expand Up @@ -57,12 +57,13 @@ log4j.appender.authorizerAppender.File=${kafka.logs.dir}/kafka-authorizer.log
log4j.appender.authorizerAppender.layout=org.apache.log4j.PatternLayout
log4j.appender.authorizerAppender.layout.ConversionPattern=[%d] %p %m (%c)%n

# Change the line below to adjust ZK client logging
log4j.logger.org.apache.zookeeper=INFO
# Change the two lines below to adjust ZK client logging
log4j.logger.org.I0Itec.zkclient.ZkClient=DEBUG
log4j.logger.org.apache.zookeeper=DEBUG

# Change the two lines below to adjust the general broker logging level (output to server.log and stdout)
log4j.logger.kafka=INFO
log4j.logger.org.apache.kafka=INFO
log4j.logger.kafka=DEBUG
log4j.logger.org.apache.kafka=DEBUG

# Change to DEBUG or TRACE to enable request logging
log4j.logger.kafka.request.logger=WARN, requestAppender
Expand All @@ -79,7 +80,7 @@ log4j.additivity.kafka.network.RequestChannel$=false
log4j.logger.kafka.controller=TRACE, controllerAppender
log4j.additivity.kafka.controller=false

log4j.logger.kafka.log.LogCleaner=INFO, cleanerAppender
log4j.logger.kafka.log.LogCleaner=DEBUG, cleanerAppender
log4j.additivity.kafka.log.LogCleaner=false

log4j.logger.kafka.log.SkimpyOffsetMap=INFO, cleanerAppender
Expand All @@ -88,7 +89,7 @@ log4j.additivity.kafka.log.SkimpyOffsetMap=false
log4j.logger.state.change.logger=TRACE, stateChangeAppender
log4j.additivity.state.change.logger=false

# Access denials are logged at INFO level, change to DEBUG to also log allowed accesses
log4j.logger.kafka.authorizer.logger=INFO, authorizerAppender
# Access denials are logged at DEBUG level, which also logs allowed accesses
log4j.logger.kafka.authorizer.logger=DEBUG, authorizerAppender
log4j.additivity.kafka.authorizer.logger=false

Loading