removed scheduling function
This commit is contained in:
27
tools.py
27
tools.py
@@ -899,33 +899,6 @@ class Until:
|
||||
return step < self._until
|
||||
|
||||
|
||||
def schedule(string, step):
|
||||
try:
|
||||
return float(string)
|
||||
except ValueError:
|
||||
match = re.match(r"linear\((.+),(.+),(.+)\)", string)
|
||||
if match:
|
||||
initial, final, duration = [float(group) for group in match.groups()]
|
||||
mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0]
|
||||
return (1 - mix) * initial + mix * final
|
||||
match = re.match(r"warmup\((.+),(.+)\)", string)
|
||||
if match:
|
||||
warmup, value = [float(group) for group in match.groups()]
|
||||
scale = torch.clip(step / warmup, 0, 1)
|
||||
return scale * value
|
||||
match = re.match(r"exp\((.+),(.+),(.+)\)", string)
|
||||
if match:
|
||||
initial, final, halflife = [float(group) for group in match.groups()]
|
||||
return (initial - final) * 0.5 ** (step / halflife) + final
|
||||
match = re.match(r"horizon\((.+),(.+),(.+)\)", string)
|
||||
if match:
|
||||
initial, final, duration = [float(group) for group in match.groups()]
|
||||
mix = torch.clip(step / duration, 0, 1)
|
||||
horizon = (1 - mix) * initial + mix * final
|
||||
return 1 - 1 / horizon
|
||||
raise NotImplementedError(string)
|
||||
|
||||
|
||||
def weight_init(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
in_num = m.in_features
|
||||
|
||||
Reference in New Issue
Block a user