Browse Source

modify weight path

Xintao 2 years ago
parent
commit
0ac8d66d39
8 changed files with 23 additions and 28 deletions
  1. 1 1
      .gitignore
  2. 1 1
      MANIFEST.in
  3. 15 17
      cog_predict.py
  4. 1 1
      docs/anime_video_model.md
  5. 2 2
      inference_realesrgan.py
  6. 1 1
      inference_realesrgan_video.py
  7. 2 5
      realesrgan/utils.py
  8. 0 0
      weights/README.md

+ 1 - 1
.gitignore

@@ -5,7 +5,7 @@ results/*
 tb_logger/*
 wandb/*
 tmp/*
-realesrgan/weights/*
+weights/*
 
 version.py
 

+ 1 - 1
MANIFEST.in

@@ -5,4 +5,4 @@ include inference_realesrgan.py
 include VERSION
 include LICENSE
 include requirements.txt
-include realesrgan/weights/README.md
+include weights/README.md

+ 15 - 17
cog_predict.py

@@ -29,52 +29,50 @@ class Predictor(BasePredictor):
     def setup(self):
         os.makedirs('output', exist_ok=True)
         # download weights
-        if not os.path.exists('realesrgan/weights/realesr-general-x4v3.pth'):
+        if not os.path.exists('weights/realesr-general-x4v3.pth'):
             os.system(
-                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./realesrgan/weights'
+                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights'
             )
-        if not os.path.exists('realesrgan/weights/GFPGANv1.4.pth'):
+        if not os.path.exists('weights/GFPGANv1.4.pth'):
+            os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./weights')
+        if not os.path.exists('weights/RealESRGAN_x4plus.pth'):
             os.system(
-                'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./realesrgan/weights'
+                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./weights'
             )
-        if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus.pth'):
+        if not os.path.exists('weights/RealESRGAN_x4plus_anime_6B.pth'):
             os.system(
-                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./realesrgan/weights'
+                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./weights'
             )
-        if not os.path.exists('realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth'):
+        if not os.path.exists('weights/realesr-animevideov3.pth'):
             os.system(
-                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./realesrgan/weights'
-            )
-        if not os.path.exists('realesrgan/weights/realesr-animevideov3.pth'):
-            os.system(
-                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./realesrgan/weights'
+                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./weights'
             )
 
     def choose_model(self, scale, version, tile=0):
         half = True if torch.cuda.is_available() else False
         if version == 'General - RealESRGANplus':
             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
-            model_path = 'realesrgan/weights/RealESRGAN_x4plus.pth'
+            model_path = 'weights/RealESRGAN_x4plus.pth'
             self.upsampler = RealESRGANer(
                 scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
         elif version == 'General - v3':
             model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
-            model_path = 'realesrgan/weights/realesr-general-x4v3.pth'
+            model_path = 'weights/realesr-general-x4v3.pth'
             self.upsampler = RealESRGANer(
                 scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
         elif version == 'Anime - anime6B':
             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
-            model_path = 'realesrgan/weights/RealESRGAN_x4plus_anime_6B.pth'
+            model_path = 'weights/RealESRGAN_x4plus_anime_6B.pth'
             self.upsampler = RealESRGANer(
                 scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
         elif version == 'AnimeVideo - v3':
             model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
-            model_path = 'realesrgan/weights/realesr-animevideov3.pth'
+            model_path = 'weights/realesr-animevideov3.pth'
             self.upsampler = RealESRGANer(
                 scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
 
         self.face_enhancer = GFPGANer(
-            model_path='realesrgan/weights/GFPGANv1.4.pth',
+            model_path='weights/GFPGANv1.4.pth',
             upscale=scale,
             arch='clean',
             channel_multiplier=2,

+ 1 - 1
docs/anime_video_model.md

@@ -34,7 +34,7 @@ The following are some demos (best view in the full screen mode).
 
 ```bash
 # download model
-wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P realesrgan/weights
+wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P weights
 # single gpu and single process inference
 CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2
 # single gpu and multi process inference (you can use multi-processing to improve GPU utilization)

+ 2 - 2
inference_realesrgan.py

@@ -88,13 +88,13 @@ def main():
     if args.model_path is not None:
         model_path = args.model_path
     else:
-        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
+        model_path = os.path.join('weights', args.model_name + '.pth')
         if not os.path.isfile(model_path):
             ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
             for url in file_url:
                 # model_path will be updated
                 model_path = load_file_from_url(
-                    url=url, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
+                    url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
 
     # use dni to control the denoise strength
     dni_weight = None

+ 1 - 1
inference_realesrgan_video.py

@@ -190,7 +190,7 @@ def inference_video(args, video_save_path, device=None, total_workers=1, worker_
     # ---------------------- determine model paths ---------------------- #
     model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
     if not os.path.isfile(model_path):
-        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
+        model_path = os.path.join('weights', args.model_name + '.pth')
     if not os.path.isfile(model_path):
         raise ValueError(f'Model {args.model_name} does not exist.')
 

+ 2 - 5
realesrgan/utils.py

@@ -56,13 +56,10 @@ class RealESRGANer():
             assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
             loadnet = self.dni(model_path[0], model_path[1], dni_weight)
         else:
-            # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
+            # if the model_path starts with https, it will first download models to the folder: weights
             if model_path.startswith('https://'):
                 model_path = load_file_from_url(
-                    url=model_path,
-                    model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'),
-                    progress=True,
-                    file_name=None)
+                    url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
             loadnet = torch.load(model_path, map_location=torch.device('cpu'))
 
         # prefer to use params_ema

+ 0 - 0
realesrgan/weights/README.md → weights/README.md