Ver código fonte

Removed ensure_session to clean stuff up. May revisit later

Nel Nibcord 4 meses atrás
pai
commit
0673d6452c

+ 0 - 6
exo/inference/inference_engine.py

@@ -33,12 +33,6 @@ class InferenceEngine(ABC):
   async def save_session(self, key, value):
   async def save_session(self, key, value):
     self.session[key] = value
     self.session[key] = value
   
   
-  async def ensure_session(self, key, check, value_gen, hook=None):
-    if key not in self.session or not check(self.session[key]):
-      await self.save_session(key, value_gen())
-      if hook is not None:
-        hook()
-  
   async def clear_session(self):
   async def clear_session(self):
     self.session.empty()
     self.session.empty()
   
   

+ 0 - 5
exo/inference/tinygrad/inference.py

@@ -111,8 +111,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       Tensor.training = False
       Tensor.training = False
       return self.session['loss'](self.model, x, y, l)
       return self.session['loss'](self.model, x, y, l)
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    await self.ensure_session('loss', lambda: loss)
-    await self.ensure_session('jit', lambda: TinyJit(step)) 
     score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths))
     score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths))
     out = score.numpy()
     out = score.numpy()
     return out
     return out
@@ -126,9 +124,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       self.session['opt'].step()
       self.session['opt'].step()
       return score
       return score
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    await self.ensure_session('loss', lambda: loss)
-    await self.ensure_session('opt', lambda: opt(nn.state.get_parameters(self.model.model), lr=lr))
-    await self.ensure_session('jit', lambda: TinyJit(step)) 
       
       
     score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths).realize())
     score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths).realize())