Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
FROM python:3.11
FROM python:3.13

RUN pip install gunicorn uvicorn[standard]
WORKDIR /app

RUN mkdir /app
COPY pyproject.toml LICENSE README.md /app/
COPY src /app/src/
COPY pyproject.toml LICENSE README.md ./

RUN pip install /app
COPY src ./src

RUN pip install uv
RUN uv pip install --system --compile-bytecode .

RUN ls -la /app/src/model_fit_api/app.py

CMD ["gunicorn", "--bind", "0.0.0.0:80", "--workers", "4", "--worker-class", "uvicorn.workers.UvicornWorker", "model_fit_api.app:app"]

CMD ["gunicorn", "--bind", "0.0.0.0:80", "--workers", "4", "--worker-class", "uvicorn.workers.UvicornWorker", "model_fit_api.app:app"]
18 changes: 18 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
services:
model_fit_api:
build: .
ports:
- "80"
networks:
- proxy
environment:
VIRTUAL_HOST: fit.lc.snad.space
HTTPS_METHOD: noredirect
DYNDNS_HOST: fit.lc.snad.space
LETSENCRYPT_HOST: fit.lc.snad.space
LETSENCRYPT_EMAIL: [email protected]
restart: always

networks:
proxy:
external: true
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ dependencies = [
"pandas",
"pydantic",
"sncosmo",
"uvicorn[standard]", # Used to run the debug server
"gunicorn",
"uvicorn[standart]",
]

# On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes)
Expand Down
126 changes: 80 additions & 46 deletions src/model_fit_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,104 +4,138 @@
import numpy as np
import pandas as pd
import sncosmo
import math
from astropy.table import Table
from fastapi import FastAPI
from pydantic import BaseModel

models = [
"nugent-sn1a",
"nugent-sn91t",
"nugent-sn91bg",
"nugent-sn1bc",
"nugent-hyper",
"nugent-sn2n",
"nugent-sn2p",
"nugent-sn2l",
"salt2",
"salt3-nir",
"salt3",
"snf-2011fe",
"v19-1993j",
"v19-1998bw",
"v19-1999em",
"v19-2009ip",
]

app = FastAPI()


class Observation(BaseModel):
mjd: float
band: str
flux: float
fluxerr: float
zp: float = 8.9
zpsys: Literal["ab", "vega"] = "ab"
band: str


class Target(BaseModel):
light_curve: List[Observation]
ebv: float
t_min: float
t_max: float
count: int
name_model: str
redshift: List[float]


class Model_data(BaseModel):
parameters: Dict[str, float]
name_model: str
zp: float
zpsys: str
band_list: List[str]
t_min: float
t_max: float
count: int
brightness_type: str
band_ref: Dict[str, float]

class Point(BaseModel):
time: float
flux: float
bright: float
band: str


class Result(BaseModel):
flux_jansky: List[Point]
class Parameters(BaseModel):
degrees_of_freedom: int
covariance: List[List[float]]
chi2: float
parameters: Dict[str, float]


class Bright(BaseModel):
bright: List[Point]


def fit(data, name_model, ebv, redshift):
dust = sncosmo.CCM89Dust()
model = sncosmo.Model(source=name_model, effects=[dust], effect_names=["mw"], effect_frames=["obs"])
model.set(mwebv=ebv)
fit_params = model.param_names
fit_params.remove('mwr_v')
fit_params.remove('mwebv')
summary, fitted_model = sncosmo.fit_lc(
data, model, model.param_names, bounds={"z": (redshift[0], redshift[1])}
data, model, fit_params, bounds={"z": (redshift[0], redshift[1])}
)
return summary, fitted_model


def get_flux_and_params(summary, data, fitted_model, t_min, t_max, count):
segment = np.linspace(t_min, t_max, count)
df = data.to_pandas()
def get_bright(data: Model_data):
dust = sncosmo.CCM89Dust()
fitted_model = sncosmo.Model(source=data.name_model, effects=[dust], effect_names=["mw"], effect_frames=["obs"])
fitted_model.set(**data.parameters)
segment = np.linspace(data.t_min, data.t_max, data.count)
points = []
for band in df["band"].unique():
predicts = fitted_model.bandflux(band, segment, df["zp"][0], df["zpsys"][0])
points += [Point(time=time, flux=flux, band=band) for time, flux in zip(segment, predicts)]
return Result(
flux_jansky=points,
for band in data.band_list:
predicts = fitted_model.bandflux(band, segment, data.zp, data.zpsys)
if data.brightness_type == 'flux':
predicts = [f + data.band_ref[band[0]+band[-1]] for f in predicts]
elif data.brightness_type == 'diffmag':
predicts = [math.log(f)/math.log(10)*(-2.5) + 8.9 if f > 0 else None for f in predicts]
elif data.brightness_type == "mag":
predicts = [math.log(f + data.band_ref[band[0]+band[-1]])/math.log(10)*(-2.5) + 8.9 if f + data.band_ref[band[0]+band[-1]] > 0 else None for f in predicts]
points += [Point(time=time, bright=flux, band=band) for time, flux in zip(segment, predicts)]
return Bright(
bright=points
)


def get_params(data: Target):
df = pd.DataFrame([dict(obs) for obs in data.light_curve])
table = Table.from_pandas(df)
summary, fitted_model = fit(table, data.name_model, data.ebv, data.redshift)
try: cov=summary.covariance.tolist()
except:
cov=[[]]
print('covariance is none')
return Parameters(
parameters=dict(zip(summary.param_names, summary.parameters)),
degrees_of_freedom=summary.ndof,
covariance=summary.covariance.tolist(),
covariance=cov,
chi2=summary.chisq,
)


def approximate(data: Target):
df = pd.DataFrame([obs.model_dump() for obs in data.light_curve])
table = Table.from_pandas(df)
summary, fitted_model = fit(table, data.name_model, data.ebv, data.redshift)
result = get_flux_and_params(summary, table, fitted_model, data.t_min, data.t_max, data.count)
return result

@app.post("/api/v1/sncosmo/fit")
async def sn_cosmo_fit(data: Target):
"""Fit light curve with sncosmo."""
return get_params(data)


@app.post("/api/v1/sncosmo")
async def sn_cosmo(data: Target):
@app.post("/api/v1/sncosmo/get_curve")
async def sn_cosmo_get_curve(data: Model_data):
"""Fit light curve with sncosmo."""
return approximate(data)
return get_bright(data)


@app.get("/api/v1/models")
async def models(data: Target):
async def models():
models = [
"nugent-sn1a",
"nugent-sn91t",
"nugent-sn91bg",
"nugent-sn1bc",
"nugent-hyper",
"nugent-sn2n",
"nugent-sn2p",
"nugent-sn2l",
"salt2",
"salt3-nir",
"salt3",
"v19-1993j",
"v19-1998bw",
"v19-1999em",
"v19-2009ip",
]
return {"models": models}
Loading