1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
def _fedavg_num_clients(C, K):
    """Return how many clients to sample per round: floor(C * K), at least 1.

    A tiny epsilon absorbs binary floating-point error, so that e.g.
    C=0.29, K=100 selects 29 clients instead of int(28.999999999999996) == 28.
    """
    return max(int(C * K + 1e-9), 1)


def _fedavg_average(states):
    """Return a new dict holding the element-wise mean of ``states``.

    ``states`` is a non-empty list of state dicts sharing the same keys.
    Values are summed in list order and divided by ``len(states)``. The input
    dicts are not modified.
    """
    count = len(states)
    averaged = {}
    for name in states[0]:
        total = states[0][name]
        for state in states[1:]:
            total = total + state[name]
        averaged[name] = total / count
    return averaged


def training(model, rounds, batch_size, lr, ds, L, data_dict, C, K, E, plt_title, plt_color):
    """Train ``model`` with federated averaging and plot the per-round loss.

    Each round samples clients without replacement from ``range(K)``. The
    number of clients is a fraction ``C`` of ``K``, with at least one client.
    Each sampled client runs ``ClientUpdate`` on its own shard. The global
    parameters are then replaced by the unweighted element-wise mean of the
    weights the clients return.

    Args:
        model: the global model. It must provide ``state_dict()`` and
            ``load_dict()``; the ``set_value``/``load_dict`` calls suggest a
            PaddlePaddle ``Layer`` (TODO confirm).
        rounds: number of communication rounds.
        batch_size: local mini-batch size passed to ``ClientUpdate``.
        lr: local learning rate passed to ``ClientUpdate``.
        ds: training inputs, indexed by a client's index collection.
        L: training labels, indexed the same way as ``ds``.
        data_dict: maps each client id in ``range(K)`` to that client's
            sample indices.
        C: fraction of the ``K`` clients sampled each round.
        K: total number of clients.
        E: local epochs per sampled client per round.
        plt_title: plot title; the figure is also saved as ``<plt_title>.jpg``.
        plt_color: matplotlib Tableau color name, used as ``'tab:' + plt_color``.

    Returns:
        ``model.state_dict()`` after the final round.

    Raises:
        ValueError: from ``np.random.choice`` when more clients are requested
            than exist (``C > 1``).
    """
    global_weights = model.state_dict()
    train_loss = []
    start = time.perf_counter()  # monotonic clock for elapsed-time reporting
    num_selected = _fedavg_num_clients(C, K)  # identical for every round

    for curr_round in range(1, rounds + 1):
        client_weights, local_loss = [], []
        selected = np.random.choice(range(K), num_selected, replace=False)
        for client in selected:
            idxs = data_dict[client]
            local_update = ClientUpdate(ds[idxs], L[idxs], batch_size=batch_size,
                                        learning_rate=lr, epochs=E)
            weights, loss = local_update.train(model)
            client_weights.append(weights)
            local_loss.append(loss)

        # Copy the averaged client parameters into the global model's tensors,
        # then reload them into the model.
        for name, value in _fedavg_average(client_weights).items():
            global_weights[name].set_value(value)
        model.load_dict(global_weights)

        loss_avg = sum(local_loss) / len(local_loss)
        if curr_round % 10 == 0:
            print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))
        train_loss.append(loss_avg)

    end = time.perf_counter()

    fig, ax = plt.subplots()
    x_axis = np.arange(1, rounds + 1)
    y_axis = np.array(train_loss)
    ax.plot(x_axis, y_axis, color='tab:' + plt_color)
    ax.set(xlabel='Number of Rounds', ylabel='Train Loss', title=plt_title)
    ax.grid()
    fig.savefig(plt_title + '.jpg', format='jpg')

    print("Training Done!")
    print("Total time taken to Train: {}".format(end - start))
    return model.state_dict()
# Build a fresh MNIST CNN and train it with federated averaging on the IID
# client partition; the returned value is the trained model's state dict.
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(
    mnist_cnn,
    rounds=rounds,
    batch_size=batch_size,
    lr=lr,
    ds=train_x,
    L=train_y,
    data_dict=iid_dict,
    C=C,
    K=K,
    E=E,
    plt_title="MNIST CNN on IID Dataset",
    plt_color="orange",
)
|