Skip to content

Commit

Permalink
Added --reference (#15)
Browse files Browse the repository at this point in the history
* added reference

* smaller changes

* text mistake fix

* subset tsv df

* docu update
  • Loading branch information
jonas-fuchs committed Aug 6, 2024
1 parent 07e59b3 commit d5763bd
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 25 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ positional arguments:
options:
-h, --help show this help message and exit
-r ref_id, --reference ref_id
reference identifier
--name virHEAT_plot.pdf
plot name and file type (pdf, png, svg, jpg). Default: virHEAT_plot.pdf
-l None, --genome-length None
Expand Down
2 changes: 1 addition & 1 deletion virheat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""plot vcf data as a heatmap mapped to a virus genome"""
_program = "virheat"
__version__ = "0.7"
__version__ = "0.7.1"
34 changes: 21 additions & 13 deletions virheat/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def get_args(sysargs):
nargs=2,
help="folder containing input files and output folder"
)
parser.add_argument(
"-r",
"--reference",
type=str,
metavar="ref_id",
required=True,
default=None,
help="reference identifier"
)
parser.add_argument(
"--name",
type=str,
Expand Down Expand Up @@ -147,23 +156,25 @@ def main(sysargs=sys.argv[1:]):
sys.exit("\033[31m\033[1mERROR:\033[0m No VCF files provided")
else:
if args.scores:
reference_name, frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, threshold=args.threshold, scores=True)
frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, args.reference, threshold=args.threshold, scores=True)
n_scores = len(args.scores)
else:
reference_name, frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, threshold=args.threshold)
frequency_lists, unique_mutations, file_names = data_prep.extract_vcf_data(vcf_files, args.reference, threshold=args.threshold)

if args.zoom:
unique_mutations = data_prep.zoom_to_genomic_regions(unique_mutations, args.zoom)
frequency_array = data_prep.create_freq_array(unique_mutations, frequency_lists)

# user specified delete options (removes mutations based on various rationales)
if args.delete:
frequency_array = data_prep.delete_common_mutations(frequency_array, unique_mutations)
if args.delete_n is not None:
frequency_array = data_prep.delete_n_mutations(frequency_array, unique_mutations, args.delete_n)
# enables the deletion option only if more than 1 vcf file is provided
if len(vcf_files) > 1:
# user specified delete options (removes mutations based on various rationales)
if args.delete:
frequency_array = data_prep.delete_common_mutations(frequency_array, unique_mutations)
if args.delete_n is not None:
frequency_array = data_prep.delete_n_mutations(frequency_array, unique_mutations, args.delete_n)

# annotate low coverage if per base coverage from qualimap was provided
data_prep.annotate_non_covered_regions(args.input[0], args.min_cov, frequency_array, file_names, unique_mutations)
data_prep.annotate_non_covered_regions(args.input[0], args.min_cov, frequency_array, file_names, unique_mutations, args.reference)

# define relative locations of all items in the plot
n_samples, n_mutations = len(frequency_array), len(frequency_array[0])
Expand All @@ -178,10 +189,7 @@ def main(sysargs=sys.argv[1:]):
if args.gff3_path is not None and args.genome_length is not None:
sys.exit("\033[31m\033[1mERROR:\033[0m Do not provide the -g and -l argument simultaneously!")
elif args.gff3_path is not None:
gff3_info, gff3_ref_name = data_prep.parse_gff3(args.gff3_path)
# issue a warning if #CHROM and gff3 do not match
if gff3_ref_name not in reference_name and reference_name not in gff3_ref_name:
print("\033[31m\033[1mWARNING:\033[0m gff3 reference does not match the vcf reference!")
gff3_info = data_prep.parse_gff3(args.gff3_path, args.reference)
genome_end = data_prep.get_genome_end(gff3_info)
genes_with_mutations, n_tracks = data_prep.create_track_dict(unique_mutations, gff3_info, args.gff3_annotations)
elif args.genome_length is not None:
Expand Down Expand Up @@ -225,7 +233,7 @@ def main(sysargs=sys.argv[1:]):
plotting.create_heatmap(ax, frequency_array, cmap_cells)
mutation_set = plotting.create_genome_vis(ax, genome_y_location, n_mutations, unique_mutations, start, stop)
plotting.create_axis(ax, n_mutations, min_y_location, n_samples, file_names, start, stop, genome_y_location,
unique_mutations, reference_name)
unique_mutations, args.reference)
plotting.create_mutation_legend(mutation_set, min_y_location, n_samples, n_scores)
plotting.create_colorbar(args.threshold, cmap_cells, min_y_location, n_samples, ax)
# plot gene track
Expand Down
27 changes: 16 additions & 11 deletions virheat/scripts/data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def convert_string(string):
return string


def read_vcf(vcf_file):
def read_vcf(vcf_file, reference):
"""
parse vcf files to dictionary
"""
Expand All @@ -62,10 +62,10 @@ def read_vcf(vcf_file):
header = header_lines[0]
# get each line as frequency_lists
with open(vcf_file, "r") as f:
lines = [l.split("\t") for l in f if not l.startswith('#')]
lines = [l.split("\t") for l in f if l.startswith(reference)]
# check if vcf is empty
if not lines:
print(f"\033[31m\033[1mWARNING:\033[0m {vcf_file} is empty!")
print(f"\033[31m\033[1mWARNING:\033[0m {vcf_file} has no variants to {reference}!")
# get standard headers as keys
for key in header[0:6]:
vcf_dict[key] = []
Expand Down Expand Up @@ -104,7 +104,7 @@ def read_vcf(vcf_file):
return vcf_dict


def extract_vcf_data(vcf_files, threshold=0, scores=False):
def extract_vcf_data(vcf_files, reference, threshold=0, scores=False):
"""
extract relevant vcf data
"""
Expand All @@ -114,7 +114,7 @@ def extract_vcf_data(vcf_files, threshold=0, scores=False):

for file in vcf_files:
file_names.append(os.path.splitext(os.path.basename(file))[0])
vcf_dict = read_vcf(file)
vcf_dict = read_vcf(file, reference)
frequency_list = []
# write all mutation info in a '_' sep string
for idx in range(0, len(vcf_dict["#CHROM"])):
Expand All @@ -138,8 +138,10 @@ def extract_vcf_data(vcf_files, threshold=0, scores=False):
unique_mutations = sorted(
{x[0] for li in frequency_lists for x in li}, key=lambda x: int(x.split("_")[0])
)
if not unique_mutations:
sys.exit(f"\033[31m\033[1mERROR:\033[0m No variants to {reference} in all vcf files!")

return vcf_dict["#CHROM"][0], frequency_lists, unique_mutations, file_names
return frequency_lists, unique_mutations, file_names


def extract_scores(unique_mutations, scores_file, aa_pos_col, score_col):
Expand Down Expand Up @@ -185,7 +187,7 @@ def create_freq_array(unique_mutations, frequency_lists):
return np.array(frequency_array)


def annotate_non_covered_regions(coverage_dir, min_coverage, frequency_array, file_names, unique_mutations):
def annotate_non_covered_regions(coverage_dir, min_coverage, frequency_array, file_names, unique_mutations, reference):
"""
Insert nan values into np array if position is not covered. Needs
per base coverage tsv files created by bamqc
Expand All @@ -200,6 +202,7 @@ def annotate_non_covered_regions(coverage_dir, min_coverage, frequency_array, fi
continue
tsv_file = [file for file in per_base_coverage_files if os.path.splitext(os.path.basename(file))[0] == file_name][0]
coverage = pd.read_csv(tsv_file, sep="\t")
coverage = coverage[coverage["#chr"] == reference]
for j, (mutation, frequency) in enumerate(zip(unique_mutations, array)):
mut_pos = int(mutation.split("_")[0])
if coverage[coverage["pos"] == mut_pos].empty or all([frequency == 0, coverage[coverage["pos"] == mut_pos]["coverage"].iloc[0] <= min_coverage]):
Expand Down Expand Up @@ -267,7 +270,7 @@ def zoom_to_genomic_regions(unique_mutations, start_stop):
return zoomed_unique


def parse_gff3(file):
def parse_gff3(file, reference):
"""
parse gff3 to dictionary
"""
Expand All @@ -277,10 +280,9 @@ def parse_gff3(file):
with open(file, "r") as gff3_file:
for line in gff3_file:
# ignore comments and last line
if line.startswith("#") or line == "\n":
if not line.startswith(reference):
continue
gff_values = line.split("\t")
gff3_ref_name = gff_values[0]
# create keys
if gff_values[2] not in gff3_dict:
gff3_dict[gff_values[2]] = {}
Expand All @@ -300,7 +302,10 @@ def parse_gff3(file):

gff3_file.close()

return gff3_dict, gff3_ref_name
if not gff3_dict:
sys.exit(f"\033[31m\033[1mERROR:\033[0m {reference} not found in gff3 file.")

return gff3_dict


def get_genome_end(gff3_dict):
Expand Down

0 comments on commit d5763bd

Please sign in to comment.