
Commit

chore: lint
sgalkina committed Oct 25, 2023
1 parent 31e0fde commit a12671b
Showing 3 changed files with 54 additions and 39 deletions.
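The diff below is a pure formatting pass: single quotes are normalized to double quotes, long calls are wrapped (or collapsed) to fit the line length, and stray whitespace is dropped; no behaviour changes. As a sketch only — the formatter is not named in the commit, so Black is an assumption — the same kind of rewrite can be reproduced programmatically:

    import black  # assumed formatter, not stated in the commit

    old_src = "domain = 'd_Bacteria'\nphyla = ['f_1', 'f_2', 'f_3']\n"
    new_src = black.format_str(old_src, mode=black.FileMode())
    print(new_src)
    # domain = "d_Bacteria"
    # phyla = ["f_1", "f_2", "f_3"]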
49 changes: 32 additions & 17 deletions test/test_semisupervised_encode.py
@@ -8,7 +8,6 @@
 
 
 class TestDataLoader(unittest.TestCase):
-
     def test_permute_indices(self):
         indices = vamb.semisupervised_encode.permute_indices(10, 25, seed=1)
         set_10 = set(range(10))
@@ -22,38 +21,54 @@ class TestVAEVAE(unittest.TestCase):
     N_contigs = 111
     tnfs = np.random.random((N_contigs, 103)).astype(np.float32)
     rpkms = np.random.random((N_contigs, 14)).astype(np.float32)
-    domain = 'd_Bacteria'
-    phyla = ['f_1', 'f_2', 'f_3']
+    domain = "d_Bacteria"
+    phyla = ["f_1", "f_2", "f_3"]
     classes = {
-        'f_1': ['c_11', 'c_21', 'c_31'],
-        'f_2': ['c_12', 'c_22', 'c_32'],
-        'f_3': ['c_13', 'c_23', 'c_33'],
+        "f_1": ["c_11", "c_21", "c_31"],
+        "f_2": ["c_12", "c_22", "c_32"],
+        "f_3": ["c_13", "c_23", "c_33"],
     }
     lengths = np.random.randint(2000, 5000, size=N_contigs)
 
     def make_random_annotation(self):
         phylum = np.random.choice(self.phyla, 1)[0]
         clas = np.random.choice(self.classes[phylum], 1)[0]
         if np.random.random() <= 0.2:
-            return ';'.join([self.domain])
+            return ";".join([self.domain])
         if 0.2 < np.random.random() <= 0.5:
-            return ';'.join([self.domain, phylum])
-        return ';'.join([self.domain, phylum, clas])
+            return ";".join([self.domain, phylum])
+        return ";".join([self.domain, phylum, clas])
 
     def make_random_annotations(self):
         return [self.make_random_annotation() for _ in range(self.N_contigs)]
 
     def test_make_graph(self):
         annotations = self.make_random_annotations()
         nodes, ind_nodes, table_parent = vamb.h_loss.make_graph(annotations)
         print(nodes, ind_nodes, table_parent)
-        self.assertTrue(set(nodes).issubset(set([
-            'Domain', 'd_Archaea', 'd_Bacteria',
-            'f_1', 'f_2', 'f_3',
-            'c_11', 'c_21', 'c_31',
-            'c_12', 'c_22', 'c_32',
-            'c_13', 'c_23', 'c_33',
-        ])))
+        self.assertTrue(
+            set(nodes).issubset(
+                set(
+                    [
+                        "Domain",
+                        "d_Archaea",
+                        "d_Bacteria",
+                        "f_1",
+                        "f_2",
+                        "f_3",
+                        "c_11",
+                        "c_21",
+                        "c_31",
+                        "c_12",
+                        "c_22",
+                        "c_32",
+                        "c_13",
+                        "c_23",
+                        "c_33",
+                    ]
+                )
+            )
+        )
         for p, cls in self.classes.items():
             for c in cls:
                 for f in self.phyla:
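For reference, the annotation strings built by make_random_annotation above are semicolon-joined lineages truncated at domain, phylum, or class level. A small illustration (not part of the commit), showing the three shapes the test can emit:

    domain, phylum, clas = "d_Bacteria", "f_1", "c_11"
    print(";".join([domain]))                # d_Bacteria
    print(";".join([domain, phylum]))        # d_Bacteria;f_1
    print(";".join([domain, phylum, clas]))  # d_Bacteria;f_1;c_11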
40 changes: 21 additions & 19 deletions vamb/__main__.py
@@ -1521,7 +1521,9 @@ class BasicArguments(object):
 
     def __init__(self, args):
         self.args = args
-        self.comp_options = CompositionOptions(self.args.fasta, self.args.composition, self.args.minlength)
+        self.comp_options = CompositionOptions(
+            self.args.fasta, self.args.composition, self.args.minlength
+        )
         self.abundance_options = AbundanceOptions(
             self.args.bampaths,
             self.args.abundancepath,
@@ -1582,8 +1584,8 @@ def __init__(self, args):
             dropout=args.dropout,
         )
         self.vae_training_options = VAETrainingOptions(
-            nepochs=args.nepochs,
-            batchsize=args.batchsize,
+            nepochs=args.nepochs,
+            batchsize=args.batchsize,
             batchsteps=args.batchsteps,
         )
         self.aae_options = None
@@ -1616,8 +1618,8 @@ def __init__(self, args):
 
     def init_encoder_and_training(self):
         self.encoder_options = EncoderOptions(
-            vae_options=self.vae_options,
-            aae_options=self.aae_options,
+            vae_options=self.vae_options,
+            aae_options=self.aae_options,
             alpha=self.args.alpha,
         )
         self.training_options = TrainingOptions(
@@ -1659,8 +1661,8 @@ class VAEVAEArguments(BinnerArguments):
     def __init__(self, args):
         super(VAEVAEArguments, self).__init__(args)
         self.encoder_options = EncoderOptions(
-            vae_options=self.vae_options,
-            aae_options=None,
+            vae_options=self.vae_options,
+            aae_options=None,
             alpha=args.alpha,
         )
         self.training_options = TrainingOptions(
@@ -2213,58 +2215,58 @@ def main():
         parser.print_help()
         sys.exit()
 
-    subparsers = parser.add_subparsers(dest='subcommand')
+    subparsers = parser.add_subparsers(dest="subcommand")
 
-    predict_parser = subparsers.add_parser('predict', help='predict help')
+    predict_parser = subparsers.add_parser("predict", help="predict help")
     add_input_output_arguments(predict_parser)
     add_taxonomy_arguments(predict_parser, taxonomy_only=True)
     add_predictor_arguments(predict_parser)
 
-    vaevae_parserbin_parser = subparsers.add_parser('bin', help='bin help')
-    subparsers_model = vaevae_parserbin_parser.add_subparsers(dest='model_subcommand')
+    vaevae_parserbin_parser = subparsers.add_parser("bin", help="bin help")
+    subparsers_model = vaevae_parserbin_parser.add_subparsers(dest="model_subcommand")
 
-    vae_parser = subparsers_model.add_parser('vae', help='vae_parser help')
+    vae_parser = subparsers_model.add_parser("vae", help="vae_parser help")
     add_input_output_arguments(vae_parser)
     add_vae_arguments(vae_parser)
     add_clustering_arguments(vae_parser)
     add_predictor_arguments(vae_parser)
 
-    aae_parser = subparsers_model.add_parser('aae', help='aae_parser help')
+    aae_parser = subparsers_model.add_parser("aae", help="aae_parser help")
     add_input_output_arguments(aae_parser)
     add_aae_arguments(aae_parser)
     add_clustering_arguments(aae_parser)
 
-    vaeaae_parser = subparsers_model.add_parser('vaeaae', help='vaeaae_parser help')
+    vaeaae_parser = subparsers_model.add_parser("vaeaae", help="vaeaae_parser help")
     add_input_output_arguments(vaeaae_parser)
     add_vae_arguments(vaeaae_parser)
     add_aae_arguments(vaeaae_parser)
     add_clustering_arguments(vaeaae_parser)
 
-    vaevae_parser = subparsers_model.add_parser('vaevae', help='vaevae_parser help')
+    vaevae_parser = subparsers_model.add_parser("vaevae", help="vaevae_parser help")
     add_input_output_arguments(vaevae_parser)
     add_vae_arguments(vaevae_parser)
     add_clustering_arguments(vaevae_parser)
     add_predictor_arguments(vaevae_parser)
     add_taxonomy_arguments(vaevae_parser)
 
-    recluster_parser = subparsers.add_parser('recluster', help='recluster help')
+    recluster_parser = subparsers.add_parser("recluster", help="recluster help")
     add_input_output_arguments(recluster_parser)
     add_reclustering_arguments(recluster_parser)
     add_taxonomy_arguments(recluster_parser, predictions_only=True)
 
     args = parser.parse_args()
 
-    if args.subcommand == 'predict':
+    if args.subcommand == "predict":
         runner = TaxometerArguments(args)
-    elif args.subcommand == 'bin':
+    elif args.subcommand == "bin":
         classes_map = dict(
             vae=VAEArguments,
             aae=AAEArguments,
             vaeaae=VAEAAEArguments,
             vaevae=VAEVAEArguments,
         )
         runner = classes_map[args.model_subcommand](args)
-    elif args.subcommand == 'recluster':
+    elif args.subcommand == "recluster":
         runner = ReclusteringArguments(args)
     runner.run()
 
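The subparser wiring above gives the CLI a two-level command structure: top-level subcommands predict, bin, and recluster, with bin taking a model subcommand (vae, aae, vaeaae, or vaevae). A minimal, self-contained sketch of that layout follows; the option flags added by the add_*_arguments helpers are omitted, and this parser is a stand-in for illustration, not vamb's own:

    import argparse

    parser = argparse.ArgumentParser(prog="vamb")
    subparsers = parser.add_subparsers(dest="subcommand")
    subparsers.add_parser("predict")
    bin_parser = subparsers.add_parser("bin")
    model_subparsers = bin_parser.add_subparsers(dest="model_subcommand")
    for name in ("vae", "aae", "vaeaae", "vaevae"):
        model_subparsers.add_parser(name)
    subparsers.add_parser("recluster")

    args = parser.parse_args(["bin", "vaevae"])
    print(args.subcommand, args.model_subcommand)  # bin vaevae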
4 changes: 1 addition & 3 deletions vamb/semisupervised_encode.py
@@ -86,9 +86,7 @@ def kld_gauss(p_mu, p_logstd, q_mu, q_logstd):
 
 def _make_dataset(rpkm, tnf, lengths, batchsize=256, destroy=False, cuda=False):
     n_workers = 4 if cuda else 1
-    dataloader = _encode.make_dataloader(
-        rpkm, tnf, lengths, batchsize, destroy, cuda
-    )
+    dataloader = _encode.make_dataloader(rpkm, tnf, lengths, batchsize, destroy, cuda)
     (
         depthstensor,
         tnftensor,
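The _make_dataset helper above simply forwards its arrays to _encode.make_dataloader. A sketch of the expected inputs, reusing the shapes from the test fixture in this commit (103 TNF features, 14 abundance samples); calling the public vamb.encode.make_dataloader with these arguments is an assumption on my part, not something shown in the diff:

    import numpy as np
    import vamb

    n_contigs = 111
    rpkm = np.random.random((n_contigs, 14)).astype(np.float32)
    tnf = np.random.random((n_contigs, 103)).astype(np.float32)
    lengths = np.random.randint(2000, 5000, size=n_contigs)

    # batchsize mirrors the positional argument seen in the call above
    dataloader = vamb.encode.make_dataloader(rpkm, tnf, lengths, batchsize=64)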
