Browse Source

Added GPU selection feature to python inference (#321)

* Added GPU selection feature to python inference

* pylint pep8 fixes

* pep8 fixes
Mert Cobanov 3 years ago
parent
commit
6b15fc6936
2 changed files with 20 additions and 3 deletions
  1. 5 1
      inference_realesrgan.py
  2. 15 2
      realesrgan/utils.py

+ 5 - 1
inference_realesrgan.py

@@ -39,6 +39,9 @@ def main():
         type=str,
         default='auto',
         help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
+    parser.add_argument(
+        '-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')
+
     args = parser.parse_args()
 
     # determine models according to model names
@@ -71,7 +74,8 @@ def main():
         tile=args.tile,
         tile_pad=args.tile_pad,
         pre_pad=args.pre_pad,
-        half=not args.fp32)
+        half=not args.fp32,
+        gpu_id=args.gpu_id)
 
     if args.face_enhance:  # Use GFPGAN for face enhancement
         from gfpgan import GFPGANer

+ 15 - 2
realesrgan/utils.py

@@ -26,7 +26,16 @@ class RealESRGANer():
         half (float): Whether to use half precision during inference. Default: False.
     """
 
-    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None):
+    def __init__(self,
+                 scale,
+                 model_path,
+                 model=None,
+                 tile=0,
+                 tile_pad=10,
+                 pre_pad=10,
+                 half=False,
+                 device=None,
+                 gpu_id=None):
         self.scale = scale
         self.tile_size = tile
         self.tile_pad = tile_pad
@@ -35,7 +44,11 @@ class RealESRGANer():
         self.half = half
 
         # initialize model
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+        if gpu_id:
+            self.device = torch.device(
+                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+        else:
+            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
         # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
         if model_path.startswith('https://'):
             model_path = load_file_from_url(