Thank you for your great work!
I have a question about the conditional loss in the FactorCL-SSL case.
In IRFL_model.py, line 311, the conditional CLUB loss is computed as follows:
    self.club_x1x2_cond(
        torch.cat([self.linears_club_x1x2_cond[0](x1_embed),
                   self.linears_club_x1x2_cond[0](x1_embed)], dim=1),
        torch.cat([self.linears_club_x1x2_cond[1](x2_embed),
                   self.linears_club_x1x2_cond[1](x2_embed)], dim=1))
However, here each embedding is concatenated with itself. Following Eq. (8) in the paper, I think the "embeds" should be concatenated with the "aug_embeds" instead, like this:
    self.club_x1x2_cond(
        torch.cat([self.linears_club_x1x2_cond[0](x1_embed),
                   self.linears_club_x1x2_cond[0](x1_aug_embed)], dim=1),
        torch.cat([self.linears_club_x1x2_cond[1](x2_embed),
                   self.linears_club_x1x2_cond[1](x2_aug_embed)], dim=1))
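
To illustrate what I mean, here is a minimal standalone sketch (not the repository's code). The tensor sizes, the noise-based "augmentation", and the plain nn.Linear layers standing in for self.linears_club_x1x2_cond are all assumptions made up for this example. It only shows that the current concatenation repeats the same projection twice, so the conditioning input carries no information from the augmented view, while the proposed version does.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch, dim, proj = 4, 8, 3

# Stand-in for self.linears_club_x1x2_cond[0] (assumption: a plain linear layer).
lin_x1 = nn.Linear(dim, proj)

# Stand-ins for the encoder outputs of the original and augmented views.
x1_embed = torch.randn(batch, dim)
x1_aug_embed = x1_embed + 0.1 * torch.randn(batch, dim)  # toy augmentation

# Current code path: the original embedding is projected and concatenated with itself.
cur_x1 = torch.cat([lin_x1(x1_embed), lin_x1(x1_embed)], dim=1)

# Proposed code path: the original embedding is concatenated with the augmented one.
new_x1 = torch.cat([lin_x1(x1_embed), lin_x1(x1_aug_embed)], dim=1)

# Compare the two halves of each concatenated tensor.
cur_halves_equal = torch.allclose(cur_x1[:, :proj], cur_x1[:, proj:])
new_halves_equal = torch.allclose(new_x1[:, :proj], new_x1[:, proj:])

print("shape:", tuple(cur_x1.shape))
print("current halves identical: ", cur_halves_equal)
print("proposed halves identical:", new_halves_equal)
```

Both versions produce tensors of the same shape, so the existing code runs without error either way; the difference is only in whether the second half depends on the augmentation.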
Since I'm a beginner in this field, I may have misunderstood something. Is the current implementation intended, or should the augmented embeddings be used here?
Your response would be really helpful for me! Thank you.