From a659bf56136148b1dc3fbaf2ca105600b0a9884d Mon Sep 17 00:00:00 2001
From: Jasper
Date: Sun, 24 Sep 2023 22:32:57 -0700
Subject: [PATCH 1/2] add graph-building capabilities, minor fixes

---
Review notes (text between the "---" line and the first "diff --git" line is
commentary, not part of the change):
 * Dropped the bernoulli.py hunk: it added a second `import numpy as np` to a
   file whose context lines already showed that import.
 * Dropped the unused `import setuptools` from setup.py.
 * build_graph now walks the cache with an explicit stack instead of recursing
   once per tree level; the three new helpers have docstrings.
 * NOTE(review): line structure and indentation were reconstructed from a
   whitespace-collapsed copy. Context-line indentation (setup.py in
   particular) is a best guess, and the post-image ids on the util.py and
   setup.py `index` lines predate these edits. Confirm against the repository
   before applying.

 hfppl/distributions/geometric.py      |  1 +
 hfppl/distributions/logcategorical.py |  1 +
 hfppl/util.py                         | 65 ++++++++++++++++++++++++++++-
 setup.py                              |  4 +++-
 4 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/hfppl/distributions/geometric.py b/hfppl/distributions/geometric.py
index aa2be13..e63e07d 100644
--- a/hfppl/distributions/geometric.py
+++ b/hfppl/distributions/geometric.py
@@ -1,3 +1,4 @@
+import numpy as np
 from .distribution import Distribution
 
 class Geometric(Distribution):
diff --git a/hfppl/distributions/logcategorical.py b/hfppl/distributions/logcategorical.py
index 5dcd16c..24d165f 100644
--- a/hfppl/distributions/logcategorical.py
+++ b/hfppl/distributions/logcategorical.py
@@ -1,3 +1,4 @@
+import numpy as np
 from .distribution import Distribution
 
 class LogCategorical(Distribution):
diff --git a/hfppl/util.py b/hfppl/util.py
index ee01d7e..389c1e5 100644
--- a/hfppl/util.py
+++ b/hfppl/util.py
@@ -1,6 +1,8 @@
 """Utility functions"""
 
 import numpy as np
+import networkx as nx
+import matplotlib.pyplot as plt
 
 def logsumexp(nums):
     m = np.max(nums)
@@ -18,4 +20,65 @@ def log_softmax(nums):
     return nums - logsumexp(nums)
 
 def softmax(nums):
-    return np.exp(log_softmax(nums))
\ No newline at end of file
+    return np.exp(log_softmax(nums))
+
+def build_graph(LLM, node=None, path=None, graph=None, level=0):
+    """Build a directed graph of the token tree rooted at `LLM.cache`.
+
+    Every tree node becomes a graph node whose id is its token-id path joined
+    with '->' (the empty string for an empty path). Nodes carry a `label` (the
+    decoded last token, or 'ROOT') and a `level` (start `level` plus depth).
+
+    Args:
+        LLM: object exposing `cache` (the default start node) and `tokenizer`.
+        node: tree node to start from; defaults to `LLM.cache`.
+        path: token ids leading to `node`; defaults to the empty path.
+        graph: `nx.DiGraph` to extend; a new one is created when omitted.
+        level: value stored as `level` on the start node.
+
+    Returns:
+        The graph containing `node` and all of its descendants.
+    """
+    if graph is None:
+        graph = nx.DiGraph()
+    if node is None:
+        node = LLM.cache
+    if path is None:
+        path = []
+
+    # Walk the tree with an explicit stack rather than recursion, so a cache
+    # deeper than the interpreter's recursion limit cannot raise RecursionError.
+    stack = [(node, list(path), level)]
+    while stack:
+        current, current_path, current_level = stack.pop()
+
+        # Add node to graph
+        node_id = '->'.join(str(token_id) for token_id in current_path)
+        node_label = LLM.tokenizer.decode([current_path[-1]]) if current_path else 'ROOT'
+        graph.add_node(node_id, label=node_label, level=current_level)
+
+        # Add edge to graph
+        if current_path:
+            parent_id = '->'.join(str(token_id) for token_id in current_path[:-1])
+            graph.add_edge(parent_id, node_id)
+
+        # Push children in reverse so they are visited in their original order
+        for token_id, child in reversed(list(current.children.items())):
+            stack.append((child, current_path + [token_id], current_level + 1))
+
+    return graph
+
+def draw_graph(graph):
+    """Draw `graph` laid out by each node's `level` and show it with matplotlib.
+
+    Every node must carry the `label` and `level` attributes set by `build_graph`.
+    """
+    pos = nx.multipartite_layout(graph, subset_key="level")  # Position nodes at different levels
+    labels = {node: data['label'] for node, data in graph.nodes(data=True)}
+    nx.draw(graph, pos, labels=labels, arrows=True)
+    plt.show()
+
+def show_graph(LLM):
+    """Build the cache graph of `LLM` and display it."""
+    graph = build_graph(LLM)
+    draw_graph(graph)
diff --git a/setup.py b/setup.py
index 58f0c49..3149347 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,9 @@
                       'transformers==4.30',
                       'bitsandbytes',
                       'accelerate',
-                      'sentencepiece'
+                      'sentencepiece',
+                      'networkx',
+                      'matplotlib'
                       ],
 
     classifiers=[

From ea5c2aef64bddc567e112afbffcebb15a3e81f61 Mon Sep 17 00:00:00 2001
From: Jasper
Date: Sun, 24 Sep 2023 22:44:32 -0700
Subject: [PATCH 2/2] add graph showing to default example

---
 examples/hard_constraints.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/hard_constraints.py b/examples/hard_constraints.py
index e5c5431..96dfa20 100644
--- a/examples/hard_constraints.py
+++ b/examples/hard_constraints.py
@@ -1,7 +1,7 @@
 import string
 import asyncio
 from hfppl import Model, CachedCausalLM, Token, LMContext, smc_standard
-
+from hfppl.util import show_graph
 
 # Load the language model.
 # Vicuna is an open model; to use a model with restricted access, like LLaMA 2,
@@ -77,4 +77,6 @@ async def main():
     for p in particles:
         print(f"{p.context}")
 
+    show_graph(LLM)
+
 asyncio.run(main())
\ No newline at end of file