Merge pull request #17 from soo-h/master
merge updates for v1.1.2
cailigd committed Jul 1, 2024
2 parents 74eb723 + a0a9b76 commit e46c976
Showing 11 changed files with 30 additions and 56 deletions.
10 changes: 5 additions & 5 deletions MuRaL/nn_utils.py
@@ -123,7 +123,7 @@ def run_time_view_model_predict_m(model, dataloader, criterion, device, n_class,
total_loss = 0
batch_count = 0
get_batch_time_recode = 0
- get_batch_train_recode = 0
+ get_batch_predict_recode = 0
step_time = time.time()
get_batch_time = time.time()

@@ -136,7 +136,7 @@ def run_time_view_model_predict_m(model, dataloader, criterion, device, n_class,
print("get 500 batch used time: ", get_batch_time_recode)
get_batch_time_recode = 0

- batch_train_time = time.time()
+ batch_predict_time = time.time()
cat_x = cat_x.to(device)
cont_x = cont_x.to(device)
distal_x = distal_x.to(device)
@@ -155,10 +155,10 @@ def run_time_view_model_predict_m(model, dataloader, criterion, device, n_class,
print(f"Batch Number: {batch_count}; Mean Time of 500 batch: {(time.time()-step_time)}")
step_time = time.time()

- get_batch_train_recode += time.time() - batch_train_time
+ get_batch_predict_recode += time.time() - batch_predict_time
if batch_count % 500 == 0:
print("training 500 batch used time:", get_batch_train_recode)
get_batch_train_recode = 0
print("training 500 batch used time:", get_batch_predict_recode)
get_batch_predict_recode = 0
get_batch_time = time.time()

if device == torch.device('cpu'):
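For readers unfamiliar with this instrumentation: the renamed counters separate DataLoader fetch time from per-batch prediction time, each reported and reset every 500 batches. A minimal, self-contained sketch of that pattern (illustrative names, not MuRaL's API):

    import time

    def timed_predict_loop(model, dataloader, device):
        load_time = 0.0      # mirrors get_batch_time_recode: time spent fetching batches
        predict_time = 0.0   # mirrors get_batch_predict_recode: time spent on transfer + forward pass
        batch_count = 0
        t_fetch = time.time()
        for batch in dataloader:
            load_time += time.time() - t_fetch
            t_pred = time.time()
            # ... move the batch to `device` and run the model here ...
            predict_time += time.time() - t_pred
            batch_count += 1
            if batch_count % 500 == 0:
                print("get 500 batch used time:", load_time)
                print("predict 500 batch used time:", predict_time)
                load_time = predict_time = 0.0
            t_fetch = time.time()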
8 changes: 1 addition & 7 deletions MuRaL/run_predict.py
@@ -113,12 +113,6 @@ def parse_arguments(parser):
between RAM memory and preprocessing speed. It is recommended to use 300k.
Default: 300000.""" ).strip())

- optional.add_argument('--sampled_segments', type=int, metavar='INT', default=[10], nargs='+',
- help=textwrap.dedent("""
- Number of segments chosen for generating samples for batches in DataLoader.
- Default: 10.
- """ ).strip())

optional.add_argument('--pred_batch_size', metavar='INT', default=16,
help=textwrap.dedent("""
Size of mini batches for prediction. Default: 16.
@@ -208,7 +202,7 @@ def main():
ref_genome= args.ref_genome

pred_batch_size = args.pred_batch_size
- sampled_segments = args.sampled_segments
+ sampled_segments = 1
# Output file path
pred_file = args.pred_file

4 changes: 3 additions & 1 deletion MuRaL/run_train_TL_raytune.py
@@ -441,7 +441,9 @@ def main():

if not sampled_segments:
sampled_segments = args.sampled_segments = para_read_from_config('sampled_segments', config)

+ else:
+ sampled_segments = args.sampled_segments[0]

args.seq_only = config['seq_only']


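The `[0]` indexing above exists because `--sampled_segments` is declared with `nargs='+'`, so argparse always returns a list even when a single value is given. A quick, standalone illustration of that behavior (not MuRaL code):

    import argparse
    import textwrap

    parser = argparse.ArgumentParser()
    parser.add_argument('--sampled_segments', type=int, metavar='INT', default=[10], nargs='+',
                        help=textwrap.dedent("""
                        Number of segments chosen for generating samples for batches in DataLoader.
                        Default: 10.
                        """).strip())

    args = parser.parse_args(['--sampled_segments', '10'])
    print(args.sampled_segments)     # [10] -- always a list with nargs='+'
    print(args.sampled_segments[0])  # 10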
2 changes: 1 addition & 1 deletion MuRaL/run_train_raytune.py
@@ -209,7 +209,7 @@ def parse_arguments(parser):
Default: 0.25.
""" ).strip())

- learn_args.add_argument('--segment_center', type=int, metavar='INT', default=300000,
+ learn_args.add_argument('--segment_center', type=int, metavar='INT', default=300000,
help=textwrap.dedent("""
The maximum encoding unit of the sequence. It affects trade-off
between RAM memory and preprocessing speed. It is recommended to use 300k.
22 changes: 0 additions & 22 deletions dirichlet.patch

This file was deleted.

10 changes: 0 additions & 10 deletions dirichlet_install.sh

This file was deleted.

2 changes: 1 addition & 1 deletion dirichlet_python/dirichletcal/calib/multinomial.py
@@ -274,7 +274,7 @@ def _newton_update(weights_0, X, XX_T, target, k, method_, maxiter=int(1024),
updates = gradient / hessian
else:
try:
- inverse = scipy.linalg.pinv2(hessian)
+ inverse = scipy.linalg.pinv(hessian)
updates = np.matmul(inverse, gradient)
except (raw_np.linalg.LinAlgError, ValueError) as err:
logging.error(err)
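Context for this one-line swap: `scipy.linalg.pinv2` was deprecated and then removed in recent SciPy releases, while `scipy.linalg.pinv` computes the same Moore-Penrose pseudoinverse (via SVD in current releases), so it serves as a drop-in replacement here. A small sanity check, assuming a current SciPy install:

    import numpy as np
    import scipy.linalg

    hessian = np.array([[4.0, 2.0],
                        [2.0, 1.0]])    # singular, so a plain inverse would fail
    inverse = scipy.linalg.pinv(hessian)  # Moore-Penrose pseudoinverse
    print(np.allclose(hessian @ inverse @ hessian, hessian))  # True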
10 changes: 5 additions & 5 deletions dirichlet_python/setup.py
@@ -29,11 +29,11 @@
],
python_requires='>=3.6',
install_requires = [
- 'numpy>=1.14.2'
- 'scipy>=1.0.0'
- 'scikit-learn>=0.19.1'
- 'jax'
- 'jaxlib'
+ 'numpy>=1.14.2',
+ 'scipy>=1.0.0',
+ 'scikit-learn>=0.19.1',
+ 'jax',
+ 'jaxlib',
'autograd'
]
)
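The added commas matter: without them, Python concatenates adjacent string literals, so the old `install_requires` collapsed into a single unsatisfiable requirement instead of six separate ones. A quick demonstration:

    broken = [
        'numpy>=1.14.2'
        'scipy>=1.0.0'
        'scikit-learn>=0.19.1'
        'jax'
        'jaxlib'
        'autograd'
    ]
    print(len(broken))  # 1
    print(broken[0])    # 'numpy>=1.14.2scipy>=1.0.0scikit-learn>=0.19.1jaxjaxlibautograd'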
5 changes: 4 additions & 1 deletion environment.yml
@@ -1,11 +1,13 @@
channels:
- pytorch
- bioconda
- conda-forge
- defaults
dependencies:
- python=3.8.5
- pip=22.0.4
- - numpy=1.21.2
+ - numpy=1.23.5
- pandas=1.4.1
- cudatoolkit=11.3
- pytorch=1.10.2
- pybigwig=0.3.17
@@ -24,3 +26,4 @@ dependencies:
- jaxlib==0.4.13
- autograd==1.6.2
- protobuf==3.20.0
+ - setuptools==69.0.2
5 changes: 4 additions & 1 deletion environment_cpu.yml
@@ -1,11 +1,13 @@
channels:
- pytorch
- bioconda
- conda-forge
- defaults
dependencies:
- python=3.8.5
- pip=22.0.4
- - numpy=1.21.2
+ - numpy=1.23.5
- pandas=1.4.1
- pytorch=1.10.2
- cpuonly=2.0
- pybigwig=0.3.17
@@ -24,3 +26,4 @@ dependencies:
- jaxlib==0.4.13
- autograd==1.6.2
- protobuf==3.20.0
+ - setuptools==69.0.2
8 changes: 6 additions & 2 deletions setup.py
@@ -4,8 +4,12 @@

from setuptools import setup, find_packages

- import os
- os.system('bash dirichlet_install.sh')
+ import subprocess
+ try:
+     subprocess.check_call(['pip', 'install', '.'], cwd='dirichlet_python')
+ except subprocess.CalledProcessError as e:
+     print(f"Error installing dirichlet_python: {e}")
+     exit(1)

def get_version():
try:
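One caveat worth noting about the new call: invoking the bare `pip` executable can target a different interpreter than the one running `setup.py`. A commonly used, slightly more defensive variant (a suggestion, not what the commit does):

    import subprocess
    import sys

    # Install the vendored dirichlet_python package into the same environment
    # as the interpreter that is executing setup.py.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '.'], cwd='dirichlet_python')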
