深度學(xué)習(xí)(16)——權(quán)重加載-創(chuàng)新互聯(lián)

深度學(xué)習(xí)(16)—— 權(quán)重加載
model.load_state_dict(torch.load('/home/yangjy/projects/Jane_git_tf/weights/con_model/best1_2022-12-02-09-36.pth', map_location=device))
Question?

情形:
新的model是需要兩個(gè)模型作前期的處理后的結(jié)果,如model1得到feature1,model2得到feature2,最終(現(xiàn)在訓(xùn)練的)model(model3)需要學(xué)習(xí)的是根據(jù)feature1和feature2進(jìn)行整合和特征學(xué)習(xí)正確分辨出最終的結(jié)果。這個(gè)時(shí)候model3在第一次訓(xùn)練做初始化的時(shí)候需要加載model1和model2的權(quán)重,但是后來(lái)訓(xùn)練的時(shí)候如果初始權(quán)重是之前訓(xùn)練好的model3的權(quán)重,就不要再加載model1和model2的權(quán)重后再加載model3的權(quán)重,機(jī)器在加載的過(guò)程中都是需要消耗時(shí)間的,一方面是資源成本的浪費(fèi),無(wú)論是時(shí)間成本還是內(nèi)存占用率都是很大的消耗;其次剛剛我發(fā)現(xiàn),這樣重復(fù)性加載時(shí)影響最終的模型訓(xùn)練效果的,模型在加載權(quán)重的過(guò)程個(gè)人建議不要寫(xiě)在模型初始化的過(guò)程中,這種不靈活的寫(xiě)法,很可能會(huì)產(chǎn)生bias??!

創(chuàng)新互聯(lián)建站擁有網(wǎng)站維護(hù)技術(shù)和項(xiàng)目管理團(tuán)隊(duì),建立的售前、實(shí)施和售后服務(wù)體系,為客戶(hù)提供定制化的成都做網(wǎng)站、成都網(wǎng)站建設(shè)、網(wǎng)站維護(hù)、電信內(nèi)江機(jī)房解決方案。為客戶(hù)網(wǎng)站安全和日常運(yùn)維提供整體管家式外包優(yōu)質(zhì)服務(wù)。我們的網(wǎng)站維護(hù)服務(wù)覆蓋集團(tuán)企業(yè)、上市公司、外企網(wǎng)站、商城網(wǎng)站開(kāi)發(fā)、政府網(wǎng)站等各類(lèi)型客戶(hù)群體,為全球成百上千家企業(yè)提供全方位網(wǎng)站維護(hù)、服務(wù)器維護(hù)解決方案。
class model3(nn.Module):
    def __init__(self, num_classes, device,model1_path, model2_path,freeeze_pretain):
        super(Conv_con, self).__init__()
        self.device = device
        self.model1= MiniConvNext(num_classes=5, depths=[3, 3, 9, 3],
                                          dims=[96, 192, 384, 768], )
        self.model2 = MiniConvNext(num_classes=1, depths=[3, 3, 9, 3],
                                          dims=[96, 192, 384, 768], )
        self.freeeze_pretain = freeeze_pretain
        self._init_weights()

        self.fctl = _FCtL(512, 512)
        self.norm = LayerNorm(512, eps=1e-6, data_format="channels_last")
        self.head = nn.Linear(512, num_classes)
        self.model1_path= model1_path
        self.model2_path= model2_path

        self.set_pretrained_weight()

    def set_pretrained_weight(self):
        if self.model1_path:
            pretrained_dict = torch.load(self.model1_path, map_location=self.device)
            model_dict = self.model1.state_dict()
            # 1. filter out unnecessary keys
            pretrained_dict_b = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict_b)
            self.model1.load_state_dict(model_dict)
            self.model1.eval() 

        if self.model2_path:
            pretrained_dict = torch.load(self.model2_path, map_location=self.device)
            model_dict = self.model2.state_dict()
            # 1. filter out unnecessary keys
            pretrained_dict_b = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict_b)
            self.model2.load_state_dict(model_dict)
            if self.freeeze_pretain: # ?????fctl????????????
                self.model2.eval()


        if self.freeeze_pretain: # ??FCTL
            for name, para in self.model2.named_parameters():
                para.requires_grad = False
            for name, para in self.model1.named_parameters():
                para.requires_grad = False
        else: # ??FCTL?global??
            for name, para in self.model1.named_parameters():
                para.requires_grad = False

    def get_pretrained_weight(self):
        for name, parm in self.model2.named_parameters():
            print(f'{name}:{parm.requires_grad}')

    def forward(self, x, y):
        if self.freeeze_pretain: # ?????????????????????
            with torch.no_grad():
                feature1 = self.model1(x)
                feature2 = self.model2(y)
        else:
            with torch.no_grad():
                feature1 = self.model1(y)
            feature2 = self.model2(x)

        features = self.fctl(feature1 ,feature2 ,) # ??global?????roi????
        features_1 = self.norm(features.mean([-2, -1]))  
        out = self.head(features_1)
        return out

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.2)
                nn.init.constant_(m.bias, 0)

勸大家不要這樣寫(xiě)!不要把權(quán)重加載的事情放在初始化里面,追悔莫及!
思考:
其實(shí)我自己剛開(kāi)始覺(jué)得這種重復(fù)加載權(quán)重應(yīng)該是沒(méi)有問(wèn)題的,因?yàn)閙odel1和model2只是model3的一部分,我先加載model1和model2的權(quán)重,最后加載model3的權(quán)重也是會(huì)覆蓋剛剛加載的model1和model2的權(quán)重的,但是結(jié)果好像并不像我想想的那么簡(jiǎn)單。因?yàn)槲矣糜?xùn)練好的權(quán)重去預(yù)測(cè),先加載model1,再加載model2,之后加載model3,之后得到的結(jié)果驚掉下巴!雖然再訓(xùn)練過(guò)程中在驗(yàn)證集上準(zhǔn)確率不低,但是…所以用驗(yàn)證集驗(yàn)證是不是我的權(quán)重保存有問(wèn)題。check后發(fā)現(xiàn)沒(méi)有問(wèn)題,之后檢查數(shù)據(jù)集也沒(méi)有問(wèn)題,代碼也沒(méi)問(wèn)題,label的錯(cuò)誤之前犯過(guò)了,也不妨再檢查一遍沒(méi)有問(wèn)題。所以我又重新定義了不加載權(quán)重的predict_model ,直接加載model3的權(quán)重,這次在驗(yàn)證集上的結(jié)果才是正常的。至于原因,個(gè)人還在探索,搞明白再和大家分享。

你是否還在尋找穩(wěn)定的海外服務(wù)器提供商?創(chuàng)新互聯(lián)www.cdcxhl.cn海外機(jī)房具備T級(jí)流量清洗系統(tǒng)配攻擊溯源,準(zhǔn)確流量調(diào)度確保服務(wù)器高可用性,企業(yè)級(jí)服務(wù)器適合批量采購(gòu),新人活動(dòng)首月15元起,快前往官網(wǎng)查看詳情吧

標(biāo)題名稱(chēng):深度學(xué)習(xí)(16)——權(quán)重加載-創(chuàng)新互聯(lián)
分享路徑:http://bm7419.com/article42/ggihc.html

成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供域名注冊(cè)、企業(yè)網(wǎng)站制作定制開(kāi)發(fā)、網(wǎng)站建設(shè)動(dòng)態(tài)網(wǎng)站、網(wǎng)頁(yè)設(shè)計(jì)公司

廣告

聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶(hù)投稿、用戶(hù)轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話(huà):028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)

微信小程序開(kāi)發(fā)