Rate this post

Chúng tôi vẽ đồ thị mô hình tuyến tính của mình với các tham số ngẫu nhiên được gán cho nó. Chúng tôi nhận thấy rằng nó không phù hợp với dữ liệu của chúng tôi. Những gì chúng ta phải làm. Chúng ta cần huấn luyện mô hình này để mô hình có các thông số trọng lượng và độ chệch tối ưu và phù hợp với dữ liệu này.

Các bài viết liên quan:

Có các bước sau để đào tạo một mô hình:

Bước 1

Bước đầu tiên của chúng tôi là xác định hàm mất mát, mà chúng tôi dự định giảm thiểu. PyTorch cung cấp một cách rất hiệu quả để chỉ định hàm bị mất. PyTorch cung cấp hàm MSELoss (), được gọi là tổn thất bình phương trung bình, để tính toán tổn thất dưới dạng

criterion=nn.MSELoss()  

Bước 2

Bây giờ, bước tiếp theo của chúng tôi là cập nhật các thông số của chúng tôi. Với mục đích này, chúng tôi chỉ định trình tối ưu hóa sử dụng thuật toán giảm độ dốc. Chúng tôi sử dụng hàm SGD () được gọi là descent gradient ngẫu nhiên để tối ưu hóa. SGD giảm thiểu tổng hao hụt từng mẫu một và thường đạt đến sự hội tụ nhanh hơn nhiều vì nó sẽ thường xuyên cập nhật trọng lượng của mô hình của chúng tôi trong cùng một kích thước mẫu.

optimizer=torch.optim.SGD(model.parameters(),lr=0.01)   

Ở đây, lr là viết tắt của tỷ lệ học tập, ban đầu được đặt thành 0,01.

Bước 3

Chúng tôi sẽ đào tạo mô hình của chúng tôi cho một số kỷ nguyên được chỉ định (Chúng tôi đã tính toán hàm lỗi và sao chép ngược độ dốc xuống của hàm lỗi này để cập nhật trọng số).

epochs=100  

Và bây giờ, đối với mỗi kỷ nguyên, chúng tôi phải giảm thiểu lỗi của hệ thống mô hình của chúng tôi. Sai số chỉ đơn giản là sự so sánh giữa dự đoán của mô hình và các giá trị thực tế.

Losses=[]  
For i in range (epochs):  
    ypred=model.forward(x)  #Prediction of y  
    loss=criterion(ypred,y) #Find loss  
    losses.append()     # Add loss in list   
    optimizer.zero_grad() # Set the gradient to zero  
    loss.backward() #To compute derivatives   
    optimizer.step()    # Update the parameters   

Bước 4

Cuối cùng, chúng ta vẽ biểu đồ mô hình tuyến tính mới của mình bằng cách gọi phương thức plotfit ().

plotfit('Trained Model')  

Leave a Reply

Call now
%d bloggers like this: