diff --git a/airsenal/scripts/fill_predictedscore_table.py b/airsenal/scripts/fill_predictedscore_table.py index 408f0e76..63df054d 100644 --- a/airsenal/scripts/fill_predictedscore_table.py +++ b/airsenal/scripts/fill_predictedscore_table.py @@ -171,6 +171,11 @@ def calc_all_predicted_points( dbsession=dbsession, ) for p in predictions: + # check if db uri contains postgresql + if "postgresql" in dbsession.bind.url.drivername: + # check if the predicted_points is a float or jaxlib ArrayImpl + if hasattr(p.predicted_points, "shape"): + p.predicted_points = p.predicted_points.tolist() dbsession.add(p) dbsession.commit() print("Finished adding predictions to db")