Skip to content

Commit

Permalink
Integrate new model data dir in cli driver
Browse files (browse the repository at this point in the history)
  • Loading branch information
mittagessen committed Jan 5, 2025
1 parent 77e1d44 commit c4b26b6
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 34 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ dependencies:
- pip:
- coremltools~=8.1
- htrmopo
- platformdirs
- file:.
1 change: 1 addition & 0 deletions environment_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ dependencies:
- pip:
- coremltools~=8.1
- htrmopo
- platformdirs
- file:.
64 changes: 30 additions & 34 deletions kraken/kraken.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,28 @@
Command line drivers for recognition functionality.
"""
import dataclasses
import logging
import os
import uuid
import click
import shlex
import logging
import warnings
from functools import partial
from pathlib import Path
from typing import IO, Any, Callable, Dict, List, Union, cast
import dataclasses

import click
from PIL import Image
from pathlib import Path
from itertools import chain
from functools import partial
from importlib import resources
from platformdirs import user_data_dir
from typing import IO, Any, Callable, Dict, List, Union, cast

from rich import print
from rich.tree import Tree
from rich.table import Table
from rich.console import Group
from rich.traceback import install
from rich.logging import RichHandler
from rich.markdown import Markdown
from rich.progress import Progress

from kraken.lib import log

Expand Down Expand Up @@ -107,7 +107,7 @@ def binarizer(threshold, zoom, escale, border, perc, range, low, high, input, ou
processing_steps=ctx.meta['steps']))
else:
form = None
ext = os.path.splitext(output)[1]
ext = Path(output).suffix
if ext in ['.jpg', '.jpeg', '.JPG', '.JPEG', '']:
form = 'png'
if ext:
Expand Down Expand Up @@ -359,7 +359,6 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty
placing their respective outputs in temporary files.
"""
import glob
import os.path
import tempfile

from threadpoolctl import threadpool_limits
Expand All @@ -373,9 +372,8 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty
# expand batch inputs
if batch_input and suffix:
for batch_expr in batch_input:
for in_file in glob.glob(os.path.expanduser(batch_expr), recursive=True):

input.append((in_file, '{}{}'.format(os.path.splitext(in_file)[0], suffix)))
for in_file in glob.glob(str(Path(batch_expr).expanduser()), recursive=True):
input.append(Path(in_file).with_suffix(suffix))

# parse pdfs
if format_type == 'pdf':
Expand Down Expand Up @@ -515,29 +513,31 @@ def segment(ctx, model, boxes, text_direction, scale, maxcolseps,
logger.warning(f'Baseline model ({model}) given but legacy segmenter selected. Forcing to -bl.')
boxes = False

model = [Path(m) for m in model]
if boxes is False:
if not model:
model = [SEGMENTATION_DEFAULT_MODEL]
ctx.meta['steps'].append(ProcessingStep(id=str(uuid.uuid4()),
category='processing',
description='Baseline and region segmentation',
settings={'model': [os.path.basename(m) for m in model],
settings={'model': [m.name for m in model],
'text_direction': text_direction}))

# first try to find the segmentation models by their given names, then
# look in the kraken config folder
locations = []
for m in model:
location = None
search = [m, os.path.join(click.get_app_dir(APP_NAME), m)]
search = chain([m],
Path(user_data_dir('htrmopo')).rglob(str(m)),
Path(click.get_app_dir('kraken')).rglob(str(m)))
for loc in search:
if os.path.isfile(loc):
if loc.is_file():
location = loc
locations.append(loc)
break
if not location:
raise click.BadParameter(f'No model for {m} found')

raise click.BadParameter(f'No model for {str(m)} found')

from kraken.lib.vgsl import TorchVGSLModel
model = []
Expand Down Expand Up @@ -638,11 +638,12 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction):
nm: Dict[str, models.TorchSeqRecognizer] = {}
ign_tags = model.pop('ignore')
for k, v in model.items():
search = [v,
os.path.join(click.get_app_dir(APP_NAME), v)]
search = chain([Path(v)],
Path(user_data_dir('htrmopo')).rglob(v),
Path(click.get_app_dir('kraken')).rglob(v))
location = None
for loc in search:
if os.path.isfile(loc):
if loc.is_file():
location = loc
break
if not location:
Expand All @@ -669,7 +670,7 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction):
category='processing',
description='Text line recognition',
settings={'text_direction': text_direction,
'models': ' '.join(os.path.basename(v) for v in model.values()),
'models': ' '.join(Path(v).name for v in model.values()),
'pad': pad,
'bidi_reordering': reorder}))

Expand Down Expand Up @@ -807,29 +808,24 @@ def get(ctx, model_id):
"""
Retrieves a model from the repository.
"""
import glob

from htrmopo import get_model, get_description
from htrmopo import get_model

from kraken.repo import get_description
from kraken.lib.progress import KrakenDownloadProgressBar

try:
desc = get_description(model_id)
get_description(model_id,
filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords)
except ValueError as e:
logger.error(e)
ctx.exit(1)

print(desc)
if getattr(desc, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in desc.keywords:
logger.error('Record exists but is not a kraken-compatible model')
ctx.exit(1)

with KrakenDownloadProgressBar() as progress:
download_task = progress.add_task('Processing', total=0, visible=True if not ctx.meta['verbose'] else False)
model_dir = get_model(model_id,
lambda total, advance: progress.update(download_task, total=total, advance=advance))
model_candidates = list(filter(lambda x: x.suffix == '.mlmodel', model_dir.iter_dir()))
message(f'Model dir: {model_dir} (model files: {model_candidates})')
callback=lambda total, advance: progress.update(download_task, total=total, advance=advance))
model_candidates = list(filter(lambda x: x.suffix == '.mlmodel', model_dir.iterdir()))
message(f'Model dir: {model_dir} (model files: {", ".join(x.name for x in model_candidates)})')


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ install_requires =
lightning~=2.4.0
torchmetrics>=1.1.0
threadpoolctl~=3.5.0
platformdirs
rich

[options.extras_require]
Expand Down

0 comments on commit c4b26b6

Please sign in to comment.