-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
43 lines (33 loc) · 1.03 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import math
import os
import re

import numpy as np
import torch
from torch.utils.data import random_split
def torch_random_split_frac(dataset, fracs, generator=None):
    """Randomly split ``dataset`` into non-overlapping subsets given by fractions.

    Thin wrapper around ``torch.utils.data.random_split`` that takes fractions
    instead of absolute lengths.

    Args:
        dataset: Object accepted by ``random_split`` (must support ``len()``).
        fracs: Sequence of non-negative fractions, one per output subset, that
            sum to 1 up to floating-point tolerance (e.g. ``[0.7, 0.2, 0.1]``).
        generator: Optional ``torch.Generator`` forwarded to ``random_split``
            for reproducible splits. When ``None``, torch's default generator
            is used.

    Returns:
        The list of subsets produced by ``random_split``. Their lengths always
        add up to ``len(dataset)``.

    Raises:
        ValueError: if a fraction is negative or the fractions do not sum to 1.
    """
    fracs = list(fracs)
    if any(f < 0 for f in fracs):
        raise ValueError("fractions must be non-negative, got {}".format(fracs))
    # Exact equality fails for ordinary float inputs (0.7 + 0.2 + 0.1 != 1.0).
    if not math.isclose(sum(fracs), 1.0):
        raise ValueError("fractions must sum to 1, got {}".format(fracs))

    n = len(dataset)
    exact = [n * f for f in fracs]
    lengths = [round(x) for x in exact]
    # Rounding each share on its own can miss n (n=5, [0.5, 0.5] -> [2, 2]),
    # and random_split rejects lengths that do not sum to the dataset size.
    # Hand the missing items to the most under-rounded shares, or take the
    # excess from the most over-rounded ones.
    diff = n - sum(lengths)
    if diff:
        step = 1 if diff > 0 else -1
        order = sorted(
            range(len(lengths)),
            key=lambda i: exact[i] - lengths[i],
            reverse=diff > 0,
        )
        for k in range(abs(diff)):
            lengths[order[k % len(order)]] += step

    if generator is None:
        return random_split(dataset, lengths)
    return random_split(dataset, lengths, generator=generator)
def to_rgb(hl):
    """Placeholder for a conversion to RGB; not implemented yet.

    Always raises ``NotImplementedError``, whatever ``hl`` is.
    """
    raise NotImplementedError()
# Binary PGM header: magic "P5", then width, height and maxval separated by
# whitespace and/or "#" comments running to end of line (the netpbm spec
# allows comments anywhere in the header). Exactly one whitespace character
# separates maxval from the raster.
_PGM_HEADER_RE = re.compile(
    rb"\AP5"
    rb"(?:\s|#[^\r\n]*[\r\n])+(\d+)"  # width
    rb"(?:\s|#[^\r\n]*[\r\n])+(\d+)"  # height
    rb"(?:\s|#[^\r\n]*[\r\n])+(\d+)"  # maxval
    rb"\s"
)


def read_pgm(filename, byteorder='>'):
    """Return image data from a raw (binary, ``P5``) PGM file as numpy array.

    Format specification: http://netpbm.sourceforge.net/doc/pgm.html

    Only the first image in the file is read.

    Args:
        filename: Path of the PGM file.
        byteorder: numpy byte-order character used to decode 16-bit samples
            (maxval > 255): '>' big-endian, '<' little-endian, '=' native.
            The PGM spec stores samples most-significant byte first, hence
            the '>' default.

    Returns:
        A ``(height, width)`` array. Its dtype is ``uint8`` when maxval <= 255,
        otherwise ``uint16`` in the machine's native byte order.

    Raises:
        ValueError: if the magic number is not ``P5``, the header is
            malformed, maxval is outside 1..65535, or the raster is shorter
            than ``width * height`` samples.
    """
    with open(filename, 'rb') as f:
        buffer = f.read()
    if not buffer.startswith(b'P5'):
        first = buffer.partition(b'\n')[0]
        raise ValueError("pgm mode not supported {} in file {}".format(first, filename))

    match = _PGM_HEADER_RE.match(buffer)
    if match is None:
        raise ValueError("malformed pgm header in file {}".format(filename))
    width, height, maxval = (int(token) for token in match.groups())
    if not 0 < maxval <= 65535:
        raise ValueError("pgm maxval {} out of range in file {}".format(maxval, filename))

    wide = maxval > 255
    # Decode 16-bit samples with the requested byte order, then convert to a
    # native-order array so callers never see a byte-swapped dtype.
    raw_dtype = np.dtype(byteorder + 'u2') if wide else np.dtype('u1')
    native_dtype = np.uint16 if wide else np.uint8
    data = np.frombuffer(buffer, dtype=raw_dtype, count=width * height, offset=match.end())
    return data.astype(native_dtype, copy=False).reshape((height, width))
def correct_loss(mask):
    """Return the mask's total element count divided by the sum of its values.

    For a 0/1 or boolean mask this is ``numel / number_of_masked_in_elements``.
    Presumably it is used to rescale a loss averaged over all elements into
    one averaged over the masked-in elements only; confirm against callers.

    Args:
        mask: A ``torch.Tensor`` of any shape.

    Returns:
        A Python float, ``mask.numel() / mask.sum()``.

    Raises:
        ZeroDivisionError: if the mask sums to zero.
    """
    # numel() is the product of the shape; one sum() already reduces every
    # dimension to a scalar.
    return mask.numel() / mask.sum().item()