Skip to content

Commit

Permalink
Remove unused A and B computation
Browse files: browse the repository at this point in the history
  • Loading branch information
awgu committed Mar 13, 2024
1 parent a6bc165 commit 779d7d0
Showing 1 changed file with 10 additions and 15 deletions.
25 changes: 10 additions & 15 deletions galore_torch/galore_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_typ
self.scale = scale
self.ortho_matrix = None
self.proj_type = proj_type

def project(self, full_rank_grad, iter):

if self.proj_type == 'std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
Expand Down Expand Up @@ -41,7 +41,7 @@ def project(self, full_rank_grad, iter):
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()

return low_rank_grad

def project_back(self, low_rank_grad):
Expand All @@ -62,11 +62,11 @@ def project_back(self, low_rank_grad):
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == 'full':
full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]


return full_rank_grad * self.scale


# SVD (singular value decomposition)
def get_orthogonal_matrix(self, weights, rank, type):
module_params = weights
Expand All @@ -79,20 +79,18 @@ def get_orthogonal_matrix(self, weights, rank, type):
else:
float_data = True
matrix = module_params.data

U, s, Vh = torch.linalg.svd(matrix, full_matrices = False)

# Always make the smaller matrix the orthogonal matrix
if type=='right':
A = U[:, :rank] @ torch.diag(s[:rank])
B = Vh[:rank, :]

if not float_data:
B = B.to(original_device).type(original_type)
return B
elif type=='left':
A = U[:, :rank]
B = torch.diag(s[:rank]) @ Vh[:rank, :]
if not float_data:
A = A.to(original_device).type(original_type)
return A
Expand All @@ -105,6 +103,3 @@ def get_orthogonal_matrix(self, weights, rank, type):
return [A, B]
else:
raise ValueError('type should be left, right or full')



0 comments on commit 779d7d0

Please sign in to comment.