Skip to content

Commit 2f7147b

Browse files
authored
Updated the Python interface after recent changes to TreePPL (#7)
* Updated the serialization and deserialization of TreePPL matrices. * Updated TreePPL examples after recent changes to the TreePPL syntax.
1 parent c780fd0 commit 2f7147b

File tree

5 files changed

+12
-40
lines changed

5 files changed

+12
-40
lines changed

examples/coin.tppl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
model function coin(outcomes: Bool[]): Real {
1+
model function coin(outcomes: Bool[]) => Real {
22
assume p ~ Uniform(0.0, 1.0);
33
for i in 1 to (length(outcomes)) {
44
observe outcomes[i] ~ Bernoulli(p);

examples/crbd.tppl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function walk(node: Tree, time:Real, lambda: Real, mu: Real) {
3030
}
3131
}
3232

33-
model function crbd(tree: Tree): Real[] {
33+
model function crbd(tree: Tree) => Real[] {
3434
assume lambda ~ Gamma(1.0, 1.0);
3535
assume mu ~ Gamma(1.0, 0.5);
3636
walk(tree.left, tree.age, lambda, mu);

examples/generative_crbd.tppl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
model function generativeCrbd(time: Real, lambda: Real, mu: Real): Tree {
1+
model function generativeCrbd(time: Real, lambda: Real, mu: Real) => Tree {
22
assume waitingTime ~ Exponential(lambda + mu);
33
let eventTime = time - waitingTime;
44
if eventTime < 0.0 {

examples/treeppl_in_jupyter.ipynb

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
"source": [
7474
"%%treeppl flip samples=10\n",
7575
"\n",
76-
"model function flip(): Bool {\n",
76+
"model function flip() => Bool {\n",
7777
" assume p ~ Bernoulli(0.5);\n",
7878
" return p;\n",
7979
"}"
@@ -102,35 +102,6 @@
102102
"res.samples"
103103
]
104104
},
105-
{
106-
"cell_type": "markdown",
107-
"id": "dcd09ad5",
108-
"metadata": {},
109-
"source": [
110-
"We can dynamically adjust the number of samples as follows:"
111-
]
112-
},
113-
{
114-
"cell_type": "code",
115-
"execution_count": null,
116-
"id": "b5810a7c",
117-
"metadata": {},
118-
"outputs": [],
119-
"source": [
120-
"flip.set_samples(1000)"
121-
]
122-
},
123-
{
124-
"cell_type": "code",
125-
"execution_count": null,
126-
"id": "d961ee99",
127-
"metadata": {},
128-
"outputs": [],
129-
"source": [
130-
"res = flip()\n",
131-
"sns.countplot(y=res.samples)"
132-
]
133-
},
134105
{
135106
"cell_type": "markdown",
136107
"id": "1e839e5b",
@@ -156,7 +127,7 @@
156127
"source": [
157128
"%%treeppl coin samples=100000\n",
158129
"\n",
159-
"model function coin(outcomes: Bool[]): Real {\n",
130+
"model function coin(outcomes: Bool[]) => Real {\n",
160131
" assume p ~ Uniform(0.0, 1.0);\n",
161132
" for i in 1 to (length(outcomes)) {\n",
162133
" observe outcomes[i] ~ Bernoulli(p);\n",
@@ -200,7 +171,7 @@
200171
"source": [
201172
"%%treeppl generative_crbd samples=1\n",
202173
"\n",
203-
"model function generativeCrbd(time: Real, lambda: Real, mu: Real): Tree {\n",
174+
"model function generativeCrbd(time: Real, lambda: Real, mu: Real) => Tree {\n",
204175
" assume waitingTime ~ Exponential(lambda + mu);\n",
205176
" let eventTime = time - waitingTime;\n",
206177
" if eventTime < 0.0 {\n",
@@ -291,7 +262,7 @@
291262
" }\n",
292263
"}\n",
293264
"\n",
294-
"model function crbd(tree: Tree): Real[] {\n",
265+
"model function crbd(tree: Tree) => Real[] {\n",
295266
" assume lambda ~ Gamma(1.0, 1.0);\n",
296267
" assume mu ~ Gamma(1.0, 0.5);\n",
297268
" walk(tree.left, tree.age, lambda, mu);\n",

treeppl/serialization.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def object_hook(dictionary):
2323
return constructor_to_class[constructor](**dictionary.get("__data__", {}))
2424
if "__float__" in dictionary:
2525
return float(dictionary["__float__"])
26-
if "__tensor__" in dictionary:
27-
return np.array(dictionary["__tensor__"]).reshape(dictionary["__tensorShape__"])
26+
if set(dictionary.keys()) == {"m", "n", "arr"}:
27+
return np.array(dictionary["arr"]).reshape(dictionary["m"], dictionary["n"])
2828
return dictionary
2929

3030

@@ -46,8 +46,9 @@ def default(self, obj):
4646
try:
4747
if isinstance(obj, np.ndarray):
4848
return {
49-
"__tensor__": obj.flatten().tolist(),
50-
"__tensorShape__": obj.shape,
49+
"m": obj.shape[0],
50+
"n": obj.shape[1],
51+
"arr": obj.flatten().tolist(),
5152
}
5253
return json.JSONEncoder.default(self, obj)
5354
except TypeError:

0 commit comments

Comments
 (0)