diff --git a/aspen/model.py b/aspen/model.py index 3f5a75f5..6c3dff03 100644 --- a/aspen/model.py +++ b/aspen/model.py @@ -97,7 +97,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: class Linear(): def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True, device: str = None): - if device == None: + if device is None: device = weight.device row, col = weight.shape if load_in_8bit: diff --git a/mlora.py b/mlora.py index 1fdd3f5d..a6d71146 100644 --- a/mlora.py +++ b/mlora.py @@ -44,7 +44,7 @@ def log(msg: str): exit(-1) -if args.model_name_or_path == None: +if args.model_name_or_path is None: print('error: Argument --model_name_or_path are required.') parser.print_help() exit(-1)