Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create nougat.ipynb #85

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 325 additions & 0 deletions notebooks/nougat.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "3IGr7-SPuivC"
},
"source": [
"# **Nougat** : Neural Optical Understanding for Academic Documents\n",
"# **A Gradio Demo**\n",
"\n",
"## Lukas Blecher et al. [Paper](https://arxiv.org/pdf/2308.13418.pdf), [Project](https://facebookresearch.github.io/nougat/)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o1psC42ludfh"
},
"source": [
"### Installing the required libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wFetOZtjXT4D",
"outputId": "e20caad0-6539-474c-a765-eaf40044e952"
},
"outputs": [],
"source": [
"!pip install gradio -U -q\n",
"import gradio as gr"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VTVyJd43TJjh",
"outputId": "048a4a98-afe1-4505-dd47-6e81c89fe11c"
},
"outputs": [],
"source": [
"!pip install nougat-ocr -q"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DHGxzqdmkbVf"
},
"source": [
"### Download a smaple pdf file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RMsn4EE1j9Gl",
"outputId": "2a085c28-d0df-4904-d4b9-5e79774d2a9f"
},
"outputs": [],
"source": [
"# Download a sample pdf file - https://arxiv.org/pdf/2308.13418.pdf (nougat paper)\n",
"import requests\n",
"import os\n",
"\n",
"# create a new input directory for pdf downloads\n",
"if not os.path.exists(\"input\"):\n",
" os.mkdir(\"input\")\n",
"def get_pdf(pdf_link):\n",
"\n",
" # Send a GET request to the PDF link\n",
" response = requests.get(pdf_link)\n",
"\n",
" if response.status_code == 200:\n",
" # Save the PDF content to a local file\n",
" with open(\"input/nougat.pdf\", 'wb') as pdf_file:\n",
" pdf_file.write(response.content)\n",
" print(\"PDF downloaded successfully.\")\n",
" else:\n",
" print(\"Failed to download the PDF.\")\n",
" return\n",
"\n",
"\n",
"get_pdf(\"https://arxiv.org/pdf/2308.13418.pdf\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ghKR79CknBcP"
},
"source": [
"### Downloading model weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "y88ZxhOVaOXt",
"outputId": "69b653ac-d8c0-42cf-d632-0a7815753989"
},
"outputs": [],
"source": [
"from nougat.utils.checkpoint import get_checkpoint\n",
"CHECKPOINT = get_checkpoint('nougat')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BUG9t33OpKao"
},
"source": [
"### Writing inference functions for Gradio app"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OGRpNZkFzyz1"
},
"outputs": [],
"source": [
"import subprocess\n",
"import uuid\n",
"import requests\n",
"import re\n",
"\n",
"# Download pdf from a given link\n",
"def get_pdf(pdf_link):\n",
" # Generate a unique filename\n",
" unique_filename = f\"input/downloaded_paper_{uuid.uuid4().hex}.pdf\"\n",
"\n",
" # Send a GET request to the PDF link\n",
" response = requests.get(pdf_link)\n",
"\n",
" if response.status_code == 200:\n",
" # Save the PDF content to a local file\n",
" with open(unique_filename, 'wb') as pdf_file:\n",
" pdf_file.write(response.content)\n",
" print(\"PDF downloaded successfully.\")\n",
" else:\n",
" print(\"Failed to download the PDF.\")\n",
" return unique_filename\n",
"\n",
"\n",
"# Run nougat on the pdf file\n",
"def nougat_ocr(file_name):\n",
"\n",
" # Command to run\n",
" cli_command = [\n",
" 'nougat',\n",
" '--out', 'output',\n",
" 'pdf', file_name,\n",
" '--checkpoint', CHECKPOINT,\n",
" '--markdown'\n",
" ]\n",
"\n",
" # Run the command\n",
" subprocess.run(cli_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)\n",
"\n",
" return\n",
"\n",
"\n",
"# predict function / driver function\n",
"def paper_read(pdf_file, pdf_link):\n",
" if pdf_file is None:\n",
" if pdf_link == '':\n",
" print(\"No file is uploaded and No link is provided\")\n",
" return \"No data provided. Upload a pdf file or provide a pdf link and try again!\"\n",
" else:\n",
" file_name = get_pdf(pdf_link)\n",
" else:\n",
" file_name = pdf_file.name\n",
"\n",
" nougat_ocr(file_name)\n",
"\n",
" # Open the file for reading\n",
" file_name = file_name.split('/')[-1][:-4]\n",
" with open(f'output/{file_name}.mmd', 'r') as file:\n",
" content = file.read()\n",
"\n",
" return content\n",
"\n",
"\n",
"# Handling examples in Gradio app\n",
"def process_example(pdf_file,pdf_link):\n",
" ocr_content = paper_read(pdf_file,pdf_link)\n",
" return gr.update(value=ocr_content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IjVq6B5MGAGt"
},
"outputs": [],
"source": [
"# fixing the size of markdown component in gradio app\n",
"css = \"\"\"\n",
" #mkd {\n",
" height: 500px;\n",
" overflow: auto;\n",
" border: 1px solid #ccc;\n",
" }\n",
"\"\"\"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xSrKYBORuTCl"
},
"source": [
"### Building Gradio UI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 750
},
"id": "TnvIHr5ITJbl",
"outputId": "cb27bf8d-519f-46ac-de7c-8dc1956a74d0"
},
"outputs": [],
"source": [
"# Gradio Blocks\n",
"with gr.Blocks(css =css) as demo:\n",
" with gr.Row():\n",
" mkd = gr.Markdown('<h4><center>Upload a PDF</center></h4>',scale=1)\n",
" mkd = gr.Markdown('<h4><center><i>OR</i></center></h4>',scale=1)\n",
" mkd = gr.Markdown('<h4><center>Provide a PDF link</center></h4>',scale=1)\n",
"\n",
" with gr.Row(equal_height=True):\n",
" pdf_file = gr.File(label='PDF📃', file_count='single', scale=1)\n",
" pdf_link = gr.Textbox(placeholder='Enter an arxiv link here', label='PDF link🔗🌐', scale=1)\n",
"\n",
" with gr.Row():\n",
" btn = gr.Button('Run NOUGAT🍫')\n",
" clr = gr.Button('Clear🚿')\n",
"\n",
" output_headline = gr.Markdown(\"<h3><center>PDF converted into markup language through Nougat-OCR👇:</center></h3>\")\n",
" parsed_output = gr.Markdown(r'OCR Output📃🔤',elem_id='mkd', scale=1, latex_delimiters=[{ \"left\": r\"\\(\", \"right\": r\"\\)\", \"display\": False },{ \"left\": r\"\\[\", \"right\": r\"\\]\", \"display\": True }])\n",
"\n",
" btn.click(paper_read, [pdf_file, pdf_link], parsed_output )\n",
" clr.click(lambda : (gr.update(value=None),\n",
" gr.update(value=None),\n",
" gr.update(value=None)),\n",
" [],\n",
" [pdf_file, pdf_link, parsed_output]\n",
" )\n",
"\n",
" # gr.Examples(\n",
" # [[\"nougat.pdf\", \"\"], [None, \"https://arxiv.org/pdf/2308.08316.pdf\"]],\n",
" # inputs = [pdf_file, pdf_link],\n",
" # outputs = parsed_output,\n",
" # fn=process_example,\n",
" # cache_examples=True,\n",
" # label='Click on any examples below to get Nougat OCR results quickly:'\n",
" # )\n",
"\n",
"demo.queue()\n",
"demo.launch(share=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 4
}