-
Notifications
You must be signed in to change notification settings - Fork 233
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
992a462
commit 8e530e3
Showing
10 changed files
with
1,494 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# 🤖 BERT Base Cased Integration | ||
|
||
Integrate the BERT (Masked Language Model) to predict masked words using the cased model, seamlessly integrated with uAgent. | ||
BERT is a transformers model pretrained on a large corpus of English data in a self-supervised fashion. This means it was pretrained on the raw texts only, with no humans labeling them in any way (which is why it can use lots of publicly available data) with an automatic process to generate inputs and labels from those texts. | ||
This model is cased: it does make a difference between english and English. | ||
|
||
![BERT Logo](https://path/to/bert-logo.png) | ||
|
||
## 🛠️ Prerequisites | ||
|
||
Before getting started, make sure you have the following software installed: | ||
|
||
- Python (v3.10+ recommended) | ||
- Poetry (a Python packaging and dependency management tool) | ||
|
||
## 🚀 Setup | ||
|
||
1. **Obtain HuggingFace API Token:** | ||
|
||
- Visit [HuggingFace](https://huggingface.co/). | ||
- Sign up or log in. | ||
- Navigate to `Profile -> Settings -> Access Tokens`. | ||
- Copy an existing token or create a new one. | ||
|
||
2. **Configure Environment:** | ||
|
||
- Set your HuggingFace API Token as follows: | ||
|
||
``` | ||
export HUGGING_FACE_ACCESS_TOKEN="{Your HuggingFace API Token}" | ||
``` | ||
|
||
3. **Install Dependencies:** | ||
|
||
```bash | ||
poetry install | ||
``` | ||
|
||
## 📋 Running The Script | ||
|
||
To run the project, use the following command: | ||
|
||
``` | ||
poetry run python main.py | ||
``` | ||
You can change the input text to whatever you like. Open the `src/bert_base_user.py` file and modify the value of the `INPUT_TEXT` variable in `agents/bert_base_user.py`. | ||
The BERT Base model will suggest the most appropriate word to be replaced. |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
[tool.poetry] | ||
name = "bert-base-cased-uagent-integration" | ||
version = "0.0.1" | ||
description = "BERT base model (cased) integration with fetch.ai uagent" | ||
authors = ["Laxmikant Tripathi <[email protected]>"] | ||
readme = "README.md" | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.10,<3.12" | ||
requests = "^2.31.0" | ||
uagents = "^0.6.2" | ||
|
||
|
||
[build-system] | ||
requires = ["poetry-core"] | ||
build-backend = "poetry.core.masonry.api" |
Empty file.
Empty file.
82 changes: 82 additions & 0 deletions
82
integrations/bert-base-cased/src/agents/bert_base_agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import os | ||
import requests | ||
from uagents import Agent, Context, Protocol | ||
from messages import UAResponse, UARequest, Error | ||
from uagents.setup import fund_agent_if_low | ||
|
||
# Define constants | ||
HUGGING_FACE_ACCESS_TOKEN = os.getenv("HUGGING_FACE_ACCESS_TOKEN", "") | ||
|
||
if not HUGGING_FACE_ACCESS_TOKEN: | ||
raise Exception( | ||
"You need to provide an HUGGING_FACE_ACCESS_TOKEN, by exporting env, please follow the README" | ||
) | ||
|
||
BERT_BASE_URL = "https://api-inference.huggingface.co/models/bert-base-cased" | ||
|
||
# Set headers for API requests | ||
HEADERS = {"Authorization": f"Bearer {HUGGING_FACE_ACCESS_TOKEN}"} | ||
|
||
|
||
def create_and_fund_agent() -> Agent: | ||
""" | ||
Create an agent for BERT-based text prediction and fund it if necessary. | ||
Returns: | ||
Agent: The created and funded agent. | ||
""" | ||
agent = Agent( | ||
name="bert_base_agent", | ||
seed=HUGGING_FACE_ACCESS_TOKEN, | ||
) | ||
os.environ["AI_MODEL_AGENT_ADDRESS"] = agent.address | ||
fund_agent_if_low(agent.wallet.address()) | ||
return agent | ||
|
||
|
||
agent = create_and_fund_agent() | ||
# Define a protocol for UARequests | ||
bert_prediction_protocol = Protocol("UARequest") | ||
|
||
|
||
async def predict_text(ctx: Context, sender: str, text: str) -> None: | ||
""" | ||
Predict text using the Hugging Face BERT-based model and send the result as a response. | ||
Args: | ||
ctx (Context): The context object. | ||
sender (str): The sender's identifier. | ||
text (str): The text to be classified. | ||
""" | ||
try: | ||
response = requests.post(BERT_BASE_URL, headers=HEADERS, json={"inputs": text}) | ||
|
||
if response.status_code != 200: | ||
error_message = response.json().get("error", "Unknown error") | ||
await ctx.send(sender, Error(error=f"API Error: {error_message}")) | ||
return | ||
|
||
model_result = response.json()[0] | ||
await ctx.send(sender, UAResponse(response=model_result)) | ||
except Exception as ex: | ||
await ctx.send(sender, Error(error=f"Exception: {str(ex)}")) | ||
|
||
|
||
@bert_prediction_protocol.on_message(model=UARequest, replies={UAResponse, Error}) | ||
async def handle_request(ctx: Context, sender: str, request: UARequest) -> None: | ||
""" | ||
Handle UARequest for text prediction and send the response. | ||
Args: | ||
ctx (Context): The context object. | ||
sender (str): The sender's identifier. | ||
request (UARequest): The UARequest containing the text to classify. | ||
""" | ||
ctx.logger.info(f"Got request from {sender} for text prediction: {request.text}") | ||
await predict_text(ctx, sender, request.text) | ||
|
||
|
||
agent.include(bert_prediction_protocol) | ||
|
||
if __name__ == "__main__": | ||
bert_prediction_protocol.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import os | ||
from uagents import Agent, Context, Protocol | ||
from uagents.setup import fund_agent_if_low | ||
|
||
from messages import Error, UARequest, UAResponse | ||
|
||
# Constants | ||
INPUT_TEXT = "India is a [MASK] Capital of World." | ||
AI_MODEL_AGENT_ADDRESS = os.getenv("AI_MODEL_AGENT_ADDRESS", "") | ||
|
||
# Create a user agent | ||
user = Agent( | ||
name="bert_base_user", | ||
) | ||
|
||
# Fund the user agent if necessary | ||
fund_agent_if_low(user.wallet.address()) | ||
|
||
# Define a protocol for user-agent communication | ||
bert_base_user = Protocol("Request") | ||
|
||
@bert_base_user.on_interval(360, messages=UARequest) | ||
async def predict_masking(ctx: Context): | ||
""" | ||
Periodically sends a UARequest to the AI model agent to predict a masking word. | ||
Args: | ||
ctx (Context): The context object. | ||
""" | ||
ctx.logger.info(f"Asking AI model agent to find masking word: {INPUT_TEXT}") | ||
await ctx.send(AI_MODEL_AGENT_ADDRESS, UARequest(text=INPUT_TEXT)) | ||
|
||
@bert_base_user.on_message(model=UAResponse) | ||
async def handle_data(ctx: Context, sender: str, data: UAResponse): | ||
""" | ||
Handles the response data received from the AI model agent. | ||
Args: | ||
ctx (Context): The context object. | ||
sender (str): The sender's identifier. | ||
data (UAResponse): The UAResponse containing the predicted masking word. | ||
""" | ||
ctx.logger.info(f"Got response from AI model agent: {data.response}") | ||
|
||
@bert_base_user.on_message(model=Error) | ||
async def handle_error(ctx: Context, sender: str, error: Error): | ||
""" | ||
Handles error messages received from the AI model agent. | ||
Args: | ||
ctx (Context): The context object. | ||
sender (str): The sender's identifier. | ||
error (Error): The Error message from the AI model agent. | ||
""" | ||
ctx.logger.info(f"Got error from AI model agent: {error}") | ||
|
||
# Include the user agent in the agent setup | ||
user.include(bert_base_user) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from uagents import Bureau | ||
|
||
from agents.bert_base_agent import agent | ||
from agents.bert_base_user import user | ||
|
||
if __name__ == "__main__": | ||
bureau = Bureau(endpoint="http://127.0.0.1:8000/submit", port=8000) | ||
print(f"Adding AI model agent to Bureau: {agent.address}") | ||
bureau.add(agent) | ||
|
||
print(f"Adding user agent to Bureau: {user.address}") | ||
bureau.add(user) | ||
bureau.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .base import UAResponse, UARequest, Error |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from uagents import Model | ||
|
||
|
||
class UARequest(Model): | ||
text: str | ||
|
||
|
||
class Error(Model): | ||
error: str | ||
|
||
|
||
class UAResponse(Model): | ||
response: dict |