网页资讯视频图片知道文库贴吧地图采购
进入贴吧全吧搜索

 
 
 
日一二三四五六
       
       
       
       
       
       

签到排名:今日本吧第个签到,

本吧因你更精彩,明天继续来努力!

本吧签到人数:0

一键签到
成为超级会员,使用一键签到
一键签到
本月漏签0次!
0
成为超级会员,赠送8张补签卡
如何使用?
点击日历上漏签日期,即可进行补签。
连续签到:天  累计签到:天
0
超级会员单次开通12个月以上,赠送连续签到卡3张
使用连续签到卡
01月08日漏签0天
python吧 关注:480,616贴子:1,983,710
  • 看贴

  • 图片

  • 吧主推荐

  • 视频

  • 游戏

  • 4回复贴,共1页
<<返回python吧
>0< 加载中...

TypeError: super(type, obj): obj must be an instance or subt

  • 只看楼主
  • 收藏

  • 回复
  • 无极玄少
  • 童生
    2
该楼层疑似违规已被系统折叠 隐藏此楼查看此楼



  • cloveses
  • 榜眼
    12
该楼层疑似违规已被系统折叠 隐藏此楼查看此楼
super()中应该是nn.Module


2026-01-08 20:37:27
广告
不感兴趣
开通SVIP免广告
  • 无极玄少
  • 童生
    2
该楼层疑似违规已被系统折叠 隐藏此楼查看此楼
def compute_energy(self, z, phi=None, mu=None, cov=None, size_average=True):
if phi is None:
phi = to_var(self.phi)
if mu is None:
mu = to_var(self.mu)
if cov is None:
cov = to_var(self.cov)
k, D, _ = cov.size()
z_mu = (z.unsqueeze(1)- mu.unsqueeze(0))
cov_inverse = []
det_cov = []
cov_diag = 0
eps = 1e-12
for i in range(k):
# K x D x D
cov_k = cov[i] + to_var(torch.eye(D)*eps)
cov_inverse.append(torch.inverse(cov_k).unsqueeze(0))
#det_cov.append(np.linalg.det(cov_k.data.cpu().numpy()* (2*np.pi)))
det_cov.append((Cholesky.apply(cov_k.cpu() * (2*np.pi)).diag().prod()).unsqueeze(0))
cov_diag = cov_diag + torch.sum(1 / cov_k.diag())
# K x D x D
cov_inverse = torch.cat(cov_inverse, dim=0)
# K
#det_cov = to_var(torch.from_numpy(np.float32(np.array(det_cov))))
# N x K
exp_term_tmp = -0.5 * torch.sum(torch.sum(z_mu.unsqueeze(-1) * cov_inverse.unsqueeze(0), dim=-2) * z_mu, dim=-1)
# for stability (logsumexp)
max_val = torch.max((exp_term_tmp).clamp(min=0), dim=1, keepdim=True)[0]
exp_term = torch.exp(exp_term_tmp - max_val)
# sample_energy = -max_val.squeeze() - torch.log(torch.sum(phi.unsqueeze(0) * exp_term / (det_cov).unsqueeze(0), dim = 1) + eps)
sample_energy = -max_val.squeeze() - torch.log(torch.sum(phi.unsqueeze(0) * exp_term / (torch.sqrt(det_cov)).unsqueeze(0), dim = 1) + eps)
# sample_energy = -max_val.squeeze() - torch.log(torch.sum(phi.unsqueeze(0) * exp_term / (torch.sqrt((2*np.pi)**D * det_cov)).unsqueeze(0), dim = 1) + eps)
在Py3.6 下我运行上面那段程序的时候报错如下:

请教该如何更改?


登录百度账号

扫二维码下载贴吧客户端

下载贴吧APP
看高清直播、视频!
  • 贴吧页面意见反馈
  • 违规贴吧举报反馈通道
  • 贴吧违规信息处理公示
  • 4回复贴,共1页
<<返回python吧
分享到:
©2026 Baidu贴吧协议|隐私政策|吧主制度|意见反馈|网络谣言警示