|
@@ -26,7 +26,16 @@ class RealESRGANer():
|
|
|
half (float): Whether to use half precision during inference. Default: False.
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None):
|
|
|
+ def __init__(self,
|
|
|
+ scale,
|
|
|
+ model_path,
|
|
|
+ model=None,
|
|
|
+ tile=0,
|
|
|
+ tile_pad=10,
|
|
|
+ pre_pad=10,
|
|
|
+ half=False,
|
|
|
+ device=None,
|
|
|
+ gpu_id=None):
|
|
|
self.scale = scale
|
|
|
self.tile_size = tile
|
|
|
self.tile_pad = tile_pad
|
|
@@ -35,7 +44,11 @@ class RealESRGANer():
|
|
|
self.half = half
|
|
|
|
|
|
# initialize model
|
|
|
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
|
|
+ if gpu_id:
|
|
|
+ self.device = torch.device(
|
|
|
+ f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
|
|
|
+ else:
|
|
|
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
|
|
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
|
|
if model_path.startswith('https://'):
|
|
|
model_path = load_file_from_url(
|