Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
scrungus committed Jul 16, 2024
1 parent 2eddd7e commit b50b40d
Show file tree
Hide file tree
Showing 3 changed files with 309 additions and 181 deletions.
178 changes: 178 additions & 0 deletions coral_credits/api/db_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from django.db.models import F
from django.shortcuts import get_object_or_404
from django.utils import timezone

from coral_credits.api import models, db_exceptions


def get_current_lease(current_lease):
current_consumer = get_object_or_404(
models.Consumer, consumer_uuid=current_lease.lease_id
)
current_resource_requests = models.CreditAllocationResource.objects.filter(
consumer=current_consumer,
)
return current_consumer, current_resource_requests


def get_resource_provider_account(project_id):
resource_provider_account = models.ResourceProviderAccount.objects.get(
project_id=project_id
)
return resource_provider_account


def get_credit_allocations(resource_provider_account):
# Find all associated active CreditAllocations
# Make sure we only look for CreditAllocations valid for the current time
now = timezone.now()
credit_allocations = models.CreditAllocation.objects.filter(
account=resource_provider_account.account, start__lte=now, end__gte=now
).order_by("-start")

return credit_allocations


def get_credit_allocation_resources(credit_allocations, resource_classes):
"""
Returns a dictionary of the form:
{
"resource_class": "credit_resource_allocation"
}
"""
resource_allocations = {}
for credit_allocation in credit_allocations:
for resource_class in resource_classes:
credit_allocation_resource = models.CreditAllocationResource.objects.filter(
allocation=credit_allocation, resource_class=resource_class
).first()
if not credit_allocation_resource:
raise db_exceptions.NoCreditAllocation(
f"No credit allocated for resource_type {resource_class}"
)
resource_allocations[resource_class] = credit_allocation_resource
return resource_allocations


def get_resource_requests(lease, current_resource_requests=None):
"""
Returns a dictionary of the form:
{
"resource_class": "resource_hours"
}
"""
resource_requests = {}

for reservation in lease.reservations:
for (
resource_type,
amount,
) in reservation.resource_requests.inventories.data.items():
resource_class = get_object_or_404(models.ResourceClass, name=resource_type)
try:
# Keep it simple, ust take min for now
# TODO(tylerchristie): check we can allocate max
# CreditAllocationResource is a record of the number of resource_hours
# available for one unit of a ResourceClass, so we multiply
# lease_duration by units required.
requested_resource_hours = round(
float(amount["total"]) * reservation.min * lease.duration,
1,
)
if current_resource_requests:
delta_resource_hours = calculate_delta_resource_hours(
requested_resource_hours,
current_resource_requests,
resource_class,
)
else:
delta_resource_hours = requested_resource_hours

resource_requests[resource_class] = delta_resource_hours

except KeyError:
raise db_exceptions.ResourceRequestFormatError(
f"Unable to recognize {resource_type} format {amount}"
)

return resource_requests


def calculate_delta_resource_hours(
requested_resource_hours, current_resource_requests, resource_class
):
# Case: user requests the same resource
current_resource_request = current_resource_requests.filter(
resource_class=resource_class
).first()
if current_resource_request:
current_resource_hours = current_resource_request.resource_hours
return requested_resource_hours - current_resource_hours
# Case: user requests a new resource
return requested_resource_hours


def check_credit_allocations(resource_requests, credit_allocations):
"""
Subtracts resources requested from credit allocations and ensures all results are non-negative.
"""
result = {}
for resource_class in credit_allocations:
result[resource_class] = (
credit_allocations[resource_class].resource_hours
- resource_requests[resource_class]
)

if result[resource_class] < 0:
raise db_exceptions.InsufficientCredits(
f"Insufficient {resource_class.name} credits available. "
f"Requested:{resource_requests[resource_class]}, "
f"Available:{credit_allocations[resource_class]}"
)

return result


def check_credit_balance(credit_allocations, resource_requests):
# TODO(tylerchristie) Fresh DB query
credit_allocation_resources = get_credit_allocation_resources(
credit_allocations, resource_requests.keys()
)
for allocation in credit_allocation_resources.values():

if allocation.resource_hours < 0:
# We raise an exception so the rollback is handled
raise db_exceptions.InsufficientCredits(
(
f"Insufficient "
f"{allocation.resource_class.name} "
f"credits after allocation."
)
)


def spend_credits(
lease, resource_provider_account, context, resource_requests, credit_allocations
):

consumer = models.Consumer.objects.create(
consumer_ref=lease.lease_name,
consumer_uuid=lease.lease_id,
resource_provider_account=resource_provider_account,
user_ref=context.user_id,
start=lease.start_date,
end=lease.end_time,
)

for resource_class in resource_requests:
models.ResourceConsumptionRecord.objects.create(
consumer=consumer,
resource_class=resource_class,
resource_hours=resource_requests[resource_class],
)
# Subtract expenditure from CreditAllocationResource
credit_allocations[resource_class].resource_hours = (
credit_allocations[resource_class].resource_hours
- resource_requests[resource_class]
)
credit_allocations[resource_class].save()
84 changes: 73 additions & 11 deletions coral_credits/api/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from rest_framework import serializers

from coral_credits.api import models
from coral_credits.api.business_objects import *


class ResourceClassSerializer(serializers.HyperlinkedModelSerializer):
Expand Down Expand Up @@ -60,38 +61,45 @@ class Meta:
fields = ["consumer_ref", "resource_provider", "start", "end", "resources"]


class ContextSerializer(serializers.Serializer):
user_id = serializers.UUIDField()
project_id = serializers.UUIDField()
auth_url = serializers.URLField()
region_name = serializers.CharField()


class InventorySerializer(serializers.Serializer):
def to_representation(self, instance):
return {key: value for key, value in instance.items()}
return instance.data

def to_internal_value(self, data):
return data

def create(self, validated_data):
return Inventory(data=validated_data)


class ResourceRequestSerializer(serializers.Serializer):
inventories = InventorySerializer()
# TODO(tylerchristie)
# resource_provider_generation = serializers.IntegerField(required=False)

def to_representation(self, instance):
return {key: value for key, value in instance.items()}
return {key: value for key, value in instance.__dict__.items()}

def to_internal_value(self, data):
return data

def create(self, validated_data):
inventories = InventorySerializer().create(validated_data.pop("inventories"))
return ResourceRequest(inventories=inventories)


class AllocationSerializer(serializers.Serializer):
id = serializers.CharField()
hypervisor_hostname = serializers.UUIDField()
extra = serializers.DictField()

def create(self, validated_data):
return Allocation(
id=validated_data["id"],
hypervisor_hostname=validated_data["hypervisor_hostname"],
extra=validated_data.get("extra", {}),
)


class ReservationSerializer(serializers.Serializer):
resource_type = serializers.CharField()
Expand All @@ -104,6 +112,20 @@ class ReservationSerializer(serializers.Serializer):
)
resource_requests = ResourceRequestSerializer()

def create(self, validated_data):
allocations = [
AllocationSerializer().create(alloc)
for alloc in validated_data.get("allocations", [])
]
resource_requests = ResourceRequestSerializer().create(
validated_data.pop("resource_requests")
)
return Reservation(
**validated_data,
allocations=allocations,
resource_requests=resource_requests,
)


class LeaseSerializer(serializers.Serializer):
lease_id = serializers.UUIDField()
Expand All @@ -112,18 +134,58 @@ class LeaseSerializer(serializers.Serializer):
end_time = serializers.DateTimeField()
reservations = serializers.ListField(child=ReservationSerializer())

def create(self, validated_data):
reservations = [
ReservationSerializer().create(res)
for res in validated_data.pop("reservations")
]
return Lease(
lease_id=validated_data["lease_id"],
lease_name=validated_data["lease_name"],
start_date=validated_data["start_date"],
end_time=validated_data["end_time"],
reservations=reservations,
)


class ContextSerializer(serializers.Serializer):
user_id = serializers.UUIDField()
project_id = serializers.UUIDField()
auth_url = serializers.URLField()
region_name = serializers.CharField()

def create(self, validated_data):
return Context(
user_id=validated_data["user_id"],
project_id=validated_data["project_id"],
auth_url=validated_data["auth_url"],
region_name=validated_data["region_name"],
)


class ConsumerRequest(serializers.Serializer):
class ConsumerRequestSerializer(serializers.Serializer):
def __init__(self, *args, current_lease_required=False, **kwargs):
super().__init__(*args, **kwargs)
# current_lease required on update but not create
# Optional field current_lease
self.fields["current_lease"] = LeaseSerializer(
required=current_lease_required, allow_null=(not current_lease_required)
)

context = ContextSerializer()
lease = LeaseSerializer()

def create(self, validated_data):
context = ContextSerializer().create(validated_data["context"])
lease = LeaseSerializer().create(validated_data["lease"])
current_lease = (
LeaseSerializer().create(validated_data["current_lease"])
if "current_lease" in validated_data
else None
)
return ConsumerRequest(
context=context, lease=lease, current_lease=current_lease
)

def to_internal_value(self, data):
# Custom validation or processing can be added here if needed
return super().to_internal_value(data)
Expand Down
Loading

0 comments on commit b50b40d

Please sign in to comment.