Commit: Kafka services and MLflow logging added
New file: Airflow DAG for backtests
@@ -0,0 +1,29 @@
from airflow import DAG
from airflow.operators.python import PythonOperator  # python_operator module is deprecated in Airflow 2
from datetime import datetime, timedelta
from app.services.backtest_service import run_backtest_by_id

default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': datetime(2023, 1, 1),
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}

dag = DAG(
    'backtest_dag',
    default_args=default_args,
    description='DAG for running backtests',
    schedule_interval=timedelta(days=1),
)

def run_backtest(backtest_id, *args, **kwargs):
    run_backtest_by_id(backtest_id)

run_backtest_task = PythonOperator(
    task_id='run_backtest',
    python_callable=run_backtest,
    # The original template '{{ task_instance.task_id }}' renders to the
    # literal string 'run_backtest', not a backtest id; reading the id from
    # the DAG run's conf is one way to actually pass one in.
    op_args=['{{ dag_run.conf["backtest_id"] }}'],
    dag=dag,
)
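With the task reading its backtest id from the DAG run's conf, a run can be triggered for a specific backtest. A minimal sketch, assuming the Airflow 2 stable REST API is enabled with basic auth; the URL and credentials are placeholders:

import requests

# Queue a run of backtest_dag for backtest 42; the id arrives in
# dag_run.conf["backtest_id"] inside the task template above.
resp = requests.post(
    "http://localhost:8080/api/v1/dags/backtest_dag/dagRuns",
    auth=("airflow", "airflow"),
    json={"conf": {"backtest_id": 42}},
)
resp.raise_for_status()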
Modified file: application Config (database URI updated)
@@ -1,4 +1,4 @@
 class Config:
-    SQLALCHEMY_DATABASE_URI = 'db_utl'
+    SQLALCHEMY_DATABASE_URI = 'postgresql://test_user:password@localhost/test_db'
     SQLALCHEMY_TRACK_MODIFICATIONS = False
     JWT_SECRET_KEY = 'your_secret_key'
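For context, a class-based config like this is typically applied with Flask's from_object; the import path below is an assumption, since the file's location is not shown in the diff:

from flask import Flask
from app.config import Config  # assumed module path

app = Flask(__name__)
app.config.from_object(Config)  # copies the uppercase attributes into app.config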
New file: app/services/backtest_service.py
@@ -0,0 +1,39 @@
from app.models.backtest import Backtest, Result
from app import db
from app.services.kafka_service import kafka_service
from app.services.mlflow_service import mlflow_service

def run_backtest_by_id(backtest_id):
    backtest = Backtest.query.get(backtest_id)
    if not backtest:
        return

    # Simulate backtest processing
    result = Result(
        backtest_id=backtest_id,
        total_return=10.5,
        number_of_trades=20,
        winning_trades=15,
        losing_trades=5,
        max_drawdown=3.5,
        sharpe_ratio=1.8
    )
    db.session.add(result)
    db.session.commit()

    # Log metrics to MLflow
    metrics = {
        "total_return": result.total_return,
        "number_of_trades": result.number_of_trades,
        "winning_trades": result.winning_trades,
        "losing_trades": result.losing_trades,
        "max_drawdown": result.max_drawdown,
        "sharpe_ratio": result.sharpe_ratio
    }
    mlflow_service.log_metrics(run_name=f"Backtest_{backtest_id}", metrics=metrics)

    # Publish result to Kafka
    kafka_service.produce('backtest_results', {
        "backtest_id": backtest_id,
        "metrics": metrics
    })
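The service assumes a Result model with one column per logged metric. A hypothetical sketch consistent with the constructor call above; the real model lives in app/models/backtest.py, and the table names and column types here are guesses:

from app import db

class Result(db.Model):  # hypothetical sketch; see app/models/backtest.py for the real one
    __tablename__ = 'results'  # assumed table name

    id = db.Column(db.Integer, primary_key=True)
    backtest_id = db.Column(db.Integer, db.ForeignKey('backtests.id'))  # assumed FK target
    total_return = db.Column(db.Float)
    number_of_trades = db.Column(db.Integer)
    winning_trades = db.Column(db.Integer)
    losing_trades = db.Column(db.Integer)
    max_drawdown = db.Column(db.Float)
    sharpe_ratio = db.Column(db.Float)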
New file: app/services/kafka_service.py
@@ -0,0 +1,30 @@
from confluent_kafka import Producer, Consumer, KafkaError, KafkaException
import json

class KafkaService:
    def __init__(self, brokers):
        self.producer = Producer({'bootstrap.servers': brokers})
        self.consumer = Consumer({
            'bootstrap.servers': brokers,
            'group.id': 'backtest_group',
            'auto.offset.reset': 'earliest'
        })

    def produce(self, topic, message):
        self.producer.produce(topic, key=None, value=json.dumps(message))
        self.producer.flush()

    def consume(self, topic, callback):
        self.consumer.subscribe([topic])
        while True:
            msg = self.consumer.poll(timeout=1.0)
            if msg is None:
                continue
            if msg.error():
                # KafkaError was not imported in the original, which would
                # raise a NameError on the first consumer error.
                if msg.error().code() == KafkaError._PARTITION_EOF:
                    continue
                else:
                    raise KafkaException(msg.error())
            callback(json.loads(msg.value()))

kafka_service = KafkaService(brokers='localhost:9092')
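On the consuming side, consume blocks forever and hands each JSON-decoded message to the callback. A minimal sketch of a consumer for the results published by the backtest service:

from app.services.kafka_service import kafka_service

def handle_result(message):
    # message is the dict produced by run_backtest_by_id
    print(f"Backtest {message['backtest_id']}: {message['metrics']}")

kafka_service.consume('backtest_results', handle_result)  # blocks; run in its own process

One design note: produce() calls flush() on every message, which waits for delivery and so trades throughput for simplicity.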
New file: app/services/mlflow_service.py
@@ -0,0 +1,13 @@
import mlflow

class MLflowService:
    def __init__(self, tracking_uri):
        mlflow.set_tracking_uri(tracking_uri)

    def log_metrics(self, run_name, metrics):
        with mlflow.start_run(run_name=run_name):
            for key, value in metrics.items():
                mlflow.log_metric(key, value)

mlflow_service = MLflowService(tracking_uri='http://localhost:5000')
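Each call to log_metrics opens a fresh MLflow run under the given name and logs one metric per key, so all of a backtest's metrics land in a single run. Usage, mirroring the call in backtest_service:

from app.services.mlflow_service import mlflow_service

mlflow_service.log_metrics(
    run_name="Backtest_42",
    metrics={"total_return": 10.5, "sharpe_ratio": 1.8},  # values must be numeric
)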