Descente du gradient stochastique
Remarque : Battre les cartes
Les données sont rangées dans l'ordre fixé dans les tableaux x et y. Plutôt que de parcourir les données dans l'ordre de x et y, on peut mélanger les données. L'algorithme s'appelle alors descente du gradient stochastique.
Exemple : Mélanger les données : méthode shuffle
Par rapport au programme précédent, la ligne :
dataset = dataset.repeat( num_epochs ).batch( batch_size )
a été modifiée en
dataset = dataset.shuffle(500).repeat( num_epochs ).batch( batch_size )
1
import tensorflow as tf
2
import numpy as np
3
4
class Model(object):
5
def __init__(self, a, b):
6
self.a = a
7
self.b = b
8
def __call__(self, x):
9
return self.a * x + self.b
10
11
def train(model, inputs, outputs, learning_rate):
12
with tf.GradientTape() as t:
13
t.watch([model.a, model.b])
14
current_loss = perte(model(inputs), outputs)
15
da, db = t.gradient(current_loss, [model.a, model.b])
16
model.a = tf.add(model.a,tf.constant(-learning_rate * da))
17
model.b = tf.add(model.b,tf.constant(-learning_rate * db))
18
19
def perte(predicted_y, target_y):
20
return tf.reduce_mean(tf.square(predicted_y - target_y))
21
22
23
x = np.array([1, 5, 8, 9 ,10, 15,13, 3,-2],np.float32)
24
y = np.array([-2,-5, -7, -12 ,-15, -5, -12,-10,-5],np.float32)
25
nombre_donnees = x.shape[0]
26
dataset = tf.data.Dataset.from_tensor_slices(( x , y ))
27
28
batch_size = 2
29
num_epochs = 2
30
dataset = dataset.shuffle(500).repeat( num_epochs ).batch( batch_size )
31
iterator = dataset.__iter__()
32
learning_rate = 0.01
33
nbre_lot = nombre_donnees // batch_size
34
model = Model(tf.Variable(7.3),tf.Variable(5.5))
35
36
print('(a, b) : (', model.a.numpy(),', ',model.b.numpy(),')')
37
print('Perte: ', perte(model(x), y).numpy())
38
39
40
for epoch in range(num_epochs):
41
for i in range(nbre_lot):
42
print("************** Epoque,lot : ",epoch,'//',i)
43
x_batch , y_batch = iterator.get_next()
44
print('Donnees ajustees :', x_batch,'/',y_batch)
45
train(model, x_batch, y_batch, learning_rate)
46
print('(a, b) : (', model.a.numpy(),', ',model.b.numpy(),')')
47
print('Perte: ', perte(model(x), y).numpy())
48
49
50